# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from typing import (
Dict,
List,
Optional,
)
import numpy as np
from deepmd.descriptor.descriptor import (
Descriptor,
)
from deepmd.env import (
MODEL_VERSION,
global_cvt_2_ener_float,
op_module,
tf,
)
from deepmd.fit import (
DipoleFittingSeA,
DOSFitting,
EnerFitting,
GlobalPolarFittingSeA,
PolarFittingSeA,
)
from deepmd.fit.fitting import (
Fitting,
)
from deepmd.loss.loss import (
Loss,
)
from deepmd.utils.argcheck import (
type_embedding_args,
)
from deepmd.utils.graph import (
get_tensor_by_name_from_graph,
)
from deepmd.utils.pair_tab import (
PairTab,
)
from deepmd.utils.spin import (
Spin,
)
from deepmd.utils.type_embed import (
TypeEmbedNet,
)
from .model import (
Model,
)
from .model_stat import (
make_stat_input,
merge_sys_stat,
)
[docs]class MultiModel(Model):
"""Multi-task model.
Parameters
----------
descriptor
Descriptor
fitting_net_dict
Dictionary of fitting nets
fitting_type_dict
deprecated argument
type_embedding
Type embedding net
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
data_stat_nbatch
Number of frames used for data statistic
data_stat_protect
Protect parameter for atomic energy regression
use_srtab
The table for the short-range pairwise interaction added on top of DP. The table is a text data file with (N_t + 1) * N_t / 2 + 1 columns. The first column is the distance between atoms. The second to the last columns are energies for pairs of certain types. For example, if we have two atom types, 0 and 1, the columns from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly.
smin_alpha
The short-range tabulated interaction will be switched according to the distance of the nearest neighbor. This distance is calculated by softmin. This parameter is the decaying parameter in the softmin. It is only required when `use_srtab` is provided.
sw_rmin
The lower boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided.
sw_rmax
The upper boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided.
"""
model_type = "multi_task"
def __init__(
    self,
    descriptor: dict,
    fitting_net_dict: dict,
    fitting_type_dict: Optional[dict] = None,  # deprecated
    type_embedding=None,
    type_map: Optional[List[str]] = None,
    data_stat_nbatch: int = 10,
    data_stat_protect: float = 1e-2,
    use_srtab: Optional[str] = None,  # all the ener fitting will do this
    smin_alpha: Optional[float] = None,
    sw_rmin: Optional[float] = None,
    sw_rmax: Optional[float] = None,
    **kwargs,
) -> None:
    """Set up the shared descriptor, the per-task fittings and the type embedding.

    ``descriptor``, each value of ``fitting_net_dict`` and ``type_embedding``
    may be given either as an already constructed object or as a dict of
    keyword arguments from which the object is built here.
    ``fitting_type_dict`` and ``**kwargs`` are accepted but not read.
    """
    super().__init__(
        descriptor=descriptor,
        fitting_net_dict=fitting_net_dict,
        type_embedding=type_embedding,
        type_map=type_map,
        data_stat_nbatch=data_stat_nbatch,
        data_stat_protect=data_stat_protect,
        use_srtab=use_srtab,
        smin_alpha=smin_alpha,
        sw_rmin=sw_rmin,
        sw_rmax=sw_rmax,
    )
    # NOTE(review): self.spin is not assigned in this method; presumably the
    # base-class constructor provides it -- confirm.
    if not (self.spin is None or isinstance(self.spin, Spin)):
        self.spin = Spin(**self.spin)

    # --- descriptor shared by all tasks ---
    if isinstance(descriptor, Descriptor):
        self.descrpt = descriptor
    else:
        self.descrpt = Descriptor(
            **descriptor,
            ntypes=len(self.get_type_map()),
            multi_task=True,
            spin=self.spin,
        )

    # --- one fitting per task key ---
    fitting_dict = {}
    for task_key, task_fitting in fitting_net_dict.items():
        if not isinstance(task_fitting, Fitting):
            task_fitting = Fitting(
                **task_fitting, descrpt=self.descrpt, spin=self.spin
            )
        fitting_dict[task_key] = task_fitting

    # --- type embedding ---
    if isinstance(type_embedding, TypeEmbedNet):
        self.typeebd = type_embedding
    elif type_embedding is not None:
        self.typeebd = TypeEmbedNet(
            **type_embedding,
            padding=self.descrpt.explicit_ntypes,
        )
    elif self.descrpt.explicit_ntypes:
        # No user configuration: fall back to the documented defaults, with
        # the activation function switched off.
        embed_kwargs = {arg.name: arg.default for arg in type_embedding_args()}
        embed_kwargs["activation_function"] = None
        self.typeebd = TypeEmbedNet(
            **embed_kwargs,
            padding=True,
        )
    else:
        self.typeebd = None

    # --- values cached from the descriptor ---
    self.rcut = self.descrpt.get_rcut()
    self.ntypes = self.descrpt.get_ntypes()

    # --- fittings ---
    self.fitting_dict = fitting_dict
    self.numb_fparam_dict = {
        task_key: task_fitting.get_numb_fparam()
        for task_key, task_fitting in self.fitting_dict.items()
        if isinstance(task_fitting, EnerFitting)
    }

    # --- remaining inputs ---
    self.type_map = [] if type_map is None else type_map
    self.data_stat_nbatch = data_stat_nbatch
    self.data_stat_protect = data_stat_protect
    self.srtab_name = use_srtab
    if self.srtab_name is None:
        self.srtab = None
    else:
        self.srtab = PairTab(self.srtab_name)
        self.smin_alpha = smin_alpha
        self.sw_rmin = sw_rmin
        self.sw_rmax = sw_rmax
