Source code for dpdata.formats.xyz.quip_gap_xyz

#!/usr/bin/env python3
# %%
from __future__ import annotations

import re
from collections import OrderedDict

import numpy as np

from dpdata.periodic_table import Element


class QuipGapxyzSystems:
    """Iterate over the frames of a QUIP/GAP extended-XYZ file.

    Each call to ``next`` reads one frame from the file and returns the dict
    built by :meth:`handle_single_xyz_frame`.

    Parameters
    ----------
    file_name : str or path-like
        path of the extended-XYZ file. It is opened for reading in
        ``__init__`` and closed once every frame has been read (or when the
        object is finalized).
    """

    # ``key=value`` / ``key="value with spaces"`` pairs of the header line.
    _FIELD_VALUE_PATTERN = re.compile(
        r"(?P<key>\S+)=(?P<quote>[\'\"]?)(?P<value>.*?)(?P=quote)\s+"
    )
    # ``name:datatype:ncolumns`` triples of the ``Properties`` field.
    _PROP_PATTERN = re.compile(
        r"(?P<key>\w+?):(?P<datatype>[a-zA-Z]):(?P<value>\d+)"
    )
    # Supported per-atom properties and the datatype letter each must declare.
    _PROP_DATATYPES = {"species": "S", "pos": "R", "Z": "I", "force": "R"}

    def __init__(self, file_name):
        self.file_object = open(file_name)
        self.block_generator = self.get_block_generator()

    def __iter__(self):
        return self

    def __next__(self):
        return self.handle_single_xyz_frame(next(self.block_generator))

    def __del__(self):
        # ``open`` may have raised in ``__init__``, leaving no attribute behind.
        file_object = getattr(self, "file_object", None)
        if file_object is not None:
            file_object.close()

    def get_block_generator(self):
        """Yield the raw lines of each frame in the file.

        A frame starts at a line whose first token is an integer atom count
        ``N``; the generator then reads ``N + 1`` more lines (header plus one
        line per atom). Lines that do not start with an integer (e.g. blank
        lines between frames) are skipped.

        Yields
        ------
        list[str]
            the ``N + 2`` lines of one frame, newlines included

        Raises
        ------
        RuntimeError
            if the file ends before a frame is complete
        """
        count_pattern = re.compile(r"^\s*(\d+)\s*")
        try:
            while True:
                line = self.file_object.readline()
                if not line:
                    break
                match = count_pattern.match(line)
                if not match:
                    continue
                atom_num = int(match.group(1))
                lines = [line]
                for _ in range(atom_num + 1):
                    lines.append(self.file_object.readline())
                    if not lines[-1]:
                        raise RuntimeError(
                            f"this xyz file may lack of lines, should be {atom_num + 2};lines:{lines}"
                        )
                yield lines
        finally:
            # The generator holds ``self`` and ``self`` holds the generator, so
            # ``__del__`` only runs on a cyclic GC pass; close eagerly instead.
            self.file_object.close()

    @staticmethod
    def _parse_header(line):
        """Return the ``key=value`` pairs of a header line as a dict (quotes removed)."""
        # Trailing space so that ``\s+`` also terminates the last pair.
        padded = line.strip("\n").strip() + " "
        return {
            match.group("key"): match.group("value")
            for match in QuipGapxyzSystems._FIELD_VALUE_PATTERN.finditer(padded)
        }

    @staticmethod
    def _split_columns(properties, atom_lines):
        """Slice the per-atom table into the string columns declared by ``Properties``.

        Returns a dict mapping each property name to a 2-D string array
        ``(n_atoms, ncolumns)``. Raises RuntimeError for an unknown property
        or a property whose declared datatype letter is not the expected one.
        """
        data_array = np.array([line.split() for line in atom_lines])
        columns = {}
        used_column = 0
        for match in QuipGapxyzSystems._PROP_PATTERN.finditer(properties):
            key = match.group("key")
            datatype = match.group("datatype")
            expected = QuipGapxyzSystems._PROP_DATATYPES.get(key)
            if expected is None:
                raise RuntimeError("unknown field {}".format(key))
            if datatype != expected:
                raise RuntimeError(
                    f"datatype for {key} must be '{expected}' instead of {datatype}"
                )
            width = int(match.group("value"))
            columns[key] = data_array[:, used_column : used_column + width]
            used_column += width
        return columns

    @staticmethod
    def _parse_matrix(text):
        """Parse nine whitespace-separated numbers into a ``(1, 3, 3)`` float64 array."""
        return np.array(text.split()).reshape(1, 3, 3).astype(np.float64)

    @staticmethod
    def handle_single_xyz_frame(lines):
        """Convert the raw lines of one frame into a dpdata-style dict.

        Parameters
        ----------
        lines : list[str]
            atom-count line, header line, then one line per atom

        Returns
        -------
        dict
            ``atom_names`` and ``atom_numbs`` (species in order of first
            appearance and their counts), ``atom_types``, ``cells``,
            ``coords``, ``energies``, ``forces``, ``orig``, plus ``virials``
            when the header has a non-empty ``virial`` field

        Raises
        ------
        RuntimeError
            if the line count does not match the atom count, a property is
            unknown or declares the wrong datatype, or ``species`` is missing
        KeyError
            if the header lacks ``Properties``, ``Lattice`` or ``energy``
        """
        atom_num = int(lines[0].strip("\n").strip())
        if len(lines) != atom_num + 2:
            raise RuntimeError(
                f"format error, atom_num=={atom_num}, {len(lines)}!=atom_num+2"
            )
        field_dict = QuipGapxyzSystems._parse_header(lines[1])
        columns = QuipGapxyzSystems._split_columns(field_dict["Properties"], lines[2:])

        if "species" not in columns:
            raise RuntimeError("type_array can't be None type, check .xyz file")
        # Species are numbered in order of first appearance.
        type_map = {}
        atom_types = np.array(
            [
                type_map.setdefault(name, len(type_map))
                for name in columns["species"].flatten()
            ],
            dtype=int,
        )

        virial_str = field_dict.get("virial", None)
        virials = QuipGapxyzSystems._parse_matrix(virial_str) if virial_str else None

        info_dict = {
            "atom_names": list(type_map),
            "atom_numbs": list(np.bincount(atom_types, minlength=len(type_map))),
            "atom_types": atom_types,
            "cells": QuipGapxyzSystems._parse_matrix(field_dict["Lattice"]),
            "coords": np.array([columns.get("pos")]).astype(np.float64),
            "energies": np.array([field_dict["energy"]]).astype(np.float64),
            "forces": np.array([columns.get("force")]).astype(np.float64),
        }
        if virials is not None:
            info_dict["virials"] = virials
        info_dict["orig"] = np.zeros(3)
        return info_dict
