Source code for deepmd.model.ener

import numpy as np
from typing import Tuple, List

from deepmd.env import tf
from deepmd.utils.pair_tab import PairTab
from deepmd.utils.graph import load_graph_def
from deepmd.common import ClassArg
from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION, GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import op_module
from .model_stat import make_stat_input, merge_sys_stat

class EnerModel:
    """Energy model.

    Combines a descriptor and a fitting net into a model that outputs the
    energy, force and virial, optionally mixed with a tabulated short-range
    pairwise interaction.

    Parameters
    ----------
    descrpt
            Descriptor
    fitting
            Fitting net
    typeebd
            Type embedding net. When provided, its output is passed to the
            descriptor and the fitting net through
            `input_dict['type_embedding']`.
    type_map
            Mapping atom type to the name (str) of the type.
            For example `type_map[1]` gives the name of the type 1.
    data_stat_nbatch
            Number of frames used for data statistic
    data_stat_protect
            Protect parameter for atomic energy regression
    use_srtab
            The table for the short-range pairwise interaction added on top
            of DP. The table is a text data file with
            (N_t + 1) * N_t / 2 + 1 columns. The first column is the distance
            between atoms. The second to the last columns are energies for
            pairs of certain types. For example we have two atom types,
            0 and 1. The columns from 2nd to 4th are for 0-0, 0-1 and 1-1
            correspondingly.
    smin_alpha
            The short-range tabulated interaction will be switched according
            to the distance of the nearest neighbor. This distance is
            calculated by softmin. This parameter is the decaying parameter
            in the softmin. It is only required when `use_srtab` is provided.
    sw_rmin
            The lower boundary of the interpolation between short-range
            tabulated interaction and DP. It is only required when
            `use_srtab` is provided.
    sw_rmax
            The upper boundary of the interpolation between short-range
            tabulated interaction and DP. It is only required when
            `use_srtab` is provided.
    """
    model_type = 'ener'

    def __init__(
            self,
            descrpt,
            fitting,
            typeebd=None,
            type_map: List[str] = None,
            data_stat_nbatch: int = 10,
            data_stat_protect: float = 1e-2,
            use_srtab: str = None,
            smin_alpha: float = None,
            sw_rmin: float = None,
            sw_rmax: float = None
    ) -> None:
        """Store the sub-networks and options; load the pair table if requested."""
        # descriptor
        self.descrpt = descrpt
        self.rcut = self.descrpt.get_rcut()
        self.ntypes = self.descrpt.get_ntypes()
        # fitting
        self.fitting = fitting
        self.numb_fparam = self.fitting.get_numb_fparam()
        # type embedding
        self.typeebd = typeebd
        # other inputs
        self.type_map = [] if type_map is None else type_map
        self.data_stat_nbatch = data_stat_nbatch
        self.data_stat_protect = data_stat_protect
        # short-range table; `self.srtab is None` means "no table" everywhere
        # else in this class.
        self.srtab_name = use_srtab
        self.srtab = PairTab(self.srtab_name) if self.srtab_name is not None else None
        # Always define the switching parameters so that every instance has
        # the same attribute set; they are only read when a table is in use.
        self.smin_alpha = smin_alpha
        self.sw_rmin = sw_rmin
        self.sw_rmax = sw_rmax

    def get_rcut(self):
        """Return the cut-off radius reported by the descriptor."""
        return self.rcut

    def get_ntypes(self):
        """Return the number of atom types reported by the descriptor."""
        return self.ntypes

    def get_type_map(self) -> List[str]:
        """Return the type map (empty list if none was given)."""
        return self.type_map

    def data_stat(self, data):
        """Compute input and output statistics from `data`.

        The per-system statistics go to the output statistics of the fitting
        net; the merged statistics go to the input statistics of the
        descriptor and the fitting net.

        Parameters
        ----------
        data
                Data source handed to `make_stat_input`.
        """
        all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys=False)
        m_all_stat = merge_sys_stat(all_stat)
        self._compute_input_stat(m_all_stat, protection=self.data_stat_protect)
        self._compute_output_stat(all_stat)

    def _compute_input_stat(self, all_stat, protection=1e-2):
        """Forward the merged statistics to the descriptor and the fitting net."""
        self.descrpt.compute_input_stats(all_stat['coord'],
                                         all_stat['box'],
                                         all_stat['type'],
                                         all_stat['natoms_vec'],
                                         all_stat['default_mesh'],
                                         all_stat)
        self.fitting.compute_input_stats(all_stat, protection=protection)

    def _compute_output_stat(self, all_stat):
        """Forward the per-system statistics to the fitting net."""
        self.fitting.compute_output_stats(all_stat)

    def build(self,
              coord_,
              atype_,
              natoms,
              box,
              mesh,
              input_dict,
              frz_model=None,
              suffix='',
              reuse=None):
        """Build the computational graph of the energy model.

        Parameters
        ----------
        coord_
                Coordinates tensor; reshaped to `[-1, natoms[1] * 3]` for the
                returned dict.
        atype_
                Atom type tensor; reshaped to `[-1, natoms[1]]`.
        natoms
                Atom count vector. `natoms[0]` sizes the per-atom energy and
                `natoms[1]` sizes coordinates, forces and atomic virials
                (presumably local / total atom numbers - confirm against the
                descriptor documentation).
        box
                Box tensor, forwarded to the descriptor.
        mesh
                Mesh tensor, forwarded to the descriptor.
        input_dict
                Extra inputs for the descriptor and the fitting net. It is
                modified in place: the key `'type_embedding'` is added when a
                type embedding net is present.
        frz_model
                Frozen model file. When given, the descriptor is imported
                from this file instead of being built.
        suffix
                Suffix appended to the names of scopes and output tensors.
        reuse
                Passed to the variable scopes as their `reuse` flag.

        Returns
        -------
        dict
                Tensors keyed by `'energy'`, `'force'`, `'virial'`,
                `'atom_ener'`, `'atom_virial'`, `'coord'` and `'atype'`.
        """
        with tf.variable_scope('model_attr' + suffix, reuse=reuse):
            # These constants are never read from Python; they are created so
            # that the attributes are stored as named nodes in the graph.
            t_tmap = tf.constant(' '.join(self.type_map),
                                 name='tmap',
                                 dtype=tf.string)
            t_mt = tf.constant(self.model_type,
                               name='model_type',
                               dtype=tf.string)
            t_ver = tf.constant(MODEL_VERSION,
                                name='model_version',
                                dtype=tf.string)

            # NOTE(review): the flattened source does not show whether the
            # table variables live inside the 'model_attr' scope; they are
            # placed inside it here - confirm against the upstream file.
            if self.srtab is not None:
                tab_info, tab_data = self.srtab.get()
                self.tab_info = tf.get_variable(
                    't_tab_info',
                    tab_info.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_info, dtype=tf.float64))
                self.tab_data = tf.get_variable(
                    't_tab_data',
                    tab_data.shape,
                    dtype=tf.float64,
                    trainable=False,
                    initializer=tf.constant_initializer(tab_data, dtype=tf.float64))

        coord = tf.reshape(coord_, [-1, natoms[1] * 3])
        atype = tf.reshape(atype_, [-1, natoms[1]])

        # type embedding if any
        if self.typeebd is not None:
            type_embedding = self.typeebd.build(
                self.ntypes,
                reuse=reuse,
                suffix=suffix,
            )
            input_dict['type_embedding'] = type_embedding

        if frz_model is None:
            dout = self.descrpt.build(coord_,
                                      atype_,
                                      natoms,
                                      box,
                                      mesh,
                                      input_dict,
                                      suffix=suffix,
                                      reuse=reuse)
            dout = tf.identity(dout, name='o_descriptor')
        else:
            # Descriptor comes from a frozen graph: re-create its attribute
            # nodes, then import the tensors the descriptor needs.
            tf.constant(self.rcut,
                        name='descrpt_attr/rcut',
                        dtype=GLOBAL_TF_FLOAT_PRECISION)
            tf.constant(self.ntypes,
                        name='descrpt_attr/ntypes',
                        dtype=tf.int32)
            feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
            return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0']
            imported_tensors = self._import_graph_def_from_frz_model(
                frz_model, feed_dict, return_elements)
            # The descriptor output was requested last; everything before it
            # belongs to the descriptor itself.
            dout = imported_tensors[-1]
            self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1])

        if self.srtab is not None:
            nlist, rij, sel_a, sel_r = self.descrpt.get_nlist()
            nnei_a = np.cumsum(sel_a)[-1]
            nnei_r = np.cumsum(sel_r)[-1]

        atom_ener = self.fitting.build(dout,
                                       natoms,
                                       input_dict,
                                       reuse=reuse,
                                       suffix=suffix)

        if self.srtab is not None:
            sw_lambda, sw_deriv = op_module.soft_min_switch(
                atype,
                rij,
                nlist,
                natoms,
                sel_a=sel_a,
                sel_r=sel_r,
                alpha=self.smin_alpha,
                rmin=self.sw_rmin,
                rmax=self.sw_rmax)
            inv_sw_lambda = 1.0 - sw_lambda
            # NOTICE:
            # atom energy is not scaled,
            # force and virial are scaled
            tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab(
                self.tab_info,
                self.tab_data,
                atype,
                rij,
                nlist,
                natoms,
                sw_lambda,
                sel_a=sel_a,
                sel_r=sel_r)
            # Difference of the unscaled energies, taken before either term
            # is multiplied by its switching weight.
            energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]])
            tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape(tab_atom_ener, [-1])
            atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
            energy_raw = tab_atom_ener + atom_ener
        else:
            energy_raw = atom_ener

        energy_raw = tf.reshape(energy_raw,
                                [-1, natoms[0]],
                                name='o_atom_energy' + suffix)
        energy = tf.reduce_sum(global_cvt_2_ener_float(energy_raw),
                               axis=1,
                               name='o_energy' + suffix)

        force, virial, atom_virial = self.descrpt.prod_force_virial(atom_ener, natoms)

        if self.srtab is not None:
            sw_force = op_module.soft_min_force(energy_diff,
                                                sw_deriv,
                                                nlist,
                                                natoms,
                                                n_a_sel=nnei_a,
                                                n_r_sel=nnei_r)
            force = force + sw_force + tab_force

        force = tf.reshape(force, [-1, 3 * natoms[1]], name="o_force" + suffix)

        if self.srtab is not None:
            sw_virial, sw_atom_virial = op_module.soft_min_virial(
                energy_diff,
                sw_deriv,
                rij,
                nlist,
                natoms,
                n_a_sel=nnei_a,
                n_r_sel=nnei_r)
            atom_virial = atom_virial + sw_atom_virial + tab_atom_virial
            virial = virial + sw_virial \
                + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1)

        virial = tf.reshape(virial, [-1, 9], name="o_virial" + suffix)
        atom_virial = tf.reshape(atom_virial,
                                 [-1, 9 * natoms[1]],
                                 name="o_atom_virial" + suffix)

        model_dict = {}
        model_dict['energy'] = energy
        model_dict['force'] = force
        model_dict['virial'] = virial
        model_dict['atom_ener'] = energy_raw
        model_dict['atom_virial'] = atom_virial
        model_dict['coord'] = coord
        model_dict['atype'] = atype

        return model_dict

    def _import_graph_def_from_frz_model(self, frz_model, feed_dict, return_elements):
        """Import `return_elements` from the frozen model `frz_model`.

        `feed_dict` is used as the input map of the imported graph, and the
        nodes are imported without a name prefix.
        """
        # load_graph_def returns (graph, graph_def); only the latter is needed.
        _, graph_def = load_graph_def(frz_model)
        return tf.import_graph_def(graph_def,
                                   input_map=feed_dict,
                                   return_elements=return_elements,
                                   name="")