def get_rcut(self):
    """Return the cutoff radius stored on the model (``self.rcut``)."""
    return self.rcut
def get_ntypes(self):
    """Return the number of atom types stored on the model (``self.ntypes``)."""
    return self.ntypes
def get_type_map(self):
    """Return the stored type map (``self.type_map``), without copying it."""
    return self.type_map
def data_stat(self, data):
    """Compute input and output statistics for every task in ``data``.

    Parameters
    ----------
    data
        Mapping from fitting key to that task's data object. Each value is
        passed to ``make_stat_input`` and must expose a ``mixed_type``
        attribute.
    """
    for fitting_key in data:
        task_data = data[fitting_key]
        per_system_stat = make_stat_input(
            task_data, self.data_stat_nbatch, merge_sys=False
        )
        merged_stat = merge_sys_stat(per_system_stat)
        is_mixed = task_data.mixed_type
        # Input statistics consume the merged view, output statistics the
        # per-system view.
        self._compute_input_stat(
            merged_stat,
            protection=self.data_stat_protect,
            mixed_type=is_mixed,
            fitting_key=fitting_key,
        )
        self._compute_output_stat(
            per_system_stat,
            mixed_type=is_mixed,
            fitting_key=fitting_key,
        )
    # Once every task has been processed, let the descriptor merge what it
    # collected in its ``stat_dict``.
    descrpt = self.descrpt
    descrpt.merge_input_stats(descrpt.stat_dict)
def _compute_input_stat(
    self, all_stat, protection=1e-2, mixed_type=False, fitting_key=""
):
    """Feed the statistics to the descriptor and to one fitting.

    The descriptor always receives coord, box, type, natoms_vec,
    default_mesh and the whole ``all_stat`` dict positionally; for
    mixed-type data ``mixed_type`` and ``all_stat["real_natoms_vec"]`` are
    appended. The fitting selected by ``fitting_key`` is updated only if it
    defines ``compute_input_stats``.
    """
    descrpt_args = [
        all_stat["coord"],
        all_stat["box"],
        all_stat["type"],
        all_stat["natoms_vec"],
        all_stat["default_mesh"],
        all_stat,
    ]
    if mixed_type:
        descrpt_args.append(mixed_type)
        descrpt_args.append(all_stat["real_natoms_vec"])
    self.descrpt.compute_input_stats(*descrpt_args)

    fitting = self.fitting_dict[fitting_key]
    if hasattr(fitting, "compute_input_stats"):
        fitting.compute_input_stats(all_stat, protection=protection)
