# Source code for deepmd.model.pairwise_dprc

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Dict,
    List,
    Optional,
    Union,
)

from deepmd.common import (
    add_data_requirement,
    make_default_mesh,
)
from deepmd.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    MODEL_VERSION,
    op_module,
    tf,
)
from deepmd.loss.loss import (
    Loss,
)
from deepmd.model.model import (
    Model,
)
from deepmd.utils.graph import (
    load_graph_def,
)
from deepmd.utils.spin import (
    Spin,
)
from deepmd.utils.type_embed import (
    TypeEmbedNet,
)


class PairwiseDPRc(Model):
    """Pairwise Deep Potential - Range Correction.

    Combines two sub-models: one evaluated on the QM region and one
    evaluated on QM/MM pairwise sub-frames, sharing a single
    type-embedding net.
    """

    model_type = "ener"

    def __init__(
        self,
        qm_model: dict,
        qmmm_model: dict,
        type_embedding: Union[dict, TypeEmbedNet],
        type_map: List[str],
        data_stat_nbatch: int = 10,
        data_stat_nsample: int = 10,
        data_stat_protect: float = 1e-2,
        use_srtab: Optional[str] = None,
        smin_alpha: Optional[float] = None,
        sw_rmin: Optional[float] = None,
        sw_rmax: Optional[float] = None,
        spin: Optional[Spin] = None,
        compress: Optional[dict] = None,
        **kwargs,
    ) -> None:
        # Internal switch to compare the old (per-frame) and new
        # (merged-frame) code paths; they are expected to give the same
        # results.
        self.merge_frames = True
        super().__init__(
            type_embedding=type_embedding,
            type_map=type_map,
            data_stat_nbatch=data_stat_nbatch,
            data_stat_nsample=data_stat_nsample,
            data_stat_protect=data_stat_protect,
            use_srtab=use_srtab,
            smin_alpha=smin_alpha,
            sw_rmin=sw_rmin,
            sw_rmax=sw_rmax,
            spin=spin,
            compress=compress,
            **kwargs,
        )
        # One type-embedding net shared by both sub-models.
        if isinstance(type_embedding, TypeEmbedNet):
            self.typeebd = type_embedding
        else:
            # must use se_atten, so padding must be True
            self.typeebd = TypeEmbedNet(**type_embedding, padding=True)
        shared_kwargs = {
            "type_map": type_map,
            "type_embedding": self.typeebd,
            "compress": compress,
        }
        self.qm_model = Model(**qm_model, **shared_kwargs)
        self.qmmm_model = Model(**qmmm_model, **shared_kwargs)
        # aparam holds one integer per atom; it is consumed as the
        # fragment index by build() — see dprc_pairwise_idx usage there.
        add_data_requirement("aparam", 1, atomic=True, must=True, high_prec=False)
        self.ntypes = len(type_map)
        self.rcut = max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut())
    def build(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box_: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: dict,
        frz_model=None,
        ckpt_meta: Optional[str] = None,
        suffix: str = "",
        reuse: Optional[bool] = None,
    ):
        """Build the pairwise DPRc model graph.

        The frame is split by the per-atom fragment indices carried in
        ``input_dict["aparam"]``: a QM sub-frame and a set of QM/MM
        pairwise sub-frames are built, each sub-model is evaluated on its
        sub-frames, and the per-atom energies, forces and virials are
        mapped back and summed.

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinate of atoms
        atype_ : tf.Tensor
            The type of atoms
        natoms : tf.Tensor
            The number of atoms. This tensor has the length of Ntypes + 2
        box_ : tf.Tensor
            The box
        mesh : tf.Tensor
            For historical reasons, only the length of the Tensor matters
        input_dict : dict
            Input data; must contain the key ``"aparam"``
        frz_model : str, optional
            The path to the frozen model
        ckpt_meta : str, optional
            The path prefix of the checkpoint and meta files
        suffix : str
            suffix to name scope
        reuse : bool, optional
            whether to reuse the variables

        Returns
        -------
        dict
            Model outputs: energy, force, virial, atom_ener, atom_virial,
            coord, atype.
        """
        # Feed dict of the global (un-split) frame, shared by both
        # sub-models through their input dicts.
        feed_dict = self.get_feed_dict(
            coord_, atype_, natoms, box_, mesh, aparam=input_dict["aparam"]
        )
        input_dict_qm = {"global_feed_dict": feed_dict}
        input_dict_qmmm = {"global_feed_dict": feed_dict}
        # Model attribute constants saved into the graph for inference.
        with tf.variable_scope("model_attr" + suffix, reuse=reuse):
            t_tmap = tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string)
            t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string)
            t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)
        with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
            t_dfparam = tf.constant(0, name="dfparam", dtype=tf.int32)
            # one aparam per atom (the fragment index)
            t_daparam = tf.constant(1, name="daparam", dtype=tf.int32)
            t_aparam_nall = tf.constant(True, name="aparam_nall", dtype=tf.bool)
        with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse):
            t_ntypes = tf.constant(self.ntypes, name="ntypes", dtype=tf.int32)
            t_rcut = tf.constant(
                self.rcut, name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION
            )
        # convert X-frame to X-Y-frame coordinates
        box = tf.reshape(box_, [-1, 9])
        nframes = tf.shape(box)[0]
        # aparam carries the per-atom fragment index; cast to int for the
        # pairwise-index op.
        idxs = tf.cast(input_dict["aparam"], tf.int32)
        idxs = tf.reshape(idxs, (nframes, natoms[1]))
        # Forward maps scatter global atoms into sub-frames; backward maps
        # gather sub-frame results back to global atom slots.
        (
            forward_qm_map,
            backward_qm_map,
            forward_qmmm_map,
            backward_qmmm_map,
            natoms_qm,
            natoms_qmmm,
            qmmm_frame_idx,
        ) = op_module.dprc_pairwise_idx(idxs, natoms)
        coord = tf.reshape(coord_, [nframes, natoms[1], 3])
        atype = tf.reshape(atype_, [nframes, natoms[1], 1])
        nframes_qmmm = tf.shape(qmmm_frame_idx)[0]
        if self.merge_frames:
            # Merge all QM/MM sub-frames into a single frame; the op also
            # rewrites the maps and builds the matching neighbor mesh.
            (
                forward_qmmm_map,
                backward_qmmm_map,
                natoms_qmmm,
                mesh_qmmm,
            ) = op_module.convert_forward_map(forward_qmmm_map, natoms_qmmm, natoms)
            coord_qmmm = tf.reshape(coord, [1, -1, 3])
            atype_qmmm = tf.reshape(atype, [1, -1, 1])
            # NOTE(review): only box[0] is used here — presumably all
            # frames share the same box in the merged path; confirm.
            box_qmmm = tf.reshape(box[0], [1, 9])
        else:
            mesh_qmmm = make_default_mesh(False, True)
            coord_qmmm = tf.gather(coord, qmmm_frame_idx)
            atype_qmmm = tf.gather(atype, qmmm_frame_idx)
            box_qmmm = tf.gather(box, qmmm_frame_idx)
        # Placeholder -1 marks padded atom slots; atype -1 denotes a
        # padded/virtual atom for the sub-models.
        coord_qm = gather_placeholder(coord, forward_qm_map)
        atype_qm = gather_placeholder(atype, forward_qm_map, placeholder=-1)
        coord_qmmm = gather_placeholder(coord_qmmm, forward_qmmm_map)
        atype_qmmm = gather_placeholder(atype_qmmm, forward_qmmm_map, placeholder=-1)
        box_qm = box
        type_embedding = self.build_type_embedding(
            self.ntypes,
            reuse=reuse,
            suffix=suffix,
            frz_model=frz_model,
            ckpt_meta=ckpt_meta,
        )
        input_dict_qm["type_embedding"] = type_embedding
        input_dict_qmmm["type_embedding"] = type_embedding
        mesh_mixed_type = make_default_mesh(False, True)

        # allow loading a frozen QM model that has only QM types
        # Note: here we don't map the type between models, so
        # the type of the frozen model must be the same as
        # the first Ntypes of the current model
        if self.get_ntypes() > self.qm_model.get_ntypes():
            natoms_qm = tf.slice(natoms_qm, [0], [self.qm_model.get_ntypes() + 2])
        assert self.get_ntypes() == self.qmmm_model.get_ntypes()

        qm_dict = self.qm_model.build(
            coord_qm,
            atype_qm,
            natoms_qm,
            box_qm,
            mesh_mixed_type,
            input_dict_qm,
            frz_model=frz_model,
            ckpt_meta=ckpt_meta,
            suffix="_qm" + suffix,
            reuse=reuse,
        )
        qmmm_dict = self.qmmm_model.build(
            coord_qmmm,
            atype_qmmm,
            natoms_qmmm,
            box_qmmm,
            mesh_qmmm,
            input_dict_qmmm,
            frz_model=frz_model,
            ckpt_meta=ckpt_meta,
            suffix="_qmmm" + suffix,
            reuse=reuse,
        )
        if self.merge_frames:
            # The merged frame produced one set of outputs; tile them back
            # to one copy per QM/MM sub-frame before the backward mapping.
            qmmm_dict = qmmm_dict.copy()
            sub_nframes = tf.shape(backward_qmmm_map)[0]
            qmmm_dict["force"] = tf.tile(qmmm_dict["force"], [sub_nframes, 1])
            qmmm_dict["atom_ener"] = tf.tile(qmmm_dict["atom_ener"], [sub_nframes, 1])
            qmmm_dict["atom_virial"] = tf.tile(
                qmmm_dict["atom_virial"], [sub_nframes, 1]
            )
        # Forces: map sub-frame forces back to global atom slots; the
        # QM/MM contributions of each original frame are summed by
        # segment_sum over qmmm_frame_idx.
        force_qm = gather_placeholder(
            tf.reshape(qm_dict["force"], (nframes, natoms_qm[1], 3)),
            backward_qm_map,
            placeholder=0.0,
        )
        force_qmmm = tf.math.segment_sum(
            gather_placeholder(
                tf.reshape(qmmm_dict["force"], (nframes_qmmm, natoms_qmmm[1], 3)),
                backward_qmmm_map,
                placeholder=0.0,
            ),
            qmmm_frame_idx,
        )
        force = force_qm + force_qmmm
        force = tf.reshape(force, (nframes, 3 * natoms[1]), name="o_force" + suffix)

        # Atomic energies: only local atoms (natoms[0]) are mapped back.
        backward_qm_map_nloc = tf.slice(backward_qm_map, [0, 0], [-1, natoms[0]])
        backward_qmmm_map_nloc = tf.slice(backward_qmmm_map, [0, 0], [-1, natoms[0]])
        atom_ener_qm = gather_placeholder(
            qm_dict["atom_ener"], backward_qm_map_nloc, placeholder=0.0
        )
        atom_ener_qmmm = tf.math.segment_sum(
            gather_placeholder(
                qmmm_dict["atom_ener"], backward_qmmm_map_nloc, placeholder=0.0
            ),
            qmmm_frame_idx,
        )
        atom_ener = atom_ener_qm + atom_ener_qmmm
        atom_ener = tf.identity(atom_ener, name="o_atom_energy" + suffix)

        # Atomic virials: same backward mapping as forces, 9 components.
        atom_virial_qm = gather_placeholder(
            tf.reshape(qm_dict["atom_virial"], (nframes, natoms_qm[1], 9)),
            backward_qm_map,
            placeholder=0.0,
        )
        atom_virial_qmmm = tf.math.segment_sum(
            gather_placeholder(
                tf.reshape(qmmm_dict["atom_virial"], (nframes_qmmm, natoms_qmmm[1], 9)),
                backward_qmmm_map,
                placeholder=0.0,
            ),
            qmmm_frame_idx,
        )
        atom_virial = atom_virial_qm + atom_virial_qmmm
        atom_virial = tf.reshape(
            atom_virial, (nframes, 9 * natoms[1]), name="o_atom_virial" + suffix
        )

        # Frame totals are sums over the per-atom quantities.
        energy = tf.reduce_sum(atom_ener, axis=1, name="o_energy" + suffix)
        virial = tf.reduce_sum(
            tf.reshape(atom_virial, (nframes, natoms[1], 9)),
            axis=1,
            name="o_virial" + suffix,
        )

        model_dict = {}
        model_dict["energy"] = energy
        model_dict["force"] = force
        model_dict["virial"] = virial
        model_dict["atom_ener"] = atom_ener
        model_dict["atom_virial"] = atom_virial
        model_dict["coord"] = coord_
        model_dict["atype"] = atype_
        return model_dict
