Source code for dpdata.plugins.abacus
from __future__ import annotations
import os
from typing import TYPE_CHECKING
import numpy as np
import dpdata.formats.abacus.md
import dpdata.formats.abacus.relax
import dpdata.formats.abacus.scf
from dpdata.data_type import Axis, DataType
from dpdata.format import Format
from dpdata.formats.abacus.stru import get_frame_from_stru, make_unlabeled_stru
from dpdata.utils import open_file
if TYPE_CHECKING:
from dpdata.utils import FileType
[docs]
@Format.register("abacus/stru")
@Format.register("stru")
class AbacusSTRUFormat(Format):
    """Format plugin registered under the ``abacus/stru`` and ``stru`` keys."""

    def from_system(self, file_name, **kwargs):
        """Load one frame with ``get_frame_from_stru`` and register magnetic types.

        Parameters
        ----------
        file_name
            Handed unchanged to ``get_frame_from_stru``.
        **kwargs : dict
            Accepted but not used by this method.

        Returns
        -------
        The object returned by ``get_frame_from_stru``.
        """
        frame = get_frame_from_stru(file_name)
        # Make "spins" / "force_mags" known to dpdata when the frame has them.
        register_mag_data(frame)
        return frame

    def to_system(self, data, file_name: FileType, frame_idx=0, **kwargs):
        """Dump the system into ABACUS STRU format file.

        Parameters
        ----------
        data : dict
            System data
        file_name : FileType
            The output file name; its directory part is forwarded to
            ``make_unlabeled_stru`` as ``dest_dir``.
        frame_idx : int
            The index of the frame to dump
        **kwargs : dict
            Other parameters, forwarded to ``make_unlabeled_stru``.
        """
        # NOTE(review): os.path.dirname needs a path-like file_name; confirm
        # whether FileType may also be an already-open file object.
        dest_dir = os.path.dirname(file_name)
        stru_text = make_unlabeled_stru(
            data=data,
            frame_idx=frame_idx,
            dest_dir=dest_dir,
            **kwargs,
        )
        with open_file(file_name, "w") as out:
            out.write(stru_text)
[docs]
def register_mag_data(data):
    """Register the optional magnetic data types present in ``data``.

    For each of the keys ``"spins"`` and ``"force_mags"`` found in ``data``,
    a non-required ``DataType`` of shape ``(NFRAMES, NATOMS, 3)`` is
    registered on ``dpdata.System`` and then on ``dpdata.LabeledSystem``.
    Nothing happens when neither key is present.

    Parameters
    ----------
    data
        Container supporting ``in`` membership tests on key names
        (presumably the frame dict produced by an ABACUS reader).
    """
    # (key looked up in ``data``, matching deepmd_name), in registration order.
    mag_fields = (
        ("spins", "spin"),
        ("force_mags", "force_mag"),
    )
    for key, deepmd_name in mag_fields:
        if key not in data:
            continue
        mag_dtype = DataType(
            key,
            np.ndarray,
            (Axis.NFRAMES, Axis.NATOMS, 3),
            required=False,
            deepmd_name=deepmd_name,
        )
        for system_cls in (dpdata.System, dpdata.LabeledSystem):
            system_cls.register_data_type(mag_dtype)
[docs]
def register_move_data(data):
    """Register the optional ``"move"`` data type when ``data`` carries it.

    When ``"move"`` is a key of ``data``, a non-required ``DataType`` of shape
    ``(NFRAMES, NATOMS, 3)`` is registered on both ``dpdata.System`` and
    ``dpdata.LabeledSystem``. Nothing happens when the key is absent.

    Parameters
    ----------
    data
        Container supporting ``in`` membership tests on key names
        (presumably the frame dict produced by an ABACUS reader).
    """
    if "move" not in data:
        return
    dt = DataType(
        "move",
        np.ndarray,
        (Axis.NFRAMES, Axis.NATOMS, 3),
        required=False,
        deepmd_name="move",
    )
    dpdata.System.register_data_type(dt)
    # The callers visible in this module are all from_labeled_system loaders,
    # so the type must be known to LabeledSystem as well; this mirrors what
    # register_mag_data does for its keys.
    dpdata.LabeledSystem.register_data_type(dt)
[docs]
@Format.register("abacus/scf")
@Format.register("abacus/pw/scf")
@Format.register("abacus/lcao/scf")
class AbacusSCFFormat(Format):
    """Format plugin registered under the ``abacus/scf``, ``abacus/pw/scf``
    and ``abacus/lcao/scf`` keys.
    """

    # @Format.post("rot_lower_triangular")
    def from_labeled_system(self, file_name, **kwargs):
        """Load labeled data through ``dpdata.formats.abacus.scf.get_frame``.

        Parameters
        ----------
        file_name
            Handed unchanged to ``get_frame``.
        **kwargs : dict
            Accepted but not used by this method.

        Returns
        -------
        The object returned by ``get_frame``, after the optional magnetic and
        ``"move"`` data types it contains have been registered.
        """
        frame = dpdata.formats.abacus.scf.get_frame(file_name)
        # Same order as before: magnetic types first, then "move".
        for register in (register_mag_data, register_move_data):
            register(frame)
        return frame
[docs]
@Format.register("abacus/md")
@Format.register("abacus/pw/md")
@Format.register("abacus/lcao/md")
class AbacusMDFormat(Format):
    """Format plugin registered under the ``abacus/md``, ``abacus/pw/md``
    and ``abacus/lcao/md`` keys.
    """

    # @Format.post("rot_lower_triangular")
    def from_labeled_system(self, file_name, **kwargs):
        """Load labeled data through ``dpdata.formats.abacus.md.get_frame``.

        Parameters
        ----------
        file_name
            Handed unchanged to ``get_frame``.
        **kwargs : dict
            Accepted but not used by this method.

        Returns
        -------
        The object returned by ``get_frame``, after the optional magnetic and
        ``"move"`` data types it contains have been registered.
        """
        frame = dpdata.formats.abacus.md.get_frame(file_name)
        # Same order as before: magnetic types first, then "move".
        for register in (register_mag_data, register_move_data):
            register(frame)
        return frame
[docs]
@Format.register("abacus/relax")
@Format.register("abacus/pw/relax")
@Format.register("abacus/lcao/relax")
class AbacusRelaxFormat(Format):
    """Format plugin registered under the ``abacus/relax``,
    ``abacus/pw/relax`` and ``abacus/lcao/relax`` keys.
    """

    # @Format.post("rot_lower_triangular")
    def from_labeled_system(self, file_name, **kwargs):
        """Load labeled data through ``dpdata.formats.abacus.relax.get_frame``.

        Parameters
        ----------
        file_name
            Handed unchanged to ``get_frame``.
        **kwargs : dict
            Accepted but not used by this method.

        Returns
        -------
        The object returned by ``get_frame``, after the optional magnetic and
        ``"move"`` data types it contains have been registered.
        """
        frame = dpdata.formats.abacus.relax.get_frame(file_name)
        # Same order as before: magnetic types first, then "move".
        for register in (register_mag_data, register_move_data):
            register(frame)
        return frame