def _compute_output_stat(self, all_stat, mixed_type=False, fitting_key=""):
    """Forward ``all_stat`` to one fitting's ``compute_output_stats``.

    Does nothing when the fitting selected by ``fitting_key`` has no
    ``compute_output_stats`` attribute. The ``mixed_type`` keyword is only
    passed on when it is truthy.
    """
    fitting = self.fitting_dict[fitting_key]
    if not hasattr(fitting, "compute_output_stats"):
        return
    extra_kwargs = {"mixed_type": mixed_type} if mixed_type else {}
    fitting.compute_output_stats(all_stat, **extra_kwargs)
def build(
    self,
    coord_,
    atype_,
    natoms,
    box,
    mesh,
    input_dict,
    frz_model=None,
    ckpt_meta: Optional[str] = None,
    suffix="",
    reuse=None,
):
    """Build the graph of every fitting on top of the shared descriptor.

    Parameters
    ----------
    coord_
        Coordinates; reshaped to ``[-1, natoms[1] * 3]``.
    atype_
        Atom types; reshaped to ``[-1, natoms[1]]``.
    natoms
        Atom-count vector, indexed as ``natoms[0]`` and ``natoms[1]`` for
        reshape sizes and as ``natoms[2 + type_i]`` when counting the atoms
        of the selected types of a tensor fitting.
    box, mesh
        Forwarded unchanged to ``build_descrpt``.
    input_dict
        Dict of extra inputs, or None for an empty one. It is updated in
        place with ``"nframes"``, ``"atype"`` and, when a type embedding net
        exists, ``"type_embedding"``.
    frz_model, ckpt_meta
        Forwarded to ``build_type_embedding`` and ``build_descrpt``.
    suffix
        Appended to the ``model_attr`` scope name and to most output tensor
        names (``o_descriptor`` is named without it).
    reuse
        Forwarded to the variable scope, the type embedding, the descriptor
        and the fittings.

    Returns
    -------
    dict
        Maps each fitting key to a dict of output tensors. Energy fittings
        yield ``energy``, ``force``, ``virial``, ``atom_ener``,
        ``atom_virial``, ``coord`` and ``atype``. Dipole / polar fittings
        yield the tensor itself and, for the non-global kinds, additionally
        ``global_<name>``, ``force``, ``virial`` and ``atom_virial``.
        Fittings of any other class get no entry.
    """
    tensor_fitting_types = (
        DipoleFittingSeA,
        PolarFittingSeA,
        GlobalPolarFittingSeA,
    )

    if input_dict is None:
        input_dict = {}
    with tf.variable_scope("model_attr" + suffix, reuse=reuse):
        # These constants are created only for their presence in the graph.
        tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string)
        tf.constant(self.model_type, name="model_type", dtype=tf.string)
        tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)
        sel_type = {}
        natomsel = {}
        nout = {}
        for fitting_key, fitting in self.fitting_dict.items():
            if isinstance(fitting, tensor_fitting_types):
                sel_type[fitting_key] = fitting.get_sel_type()
                natomsel[fitting_key] = sum(
                    natoms[2 + type_i] for type_i in sel_type[fitting_key]
                )
                nout[fitting_key] = fitting.get_out_size()
                tf.constant(
                    sel_type[fitting_key],
                    name=f"sel_type_{fitting_key}",
                    dtype=tf.int32,
                )
                tf.constant(
                    nout[fitting_key],
                    name=f"output_dim_{fitting_key}",
                    dtype=tf.int32,
                )
        # NOTE(review): the table variables are created inside the
        # "model_attr" scope here; the flattened source does not show the
        # nesting level -- confirm against the original file.
        if self.srtab is not None:
            tab_info, tab_data = self.srtab.get()
            self.tab_info = tf.get_variable(
                "t_tab_info",
                tab_info.shape,
                dtype=tf.float64,
                trainable=False,
                initializer=tf.constant_initializer(tab_info, dtype=tf.float64),
            )
            self.tab_data = tf.get_variable(
                "t_tab_data",
                tab_data.shape,
                dtype=tf.float64,
                trainable=False,
                initializer=tf.constant_initializer(tab_data, dtype=tf.float64),
            )

    coord = tf.reshape(coord_, [-1, natoms[1] * 3])
    atype = tf.reshape(atype_, [-1, natoms[1]])
    input_dict["nframes"] = tf.shape(coord)[0]

    # type embedding if any
    if self.typeebd is not None:
        type_embedding = self.build_type_embedding(
            self.ntypes,
            reuse=reuse,
            suffix=suffix,
            frz_model=frz_model,
            ckpt_meta=ckpt_meta,
        )
        input_dict["type_embedding"] = type_embedding
    input_dict["atype"] = atype_

    dout = self.build_descrpt(
        coord,
        atype,
        natoms,
        box,
        mesh,
        input_dict,
        frz_model=frz_model,
        ckpt_meta=ckpt_meta,
        suffix=suffix,
        reuse=reuse,
    )
    dout = tf.identity(dout, name="o_descriptor")

    if self.srtab is not None:
        nlist, rij, sel_a, sel_r = self.descrpt.get_nlist()
        nnei_a = np.cumsum(sel_a)[-1]
        nnei_r = np.cumsum(sel_r)[-1]
        sw_lambda, sw_deriv = op_module.soft_min_switch(
            atype,
            rij,
            nlist,
            natoms,
            sel_a=sel_a,
            sel_r=sel_r,
            alpha=self.smin_alpha,
            rmin=self.sw_rmin,
            rmax=self.sw_rmax,
        )
        inv_sw_lambda = 1.0 - sw_lambda
        # NOTICE:
        # atom energy is not scaled,
        # force and virial are scaled
        tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab(
            self.tab_info,
            self.tab_data,
            atype,
            rij,
            nlist,
            natoms,
            sw_lambda,
            sel_a=sel_a,
            sel_r=sel_r,
        )

    rot_mat = self.descrpt.get_rot_mat()
    rot_mat = tf.identity(rot_mat, name="o_rot_mat" + suffix)
    self.atom_ener = {}
    model_dict = {}
    for fitting_key, fitting in self.fitting_dict.items():
        if isinstance(fitting, EnerFitting):
            atom_ener = fitting.build(
                dout,
                natoms,
                input_dict,
                reuse=reuse,
                suffix=f"_{fitting_key}" + suffix,
            )
            self.atom_ener[fitting_key] = atom_ener
            if self.srtab is not None:
                energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]])
                # Keep the scaled table energy in a loop-local name:
                # rebinding ``tab_atom_ener`` here would hand an already
                # scaled table energy to the next energy fitting.
                scaled_tab_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(
                    tab_atom_ener, [-1]
                )
                atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
                energy_raw = scaled_tab_ener + atom_ener
            else:
                energy_raw = atom_ener
            energy_raw = tf.reshape(
                energy_raw,
                [-1, natoms[0]],
                name=f"o_atom_energy_{fitting_key}" + suffix,
            )
            energy = tf.reduce_sum(
                global_cvt_2_ener_float(energy_raw),
                axis=1,
                name=f"o_energy_{fitting_key}" + suffix,
            )
            force, virial, atom_virial = self.descrpt.prod_force_virial(
                atom_ener, natoms
            )
            if self.srtab is not None:
                sw_force = op_module.soft_min_force(
                    energy_diff,
                    sw_deriv,
                    nlist,
                    natoms,
                    n_a_sel=nnei_a,
                    n_r_sel=nnei_r,
                )
                force = force + sw_force + tab_force
            force = tf.reshape(
                force,
                [-1, 3 * natoms[1]],
                name=f"o_force_{fitting_key}" + suffix,
            )
            if self.srtab is not None:
                sw_virial, sw_atom_virial = op_module.soft_min_virial(
                    energy_diff,
                    sw_deriv,
                    rij,
                    nlist,
                    natoms,
                    n_a_sel=nnei_a,
                    n_r_sel=nnei_r,
                )
                atom_virial = atom_virial + sw_atom_virial + tab_atom_virial
                virial = (
                    virial
                    + sw_virial
                    + tf.reduce_sum(
                        tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1
                    )
                )
            virial = tf.reshape(
                virial, [-1, 9], name=f"o_virial_{fitting_key}" + suffix
            )
            atom_virial = tf.reshape(
                atom_virial,
                [-1, 9 * natoms[1]],
                name=f"o_atom_virial_{fitting_key}" + suffix,
            )
            model_dict[fitting_key] = {
                "energy": energy,
                "force": force,
                "virial": virial,
                "atom_ener": energy_raw,
                "atom_virial": atom_virial,
                "coord": coord,
                "atype": atype,
            }
        elif isinstance(fitting, tensor_fitting_types):
            tensor_name = {
                DipoleFittingSeA: "dipole",
                PolarFittingSeA: "polar",
                GlobalPolarFittingSeA: "global_polar",
            }[type(fitting)]
            output = fitting.build(
                dout,
                rot_mat,
                natoms,
                input_dict,
                reuse=reuse,
                suffix=f"_{fitting_key}" + suffix,
            )
            # A global tensor has one row of ``nout`` values per frame. The
            # per-task size must be looked up, because ``nout`` itself is the
            # dict holding the sizes of all tasks.
            framesize = (
                nout[fitting_key]
                if "global" in tensor_name
                else natomsel[fitting_key] * nout[fitting_key]
            )
            output = tf.reshape(
                output,
                [-1, framesize],
                name=f"o_{tensor_name}_{fitting_key}" + suffix,
            )
            model_dict[fitting_key] = {}
            model_dict[fitting_key][tensor_name] = output

            if "global" not in tensor_name:
                gname = "global_" + tensor_name
                atom_out = tf.reshape(
                    output, [-1, natomsel[fitting_key], nout[fitting_key]]
                )
                global_out = tf.reduce_sum(atom_out, axis=1)
                global_out = tf.reshape(
                    global_out,
                    [-1, nout[fitting_key]],
                    name=f"o_{gname}_{fitting_key}" + suffix,
                )

                # Derivatives are produced one output component at a time.
                out_cpnts = tf.split(atom_out, nout[fitting_key], axis=-1)
                force_cpnts = []
                virial_cpnts = []
                atom_virial_cpnts = []
                for out_i in out_cpnts:
                    (
                        force_i,
                        virial_i,
                        atom_virial_i,
                    ) = self.descrpt.prod_force_virial(out_i, natoms)
                    force_cpnts.append(tf.reshape(force_i, [-1, 3 * natoms[1]]))
                    virial_cpnts.append(tf.reshape(virial_i, [-1, 9]))
                    atom_virial_cpnts.append(
                        tf.reshape(atom_virial_i, [-1, 9 * natoms[1]])
                    )

                # [nframe x nout x (natom x 3)]
                force = tf.concat(
                    force_cpnts,
                    axis=1,
                    name=f"o_force_{fitting_key}" + suffix,
                )
                # [nframe x nout x 9]
                virial = tf.concat(
                    virial_cpnts,
                    axis=1,
                    name=f"o_virial_{fitting_key}" + suffix,
                )
                # [nframe x nout x (natom x 9)]
                atom_virial = tf.concat(
                    atom_virial_cpnts,
                    axis=1,
                    name=f"o_atom_virial_{fitting_key}" + suffix,
                )

                model_dict[fitting_key][gname] = global_out
                model_dict[fitting_key]["force"] = force
                model_dict[fitting_key]["virial"] = virial
                model_dict[fitting_key]["atom_virial"] = atom_virial

    return model_dict