def format_single_frame(data, frame_idx):
    """Format a single frame of system data into QUIP/GAP XYZ format lines.

    Parameters
    ----------
    data : dict
        system data. Reads ``atom_types``, ``atom_names``, ``energies``,
        ``cells``, ``coords``, ``forces`` and, when present, ``virials``.
    frame_idx : int
        frame index

    Returns
    -------
    list[str]
        lines for the frame (without trailing newlines): the atom count,
        the header line, then one ``species x y z Z fx fy fz`` line per atom
    """
    atom_types = data["atom_types"]
    atom_names = data["atom_names"]
    natoms = len(atom_types)

    # Header line with per-frame metadata.
    header_parts = [f"energy={data['energies'][frame_idx]:.12e}"]
    if "virials" in data:
        virial_str = " ".join(f"{v:.12e}" for v in data["virials"][frame_idx].flatten())
        header_parts.append(f'virial="{virial_str}"')
    lattice_str = " ".join(f"{c:.12e}" for c in data["cells"][frame_idx].flatten())
    header_parts.append(f'Lattice="{lattice_str}"')
    header_parts.append("Properties=species:S:1:pos:R:3:Z:I:1:force:R:3")
    header_line = " ".join(header_parts)

    coords = data["coords"][frame_idx]
    forces = data["forces"][frame_idx]
    # Atomic numbers are looked up once per species actually present,
    # instead of constructing an Element for every atom.
    atomic_numbers = {}
    frame_lines = [str(natoms), header_line]
    for i, type_idx in enumerate(atom_types):
        species = atom_names[type_idx]
        # Index by atom position (not zip) so short coords/forces still raise.
        x, y, z = coords[i]
        fx, fy, fz = forces[i]
        if species not in atomic_numbers:
            atomic_numbers[species] = Element(species).Z
        frame_lines.append(
            f"{species} {x:.11e} {y:.11e} {z:.11e} {atomic_numbers[species]} "
            f"{fx:.11e} {fy:.11e} {fz:.11e}"
        )
    return frame_lines