[docs] def get_fitting(self) -> Union[str, dict]: """Get the fitting(s).""" return { "qm": self.qm_model.get_fitting(), "qmmm": self.qmmm_model.get_fitting(), }
[docs] def get_loss(self, loss: dict, lr) -> Union[Loss, dict]: """Get the loss function(s).""" return self.qm_model.get_loss(loss, lr)
[docs] def get_rcut(self): return max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut())
    def get_ntypes(self) -> int:
        """Return the number of atom types, i.e. ``len(type_map)``."""
        return self.ntypes
[docs] def data_stat(self, data): self.qm_model.data_stat(data) self.qmmm_model.data_stat(data)
[docs] def init_variables( self, graph: tf.Graph, graph_def: tf.GraphDef, model_type: str = "original_model", suffix: str = "", ) -> None: """Init the embedding net variables with the given frozen model. Parameters ---------- graph : tf.Graph The input frozen model graph graph_def : tf.GraphDef The input frozen model graph_def model_type : str the type of the model suffix : str suffix to name scope """ self.typeebd.init_variables(graph, graph_def, model_type=model_type) self.qm_model.init_variables( graph, graph_def, model_type=model_type, suffix="_qm" + suffix ) self.qmmm_model.init_variables( graph, graph_def, model_type=model_type, suffix="_qmmm" + suffix )
[docs] def enable_compression(self, suffix: str = "") -> None: """Enable compression. Parameters ---------- suffix : str suffix to name scope """ graph, graph_def = load_graph_def(self.compress["model_file"]) self.typeebd.init_variables(graph, graph_def) self.qm_model.enable_compression(suffix="_qm" + suffix) self.qmmm_model.enable_compression(suffix="_qmmm" + suffix)
[docs] def get_feed_dict( self, coord_: tf.Tensor, atype_: tf.Tensor, natoms: tf.Tensor, box: tf.Tensor, mesh: tf.Tensor, **kwargs, ) -> Dict[str, tf.Tensor]: """Generate the feed_dict for current descriptor. Parameters ---------- coord_ : tf.Tensor The coordinate of atoms atype_ : tf.Tensor The type of atoms natoms : tf.Tensor The number of atoms. This tensor has the length of Ntypes + 2 natoms[0]: number of local atoms natoms[1]: total number of atoms held by this processor natoms[i]: 2 <= i < Ntypes+2, number of type i atoms box : tf.Tensor The box. Can be generated by deepmd.model.make_stat_input mesh : tf.Tensor For historical reasons, only the length of the Tensor matters. if size of mesh == 6, pbc is assumed. if size of mesh == 0, no-pbc is assumed. aparam : tf.Tensor The parameters of the descriptor **kwargs : dict The keyword arguments Returns ------- feed_dict : dict[str, tf.Tensor] The output feed_dict of current descriptor """ feed_dict = { "t_coord:0": coord_, "t_type:0": atype_, "t_natoms:0": natoms, "t_box:0": box, "t_aparam:0": kwargs["aparam"], } return feed_dict
    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict):
        """Update the selection and perform neighbor statistics.

        For this model the selection is returned unchanged; the neighbor
        statistics pass is run only for its side effect of computing the
        minimal neighbor distance.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class
        """
        from deepmd.entrypoints.train import (
            get_min_nbor_dist,
        )

        # do not update sel; only find min distance
        # rcut is not important here
        get_min_nbor_dist(global_jdata, 6.0)
        return local_jdata
def gather_placeholder(
    params: tf.Tensor, indices: tf.Tensor, placeholder: float = 0.0, **kwargs
) -> tf.Tensor:
    """Like ``tf.gather`` with ``batch_dims=1``, but index ``-1`` is allowed.

    A row filled with ``placeholder`` is prepended along axis 1 and all
    indices are shifted by +1, so an index of -1 selects the placeholder
    row instead of raising an out-of-range error.
    """
    # Shape of the prepended row: (nframes, 1, <trailing dims of params>),
    # e.g. (nframes, x, 2, 3) -> (nframes, 1, 2, 3).
    pad_shape = tf.concat(
        [[tf.shape(params)[0], 1], tf.shape(params)[2:]], axis=0
    )
    pad_row = tf.cast(tf.fill(pad_shape, placeholder), params.dtype)
    padded = tf.concat([pad_row, params], axis=1)
    return tf.gather(padded, indices + 1, batch_dims=1, **kwargs)