Source code for deepmd.model.pairtab

# SPDX-License-Identifier: LGPL-3.0-or-later
from enum import (
    Enum,
)
from typing import (
    List,
    Optional,
    Union,
)

import numpy as np

from deepmd.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    MODEL_VERSION,
    global_cvt_2_ener_float,
    op_module,
    tf,
)
from deepmd.fit.fitting import (
    Fitting,
)
from deepmd.loss.loss import (
    Loss,
)
from deepmd.model.model import (
    Model,
)
from deepmd.utils.pair_tab import (
    PairTab,
)


[docs]class PairTabModel(Model): """Pairwise tabulation energy model. This model can be used to tabulate the pairwise energy between atoms for either short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not be used alone, but rather as one submodel of a linear (sum) model, such as DP+D3. Do not put the model on the first model of a linear model, since the linear model fetches the type map from the first model. At this moment, the model does not smooth the energy at the cutoff radius, so one needs to make sure the energy has been smoothed to zero. Parameters ---------- tab_file : str The path to the tabulation file. rcut : float The cutoff radius sel : int or list[int] The maxmum number of atoms in the cut-off radius """ model_type = "ener" def __init__( self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs ): super().__init__() self.tab_file = tab_file self.tab = PairTab(self.tab_file) self.ntypes = self.tab.ntypes self.rcut = rcut if isinstance(sel, int): self.sel = sel elif isinstance(sel, list): self.sel = sum(sel) else: raise TypeError("sel must be int or list[int]")
[docs] def build( self, coord_: tf.Tensor, atype_: tf.Tensor, natoms: tf.Tensor, box: tf.Tensor, mesh: tf.Tensor, input_dict: dict, frz_model: Optional[str] = None, ckpt_meta: Optional[str] = None, suffix: str = "", reuse: Optional[Union[bool, Enum]] = None, ): """Build the model. Parameters ---------- coord_ : tf.Tensor The coordinates of atoms atype_ : tf.Tensor The atom types of atoms natoms : tf.Tensor The number of atoms box : tf.Tensor The box vectors mesh : tf.Tensor The mesh vectors input_dict : dict The input dict frz_model : str, optional The path to the frozen model ckpt_meta : str, optional The path prefix of the checkpoint and meta files suffix : str, optional The suffix of the scope reuse : bool or tf.AUTO_REUSE, optional Whether to reuse the variables Returns ------- dict The output dict """ tab_info, tab_data = self.tab.get() with tf.variable_scope("model_attr" + suffix, reuse=reuse): self.tab_info = tf.get_variable( "t_tab_info", tab_info.shape, dtype=tf.float64, trainable=False, initializer=tf.constant_initializer(tab_info, dtype=tf.float64), ) self.tab_data = tf.get_variable( "t_tab_data", tab_data.shape, dtype=tf.float64, trainable=False, initializer=tf.constant_initializer(tab_data, dtype=tf.float64), ) t_tmap = tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string) t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string) t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string) with tf.variable_scope("fitting_attr" + suffix, reuse=reuse): t_dfparam = tf.constant(0, name="dfparam", dtype=tf.int32) t_daparam = tf.constant(0, name="daparam", dtype=tf.int32) with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse): t_ntypes = tf.constant(self.ntypes, name="ntypes", dtype=tf.int32) t_rcut = tf.constant( self.rcut, name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION ) coord = tf.reshape(coord_, [-1, natoms[1] * 3]) atype = tf.reshape(atype_, [-1, natoms[1]]) box = tf.reshape(box, [-1, 9]) # perhaps we need a OP that only outputs rij and nlist ( _, _, rij, nlist, _, _, ) = op_module.prod_env_mat_a_mix( coord, atype, natoms, box, mesh, np.zeros([self.ntypes, self.sel * 4]), np.ones([self.ntypes, self.sel * 4]), rcut_a=-1, rcut_r=self.rcut, rcut_r_smth=self.rcut, sel_a=[self.sel], sel_r=[0], ) scale = tf.ones([tf.shape(coord)[0], natoms[0]], dtype=tf.float64) tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab( self.tab_info, self.tab_data, atype, rij, nlist, natoms, scale, sel_a=[self.sel], sel_r=[0], ) energy_raw = tf.reshape( tab_atom_ener, [-1, natoms[0]], name="o_atom_energy" + suffix ) energy = tf.reduce_sum( global_cvt_2_ener_float(energy_raw), axis=1, name="o_energy" + suffix ) force = tf.reshape(tab_force, [-1, 3 * natoms[1]], name="o_force" + suffix) virial = tf.reshape( tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1), [-1, 9], name="o_virial" + suffix, ) atom_virial = tf.reshape( tab_atom_virial, [-1, 9 * natoms[1]], name="o_atom_virial" + suffix ) model_dict = {} model_dict["energy"] = energy model_dict["force"] = force model_dict["virial"] = virial model_dict["atom_ener"] = energy_raw model_dict["atom_virial"] = atom_virial model_dict["coord"] = coord model_dict["atype"] = atype return model_dict
[docs] def init_variables( self, graph: tf.Graph, graph_def: tf.GraphDef, model_type: str = "original_model", suffix: str = "", ) -> None: """Init the embedding net variables with the given frozen model. Parameters ---------- graph : tf.Graph The input frozen model graph graph_def : tf.GraphDef The input frozen model graph_def model_type : str the type of the model suffix : str suffix to name scope """
# skip. table can be initialized from the file
[docs] def get_fitting(self) -> Union[Fitting, dict]: """Get the fitting(s).""" # nothing needs to do return {}
[docs] def get_loss(self, loss: dict, lr) -> Optional[Union[Loss, dict]]: """Get the loss function(s).""" # nothing nees to do return
[docs] def get_rcut(self) -> float: """Get cutoff radius of the model.""" return self.rcut
[docs] def get_ntypes(self) -> int: """Get the number of types.""" return self.ntypes
[docs] def data_stat(self, data: dict): """Data staticis."""
# nothing needs to do
[docs] def enable_compression(self, suffix: str = "") -> None: """Enable compression. Parameters ---------- suffix : str suffix to name scope """
# nothing needs to do
[docs] @classmethod def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict: """Update the selection and perform neighbor statistics. Notes ----- Do not modify the input data without copying it. Parameters ---------- global_jdata : dict The global data, containing the training section local_jdata : dict The local data refer to the current class Returns ------- dict The updated local data """ from deepmd.entrypoints.train import ( update_one_sel, ) local_jdata_cpy = local_jdata.copy() return update_one_sel(global_jdata, local_jdata_cpy, True)