Source code for deepmd.model.frozen

# SPDX-License-Identifier: LGPL-3.0-or-later
from enum import (
    Enum,
)
from typing import (
    Optional,
    Union,
)

from deepmd.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    MODEL_VERSION,
    tf,
)
from deepmd.fit.fitting import (
    Fitting,
)
from deepmd.infer import (
    DeepPotential,
)
from deepmd.loss.loss import (
    Loss,
)

from .model import (
    Model,
)


[docs]class FrozenModel(Model): """Load model from a frozen model, which cannot be trained. Parameters ---------- model_file : str The path to the frozen model """ def __init__(self, model_file: str, **kwargs): super().__init__(**kwargs) self.model_file = model_file self.model = DeepPotential(model_file) self.model_type = self.model.model_type
[docs] def build( self, coord_: tf.Tensor, atype_: tf.Tensor, natoms: tf.Tensor, box: tf.Tensor, mesh: tf.Tensor, input_dict: dict, frz_model: Optional[str] = None, ckpt_meta: Optional[str] = None, suffix: str = "", reuse: Optional[Union[bool, Enum]] = None, ) -> dict: """Build the model. Parameters ---------- coord_ : tf.Tensor The coordinates of atoms atype_ : tf.Tensor The atom types of atoms natoms : tf.Tensor The number of atoms box : tf.Tensor The box vectors mesh : tf.Tensor The mesh vectors input_dict : dict The input dict frz_model : str, optional The path to the frozen model ckpt_meta : str, optional The path prefix of the checkpoint and meta files suffix : str, optional The suffix of the scope reuse : bool or tf.AUTO_REUSE, optional Whether to reuse the variables Returns ------- dict The output dict """ # reset the model to import to the correct graph extra_feed_dict = {} if input_dict is not None: if "fparam" in input_dict: extra_feed_dict["fparam"] = input_dict["fparam"] if "aparam" in input_dict: extra_feed_dict["aparam"] = input_dict["aparam"] input_map = self.get_feed_dict( coord_, atype_, natoms, box, mesh, **extra_feed_dict ) self.model = DeepPotential( self.model_file, default_tf_graph=True, load_prefix="load" + suffix, input_map=input_map, ) with tf.variable_scope("model_attr" + suffix, reuse=reuse): t_tmap = tf.constant( " ".join(self.get_type_map()), name="tmap", dtype=tf.string ) t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string) t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string) with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse): t_ntypes = tf.constant(self.get_ntypes(), name="ntypes", dtype=tf.int32) t_rcut = tf.constant( self.get_rcut(), name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION ) with tf.variable_scope("fitting_attr" + suffix, reuse=reuse): t_dfparam = tf.constant( self.model.get_dim_fparam(), name="dfparam", dtype=tf.int32 ) t_daparam = tf.constant( self.model.get_dim_aparam(), name="daparam", dtype=tf.int32 ) if self.model_type == "ener": return { "energy": tf.identity(self.model.t_energy, name="o_energy" + suffix), "force": tf.identity(self.model.t_force, name="o_force" + suffix), "virial": tf.identity(self.model.t_virial, name="o_virial" + suffix), "atom_ener": tf.identity( self.model.t_ae, name="o_atom_energy" + suffix ), "atom_virial": tf.identity( self.model.t_av, name="o_atom_virial" + suffix ), "coord": coord_, "atype": atype_, } else: raise NotImplementedError( f"Model type {self.model_type} has not been implemented. " "Contribution is welcome!" )
[docs] def get_fitting(self) -> Union[Fitting, dict]: """Get the fitting(s).""" return {}
[docs] def get_loss(self, loss: dict, lr) -> Optional[Union[Loss, dict]]: """Get the loss function(s).""" # loss should be never used for a frozen model return
[docs] def get_rcut(self): return self.model.get_rcut()
[docs] def get_ntypes(self) -> int: return self.model.get_ntypes()
[docs] def data_stat(self, data): pass
[docs] def init_variables( self, graph: tf.Graph, graph_def: tf.GraphDef, model_type: str = "original_model", suffix: str = "", ) -> None: """Init the embedding net variables with the given frozen model. Parameters ---------- graph : tf.Graph The input frozen model graph graph_def : tf.GraphDef The input frozen model graph_def model_type : str the type of the model suffix : str suffix to name scope """ pass
[docs] def enable_compression(self, suffix: str = "") -> None: """Enable compression. Parameters ---------- suffix : str suffix to name scope """ pass
[docs] def get_type_map(self) -> list: """Get the type map.""" return self.model.get_type_map()
[docs] @classmethod def update_sel(cls, global_jdata: dict, local_jdata: dict): """Update the selection and perform neighbor statistics. Parameters ---------- global_jdata : dict The global data, containing the training section local_jdata : dict The local data refer to the current class """ # we don't know how to compress it, so no neighbor statistics here return local_jdata