def init_variables(
    self,
    graph: tf.Graph,
    graph_def: tf.GraphDef,
    model_type: str = "original_model",
    suffix: str = "",
) -> None:
    """Initialize the model variables from a frozen model.

    Parameters
    ----------
    graph : tf.Graph
        The input frozen model graph
    graph_def : tf.GraphDef
        The input frozen model graph_def
    model_type : str
        The type of the model; only ``"original_model"`` is accepted
    suffix : str
        Suffix to name scope
    """
    assert (
        model_type == "original_model"
    ), "Initialization in multi-task mode does not support compressed model!"

    self.descrpt.init_variables(graph, graph_def, suffix=suffix)

    # The frozen graph stores the training script it was produced with; its
    # fitting keys tell which of the current fittings can be initialized.
    frozen_script = get_tensor_by_name_from_graph(
        graph, "train_attr/training_script"
    )
    frozen_jdata = json.loads(frozen_script)
    frozen_keys = list(frozen_jdata["model"]["fitting_net_dict"].keys())
    current_keys = set(self.fitting_dict.keys())
    added_keys = current_keys - set(frozen_keys)
    shared_keys = current_keys - added_keys

    for fitting_key in shared_keys:
        fitting = self.fitting_dict[fitting_key]
        fitting.init_variables(graph, graph_def, suffix=f"_{fitting_key}" + suffix)

    tf.constant("original_model", name="model_type", dtype=tf.string)

    if self.typeebd is not None:
        self.typeebd.init_variables(graph, graph_def, suffix=suffix)
