# Source code for deepmd.model.model

from abc import (
    ABC,
    abstractmethod,
)
from enum import (
    Enum,
)
from typing import (
    List,
    Optional,
    Union,
)

from deepmd.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    tf,
)
from deepmd.utils.graph import (
    load_graph_def,
)


class Model(ABC):
    """Abstract base class for DeePMD models.

    Concrete models implement :meth:`build` to assemble the TensorFlow
    graph that maps atomic coordinates, atom types, and cell information
    to the model outputs.  This base class also provides the shared
    machinery for building the descriptor part of the graph, either from
    scratch or by re-importing it from a frozen model / checkpoint.
    """

    @abstractmethod
    def build(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: dict,
        frz_model: Optional[str] = None,
        ckpt_meta: Optional[str] = None,
        suffix: str = "",
        reuse: Optional[Union[bool, Enum]] = None,
    ):
        """Build the model.

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinates of atoms
        atype_ : tf.Tensor
            The atom types of atoms
        natoms : tf.Tensor
            The number of atoms
        box : tf.Tensor
            The box vectors
        mesh : tf.Tensor
            The mesh vectors
        input_dict : dict
            The input dict
        frz_model : str, optional
            The path to the frozen model
        ckpt_meta : str, optional
            The path to the checkpoint and meta file
        suffix : str, optional
            The suffix of the scope
        reuse : bool or tf.AUTO_REUSE, optional
            Whether to reuse the variables

        Returns
        -------
        dict
            The output dict
        """

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        model_type: str = "original_model",
        suffix: str = "",
    ) -> None:
        """Init the embedding net variables with the given frozen model.

        Subclasses that support ``dp train init-frz-model`` must override
        this method; the base implementation always raises.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        model_type : str
            the type of the model
        suffix : str
            suffix to name scope

        Raises
        ------
        RuntimeError
            Always, since the base model does not support variable
            initialization from a frozen model.
        """
        raise RuntimeError(
            "The 'dp train init-frz-model' command does not support this model!"
        )

    def build_descrpt(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: dict,
        frz_model: Optional[str] = None,
        ckpt_meta: Optional[str] = None,
        suffix: str = "",
        reuse: Optional[Union[bool, Enum]] = None,
    ):
        """Build the descriptor part of the model.

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinates of atoms
        atype_ : tf.Tensor
            The atom types of atoms
        natoms : tf.Tensor
            The number of atoms
        box : tf.Tensor
            The box vectors
        mesh : tf.Tensor
            The mesh vectors
        input_dict : dict
            The input dict
        frz_model : str, optional
            The path to the frozen model
        ckpt_meta : str, optional
            The path to the checkpoint and meta file
        suffix : str, optional
            The suffix of the scope
        reuse : bool or tf.AUTO_REUSE, optional
            Whether to reuse the variables

        Returns
        -------
        tf.Tensor
            The descriptor tensor
        """
        if frz_model is None and ckpt_meta is None:
            # Fresh graph: build the descriptor ops directly and expose
            # the result under the canonical "o_descriptor" name so it
            # can be re-imported later from a frozen model.
            dout = self.descrpt.build(
                coord_,
                atype_,
                natoms,
                box,
                mesh,
                input_dict,
                suffix=suffix,
                reuse=reuse,
            )
            dout = tf.identity(dout, name="o_descriptor")
        else:
            # Re-import path: recreate the descriptor attribute constants
            # (they are not part of the imported sub-graph) and pull the
            # descriptor tensors out of the frozen model / checkpoint.
            tf.constant(
                self.rcut, name="descrpt_attr/rcut", dtype=GLOBAL_TF_FLOAT_PRECISION
            )
            tf.constant(self.ntypes, name="descrpt_attr/ntypes", dtype=tf.int32)
            feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
            # "o_descriptor:0" is appended last so it can be peeled off
            # below; the remaining tensors are handed back to the
            # descriptor object.
            return_elements = [*self.descrpt.get_tensor_names(), "o_descriptor:0"]
            if frz_model is not None:
                imported_tensors = self._import_graph_def_from_frz_model(
                    frz_model, feed_dict, return_elements
                )
            elif ckpt_meta is not None:
                imported_tensors = self._import_graph_def_from_ckpt_meta(
                    ckpt_meta, feed_dict, return_elements
                )
            else:
                raise RuntimeError("should not reach here")  # pragma: no cover
            dout = imported_tensors[-1]
            self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])
        return dout

    def _import_graph_def_from_frz_model(
        self, frz_model: str, feed_dict: dict, return_elements: List[str]
    ):
        """Import the sub-graph producing *return_elements* from a frozen model.

        Parameters
        ----------
        frz_model : str
            path to the frozen model file
        feed_dict : dict
            mapping from graph input names to replacement tensors
        return_elements : List[str]
            tensor names (``"name:0"`` form) to return from the import

        Returns
        -------
        list
            the imported tensors, in the order of ``return_elements``
        """
        # Strip the ":0" output index to get node names for sub-graph
        # extraction, which keeps only the ops needed for the requested
        # tensors.
        return_nodes = [x[:-2] for x in return_elements]
        graph, graph_def = load_graph_def(frz_model)
        sub_graph_def = tf.graph_util.extract_sub_graph(graph_def, return_nodes)
        return tf.import_graph_def(
            sub_graph_def, input_map=feed_dict, return_elements=return_elements, name=""
        )

    def _import_graph_def_from_ckpt_meta(
        self, ckpt_meta: str, feed_dict: dict, return_elements: List[str]
    ):
        """Import the sub-graph producing *return_elements* from a checkpoint meta file.

        Parameters
        ----------
        ckpt_meta : str
            checkpoint path prefix; ``f"{ckpt_meta}.meta"`` is loaded
        feed_dict : dict
            mapping from graph input names to replacement tensors
        return_elements : List[str]
            tensor names (``"name:0"`` form) to return from the import

        Returns
        -------
        list
            the imported tensors, in the order of ``return_elements``
        """
        return_nodes = [x[:-2] for x in return_elements]
        # Load the meta graph into a throwaway graph so the current
        # default graph is not polluted; only the extracted sub-graph is
        # imported back.
        with tf.Graph().as_default() as graph:
            tf.train.import_meta_graph(f"{ckpt_meta}.meta", clear_devices=True)
            graph_def = graph.as_graph_def()
        sub_graph_def = tf.graph_util.extract_sub_graph(graph_def, return_nodes)
        return tf.import_graph_def(
            sub_graph_def, input_map=feed_dict, return_elements=return_elements, name=""
        )