#!/usr/bin/env python3
"""Quickly create a configuration file for smooth model."""
import json
import yaml
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
__all__ = ["config"]
# Template for the generated configuration. Entries holding the placeholders
# -1 or [] are overwritten by `config` before the file is written; the other
# entries are emitted as they appear here (`config` also re-assigns
# "decay_rate").
DEFAULT_DATA: Dict[str, Any] = {
    "use_smooth": True,
    "sel_a": [],
    "rcut_smth": -1,
    "rcut": -1,
    "filter_neuron": [20, 40, 80],
    "filter_resnet_dt": False,
    "axis_neuron": 8,
    "fitting_neuron": [240, 240, 240],
    "fitting_resnet_dt": True,
    "coord_norm": True,
    "type_fitting_net": False,
    "systems": [],
    "set_prefix": "set",
    "stop_batch": -1,
    "batch_size": -1,
    "start_lr": 0.001,
    "decay_steps": -1,
    "decay_rate": 0.95,
    "start_pref_e": 0.02,
    "limit_pref_e": 1,
    "start_pref_f": 1000,
    "limit_pref_f": 1,
    "start_pref_v": 0,
    "limit_pref_v": 0,
    "seed": 1,
    "disp_file": "lcurve.out",
    "disp_freq": 1000,
    "numb_test": 10,
    "save_freq": 10000,
    "save_ckpt": "model.ckpt",
    "disp_training": True,
    "time_training": True,
}
def valid_dir(path: Path) -> None:
    """Check if directory is a valid deepmd system directory.

    Parameters
    ----------
    path : Path
        path to directory

    Raises
    ------
    OSError
        if `type.raw` is missing in the directory, or if `box.npy` or
        `coord.npy` is missing in one of its `set.*` subdirectories. The
        message names the missing file.
    """
    type_file = path / "type.raw"
    if not type_file.is_file():
        raise OSError(f"{path} is not a valid system directory: missing {type_file}")

    for set_dir in path.glob("set.*"):
        # box.npy is checked before coord.npy, as the reported file is the
        # first one found missing
        for required in ("box.npy", "coord.npy"):
            required_file = set_dir / required
            if not required_file.is_file():
                raise OSError(
                    f"{path} is not a valid system directory: missing {required_file}"
                )
