Source code for deepmd.descriptor.se

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Tuple,
)

from deepmd.env import (
    tf,
)
from deepmd.utils.graph import (
    get_embedding_net_variables_from_graph_def,
    get_tensor_by_name_from_graph,
)

from .descriptor import (
    Descriptor,
)


class DescrptSe(Descriptor):
    """A base class for smooth version of descriptors.

    Notes
    -----
    All of these descriptors have an environmental matrix and an
    embedding network (:meth:`deepmd.utils.network.embedding_net`), so
    they can share some similar methods without defining them twice.

    Attributes
    ----------
    embedding_net_variables : dict
        initial embedding network variables
    descrpt_reshape : tf.Tensor
        the reshaped descriptor
    descrpt_deriv : tf.Tensor
        the descriptor derivative
    rij : tf.Tensor
        distances between two atoms
    nlist : tf.Tensor
        the neighbor list
    """

    def _identity_tensors(self, suffix: str = "") -> None:
        """Identify tensors which are expected to be stored and restored.

        Notes
        -----
        These tensors will be identified:
            self.descrpt_reshape : o_rmat
            self.descrpt_deriv : o_rmat_deriv
            self.rij : o_rij
            self.nlist : o_nlist
        Thus, this method should be called during building the descriptor
        and after these tensors are initialized.

        Parameters
        ----------
        suffix : str
            The suffix of the scope
        """
        # Each attribute is rebound to a named identity op; the names used
        # here are the ones reported by `get_tensor_names` (plus ":0").
        self.descrpt_reshape = tf.identity(self.descrpt_reshape, name="o_rmat" + suffix)
        self.descrpt_deriv = tf.identity(
            self.descrpt_deriv, name="o_rmat_deriv" + suffix
        )
        self.rij = tf.identity(self.rij, name="o_rij" + suffix)
        self.nlist = tf.identity(self.nlist, name="o_nlist" + suffix)

    def get_tensor_names(self, suffix: str = "") -> Tuple[str, ...]:
        """Get names of tensors.

        Parameters
        ----------
        suffix : str
            The suffix of the scope

        Returns
        -------
        Tuple[str, ...]
            Names of tensors, in the order: reshaped descriptor, descriptor
            derivative, rij, nlist. This matches the argument order of
            :meth:`pass_tensors_from_frz_model`.
        """
        return (
            f"o_rmat{suffix}:0",
            f"o_rmat_deriv{suffix}:0",
            f"o_rij{suffix}:0",
            f"o_nlist{suffix}:0",
        )

    def pass_tensors_from_frz_model(
        self,
        descrpt_reshape: tf.Tensor,
        descrpt_deriv: tf.Tensor,
        rij: tf.Tensor,
        nlist: tf.Tensor,
    ) -> None:
        """Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def.

        Parameters
        ----------
        descrpt_reshape
            The passed descrpt_reshape tensor
        descrpt_deriv
            The passed descrpt_deriv tensor
        rij
            The passed rij tensor
        nlist
            The passed nlist tensor
        """
        self.rij = rij
        self.nlist = nlist
        self.descrpt_deriv = descrpt_deriv
        self.descrpt_reshape = descrpt_reshape

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """Init the embedding net variables with the given dict.

        Besides ``embedding_net_variables``, this also sets ``davg`` and
        ``dstd`` from the ``descrpt_attr{suffix}/t_avg`` and
        ``descrpt_attr{suffix}/t_std`` tensors of the graph.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str, optional
            The suffix of the scope
        """
        self.embedding_net_variables = get_embedding_net_variables_from_graph_def(
            graph_def, suffix=suffix
        )
        self.davg = get_tensor_by_name_from_graph(
            graph, f"descrpt_attr{suffix}/t_avg"
        )
        self.dstd = get_tensor_by_name_from_graph(
            graph, f"descrpt_attr{suffix}/t_std"
        )

    @property
    def precision(self) -> tf.DType:
        """Precision of filter network."""
        # `filter_precision` is not assigned anywhere in this class;
        # presumably concrete descriptors set it -- verify in subclasses.
        return self.filter_precision

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict):
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        The value returned by ``update_one_sel`` for a shallow copy of
        ``local_jdata``.
        """
        # Imported at call time rather than module level, as in the
        # original (presumably to avoid a circular import -- TODO confirm).
        from deepmd.entrypoints.train import (
            update_one_sel,
        )

        # default behavior is to update sel which is a list
        local_jdata_cpy = local_jdata.copy()
        return update_one_sel(global_jdata, local_jdata_cpy, False)