# SPDX-License-Identifier: LGPL-3.0-or-later
import warnings
from typing import (
Optional,
)
import numpy as np
from deepmd.tf.common import (
cast_precision,
get_activation_func,
get_precision,
)
from deepmd.tf.descriptor import (
DescrptSeA,
)
from deepmd.tf.env import (
GLOBAL_TF_FLOAT_PRECISION,
tf,
)
from deepmd.tf.fit.fitting import (
Fitting,
)
from deepmd.tf.loss.loss import (
Loss,
)
from deepmd.tf.loss.tensor import (
TensorLoss,
)
from deepmd.tf.utils.errors import (
GraphWithoutTensorError,
)
from deepmd.tf.utils.graph import (
get_fitting_net_variables_from_graph_def,
get_tensor_by_name_from_graph,
)
from deepmd.tf.utils.network import (
one_layer,
one_layer_rand_seed_shift,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@Fitting.register("polar")
class PolarFittingSeA(Fitting):
    r"""Fit the atomic polarizability with descriptor se_a.

    Parameters
    ----------
    ntypes
        The ntypes of the descriptor :math:`\mathcal{D}`
    dim_descrpt
        The dimension of the descriptor :math:`\mathcal{D}`
    embedding_width
        The rotation matrix dimension of the descriptor :math:`\mathcal{D}`
    neuron : list[int]
        Number of neurons in each hidden layer of the fitting net
    resnet_dt : bool
        Time-step `dt` in the resnet construction:
        y = x + dt * \phi (Wx + b)
    numb_fparam
        Number of frame parameters. Must be 0: frame parameters are not
        supported by this fitting.
    numb_aparam
        Number of atomic parameters. Must be 0: atomic parameters are not
        supported by this fitting.
    dim_case_embd
        Dimension of case specific embedding. Must be 0: not supported in
        TensorFlow.
    sel_type : list[int]
        The atom types selected to have an atomic polarizability prediction.
        If None, all atom types are selected. A single int is accepted and
        treated as a one-element list.
    fit_diag : bool
        Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix.
    scale : list[float] or float
        The output of the fitting net (polarizability matrix) for type i atom
        will be scaled by scale[i]. A single number is used for every type.
    shift_diag : bool
        Whether to shift the diagonal of the polarizability of type i atoms by
        a per-type constant estimated in ``compute_output_stats``. The shift is
        carried out after scale.
    seed : int
        Random seed for initializing the network parameters.
    activation_function : str
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    precision : str
        The precision of the embedding net parameters. Supported options are |PRECISION|
    uniform_seed
        Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
    mixed_types : bool
        If true, use a uniform fitting net for all atom types, otherwise use
        different fitting nets for different atom types.
    type_map: list[str], Optional
        A list of strings. Give the name to each type of atoms.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        embedding_width: int,
        neuron: list[int] = [120, 120, 120],
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        dim_case_embd: int = 0,
        sel_type: Optional[list[int]] = None,
        fit_diag: bool = True,
        scale: Optional[list[float]] = None,
        shift_diag: bool = True,  # YWolfeee: will support the user to decide whether to use this function
        # diag_shift : list[float] = None, YWolfeee: will not support the user to assign a shift
        seed: Optional[int] = None,
        activation_function: str = "tanh",
        precision: str = "default",
        uniform_seed: bool = False,
        mixed_types: bool = False,
        type_map: Optional[list[str]] = None,  # to be compat with input
        **kwargs,  # extra keys (e.g. from ``deserialize``) are accepted and ignored
    ) -> None:
        """Constructor."""
        self.ntypes = ntypes
        self.dim_descrpt = dim_descrpt
        # copy so instances (and serialized dicts) never alias the default list
        self.n_neuron = list(neuron)
        self.resnet_dt = resnet_dt
        self.fit_diag = fit_diag
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = one_layer_rand_seed_shift()
        self.shift_diag = shift_diag
        self.activation_function_name = activation_function
        self.fitting_activation_fn = get_activation_func(activation_function)
        self.fitting_precision = get_precision(precision)

        # Normalize sel_type to a sorted list *before* building the mask, so
        # that a single int is handled as well.
        if sel_type is None:
            sel_type = list(range(self.ntypes))
        elif not isinstance(sel_type, list):
            sel_type = [sel_type]
        self.sel_type = sorted(sel_type)
        # ntypes boolean mask: True for selected atom types
        self.sel_mask = np.array(
            [ii in self.sel_type for ii in range(self.ntypes)], dtype=bool
        )

        # Per-type scale factors, indexed by atom type (length ntypes).
        if scale is None:
            scale = [1.0 for _ in range(self.ntypes)]
        elif isinstance(scale, (list, tuple)):
            assert len(scale) == ntypes, "Scale should be a list of length ntypes."
        elif isinstance(scale, (int, float)):
            scale = [float(scale) for _ in range(ntypes)]
        else:
            raise ValueError(
                "Scale must be a list of float of length ntypes or a float."
            )
        self.scale = np.array(scale)

        # ntypes entries, indexed by atom type: the average diagonal value
        # filled in by compute_output_stats / init_variables / deserialize.
        self.constant_matrix = np.zeros(  # pylint: disable=no-explicit-dtype
            self.ntypes
        )
        self.dim_rot_mat_1 = embedding_width
        self.dim_rot_mat = self.dim_rot_mat_1 * 3
        self.fitting_net_variables = None
        # overwritten by enable_mixed_precision; read by _build_lower
        self.mixed_prec = None
        self.mixed_types = mixed_types
        self.type_map = type_map
        self.numb_fparam = numb_fparam
        self.numb_aparam = numb_aparam
        self.dim_case_embd = dim_case_embd
        if numb_fparam > 0:
            raise ValueError("numb_fparam is not supported in the polar fitting")
        if numb_aparam > 0:
            raise ValueError("numb_aparam is not supported in the polar fitting")
        if dim_case_embd > 0:
            raise ValueError("dim_case_embd is not supported in TensorFlow.")
        self.fparam_inv_std = None
        self.aparam_inv_std = None

    def get_sel_type(self) -> list[int]:
        """Get selected atom types."""
        return self.sel_type

    def get_out_size(self) -> int:
        """Get the output size. Should be 9."""
        return 9

    def compute_output_stats(self, all_stat) -> None:
        """Compute the output statistics.

        Sets ``self.avgeig`` and, when ``shift_diag`` is enabled, fills
        ``self.constant_matrix`` with the per-type diagonal shift.

        Parameters
        ----------
        all_stat
            Dictionary of inputs.
            can be prepared by model.make_stat_input
        """
        if "polarizability" not in all_stat:
            self.avgeig = np.zeros([9])  # pylint: disable=no-explicit-dtype
            warnings.warn(
                "no polarizability data, cannot do data stat. use zeros as guess"
            )
            return
        data = all_stat["polarizability"]
        all_eig = []
        for sys_data in data:
            tmp = np.concatenate(sys_data, axis=0)
            tmp = np.reshape(tmp, [-1, 3, 3])
            tmp, _ = np.linalg.eig(tmp)
            tmp = np.absolute(tmp)
            # (nframes_ss, 3) sorted absolute eigenvalues
            tmp = np.sort(tmp, axis=1)
            all_eig.append(tmp)
        # Stack the frames of all systems (axis 0): systems may have different
        # numbers of frames, and the average is taken over frames.
        all_eig = np.concatenate(all_eig, axis=0)
        self.avgeig = np.average(all_eig, axis=0)
        # YWolfeee: support polar normalization, initialize to a more appropriate point
        if self.shift_diag:
            self._compute_constant_matrix(all_stat)

    def _compute_constant_matrix(self, all_stat) -> None:
        """Fill ``self.constant_matrix`` by a least-squares fit over systems.

        Every system contributes linear equations whose unknowns are the mean
        per-type atomic polarizability (one 9-vector per selected type):
        with atomic labels, one equation per selected type; with only a
        global label, one equation ``sum_t n_t * x_t = mean global polar``.
        Systems with neither label are skipped. The mean of the diagonal of
        each fitted 3x3 matrix is stored at the index of its atom type.
        """
        nsel = len(self.sel_type)
        sys_matrix, polar_bias = [], []
        for ss in range(len(all_stat["type"])):
            nframes = all_stat["type"][ss].shape[0]
            # types of the selected atoms, taken from the first frame
            atom_has_polar = [w for w in all_stat["type"][ss][0] if w in self.sel_type]
            # for each selected type: positions of its atoms among selected atoms
            index_by_type = [
                [index for index, w in enumerate(atom_has_polar) if w == sel]
                for sel in self.sel_type
            ]
            if all_stat["find_atom_polarizability"][ss] > 0.0:
                # Atomic polar mode: one equation per selected type
                atom_polar = all_stat["atom_polarizability"][ss].reshape(
                    nframes, len(atom_has_polar), -1
                )
                for itype, index_lis in enumerate(index_by_type):
                    row = np.zeros((1, nsel))  # pylint: disable=no-explicit-dtype
                    row[0, itype] = len(index_lis)
                    sys_matrix.append(row)
                    polar_bias.append(
                        np.sum(
                            atom_polar[:, index_lis, :] / nframes,
                            axis=(0, 1),
                        ).reshape((1, 9))
                    )
            elif all_stat["find_polarizability"][ss] > 0.0:
                # Only a global polar: one equation for the whole system
                row = np.zeros((1, nsel))  # pylint: disable=no-explicit-dtype
                for itype, index_lis in enumerate(index_by_type):
                    row[0, itype] = len(index_lis)
                sys_matrix.append(row)
                polar_bias.append(
                    np.mean(all_stat["polarizability"][ss], axis=0).reshape((1, 9))
                )
            # otherwise the system carries no polarizability label at all
        if not sys_matrix:
            warnings.warn(
                "no polarizability labels found, cannot estimate the diagonal shift; keep zeros"
            )
            return
        matrix = np.concatenate(sys_matrix, axis=0)
        bias = np.concatenate(polar_bias, axis=0)
        type_polar, _, _, _ = np.linalg.lstsq(matrix, bias, rcond=None)
        for itype, sel in enumerate(self.sel_type):
            self.constant_matrix[sel] = np.mean(
                np.diagonal(type_polar[itype].reshape((3, 3)))
            )

    @cast_precision
    def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=None):
        """Build the fitting net for a contiguous slice of atoms.

        Parameters
        ----------
        start_index
            Index of the first atom of the slice.
        natoms
            Number of atoms in the slice.
        inputs
            Descriptor with ``dim_descrpt`` values per atom, flattened per frame.
        rot_mat
            Rotation matrix with ``dim_rot_mat`` values per atom, flattened per frame.
        suffix
            Name suffix of the network variables.
        reuse
            Whether to reuse the network variables.

        Returns
        -------
        tf.Tensor
            Polarizability of shape ``[nframes, natoms, 3, 3]``.
        """
        # cut-out inputs
        inputs_i = tf.slice(
            inputs, [0, start_index * self.dim_descrpt], [-1, natoms * self.dim_descrpt]
        )
        inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt])
        rot_mat_i = tf.slice(
            rot_mat,
            [0, start_index * self.dim_rot_mat],
            [-1, natoms * self.dim_rot_mat],
        )
        rot_mat_i = tf.reshape(rot_mat_i, [-1, self.dim_rot_mat_1, 3])
        layer = inputs_i
        for ii in range(0, len(self.n_neuron)):
            if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]:
                # residual connection between layers of equal width
                layer += one_layer(
                    layer,
                    self.n_neuron[ii],
                    name="layer_" + str(ii) + suffix,
                    reuse=reuse,
                    seed=self.seed,
                    use_timestep=self.resnet_dt,
                    activation_fn=self.fitting_activation_fn,
                    precision=self.fitting_precision,
                    uniform_seed=self.uniform_seed,
                    initial_variables=self.fitting_net_variables,
                    mixed_prec=self.mixed_prec,
                )
            else:
                layer = one_layer(
                    layer,
                    self.n_neuron[ii],
                    name="layer_" + str(ii) + suffix,
                    reuse=reuse,
                    seed=self.seed,
                    activation_fn=self.fitting_activation_fn,
                    precision=self.fitting_precision,
                    uniform_seed=self.uniform_seed,
                    initial_variables=self.fitting_net_variables,
                    mixed_prec=self.mixed_prec,
                )
            if (not self.uniform_seed) and (self.seed is not None):
                self.seed += self.seed_shift
        if self.fit_diag:
            bavg = np.zeros(self.dim_rot_mat_1)  # pylint: disable=no-explicit-dtype
            # (nframes x natoms) x naxis
            final_layer = one_layer(
                layer,
                self.dim_rot_mat_1,
                activation_fn=None,
                name="final_layer" + suffix,
                reuse=reuse,
                seed=self.seed,
                bavg=bavg,
                precision=self.fitting_precision,
                uniform_seed=self.uniform_seed,
                initial_variables=self.fitting_net_variables,
                mixed_prec=self.mixed_prec,
                final_layer=True,
            )
            if (not self.uniform_seed) and (self.seed is not None):
                self.seed += self.seed_shift
            # (nframes x natoms) x naxis
            final_layer = tf.reshape(
                final_layer, [tf.shape(inputs)[0] * natoms, self.dim_rot_mat_1]
            )
            # (nframes x natoms) x naxis x naxis
            final_layer = tf.matrix_diag(final_layer)
        else:
            bavg = np.zeros(self.dim_rot_mat_1 * self.dim_rot_mat_1)  # pylint: disable=no-explicit-dtype
            # (nframes x natoms) x (naxis x naxis)
            final_layer = one_layer(
                layer,
                self.dim_rot_mat_1 * self.dim_rot_mat_1,
                activation_fn=None,
                name="final_layer" + suffix,
                reuse=reuse,
                seed=self.seed,
                bavg=bavg,
                precision=self.fitting_precision,
                uniform_seed=self.uniform_seed,
                initial_variables=self.fitting_net_variables,
                mixed_prec=self.mixed_prec,
                final_layer=True,
            )
            if (not self.uniform_seed) and (self.seed is not None):
                self.seed += self.seed_shift
            # (nframes x natoms) x naxis x naxis
            final_layer = tf.reshape(
                final_layer,
                [tf.shape(inputs)[0] * natoms, self.dim_rot_mat_1, self.dim_rot_mat_1],
            )
            # symmetrize: (nframes x natoms) x naxis x naxis
            final_layer = final_layer + tf.transpose(final_layer, perm=[0, 2, 1])
        # (nframes x natoms) x naxis x 3(coord)
        final_layer = tf.matmul(final_layer, rot_mat_i)
        # (nframes x natoms) x 3(coord) x 3(coord)
        final_layer = tf.matmul(rot_mat_i, final_layer, transpose_a=True)
        # nframes x natoms x 3 x 3
        final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms, 3, 3])
        return final_layer

    def build(
        self,
        input_d: tf.Tensor,
        rot_mat: tf.Tensor,
        natoms: tf.Tensor,
        input_dict: Optional[dict] = None,
        reuse: Optional[bool] = None,
        suffix: str = "",
    ):
        """Build the computational graph for fitting net.

        Parameters
        ----------
        input_d
            The input descriptor
        rot_mat
            The rotation matrix from the descriptor.
        natoms
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        input_dict
            Additional dict for inputs. Keys read: ``type_embedding``,
            ``atype`` and ``nframes``.
        reuse
            The weights in the networks should be reused when get the variable.
        suffix
            Name suffix to identify this descriptor

        Returns
        -------
        atomic_polar
            The atomic polarizability of the selected atoms, flattened to 1-D
            (9 values per atom).
        """
        if input_dict is None:
            input_dict = {}
        type_embedding = input_dict.get("type_embedding", None)
        atype = input_dict.get("atype", None)
        nframes = input_dict.get("nframes")
        start_index = 0
        with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
            self.t_constant_matrix = tf.get_variable(
                "t_constant_matrix",
                self.constant_matrix.shape,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.constant_matrix),
            )
        inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
        rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])
        if nframes is None:
            nframes = tf.shape(inputs)[0]
        if self.mixed_types or type_embedding is not None:
            # keep old behavior
            self.mixed_types = True
            # nframes x nloc
            nloc_mask = tf.reshape(
                tf.tile(tf.repeat(self.sel_mask, natoms[2:]), [nframes]), [nframes, -1]
            )
            # nframes x nloc_masked
            scale = tf.reshape(
                tf.reshape(
                    tf.tile(tf.repeat(self.scale, natoms[2:]), [nframes]), [nframes, -1]
                )[nloc_mask],
                [nframes, -1],
            )
            if self.shift_diag:
                # nframes x nloc_masked
                constant_matrix = tf.reshape(
                    tf.reshape(
                        tf.tile(
                            tf.repeat(self.t_constant_matrix, natoms[2:]), [nframes]
                        ),
                        [nframes, -1],
                    )[nloc_mask],
                    [nframes, -1],
                )
            atype_nall = tf.reshape(atype, [-1, natoms[1]])
            # (nframes x nloc_masked)
            self.atype_nloc_masked = tf.reshape(
                tf.slice(atype_nall, [0, 0], [-1, natoms[0]])[nloc_mask], [-1]
            )  ## lammps will make error
            self.nloc_masked = tf.shape(
                tf.reshape(self.atype_nloc_masked, [nframes, -1])
            )[1]
            # Keep only the selected atoms, so that descriptor, rotation
            # matrix, scale and shift all refer to the same nloc_masked atoms,
            # whether or not a type embedding is provided.
            # (nframes x nloc_masked) x dim_descrpt
            inputs = tf.reshape(
                tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask],
                [-1, self.dim_descrpt],
            )
            rot_mat = tf.reshape(
                tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat])[nloc_mask],
                [-1, self.dim_rot_mat * self.nloc_masked],
            )
        if type_embedding is not None:
            atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc_masked)
        else:
            atype_embed = None
        self.atype_embed = atype_embed
        if atype_embed is not None:
            atype_embed = tf.cast(atype_embed, self.fitting_precision)
            type_shape = atype_embed.get_shape().as_list()
            inputs = tf.concat([inputs, atype_embed], axis=1)
            # NOTE(review): this grows dim_descrpt on every build call; serialize
            # relies on the grown value as the network input dimension.
            self.dim_descrpt = self.dim_descrpt + type_shape[1]
        if not self.mixed_types:
            outs_list = []
            for type_i in range(self.ntypes):
                if type_i not in self.sel_type:
                    start_index += natoms[2 + type_i]
                    continue
                final_layer = self._build_lower(
                    start_index,
                    natoms[2 + type_i],
                    inputs,
                    rot_mat,
                    suffix="_type_" + str(type_i) + suffix,
                    reuse=reuse,
                )
                # shift and scale; both scale and t_constant_matrix are indexed
                # by atom type (length ntypes), as in the mixed-types branch
                final_layer = final_layer * self.scale[type_i]
                final_layer = final_layer + tf.slice(
                    self.t_constant_matrix, [type_i], [1]
                ) * tf.eye(
                    3,
                    batch_shape=[tf.shape(inputs)[0], natoms[2 + type_i]],
                    dtype=GLOBAL_TF_FLOAT_PRECISION,
                )
                start_index += natoms[2 + type_i]
                # concat the results
                outs_list.append(final_layer)
            outs = tf.concat(outs_list, axis=1)
        else:
            inputs = tf.reshape(inputs, [-1, self.dim_descrpt * self.nloc_masked])
            final_layer = self._build_lower(
                0, self.nloc_masked, inputs, rot_mat, suffix=suffix, reuse=reuse
            )
            # shift and scale
            final_layer *= tf.expand_dims(tf.expand_dims(scale, -1), -1)
            if self.shift_diag:
                final_layer += tf.expand_dims(
                    tf.expand_dims(constant_matrix, -1), -1
                ) * tf.eye(3, batch_shape=[1, 1], dtype=GLOBAL_TF_FLOAT_PRECISION)
            outs = final_layer
        tf.summary.histogram("fitting_net_output", outs)
        return tf.reshape(outs, [-1])

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """Init the fitting net variables with the given dict.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str
            suffix to name scope
        """
        self.fitting_net_variables = get_fitting_net_variables_from_graph_def(
            graph_def, suffix=suffix
        )
        if self.shift_diag:
            try:
                self.constant_matrix = get_tensor_by_name_from_graph(
                    graph, f"fitting_attr{suffix}/t_constant_matrix"
                )
            except GraphWithoutTensorError:
                warnings.warn(
                    "You are trying to read a model trained with shift_diag=True, but the mean of the diagonal terms of the polarizability is not stored in the graph. This will lead to wrong inference results. You may train your model with the latest DeePMD-kit to avoid this issue.",
                    stacklevel=2,
                )

    def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None:
        """Receive the mixed precision setting.

        Parameters
        ----------
        mixed_prec
            The mixed precision setting used in the embedding net. Its
            ``output_prec`` entry becomes the fitting precision.
        """
        self.mixed_prec = mixed_prec
        self.fitting_precision = get_precision(mixed_prec["output_prec"])

    def get_loss(self, loss: dict, lr) -> Loss:
        """Get the loss function."""
        return TensorLoss(
            loss,
            model=self,
            tensor_name="polar",
            tensor_size=9,
            label_name="polarizability",
        )

    def serialize(self, suffix: str) -> dict:
        """Serialize the model.

        Parameters
        ----------
        suffix : str
            Name suffix of the fitting network variables.

        Returns
        -------
        dict
            The serialized data
        """
        data = {
            "@class": "Fitting",
            "type": "polar",
            "@version": 4,
            "ntypes": self.ntypes,
            "dim_descrpt": self.dim_descrpt,
            "embedding_width": self.dim_rot_mat_1,
            "mixed_types": self.mixed_types,
            "dim_out": 3,
            "neuron": self.n_neuron,
            "resnet_dt": self.resnet_dt,
            "numb_fparam": self.numb_fparam,
            "numb_aparam": self.numb_aparam,
            "dim_case_embd": self.dim_case_embd,
            "activation_function": self.activation_function_name,
            "precision": self.fitting_precision.name,
            "exclude_types": [],
            "fit_diag": self.fit_diag,
            "scale": list(self.scale),
            "shift_diag": self.shift_diag,
            "nets": self.serialize_network(
                ntypes=self.ntypes,
                ndim=0 if self.mixed_types else 1,
                in_dim=self.dim_descrpt,
                out_dim=self.dim_rot_mat_1,
                neuron=self.n_neuron,
                activation_function=self.activation_function_name,
                resnet_dt=self.resnet_dt,
                variables=self.fitting_net_variables,
                suffix=suffix,
            ),
            "@variables": {
                "fparam_avg": None,
                "fparam_inv_std": None,
                "aparam_avg": None,
                "aparam_inv_std": None,
                "case_embd": None,
                "scale": self.scale.reshape(-1, 1),
                "constant_matrix": self.constant_matrix.reshape(-1),
            },
            "type_map": self.type_map,
        }
        return data

    @classmethod
    def deserialize(cls, data: dict, suffix: str):
        """Deserialize the model.

        Parameters
        ----------
        data : dict
            The serialized data
        suffix : str
            Name suffix of the fitting network variables.

        Returns
        -------
        Model
            The deserialized model
        """
        data = data.copy()
        check_version_compatibility(
            data.pop("@version", 1), 4, 1
        )  # to allow PT version.
        fitting = cls(**data)
        fitting.fitting_net_variables = cls.deserialize_network(
            data["nets"],
            suffix=suffix,
        )
        fitting.constant_matrix = data["@variables"]["constant_matrix"].ravel()
        return fitting
class GlobalPolarFittingSeA:
    r"""Fit the system polarizability with descriptor se_a.

    The system polarizability is the sum of the atomic polarizabilities
    predicted by a wrapped :class:`PolarFittingSeA`.

    Parameters
    ----------
    descrpt : tf.Tensor
        The descriptor. Must be a ``DescrptSeA`` instance.
    neuron : list[int]
        Number of neurons in each hidden layer of the fitting net
    resnet_dt : bool
        Time-step `dt` in the resnet construction:
        y = x + dt * \phi (Wx + b)
    sel_type : list[int]
        The atom types selected to have an atomic polarizability prediction
    fit_diag : bool
        Fit the diagonal part of the rotational invariant polarizability matrix, which will be converted to normal polarizability matrix by contracting with the rotation matrix.
    scale : list[float]
        The output of the fitting net (polarizability matrix) for type i atom will be scaled by scale[i]
    diag_shift : list[float]
        Not supported by the underlying ``PolarFittingSeA`` and ignored; a
        warning is emitted when it is given.
    seed : int
        Random seed for initializing the network parameters.
    activation_function : str
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    precision : str
        The precision of the embedding net parameters. Supported options are |PRECISION|
    """

    def __init__(
        self,
        descrpt: tf.Tensor,
        neuron: list[int] = [120, 120, 120],
        resnet_dt: bool = True,
        sel_type: Optional[list[int]] = None,
        fit_diag: bool = True,
        scale: Optional[list[float]] = None,
        diag_shift: Optional[list[float]] = None,
        seed: Optional[int] = None,
        activation_function: str = "tanh",
        precision: str = "default",
    ) -> None:
        """Constructor."""
        if not isinstance(descrpt, DescrptSeA):
            raise RuntimeError("GlobalPolarFittingSeA only supports DescrptSeA")
        if diag_shift is not None:
            warnings.warn(
                "diag_shift is not supported by PolarFittingSeA and is ignored",
                stacklevel=2,
            )
        self.ntypes = descrpt.get_ntypes()
        self.dim_descrpt = descrpt.get_dim_out()
        # Pass everything by keyword: the PolarFittingSeA signature starts with
        # (ntypes, dim_descrpt, embedding_width, ...), not with the descriptor.
        self.polar_fitting = PolarFittingSeA(
            ntypes=self.ntypes,
            dim_descrpt=self.dim_descrpt,
            embedding_width=descrpt.get_dim_rot_mat_1(),
            neuron=neuron,
            resnet_dt=resnet_dt,
            sel_type=sel_type,
            fit_diag=fit_diag,
            scale=scale,
            seed=seed,
            activation_function=activation_function,
            precision=precision,
        )

    def get_sel_type(self) -> list[int]:
        """Get selected atom types."""
        return self.polar_fitting.get_sel_type()

    def get_out_size(self) -> int:
        """Get the output size. Should be 9."""
        return self.polar_fitting.get_out_size()

    def build(
        self,
        input_d,
        rot_mat,
        natoms,
        input_dict: Optional[dict] = None,
        reuse=None,
        suffix="",
    ) -> tf.Tensor:
        """Build the computational graph for fitting net.

        Parameters
        ----------
        input_d
            The input descriptor
        rot_mat
            The rotation matrix from the descriptor.
        natoms
            The number of atoms. This tensor has the length of Ntypes + 2
            natoms[0]: number of local atoms
            natoms[1]: total number of atoms held by this processor
            natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
        input_dict
            Additional dict for inputs.
        reuse
            The weights in the networks should be reused when get the variable.
        suffix
            Name suffix to identify this descriptor

        Returns
        -------
        polar
            The system polarizability, flattened to 1-D (9 values per frame).
        """
        # only used to recover the number of frames
        inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
        outs = self.polar_fitting.build(
            input_d, rot_mat, natoms, input_dict, reuse, suffix
        )
        # nframes x natoms x 9
        outs = tf.reshape(outs, [tf.shape(inputs)[0], -1, 9])
        # sum the atomic contributions of each frame
        outs = tf.reduce_sum(outs, axis=1)
        tf.summary.histogram("fitting_net_output", outs)
        return tf.reshape(outs, [-1])

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """Init the fitting net variables with the given dict.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str
            suffix to name scope
        """
        self.polar_fitting.init_variables(
            graph=graph, graph_def=graph_def, suffix=suffix
        )

    def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None:
        """Receive the mixed precision setting.

        Parameters
        ----------
        mixed_prec
            The mixed precision setting used in the embedding net
        """
        self.polar_fitting.enable_mixed_precision(mixed_prec)

    def get_loss(self, loss: dict, lr) -> Loss:
        """Get the loss function.

        Parameters
        ----------
        loss : dict
            the loss dict
        lr : LearningRateExp
            the learning rate

        Returns
        -------
        Loss
            the loss function
        """
        return TensorLoss(
            loss,
            model=self,
            tensor_name="global_polar",
            tensor_size=9,
            atomic=False,
            label_name="polarizability",
        )

    def get_numb_fparam(self) -> int:
        """Get the number of frame parameters of the wrapped fitting."""
        return self.polar_fitting.numb_fparam

    def get_numb_aparam(self) -> int:
        """Get the number of atomic parameters of the wrapped fitting."""
        return self.polar_fitting.numb_aparam