def load_systems(dirs: List[Path]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Load systems to memory from disk.

    Parameters
    ----------
    dirs : List[Path]
        list of system directories paths

    Returns
    -------
    Tuple[List[np.ndarray], List[np.ndarray]]
        per-system atom types (always 1-D) and per-system structure cells,
        the `box.npy` arrays of all `set.*` subdirectories stacked row-wise
        (NOTE: assumed to be Nx9, which is what `get_max_density` reshapes)

    Raises
    ------
    ValueError
        if a system directory contains no `set.*` subdirectory
    """
    all_type = []
    all_box = []
    for d in dirs:
        # ndmin=1: a type.raw holding a single atom would otherwise load as a
        # 0-d array, on which len() fails
        sys_type = np.loadtxt(d / "type.raw", dtype=int, ndmin=1)
        # sorted so the row order of the stacked boxes is reproducible
        set_boxes = [np.load(s / "box.npy") for s in sorted(d.glob("set.*"))]
        if not set_boxes:
            raise ValueError(f"no set.* directory found in system {d}")
        all_type.append(sys_type)
        all_box.append(np.vstack(set_boxes))
    return all_type, all_box
def get_system_names() -> List[Path]:
    """Get system directory paths from stdin.

    The answer is split on whitespace and each token is expanded as a glob
    pattern. Relative patterns are resolved against the current working
    directory; absolute patterns are supported as well. Patterns that match
    nothing contribute no entries.

    Returns
    -------
    List[Path]
        list of system directories paths

    Raises
    ------
    OSError
        raised by `valid_dir` if a matched path is not a valid system directory
    """
    dirs = input("Enter system path(s) (seperated by space, wild card supported): \n")
    system_dirs = []
    for dir_str in dirs.split():
        pattern = Path(dir_str)
        if pattern.is_absolute():
            # Path.glob refuses absolute patterns (NotImplementedError), so
            # glob from the filesystem anchor with the remainder instead
            base = Path(pattern.anchor)
            rel_pattern = str(pattern.relative_to(base))
        else:
            base = Path.cwd()
            rel_pattern = dir_str
        for d in base.glob(rel_pattern):
            valid_dir(d)
            system_dirs.append(d)
    return system_dirs
def get_rcut() -> float:
    """Get rcut from stdin from user.

    An empty answer selects the default silently. An answer that cannot be
    parsed as a number prints a notice and falls back to the default.

    Returns
    -------
    float
        input rcut length converted to float

    Raises
    ------
    ValueError
        if rcut is not a finite number greater than 0.0
    """
    dv = 6.0
    rcut_input = input(f"Enter rcut (default {dv:.1f} A): \n")
    if not rcut_input.strip():
        # just pressing Enter means "accept the default"; that is not an error
        return dv
    try:
        rcut = float(rcut_input)
    except ValueError as e:
        print(f"invalid rcut: {e} setting to default: {dv:.1f}")
        rcut = dv
    # float() accepts "nan" and "inf"; NaN compares False against everything,
    # so the check is phrased as "not inside (0, inf)" to reject both
    if not 0 < rcut < float("inf"):
        raise ValueError("rcut should be a finite number > 0")
    return rcut
def get_batch_size_rule() -> int:
    """Get the minimal number of atoms in a batch from user from stdin.

    An empty answer selects the default silently. An answer that cannot be
    parsed as an integer prints a notice and falls back to the default.

    Returns
    -------
    int
        minimal number of atoms in a batch

    Raises
    ------
    ValueError
        if the number is <= 0
    """
    dv = 32
    matom_input = input(
        f"Enter the minimal number of atoms in a batch (default {dv:d}): \n"
    )
    if not matom_input.strip():
        # just pressing Enter means "accept the default"; that is not an error
        return dv
    try:
        matom = int(matom_input)
    except ValueError as e:
        print(f"invalid batch size: {e} setting to default: {dv:d}")
        matom = dv
    if matom <= 0:
        raise ValueError("the number should be > 0")
    return matom
def get_stop_batch() -> int:
    """Get stop batch from user from stdin.

    An empty answer selects the default silently. An answer that cannot be
    parsed as an integer prints a notice and falls back to the default.

    Returns
    -------
    int
        the stop batch number

    Raises
    ------
    ValueError
        if stop batch is <= 0
    """
    dv = 1000000
    sb_input = input(f"Enter the stop batch (default {dv:d}): \n")
    if not sb_input.strip():
        # just pressing Enter means "accept the default"; that is not an error
        return dv
    try:
        sb = int(sb_input)
    except ValueError as e:
        print(f"invalid stop batch: {e} setting to default: {dv:d}")
        sb = dv
    if sb <= 0:
        raise ValueError("the number should be > 0")
    return sb
def get_ntypes(all_type: List[np.ndarray]) -> int:
    """Count number of unique elements.

    Parameters
    ----------
    all_type : List[np.ndarray]
        list with arrays specifying elements of structures; the arrays may
        have different lengths

    Returns
    -------
    int
        number of unique elements over all structures, 0 for an empty list
    """
    if len(all_type) == 0:
        return 0
    # Flatten first: systems generally have different atom counts, and handing
    # such a ragged list straight to np.unique fails on recent NumPy versions
    flat_types = np.concatenate([np.ravel(tt) for tt in all_type])
    return len(np.unique(flat_types))
def get_max_density(
    all_type: List[np.ndarray], all_box: List[np.ndarray]
) -> np.ndarray:
    """Compute maximum density in supplied cells.

    For every system the atom count of each type is divided by the smallest
    cell volume found in that system; the element-wise maximum over all
    systems is returned.

    Parameters
    ----------
    all_type : List[np.ndarray]
        list with arrays specifying elements of structures
    all_box : List[np.ndarray]
        list with arrays specifying cells for all structures; each array must
        be reshapeable to (-1, 3, 3)

    Returns
    -------
    np.ndarray
        maximum atom density in all supplied structures for each element
        individually, indexed by type id

    Raises
    ------
    ValueError
        if a system contains a cell with zero volume
    """
    # NOTE: type ids are matched against range(ntypes), i.e. they are expected
    # to be the contiguous integers 0..ntypes-1
    ntypes = get_ntypes(all_type)
    all_max = []
    for tt, bb in zip(all_type, all_box):
        cells = np.reshape(bb, [-1, 3, 3])
        # the determinant is negative for a left-handed cell; the volume is
        # its absolute value
        volumes = np.abs(np.linalg.det(cells))
        min_v = np.min(volumes)
        if not min_v > 0:
            raise ValueError("cell with zero volume found, cannot compute density")
        type_count = np.array([np.count_nonzero(tt == ii) for ii in range(ntypes)])
        all_max.append(type_count / min_v)
    return np.max(all_max, axis=0)
def suggest_sel(
    all_type: List[np.ndarray],
    all_box: List[np.ndarray],
    rcut: float,
    ratio: float = 1.5,
) -> List[int]:
    """Suggest selection parameter.

    Parameters
    ----------
    all_type : List[np.ndarray]
        list with arrays specifying elements of structures
    all_box : List[np.ndarray]
        list with arrays specifying cells for all structures
    rcut : float
        cutoff radius
    ratio : float, optional
        multiplicative safety factor applied to the estimate, by default 1.5

    Returns
    -------
    List[int]
        one value per atom type: maximum density of that type times the
        volume of a sphere of radius `rcut` times `ratio`, truncated to int
    """
    densities = get_max_density(all_type, all_box)
    # density * (4/3 * pi * rcut^3) * ratio, evaluated left to right
    estimate = densities * 4.0 / 3.0 * np.pi * rcut ** 3 * ratio
    return list(map(int, estimate))
def suggest_batch_size(all_type: List[np.ndarray], min_atom: int) -> List[int]:
    """Get suggestion for batch size.

    For each system this is the smallest number of frames whose total atom
    count reaches `min_atom`, i.e. `min_atom / natoms` rounded up.

    Parameters
    ----------
    all_type : List[np.ndarray]
        list with arrays specifying elements of structures
    min_atom : int
        minimal number of atoms in batch

    Returns
    -------
    List[int]
        suggested batch sizes for each system
    """

    def _frames_needed(natoms: int) -> int:
        # ceiling division: one extra frame whenever there is a remainder
        whole, leftover = divmod(min_atom, natoms)
        return whole + 1 if leftover else whole

    return [_frames_needed(len(sys_type)) for sys_type in all_type]
def suggest_decay(stop_batch: int) -> Tuple[int, float]:
    """Suggest number of decay steps and decay rate.

    Parameters
    ----------
    stop_batch : int
        stop batch number

    Returns
    -------
    Tuple[int, float]
        number of decay steps (`stop_batch` floor-divided by 200) and the
        fixed decay rate 0.95
    """
    return int(stop_batch // 200), 0.95
def config(*, output: str, **kwargs):
    """Auto config file generator.

    Interactively asks for the system directories, rcut, the minimal number
    of atoms per batch and the stop batch, derives the remaining suggestions
    from the loaded systems and writes the result to `output`.

    Parameters
    ----------
    output: str
        file to write config file, its name must end in json, yml or yaml
    **kwargs
        accepted for call compatibility and ignored

    Raises
    ------
    RuntimeError
        if user does not input any systems
    ValueError
        if output file is of wrong type
    """
    # Validate the target up front: the user should not answer every prompt
    # only to fail at the end, and a rejected name must not leave an empty
    # file behind.
    if not output.endswith(("json", "yml", "yaml")):
        raise ValueError("output file must be of type json or yaml")

    all_sys = get_system_names()
    if not all_sys:
        raise RuntimeError("no system specified")
    rcut = get_rcut()
    matom = get_batch_size_rule()
    stop_batch = get_stop_batch()

    all_type, all_box = load_systems(all_sys)
    sel = suggest_sel(all_type, all_box, rcut, ratio=1.5)
    bs = suggest_batch_size(all_type, matom)
    decay_steps, decay_rate = suggest_decay(stop_batch)

    # shallow copy is enough: entries are replaced below, never mutated
    jdata = DEFAULT_DATA.copy()
    jdata["systems"] = [str(ii) for ii in all_sys]
    jdata["sel_a"] = sel
    jdata["rcut"] = rcut
    jdata["rcut_smth"] = rcut - 0.2
    jdata["stop_batch"] = stop_batch
    jdata["batch_size"] = bs
    jdata["decay_steps"] = decay_steps
    jdata["decay_rate"] = decay_rate

    with open(output, "w") as fp:
        if output.endswith("json"):
            json.dump(jdata, fp, indent=4)
        else:
            # only yml / yaml can reach this branch, see the check at the top
            yaml.safe_dump(jdata, fp, default_flow_style=False)