Source code for dpgen.util

#!/usr/bin/env python
import json
import os
from contextlib import (
    contextmanager,
)
from pathlib import Path
from typing import Union

import dpdata
import h5py
import numpy as np
from dargs import Argument
from dpdata.data_type import Axis, DataType

from dpgen import dlog

"""
some common utilities for generator, auto_test and data
"""

# constant definitions
MaxLength = 70


def sepline(ch="-", sp="-", screen=False):
    r"""Separate the output by '-'."""
    if screen:
        print(ch.center(MaxLength, sp))
    else:
        dlog.info(ch.center(MaxLength, sp))


def box_center(ch="", fill=" ", sp="|"):
    r"""Put the string at the center of | |."""
    strs = ch.center(MaxLength, fill)
    dlog.info(sp + strs[1 : len(strs) - 1] + sp)
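

# A minimal usage sketch for the two helpers above (hypothetical demo function,
# not part of dpgen.util); the message strings are made up for illustration.
def _demo_sepline_box_center():
    sepline("make_train", screen=True)  # prints 'make_train' centered in a 70-char line of '-'
    box_center("iteration 000000")      # logs the text centered between '|' characters via dlog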


def expand_sys_str(root_dir: Union[str, Path]) -> list[str]:
    """Recursively iterate over directories, taking those that contain a `type.raw` file.

    If root_dir is a file rather than a directory, it is assumed to be an HDF5 file.

    Parameters
    ----------
    root_dir : Union[str, Path]
        starting directory

    Returns
    -------
    list[str]
        list of strings pointing to system directories

    Raises
    ------
    RuntimeError
        No system was found in the directory
    """
    root_dir = Path(root_dir)
    if root_dir.is_dir():
        matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()]
        if (root_dir / "type.raw").is_file():
            matches.append(str(root_dir))
    elif root_dir.is_file():
        # HDF5 file
        with h5py.File(root_dir, "r") as f:
            # list of keys in the h5 file
            f_keys = ["/"]
            f.visit(lambda x: f_keys.append("/" + x))
            matches = [
                f"{root_dir}#{d}" for d in f_keys if str(Path(d) / "type.raw") in f_keys
            ]
    else:
        raise OSError(f"{root_dir} does not exist.")
    if len(matches) == 0:
        raise RuntimeError(f"{root_dir} does not contain any systems!")
    return matches
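

# A minimal sketch of calling expand_sys_str (hypothetical demo, not part of
# dpgen.util); the "data" directory and "all_sys.hdf5" file are assumed inputs.
def _demo_expand_sys_str():
    # every sub-directory of "data" containing type.raw is returned as a path string
    dirs = expand_sys_str("data")  # e.g. ['data/sys-Al', 'data/sys-Mg']
    # for an HDF5 file, each matching group is returned as 'file#/group'
    groups = expand_sys_str("all_sys.hdf5")
    return dirs + groups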


def normalize(arginfo: Argument, data: dict, strict_check: bool = True) -> dict:
    """Normalize and check input data.

    Parameters
    ----------
    arginfo : dargs.Argument
        argument information
    data : dict
        input data
    strict_check : bool, default=True
        whether to strictly check the data

    Returns
    -------
    dict
        normalized data
    """
    data = arginfo.normalize_value(data, trim_pattern="_*")
    arginfo.check_value(data, strict=strict_check)
    return data
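

# A minimal sketch of normalize with a hand-built dargs Argument; the schema and
# the input dict below are hypothetical, not a real dpgen argument tree.
def _demo_normalize():
    arginfo = Argument(
        "base",
        dict,
        [
            Argument("type_map", list, optional=False),
            Argument("numb_models", int, optional=True, default=4),
        ],
    )
    data = {"type_map": ["Al", "Mg"]}
    # fills in defaults (numb_models=4) and raises on unknown or invalid keys
    return normalize(arginfo, data)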


def convert_training_data_to_hdf5(input_files: list[str], h5_file: str):
    """Convert training data to HDF5 format and update the input files.

    Parameters
    ----------
    input_files : list of str
        DeePMD-kit input file names
    h5_file : str
        HDF5 file name
    """
    systems = []
    h5_dir = Path(h5_file).parent.absolute()
    cwd = Path.cwd().absolute()
    for ii in input_files:
        ii = Path(ii)
        dd = ii.parent.absolute()
        with open(ii, "r+") as f:
            jinput = json.load(f)
            if "training_data" in jinput["training"]:
                # v2.0
                p_sys = jinput["training"]["training_data"]["systems"]
            else:
                # v1.x
                p_sys = jinput["training"]["systems"]
            for idx, pp in enumerate(p_sys):
                if "#" in pp:
                    # system already stored in an HDF5 file
                    p1, p2 = pp.split("#")
                    ff = os.path.normpath(str((dd / p1).absolute().relative_to(cwd)))
                    pp = ff + "#" + p2
                    new_pp = os.path.normpath(os.path.relpath(ff, h5_dir)) + p2
                else:
                    pp = os.path.normpath(str((dd / pp).absolute().relative_to(cwd)))
                    new_pp = os.path.normpath(os.path.relpath(pp, h5_dir))
                p_sys[idx] = (
                    os.path.normpath(os.path.relpath(h5_file, dd)) + "#/" + str(new_pp)
                )
                systems.append(pp)
            f.seek(0)
            json.dump(jinput, f, indent=4)
            # drop any leftover content in case the new JSON is shorter than the original
            f.truncate()
    systems = list(set(systems))
    dlog.info("Combining %d training systems to %s...", len(systems), h5_file)
    with h5py.File(h5_file, "w") as f:
        for ii in systems:
            if "#" in ii:
                # source system is itself an HDF5 group
                p1, p2 = ii.split("#")
                p1 = os.path.normpath(os.path.relpath(p1, h5_dir))
                group = f.create_group(str(p1) + p2)
                s = dpdata.LabeledSystem(ii, fmt="deepmd/hdf5")
                s.to("deepmd/hdf5", group)
            else:
                # source system is a deepmd/npy directory on disk
                pp = os.path.normpath(os.path.relpath(ii, h5_dir))
                group = f.create_group(str(pp))
                s = dpdata.LabeledSystem(ii, fmt="deepmd/npy")
                s.to("deepmd/hdf5", group)
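

# A minimal sketch of converting two hypothetical DeePMD-kit input files; the
# file names below are assumptions, not defaults used by dpgen.
def _demo_convert_training_data_to_hdf5():
    # rewrites the "systems" entries of both input.json files so they point into
    # combined.hdf5, and packs the referenced systems into that single file
    convert_training_data_to_hdf5(
        ["train.000/input.json", "train.001/input.json"], "combined.hdf5"
    )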


@contextmanager
def set_directory(path: Path):
    """Set the current working path within the context.

    Parameters
    ----------
    path : Path
        The path to the cwd

    Yields
    ------
    None

    Examples
    --------
    >>> with set_directory(Path("some_path")):
    ...     do_something()
    """
    cwd = Path().absolute()
    path.mkdir(exist_ok=True, parents=True)
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(cwd)
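

# A minimal sketch of set_directory with a hypothetical working directory; the
# directory is created if missing and the previous working directory is
# restored on exit, even if the body raises.
def _demo_set_directory():
    with set_directory(Path("work.000")):
        Path("POSCAR").write_text("")  # any file written here lands in work.000/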


def load_file(filename: Union[str, os.PathLike]) -> dict:
    """Load data from a JSON or YAML file.

    Parameters
    ----------
    filename : str or os.PathLike
        The filename to load data from, whose suffix should be .json, .yaml, or .yml

    Returns
    -------
    dict
        The data loaded from the file

    Raises
    ------
    ValueError
        If the file format is not supported
    """
    filename = str(filename)
    if filename.endswith(".json"):
        with open(filename) as fp:
            data = json.load(fp)
    elif filename.endswith(".yaml") or filename.endswith(".yml"):
        from ruamel.yaml import YAML

        yaml = YAML(typ="safe", pure=True)
        with open(filename) as fp:
            data = yaml.load(fp)
    else:
        raise ValueError(f"Unsupported file format: {filename}")
    return data
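

# A minimal sketch of load_file with hypothetical file names; both calls return
# a plain dict, and any other suffix raises ValueError.
def _demo_load_file():
    params = load_file("param.json")
    machine = load_file("machine.yaml")
    return params, machine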


def setup_ele_temp(atomic: bool):
    """Set electronic temperature as required input data.

    Parameters
    ----------
    atomic : bool
        Whether to use atomic temperature or frame temperature
    """
    if atomic:
        ele_temp_data_type = DataType(
            "aparam",
            np.ndarray,
            shape=(Axis.NFRAMES, Axis.NATOMS, 1),
            required=False,
        )
    else:
        ele_temp_data_type = DataType(
            "fparam",
            np.ndarray,
            shape=(Axis.NFRAMES, 1),
            required=False,
        )
    dpdata.System.register_data_type(ele_temp_data_type)
    dpdata.LabeledSystem.register_data_type(ele_temp_data_type)
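

# A minimal sketch of setup_ele_temp: after registration, dpdata systems accept
# the extra key as known data. The zero-filled array below is a hypothetical
# frame-level electronic temperature, one value per frame.
def _demo_setup_ele_temp(system: dpdata.LabeledSystem):
    setup_ele_temp(atomic=False)  # register 'fparam' with shape (nframes, 1)
    system.data["fparam"] = np.zeros((system.get_nframes(), 1))
    return system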