def enable_mixed_precision(self, mixed_prec: dict):
    """Enable mixed precision for the descriptor and every fitting.

    Parameters
    ----------
    mixed_prec : dict
        The mixed precision config; the same object is handed to the
        descriptor and to each fitting.
    """
    self.descrpt.enable_mixed_precision(mixed_prec)
    for fitting in self.fitting_dict.values():
        # Pass the argument itself: this method never stores a
        # ``self.mixed_prec`` attribute, so reading one here would fail or
        # pick up an unrelated stale value.
        fitting.enable_mixed_precision(mixed_prec)
def get_numb_fparam(self) -> dict:
    """Get the number of frame parameters of each fitting.

    Fittings that are neither ``EnerFitting`` nor ``DOSFitting`` report 0.
    """
    return {
        fitting_key: (
            fitting.get_numb_fparam()
            if isinstance(fitting, (EnerFitting, DOSFitting))
            else 0
        )
        for fitting_key, fitting in self.fitting_dict.items()
    }
def get_numb_aparam(self) -> dict:
    """Get the number of atomic parameters of each fitting.

    Fittings that are neither ``EnerFitting`` nor ``DOSFitting`` report 0.
    """
    return {
        fitting_key: (
            fitting.get_numb_aparam()
            if isinstance(fitting, (EnerFitting, DOSFitting))
            else 0
        )
        for fitting_key, fitting in self.fitting_dict.items()
    }
