Source code for deepmd.descriptor.se

from typing import Tuple, List

from deepmd.env import tf
from deepmd.utils.graph import get_embedding_net_variables
from .descriptor import Descriptor


class DescrptSe(Descriptor):
    """A base class for smooth version of descriptors.

    Notes
    -----
    All of these descriptors have an environmental matrix and an
    embedding network (:meth:`deepmd.utils.network.embedding_net`), so
    they can share some similar methods without defining them twice.

    Attributes
    ----------
    embedding_net_variables : dict
        initial embedding network variables
    descrpt_reshape : tf.Tensor
        the reshaped descriptor
    descrpt_deriv : tf.Tensor
        the descriptor derivative
    rij : tf.Tensor
        distances between two atoms
    nlist : tf.Tensor
        the neighbor list
    """

    def _identity_tensors(self, suffix: str = "") -> None:
        """Identify tensors which are expected to be stored and restored.

        Notes
        -----
        These tensors will be identified:
            self.descrpt_reshape : o_rmat
            self.descrpt_deriv : o_rmat_deriv
            self.rij : o_rij
            self.nlist : o_nlist
        Thus, this method should be called during building the descriptor
        and after these tensors are initialized.

        Parameters
        ----------
        suffix : str
            The suffix of the scope
        """
        # Each attribute is rebound to a named identity op; the names used
        # here are the ones reported by ``get_tensor_names``.
        self.descrpt_reshape = tf.identity(
            self.descrpt_reshape, name='o_rmat' + suffix
        )
        self.descrpt_deriv = tf.identity(
            self.descrpt_deriv, name='o_rmat_deriv' + suffix
        )
        self.rij = tf.identity(self.rij, name='o_rij' + suffix)
        self.nlist = tf.identity(self.nlist, name='o_nlist' + suffix)

    def get_tensor_names(self, suffix: str = "") -> Tuple[str, ...]:
        """Get names of tensors.

        The names are returned in the same order as the positional
        parameters of :meth:`pass_tensors_from_frz_model`
        (descrpt_reshape, descrpt_deriv, rij, nlist).

        Parameters
        ----------
        suffix : str
            The suffix of the scope

        Returns
        -------
        Tuple[str, ...]
            Names of tensors (four entries, each ending in ``:0``)
        """
        return (
            f'o_rmat{suffix}:0',
            f'o_rmat_deriv{suffix}:0',
            f'o_rij{suffix}:0',
            f'o_nlist{suffix}:0',
        )

    def pass_tensors_from_frz_model(self,
                                    descrpt_reshape: tf.Tensor,
                                    descrpt_deriv: tf.Tensor,
                                    rij: tf.Tensor,
                                    nlist: tf.Tensor,
                                    ) -> None:
        """Pass the descrpt_reshape tensor as well as descrpt_deriv tensor
        from the frz graph_def.

        The four tensors are stored on the instance unchanged.

        Parameters
        ----------
        descrpt_reshape
            The passed descrpt_reshape tensor
        descrpt_deriv
            The passed descrpt_deriv tensor
        rij
            The passed rij tensor
        nlist
            The passed nlist tensor
        """
        self.rij = rij
        self.nlist = nlist
        self.descrpt_deriv = descrpt_deriv
        self.descrpt_reshape = descrpt_reshape

    def init_variables(self,
                       model_file: str,
                       suffix: str = "",
                       ) -> None:
        """Init the embedding net variables from the given frozen model.

        The result of ``get_embedding_net_variables`` is stored in
        ``self.embedding_net_variables``.

        Parameters
        ----------
        model_file : str
            The input frozen model file
        suffix : str, optional
            The suffix of the scope
        """
        self.embedding_net_variables = get_embedding_net_variables(
            model_file, suffix=suffix
        )