import numpy as np
from typing import Tuple, List
from deepmd.env import tf
from deepmd.utils.pair_tab import PairTab
from deepmd.utils.graph import load_graph_def
from deepmd.common import ClassArg
from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION, GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import op_module
from .model_stat import make_stat_input, merge_sys_stat
class EnerModel:
    """Energy model.

    Parameters
    ----------
    descrpt
        Descriptor
    fitting
        Fitting net
    typeebd
        Optional type embedding. When given, its ``build`` output is stored
        in ``input_dict['type_embedding']`` before the descriptor and the
        fitting net are built.
    type_map
        Mapping atom type to the name (str) of the type.
        For example `type_map[1]` gives the name of the type 1.
    data_stat_nbatch
        Number of frames used for data statistic
    data_stat_protect
        Protect parameter for atomic energy regression
    use_srtab
        The table for the short-range pairwise interaction added on top of
        DP. The table is a text data file with (N_t + 1) * N_t / 2 + 1
        columns. The first column is the distance between atoms. The second
        to the last columns are energies for pairs of certain types. For
        example, with two atom types 0 and 1, the columns from 2nd to 4th
        are for 0-0, 0-1 and 1-1 respectively.
    smin_alpha
        The short-range tabulated interaction is switched according to the
        distance of the nearest neighbor. This distance is calculated by
        softmin. This parameter is the decaying parameter in the softmin.
        It is only required when `use_srtab` is provided.
    sw_rmin
        The lower boundary of the interpolation between the short-range
        tabulated interaction and DP. It is only required when `use_srtab`
        is provided.
    sw_rmax
        The upper boundary of the interpolation between the short-range
        tabulated interaction and DP. It is only required when `use_srtab`
        is provided.
    """

    model_type = 'ener'

    def __init__(
            self,
            descrpt,
            fitting,
            typeebd=None,
            type_map: List[str] = None,
            data_stat_nbatch: int = 10,
            data_stat_protect: float = 1e-2,
            use_srtab: str = None,
            smin_alpha: float = None,
            sw_rmin: float = None,
            sw_rmax: float = None,
    ) -> None:
        """Store the components and hyper-parameters of the energy model.

        See the class docstring for the meaning of each parameter.
        """
        # descriptor
        self.descrpt = descrpt
        self.rcut = self.descrpt.get_rcut()
        self.ntypes = self.descrpt.get_ntypes()
        # fitting
        self.fitting = fitting
        self.numb_fparam = self.fitting.get_numb_fparam()
        # type embedding
        self.typeebd = typeebd
        # other inputs
        self.type_map = [] if type_map is None else type_map
        self.data_stat_nbatch = data_stat_nbatch
        self.data_stat_protect = data_stat_protect
        # optional short-range tabulated interaction
        self.srtab_name = use_srtab
        if self.srtab_name is not None:
            self.srtab = PairTab(self.srtab_name)
        else:
            self.srtab = None
        # Always stored so the attributes exist; they are only read by
        # ``build`` when ``self.srtab`` is not None.
        self.smin_alpha = smin_alpha
        self.sw_rmin = sw_rmin
        self.sw_rmax = sw_rmax

    def get_rcut(self):
        """Return the cut-off radius reported by the descriptor."""
        return self.rcut

    def get_ntypes(self):
        """Return the number of atom types reported by the descriptor."""
        return self.ntypes

    def get_type_map(self):
        """Return the list of type names (empty if none was given)."""
        return self.type_map

    def data_stat(self, data):
        """Compute input and output statistics from training data.

        Input statistics are computed on the systems merged by
        ``merge_sys_stat``; output statistics on the per-system statistics.

        Parameters
        ----------
        data
            Training data, forwarded to ``make_stat_input``.
        """
        all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
        m_all_stat = merge_sys_stat(all_stat)
        self._compute_input_stat(m_all_stat, protection=self.data_stat_protect)
        self._compute_output_stat(all_stat)

    def _compute_input_stat(self, all_stat, protection=1e-2):
        """Forward input statistics to the descriptor and the fitting net.

        Parameters
        ----------
        all_stat : dict
            Must contain the keys ``'coord'``, ``'box'``, ``'type'``,
            ``'natoms_vec'`` and ``'default_mesh'``; the whole dict is also
            passed as the last argument to the descriptor.
        protection : float
            Forwarded to ``fitting.compute_input_stats``.
        """
        self.descrpt.compute_input_stats(all_stat['coord'],
                                         all_stat['box'],
                                         all_stat['type'],
                                         all_stat['natoms_vec'],
                                         all_stat['default_mesh'],
                                         all_stat)
        self.fitting.compute_input_stats(all_stat, protection=protection)

    def _compute_output_stat(self, all_stat):
        """Forward output statistics to the fitting net."""
        self.fitting.compute_output_stats(all_stat)

    def build(self,
              coord_,
              atype_,
              natoms,
              box,
              mesh,
              input_dict,
              frz_model=None,
              suffix='',
              reuse=None):
        """Build the energy model graph.

        Parameters
        ----------
        coord_
            Atomic coordinates; reshaped to ``[-1, natoms[1] * 3]``.
        atype_
            Atom types; reshaped to ``[-1, natoms[1]]``.
        natoms
            Atom-count tensor. ``natoms[0]`` sizes the per-atom energies and
            ``natoms[1]`` sizes coordinates, forces and atomic virials
            (presumably local and total atom counts -- confirm against the
            descriptor).
        box
            Simulation box, forwarded to the descriptor.
        mesh
            Mesh, forwarded to the descriptor.
        input_dict
            Extra inputs forwarded to the descriptor and the fitting net.
            ``'type_embedding'`` is written into it when a type embedding
            is configured.
        frz_model
            Optional frozen model (passed to ``load_graph_def``). When given,
            the descriptor output is imported from that graph instead of
            being built.
        suffix
            Appended to the ``model_attr`` scope and to the ``o_*`` output
            tensor names (except ``o_descriptor``).
        reuse
            Variable reuse flag forwarded to ``tf.variable_scope`` and to the
            sub-networks.

        Returns
        -------
        dict
            Tensors under the keys ``'energy'``, ``'force'``, ``'virial'``,
            ``'atom_ener'``, ``'atom_virial'``, ``'coord'`` and ``'atype'``.
        """
        with tf.variable_scope('model_attr' + suffix, reuse=reuse):
            # Named constants recording the type map, model type and model
            # version in the graph.
            tf.constant(' '.join(self.type_map),
                        name='tmap',
                        dtype=tf.string)
            tf.constant(self.model_type,
                        name='model_type',
                        dtype=tf.string)
            tf.constant(MODEL_VERSION,
                        name='model_version',
                        dtype=tf.string)
            if self.srtab is not None:
                # Store the pair table as non-trainable float64 variables.
                tab_info, tab_data = self.srtab.get()
                self.tab_info = tf.get_variable(
                    't_tab_info',
                    tab_info.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_info, dtype=tf.float64))
                self.tab_data = tf.get_variable(
                    't_tab_data',
                    tab_data.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_data, dtype=tf.float64))

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        # type embedding if any
        if self.typeebd is not None:
            type_embedding = self.typeebd.build(
                self.ntypes,
                reuse=reuse,
                suffix=suffix,
            )
            input_dict['type_embedding'] = type_embedding

        if frz_model is None:
            dout = self.descrpt.build(coord_,
                                      atype_,
                                      natoms,
                                      box,
                                      mesh,
                                      input_dict,
                                      suffix=suffix,
                                      reuse=reuse)
            dout = tf.identity(dout, name='o_descriptor')
        else:
            # Re-create the descriptor attribute constants, then import the
            # descriptor sub-graph (and its output 'o_descriptor') from the
            # frozen model, feeding it the current inputs.
            tf.constant(self.rcut,
                        name='descrpt_attr/rcut',
                        dtype=GLOBAL_TF_FLOAT_PRECISION)
            tf.constant(self.ntypes,
                        name='descrpt_attr/ntypes',
                        dtype=tf.int32)
            feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
            return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
            imported_tensors = self._import_graph_def_from_frz_model(
                frz_model, feed_dict, return_elements)
            dout = imported_tensors[-1]
            self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])

        if self.srtab is not None:
            nlist, rij, sel_a, sel_r = self.descrpt.get_nlist()
            # total number of selected neighbours
            nnei_a = np.cumsum(sel_a)[-1]
            nnei_r = np.cumsum(sel_r)[-1]

        atom_ener = self.fitting.build(dout,
                                       natoms,
                                       input_dict,
                                       reuse=reuse,
                                       suffix=suffix)

        if self.srtab is not None:
            # Switching weight between the tabulated interaction and DP,
            # driven by the softmin nearest-neighbour distance.
            sw_lambda, sw_deriv = op_module.soft_min_switch(atype,
                                                            rij,
                                                            nlist,
                                                            natoms,
                                                            sel_a=sel_a,
                                                            sel_r=sel_r,
                                                            alpha=self.smin_alpha,
                                                            rmin=self.sw_rmin,
                                                            rmax=self.sw_rmax)
            inv_sw_lambda = 1.0 - sw_lambda
            # NOTICE:
            # the tabulated atom energy returned by pair_tab is not scaled,
            # while the tabulated force and virial are scaled by sw_lambda;
            # hence the energy is scaled explicitly below.
            tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab(
                self.tab_info,
                self.tab_data,
                atype,
                rij,
                nlist,
                natoms,
                sw_lambda,
                sel_a=sel_a,
                sel_r=sel_r)
            # E_tab - E_DP per atom; used with sw_deriv for the force/virial
            # contribution of the switching function.
            energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]])
            tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(tab_atom_ener, [-1])
            # atom_ener is rebound to the (1 - sw_lambda)-weighted DP energy,
            # which is also what the descriptor force/virial below use.
            atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
            energy_raw = tab_atom_ener + atom_ener
        else:
            energy_raw = atom_ener

        energy_raw = tf.reshape(energy_raw, [-1, natoms[0]], name='o_atom_energy' + suffix)
        energy = tf.reduce_sum(global_cvt_2_ener_float(energy_raw), axis=1, name='o_energy' + suffix)

        force, virial, atom_virial = self.descrpt.prod_force_virial(atom_ener, natoms)

        if self.srtab is not None:
            sw_force = op_module.soft_min_force(energy_diff,
                                                sw_deriv,
                                                nlist,
                                                natoms,
                                                n_a_sel=nnei_a,
                                                n_r_sel=nnei_r)
            force = force + sw_force + tab_force

        force = tf.reshape(force, [-1, 3 * natoms[1]], name="o_force" + suffix)

        if self.srtab is not None:
            sw_virial, sw_atom_virial = op_module.soft_min_virial(energy_diff,
                                                                  sw_deriv,
                                                                  rij,
                                                                  nlist,
                                                                  natoms,
                                                                  n_a_sel=nnei_a,
                                                                  n_r_sel=nnei_r)
            atom_virial = atom_virial + sw_atom_virial + tab_atom_virial
            virial = virial + sw_virial \
                + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1)

        virial = tf.reshape(virial, [-1, 9], name="o_virial" + suffix)
        atom_virial = tf.reshape(atom_virial, [-1, 9 * natoms[1]], name="o_atom_virial" + suffix)

        model_dict = {}
        model_dict['energy'] = energy
        model_dict['force'] = force
        model_dict['virial'] = virial
        model_dict['atom_ener'] = energy_raw
        model_dict['atom_virial'] = atom_virial
        model_dict['coord'] = coord
        model_dict['atype'] = atype
        return model_dict

    def _import_graph_def_from_frz_model(self, frz_model, feed_dict, return_elements):
        """Import the frozen model's graph into the current graph.

        Inputs are remapped through ``feed_dict``, and the tensors named in
        ``return_elements`` are returned. ``name=""`` imports the nodes
        without an extra name prefix.
        """
        _, graph_def = load_graph_def(frz_model)
        return tf.import_graph_def(graph_def,
                                   input_map=feed_dict,
                                   return_elements=return_elements,
                                   name="")