Source code for deepmd.model.linear

# SPDX-License-Identifier: LGPL-3.0-or-later
from enum import (
    Enum,
)
from functools import (
    lru_cache,
)
from typing import (
    List,
    Optional,
    Union,
)

from deepmd.env import (
    GLOBAL_TF_FLOAT_PRECISION,
    MODEL_VERSION,
    tf,
)
from deepmd.fit.fitting import (
    Fitting,
)
from deepmd.loss.loss import (
    Loss,
)

from .model import (
    Model,
)


class LinearModel(Model):
    """Linear model that forms a linear combination of several existing models.

    Parameters
    ----------
    models : list[dict]
        A list of models to be combined.
    weights : list[float] or str
        If the type is list[float], a list of weights for each model.
        If "mean", the weights are set to be 1 / len(models).
        If "sum", the weights are set to be 1.
    """

    def __init__(self, models: List[dict], weights: Union[List[float], str], **kwargs):
        super().__init__(**kwargs)
        self.models = [Model(**model) for model in models]
        if isinstance(weights, list):
            if len(weights) != len(models):
                raise ValueError(
                    "The length of weights is not equal to the number of models"
                )
            self.weights = weights
        elif weights == "mean":
            self.weights = [1 / len(models) for _ in range(len(models))]
        elif weights == "sum":
            self.weights = [1 for _ in range(len(models))]
        # TODO: add more weights, for example, so-called committee models
        else:
            raise ValueError(f"Invalid weights {weights}")
    def get_fitting(self) -> Union[Fitting, dict]:
        """Get the fitting(s)."""
        return {
            f"model{ii}": model.get_fitting() for ii, model in enumerate(self.models)
        }
    def get_loss(self, loss: dict, lr) -> Optional[Union[Loss, dict]]:
        """Get the loss function(s)."""
        # return the loss of the first model that provides one, or None if
        # all models return None; use a separate name so the input loss dict
        # is not clobbered between iterations
        for model in self.models:
            loss_i = model.get_loss(loss, lr)
            if loss_i is not None:
                return loss_i
        return None
    def get_rcut(self) -> float:
        """Get the cutoff radius, which is the maximum over all sub-models."""
        return max([model.get_rcut() for model in self.models])
    @lru_cache(maxsize=1)
    def get_ntypes(self) -> int:
        """Get the number of atom types."""
        # check if all models have the same ntypes
        for model in self.models:
            if model.get_ntypes() != self.models[0].get_ntypes():
                raise ValueError("Models have different ntypes")
        return self.models[0].get_ntypes()
    def data_stat(self, data):
        """Compute the data statistics for each sub-model."""
        for model in self.models:
            model.data_stat(data)
    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        model_type: str = "original_model",
        suffix: str = "",
    ) -> None:
        """Init the embedding net variables with the given frozen model.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        model_type : str
            the type of the model
        suffix : str
            suffix to name scope
        """
        for ii, model in enumerate(self.models):
            model.init_variables(
                graph, graph_def, model_type, suffix=f"_model{ii}{suffix}"
            )
    def enable_compression(self, suffix: str = "") -> None:
        """Enable compression.

        Parameters
        ----------
        suffix : str
            suffix to name scope
        """
        for ii, model in enumerate(self.models):
            model.enable_compression(suffix=f"_model{ii}{suffix}")
    def get_type_map(self) -> list:
        """Get the type map."""
        return self.models[0].get_type_map()
    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict):
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data referring to the current class

        Returns
        -------
        dict
            The updated local data
        """
        local_jdata_cpy = local_jdata.copy()
        local_jdata_cpy["models"] = [
            Model.update_sel(global_jdata, sub_jdata)
            for sub_jdata in local_jdata["models"]
        ]
        return local_jdata_cpy
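
# A minimal usage sketch for the class above (illustrative only; the names
# ``sub_model_0`` and ``sub_model_1`` are hypothetical stand-ins for ordinary
# single-model parameter dicts with their own "descriptor" and "fitting_net"
# sections, and are not part of this module):
#
#     >>> model = LinearModel(
#     ...     models=[sub_model_0, sub_model_1],
#     ...     weights="mean",
#     ... )
#     >>> model.weights  # "mean" resolves to 1 / len(models) for each model
#     [0.5, 0.5]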
class LinearEnergyModel(LinearModel):
    """Linear energy model that forms a linear combination of several existing energy models."""

    model_type = "ener"
    def build(
        self,
        coord_: tf.Tensor,
        atype_: tf.Tensor,
        natoms: tf.Tensor,
        box: tf.Tensor,
        mesh: tf.Tensor,
        input_dict: dict,
        frz_model: Optional[str] = None,
        ckpt_meta: Optional[str] = None,
        suffix: str = "",
        reuse: Optional[Union[bool, Enum]] = None,
    ) -> dict:
        """Build the model.

        Parameters
        ----------
        coord_ : tf.Tensor
            The coordinates of atoms
        atype_ : tf.Tensor
            The atom types of atoms
        natoms : tf.Tensor
            The number of atoms
        box : tf.Tensor
            The box vectors
        mesh : tf.Tensor
            The mesh vectors
        input_dict : dict
            The input dict
        frz_model : str, optional
            The path to the frozen model
        ckpt_meta : str, optional
            The path prefix of the checkpoint and meta files
        suffix : str, optional
            The suffix of the scope
        reuse : bool or tf.AUTO_REUSE, optional
            Whether to reuse the variables

        Returns
        -------
        dict
            The output dict
        """
        with tf.variable_scope("model_attr" + suffix, reuse=reuse):
            t_tmap = tf.constant(
                " ".join(self.get_type_map()), name="tmap", dtype=tf.string
            )
            t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string)
            t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)
        with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
            # non zero not supported
            t_dfparam = tf.constant(0, name="dfparam", dtype=tf.int32)
            t_daparam = tf.constant(0, name="daparam", dtype=tf.int32)
        with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse):
            t_ntypes = tf.constant(self.get_ntypes(), name="ntypes", dtype=tf.int32)
            t_rcut = tf.constant(
                self.get_rcut(), name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION
            )

        subdicts = []
        for ii, model in enumerate(self.models):
            subdict = model.build(
                coord_,
                atype_,
                natoms,
                box,
                mesh,
                input_dict,
                frz_model=frz_model,
                ckpt_meta=ckpt_meta,
                suffix=f"_model{ii}{suffix}",
                reuse=reuse,
            )
            subdicts.append(subdict)

        t_weight = tf.constant(self.weights, dtype=GLOBAL_TF_FLOAT_PRECISION)

        model_dict = {}
        # energy shape is (n_batch,), other shapes are (n_batch, -1)
        energy = tf.reduce_sum(
            tf.stack([mm["energy"] for mm in subdicts], axis=0) * t_weight[:, None],
            axis=0,
        )
        force = tf.reduce_sum(
            tf.stack([mm["force"] for mm in subdicts], axis=0)
            * t_weight[:, None, None],
            axis=0,
        )
        virial = tf.reduce_sum(
            tf.stack([mm["virial"] for mm in subdicts], axis=0)
            * t_weight[:, None, None],
            axis=0,
        )
        atom_ener = tf.reduce_sum(
            tf.stack([mm["atom_ener"] for mm in subdicts], axis=0)
            * t_weight[:, None, None],
            axis=0,
        )
        atom_virial = tf.reduce_sum(
            tf.stack([mm["atom_virial"] for mm in subdicts], axis=0)
            * t_weight[:, None, None],
            axis=0,
        )

        model_dict["energy"] = tf.identity(energy, name="o_energy" + suffix)
        model_dict["force"] = tf.identity(force, name="o_force" + suffix)
        model_dict["virial"] = tf.identity(virial, name="o_virial" + suffix)
        model_dict["atom_ener"] = tf.identity(atom_ener, name="o_atom_energy" + suffix)
        model_dict["atom_virial"] = tf.identity(
            atom_virial, name="o_atom_virial" + suffix
        )
        model_dict["coord"] = coord_
        model_dict["atype"] = atype_
        return model_dict
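
# The reduction in build() above is a plain weighted sum over the sub-model
# outputs, e.g. energy = sum_i w_i * E_i. A NumPy sketch of the same stack,
# broadcast, and reduce pattern (``e0``, ``e1``, and ``weights`` are
# hypothetical stand-ins for the per-model energies and self.weights):
#
#     >>> import numpy as np
#     >>> e0, e1 = np.array([1.0, 2.0]), np.array([3.0, 4.0])  # (n_batch,)
#     >>> weights = np.array([0.5, 0.5])  # (n_models,)
#     >>> stacked = np.stack([e0, e1], axis=0)  # (n_models, n_batch)
#     >>> (stacked * weights[:, None]).sum(axis=0)
#     array([2., 3.])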