def get_numb_dos(self) -> dict:
    """Get the number of gridpoints in energy space of each fitting.

    Fittings that are not ``DOSFitting`` report 0.
    """
    return {
        fitting_key: (
            fitting.get_numb_dos() if isinstance(fitting, DOSFitting) else 0
        )
        for fitting_key, fitting in self.fitting_dict.items()
    }
def get_fitting(self) -> dict:
    """Get the fitting(s) as a shallow copy of the internal dict."""
    fittings = self.fitting_dict
    return fittings.copy()
def get_loss(self, loss: dict, lr: dict) -> Dict[str, Loss]:
    """Build one loss per fitting.

    Parameters
    ----------
    loss : dict
        Loss parameters keyed by fitting key; a missing key falls back to
        an empty dict.
    lr : dict
        Learning-rate objects keyed by fitting key; every fitting key must
        be present.

    Returns
    -------
    Dict[str, Loss]
        Whatever each fitting's ``get_loss`` returns, keyed by fitting key.
    """
    return {
        fitting_key: fitting.get_loss(loss.get(fitting_key, {}), lr[fitting_key])
        for fitting_key, fitting in self.fitting_dict.items()
    }
@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
    """Update the selection and perform neighbor statistics.

    Only the ``"descriptor"`` entry is updated, by delegating to
    ``Descriptor.update_sel``; the result is returned in a shallow copy so
    the top level of ``local_jdata`` is left untouched.

    Parameters
    ----------
    global_jdata : dict
        The global data, containing the training section
    local_jdata : dict
        The local data refer to the current class
    """
    updated_jdata = local_jdata.copy()
    descriptor_jdata = local_jdata["descriptor"]
    updated_jdata["descriptor"] = Descriptor.update_sel(
        global_jdata, descriptor_jdata
    )
    return updated_jdata