# SPDX-License-Identifier: LGPL-3.0-or-later
import re
from collections.abc import (
Callable,
)
from typing import (
Any,
)
import numpy as np
from deepmd.dpmodel.utils.env_mat import (
EnvMat,
)
from deepmd.dpmodel.utils.network import (
EmbeddingNet,
NetworkCollection,
)
from deepmd.tf.common import (
cast_precision,
get_activation_func,
get_precision,
)
from deepmd.tf.env import (
EMBEDDING_NET_PATTERN,
GLOBAL_NP_FLOAT_PRECISION,
GLOBAL_TF_FLOAT_PRECISION,
default_tf_session_config,
op_module,
tf,
)
from deepmd.tf.utils.graph import (
get_tensor_by_name_from_graph,
)
from deepmd.tf.utils.network import (
embedding_net,
embedding_net_rand_seed_shift,
)
from deepmd.tf.utils.sess import (
run_sess,
)
from deepmd.tf.utils.tabulate import (
DPTabulate,
)
from deepmd.utils.version import (
check_version_compatibility,
)
from .descriptor import (
Descriptor,
)
from .se import (
DescrptSe,
)
@Descriptor.register("se_e3")
@Descriptor.register("se_at")
@Descriptor.register("se_a_3be")
[docs]
class DescrptSeT(DescrptSe):
r"""DeepPot-SE constructed from all information (both angular and radial) of atomic
configurations.
The embedding takes angles between two neighboring atoms as input.
Parameters
----------
rcut
The cut-off radius
rcut_smth
From where the environment matrix should be smoothed
sel : list[int]
sel[i] specifies the maximum number of type i atoms in the cut-off radius
neuron : list[int]
Number of neurons in each hidden layer of the embedding net
resnet_dt
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
trainable
If the weights of embedding net are trainable.
seed
Random seed for initializing the network parameters.
set_davg_zero
Set the shift of embedding net input to zero.
activation_function
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are |PRECISION|
uniform_seed
Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
env_protection: float
Protection parameter to prevent division by zero errors during environment matrix calculations.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.
"""
def __init__(
    self,
    rcut: float,
    rcut_smth: float,
    sel: list[int],
    neuron: list[int] = [24, 48, 96],
    resnet_dt: bool = False,
    trainable: bool = True,
    seed: int | None = None,
    exclude_types: list[list[int]] = [],
    set_davg_zero: bool = False,
    activation_function: str = "tanh",
    precision: str = "default",
    uniform_seed: bool = False,
    type_map: list[str] | None = None,  # to be compat with input
    env_protection: float = 0.0,  # not implemented; only 0.0 is accepted
    **kwargs: Any,
) -> None:
    """Store the descriptor configuration and build the statistics sub-graph.

    Besides recording the hyper-parameters, the constructor creates a
    private ``tf.Graph`` holding placeholders and a ``prod_env_mat_a`` op
    fed with zero average / unit standard deviation, plus a session
    (``self.sub_sess``) bound to that graph.

    Raises
    ------
    RuntimeError
        If ``rcut`` is smaller than ``rcut_smth``.
    NotImplementedError
        If ``exclude_types`` is non-empty or ``env_protection`` is not 0.0.
    """
    if rcut < rcut_smth:
        raise RuntimeError(
            f"rcut_smth ({rcut_smth:f}) should be no more than rcut ({rcut:f})!"
        )
    if exclude_types:
        raise NotImplementedError("exclude_types != [] is not supported.")
    if env_protection != 0.0:
        raise NotImplementedError("env_protection != 0.0 is not supported.")

    # Constructor arguments that the rest of the class reads back through
    # ``self`` (``serialize`` emits them as "sel" and "rcut"; ``_filter``
    # consumes ``seed`` / ``uniform_seed``).
    self.sel_a = sel
    self.rcut_r = rcut
    self.rcut_r_smth = rcut_smth
    self.filter_neuron = neuron
    self.filter_resnet_dt = resnet_dt
    self.seed = seed
    self.uniform_seed = uniform_seed
    self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
    self.trainable = trainable
    self.filter_activation_fn = get_activation_func(activation_function)
    self.activation_function_name = activation_function
    self.filter_precision = get_precision(precision)
    self.env_protection = env_protection
    self.orig_exclude_types = exclude_types
    self.exclude_types = set()
    self.type_map = type_map
    # Kept for symmetry with other descriptors; with the guard above the
    # loop body cannot run because exclude_types is always empty here.
    for tt in exclude_types:
        assert len(tt) == 2
        self.exclude_types.add((tt[0], tt[1]))
        self.exclude_types.add((tt[1], tt[0]))
    self.set_davg_zero = set_davg_zero

    # descrpt config
    self.sel_r = [0 for ii in range(len(self.sel_a))]
    self.ntypes = len(self.sel_a)
    assert self.ntypes == len(self.sel_r)
    # NOTE(review): -1 is assumed to mean "no separate angular cut-off" for
    # prod_env_mat_a; in this class it is only combined with rcut_r via
    # np.max in build() -- confirm against the op implementation.
    self.rcut_a = -1
    # numb of neighbors and numb of descriptors
    self.nnei_a = np.cumsum(self.sel_a)[-1]
    self.nnei_r = np.cumsum(self.sel_r)[-1]
    self.nnei = self.nnei_a + self.nnei_r
    self.ndescrpt_a = self.nnei_a * 4
    self.ndescrpt_r = self.nnei_r * 1
    self.ndescrpt = self.ndescrpt_a + self.ndescrpt_r

    # State filled in later: statistics (deserialize / enable_compression),
    # compression flag (enable_compression) and pre-trained weights.
    self.davg = None
    self.dstd = None
    self.compress = False
    self.embedding_net_variables = None

    self.place_holders = {}
    avg_zero = np.zeros([self.ntypes, self.ndescrpt]).astype(  # pylint: disable=no-explicit-dtype
        GLOBAL_NP_FLOAT_PRECISION
    )
    std_ones = np.ones([self.ntypes, self.ndescrpt]).astype(  # pylint: disable=no-explicit-dtype
        GLOBAL_NP_FLOAT_PRECISION
    )
    # Stand-alone graph used only to evaluate the raw environment matrix
    # (see _compute_dstats_sys_smth), independent of the model graph.
    sub_graph = tf.Graph()
    with sub_graph.as_default():
        name_pfx = "d_sea_"
        for ii in ["coord", "box"]:
            self.place_holders[ii] = tf.placeholder(
                GLOBAL_NP_FLOAT_PRECISION, [None, None], name=name_pfx + "t_" + ii
            )
        self.place_holders["type"] = tf.placeholder(
            tf.int32, [None, None], name=name_pfx + "t_type"
        )
        self.place_holders["natoms_vec"] = tf.placeholder(
            tf.int32, [self.ntypes + 2], name=name_pfx + "t_natoms"
        )
        self.place_holders["default_mesh"] = tf.placeholder(
            tf.int32, [None], name=name_pfx + "t_mesh"
        )
        self.stat_descrpt, descrpt_deriv, rij, nlist = op_module.prod_env_mat_a(
            self.place_holders["coord"],
            self.place_holders["type"],
            self.place_holders["natoms_vec"],
            self.place_holders["box"],
            self.place_holders["default_mesh"],
            tf.constant(avg_zero),
            tf.constant(std_ones),
            rcut_a=self.rcut_a,
            rcut_r=self.rcut_r,
            rcut_r_smth=self.rcut_r_smth,
            sel_a=self.sel_a,
            sel_r=self.sel_r,
        )
    self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
[docs]
def get_rcut(self) -> float:
    """Return the cut-off radius held in ``self.rcut_r``."""
    cutoff = self.rcut_r
    return cutoff
[docs]
def get_ntypes(self) -> int:
    """Return the number of atom types held in ``self.ntypes``."""
    n_types = self.ntypes
    return n_types
[docs]
def get_dim_out(self) -> int:
    """Return the descriptor output dimension.

    This is the width of the last embedding-net layer, i.e. the final
    entry of ``self.filter_neuron``.
    """
    *_, last_width = self.filter_neuron
    return last_width
[docs]
def get_nlist(self) -> tuple[tf.Tensor, tf.Tensor, list[int], list[int]]:
    """Return the neighbor information stored on this descriptor.

    Returns
    -------
    nlist
        Neighbor list
    rij
        The relative distance between the neighbor and the center atom.
    sel_a
        The number of neighbors with full information
    sel_r
        The number of neighbors with only radial information
    """
    neighbor_info = (self.nlist, self.rij, self.sel_a, self.sel_r)
    return neighbor_info
[docs]
def enable_compression(
    self,
    min_nbor_dist: float,
    graph: tf.Graph,
    graph_def: tf.GraphDef,
    table_extrapolate: float = 5,
    table_stride_1: float = 0.01,
    table_stride_2: float = 0.1,
    check_frequency: int = -1,
    suffix: str = "",
) -> None:
    """Switch the descriptor to its tabulated (compressed) form.

    A ``DPTabulate`` table is built from the given graph, the table
    configuration and bounds are stored on ``self``, and the ``t_avg`` /
    ``t_std`` tensors are read back from ``graph``.

    Parameters
    ----------
    min_nbor_dist
        The nearest distance between atoms
    graph : tf.Graph
        The graph of the model
    graph_def : tf.GraphDef
        The graph_def of the model
    table_extrapolate
        The scale of model extrapolation
    table_stride_1
        The uniform stride of the first table
    table_stride_2
        The uniform stride of the second table
    check_frequency
        The overflow check frequency
    suffix : str, optional
        The suffix of the scope
    """
    assert not self.filter_resnet_dt, (
        "Model compression error: descriptor resnet_dt must be false!"
    )

    self.compress = True
    self.table = DPTabulate(
        self,
        self.filter_neuron,
        graph,
        graph_def,
        activation_fn=self.filter_activation_fn,
        suffix=suffix,
    )

    # Both strides are scaled by 10 before being stored and before being
    # handed to the table builder.
    stride_1 = table_stride_1 * 10
    stride_2 = table_stride_2 * 10
    self.table_config = [table_extrapolate, stride_1, stride_2, check_frequency]
    self.lower, self.upper = self.table.build(
        min_nbor_dist, table_extrapolate, stride_1, stride_2
    )

    # Statistics stored under the descriptor attribute scope of the graph.
    self.davg = get_tensor_by_name_from_graph(graph, f"descrpt_attr{suffix}/t_avg")
    self.dstd = get_tensor_by_name_from_graph(graph, f"descrpt_attr{suffix}/t_std")
[docs]
def build(
    self,
    coord_: tf.Tensor,
    atype_: tf.Tensor,
    natoms: tf.Tensor,
    box_: tf.Tensor,
    mesh: tf.Tensor,
    input_dict: dict,
    reuse: bool | None = None,
    suffix: str = "",
) -> tf.Tensor:
    """Assemble the descriptor part of the computational graph.

    Parameters
    ----------
    coord_
        The coordinate of atoms
    atype_
        The type of atoms
    natoms
        The number of atoms. This tensor has the length of Ntypes + 2
        natoms[0]: number of local atoms
        natoms[1]: total number of atoms held by this processor
        natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
    box_ : tf.Tensor
        The box of the system
    mesh
        For historical reasons, only the length of the Tensor matters.
        if size of mesh == 6, pbc is assumed.
        if size of mesh == 0, no-pbc is assumed.
    input_dict
        Dictionary for additional inputs
    reuse
        The weights in the networks should be reused when get the variable.
    suffix
        Name suffix to identify this descriptor

    Returns
    -------
    descriptor
        The output descriptor
    """
    # Fall back to zero mean / unit std when no statistics were provided.
    stat_shape = [self.ntypes, self.ndescrpt]
    avg_init = self.davg
    if avg_init is None:
        avg_init = np.zeros(stat_shape)  # pylint: disable=no-explicit-dtype
    std_init = self.dstd
    if std_init is None:
        std_init = np.ones(stat_shape)  # pylint: disable=no-explicit-dtype

    with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse):
        # These constants are not used below; they are created so that the
        # named nodes exist in the graph under this scope.
        tf.constant(
            np.max([self.rcut_r, self.rcut_a]),
            name="rcut",
            dtype=GLOBAL_TF_FLOAT_PRECISION,
        )
        tf.constant(self.ntypes, name="ntypes", dtype=tf.int32)
        tf.constant(self.ndescrpt, name="ndescrpt", dtype=tf.int32)
        tf.constant(self.sel_a, name="sel", dtype=tf.int32)
        self.t_avg = tf.get_variable(
            "t_avg",
            avg_init.shape,
            dtype=GLOBAL_TF_FLOAT_PRECISION,
            trainable=False,
            initializer=tf.constant_initializer(avg_init),
        )
        self.t_std = tf.get_variable(
            "t_std",
            std_init.shape,
            dtype=GLOBAL_TF_FLOAT_PRECISION,
            trainable=False,
            initializer=tf.constant_initializer(std_init),
        )

    n_all = natoms[1]
    coord = tf.reshape(coord_, [-1, n_all * 3])
    box = tf.reshape(box_, [-1, 9])
    atype = tf.reshape(atype_, [-1, n_all])

    env_outputs = op_module.prod_env_mat_a(
        coord,
        atype,
        natoms,
        box,
        mesh,
        self.t_avg,
        self.t_std,
        rcut_a=self.rcut_a,
        rcut_r=self.rcut_r,
        rcut_r_smth=self.rcut_r_smth,
        sel_a=self.sel_a,
        sel_r=self.sel_r,
    )
    self.descrpt, self.descrpt_deriv, self.rij, self.nlist = env_outputs

    # One row per atom, ndescrpt columns.
    self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
    self._identity_tensors(suffix=suffix)

    self.dout, self.qmat = self._pass_filter(
        self.descrpt_reshape,
        atype,
        natoms,
        input_dict,
        suffix=suffix,
        reuse=reuse,
        trainable=self.trainable,
    )
    return self.dout
[docs]
def prod_force_virial(
    self, atom_ener: tf.Tensor, natoms: tf.Tensor
) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Derive force and virial from the atomic energy.

    The gradient of ``atom_ener`` with respect to the reshaped descriptor
    is fed to the ``prod_force_se_a`` and ``prod_virial_se_a`` ops.

    Parameters
    ----------
    atom_ener
        The atomic energy
    natoms
        The number of atoms. This tensor has the length of Ntypes + 2
        natoms[0]: number of local atoms
        natoms[1]: total number of atoms held by this processor
        natoms[i]: 2 <= i < Ntypes+2, number of type i atoms

    Returns
    -------
    force
        The force on atoms
    virial
        The total virial
    atom_virial
        The atomic virial
    """
    (energy_grad,) = tf.gradients(atom_ener, self.descrpt_reshape)

    # Target shape: [frames, natoms[0] * ndescrpt], with int64 dimensions.
    flat_shape = [
        np.asarray(-1, dtype=np.int64),
        natoms[0] * np.asarray(self.ndescrpt, dtype=np.int64),
    ]
    energy_grad_flat = tf.reshape(energy_grad, flat_shape)

    selection = {"n_a_sel": self.nnei_a, "n_r_sel": self.nnei_r}
    force = op_module.prod_force_se_a(
        energy_grad_flat,
        self.descrpt_deriv,
        self.nlist,
        natoms,
        **selection,
    )
    virial, atom_virial = op_module.prod_virial_se_a(
        energy_grad_flat,
        self.descrpt_deriv,
        self.rij,
        self.nlist,
        natoms,
        **selection,
    )
    return force, virial, atom_virial
[docs]
def _pass_filter(
    self,
    inputs: tf.Tensor,
    atype: tf.Tensor,
    natoms: tf.Tensor,
    input_dict: dict,
    reuse: bool | None = None,
    suffix: str = "",
    trainable: bool = True,
) -> tuple[tf.Tensor, None]:
    """Run the shared filter over every atom and reshape per frame.

    ``atype`` and ``input_dict`` are accepted but not used in this body.

    Returns
    -------
    tuple
        The filter output reshaped to
        ``[nframes, natoms[0], get_dim_out()]`` and ``None`` in place of
        a rotation matrix.
    """
    n_loc = natoms[0]
    framed = tf.reshape(inputs, [-1, n_loc, self.ndescrpt])
    per_atom = tf.reshape(framed, [-1, self.ndescrpt])

    # -1 as the type argument: one filter scope serves all center types.
    layer, _ = self._filter(
        per_atom,
        -1,
        name="filter_type_all" + suffix,
        natoms=natoms,
        reuse=reuse,
        trainable=trainable,
        activation_fn=self.filter_activation_fn,
    )
    layer = tf.reshape(layer, [tf.shape(framed)[0], n_loc, self.get_dim_out()])

    # Single-element concat retained so the same op is emitted to the graph.
    output = tf.concat([layer], axis=1)
    return output, None
[docs]
def _compute_dstats_sys_smth(
    self,
    data_coord: np.ndarray,
    data_box: np.ndarray,
    data_atype: np.ndarray,
    natoms_vec: np.ndarray,
    mesh: np.ndarray,
) -> tuple[list[float], list[float], list[float], list[float], list[int]]:
    """Accumulate per-type sums of the environment matrix for one system.

    The environment matrix is evaluated in the private statistics session,
    split by atom type using the counts in ``natoms_vec[2:]``, and viewed
    as rows of 4 values: column 0 is summed as the radial part, columns
    1-3 as the angular part (divided by 3).

    Returns
    -------
    tuple
        ``(sysr, sysr2, sysa, sysa2, sysn)``: per-type lists of radial sum,
        radial squared sum, angular sum, angular squared sum and row count.
    """
    feed = {
        self.place_holders["coord"]: data_coord,
        self.place_holders["type"]: data_atype,
        self.place_holders["natoms_vec"]: natoms_vec,
        self.place_holders["box"]: data_box,
        self.place_holders["default_mesh"]: mesh,
    }
    env_mat = run_sess(self.sub_sess, self.stat_descrpt, feed_dict=feed)
    env_mat = np.reshape(env_mat, [-1, self.ndescrpt * natoms_vec[0]])

    sysr, sysr2, sysa, sysa2, sysn = [], [], [], [], []
    offset = 0
    for type_i in range(self.ntypes):
        # Columns belonging to atoms of this type.
        stop = offset + self.ndescrpt * natoms_vec[2 + type_i]
        chunk = env_mat[:, offset:stop]
        offset = stop

        chunk = np.reshape(chunk, [-1, self.ndescrpt])
        rows = np.reshape(chunk, [-1, 4])
        radial = rows[:, :1]
        angular = rows[:, 1:]

        sysr.append(np.sum(radial))
        sysa.append(np.sum(angular) / 3.0)
        sysn.append(rows.shape[0])
        sysr2.append(np.sum(np.multiply(radial, radial)))
        sysa2.append(np.sum(np.multiply(angular, angular)) / 3.0)
    return sysr, sysr2, sysa, sysa2, sysn
[docs]
def _compute_std(self, sumv2: float, sumv: float, sumn: int) -> float:
val = np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn))
if np.abs(val) < 1e-2:
val = 1e-2
return val
@cast_precision
[docs]
def _filter(
    self,
    inputs: tf.Tensor,
    type_input: int,
    natoms: tf.Tensor,
    activation_fn: Callable[[tf.Tensor], tf.Tensor] | None = tf.nn.tanh,
    stddev: float = 1.0,
    bavg: float = 0.0,
    name: str = "linear",
    reuse: bool | None = None,
    trainable: bool = True,
) -> tuple[tf.Tensor, tf.Tensor]:
    """Build the three-body embedding over all neighbor-type pairs.

    For every pair ``type_i <= type_j`` the last three components of the
    environment rows of both neighbor types are contracted into
    ``env_ij``; each entry is passed through the embedding net (or the
    tabulated op when ``self.compress`` is set), and the result is
    accumulated after scaling by ``1 / (sel_a[type_i] * sel_a[type_j])``.

    ``type_input`` and ``natoms`` are accepted but not used in this body.

    Returns
    -------
    tuple
        The accumulated result and ``None``.
    """
    # inputs: natom x (nei x 4)
    outputs_size = [1, *self.filter_neuron]
    out_size = outputs_size[-1]
    with tf.variable_scope(name, reuse=reuse):
        offset_i = 0
        result = None
        for type_i in range(self.ntypes):
            nei_i = self.sel_a[type_i]
            # natom x (nei_i x 4)
            slab_i = tf.slice(inputs, [0, offset_i * 4], [-1, nei_i * 4])
            # The inner loop starts at the same neighbor offset as type_i.
            offset_j = offset_i
            offset_i += nei_i
            assert slab_i.get_shape().as_list()[1] == nei_i * 4
            # natom x nei_i x 4, then drop the first component -> x 3
            env_i = tf.reshape(slab_i, [-1, nei_i, 4])
            env_i = tf.slice(env_i, [0, 0, 1], [-1, -1, -1])

            for type_j in range(type_i, self.ntypes):
                nei_j = self.sel_a[type_j]
                # natom x (nei_j x 4)
                slab_j = tf.slice(inputs, [0, offset_j * 4], [-1, nei_j * 4])
                offset_j += nei_j
                assert slab_j.get_shape().as_list()[1] == nei_j * 4
                # natom x nei_j x 4, then drop the first component -> x 3
                env_j = tf.reshape(slab_j, [-1, nei_j, 4])
                env_j = tf.slice(env_j, [0, 0, 1], [-1, -1, -1])

                # natom x nei_i x nei_j
                env_ij = tf.einsum("ijm,ikm->ijk", env_i, env_j)
                # (natom x nei_i x nei_j) x 1
                ebd_env_ij = tf.reshape(env_ij, [-1, 1])

                if self.compress:
                    net = f"filter_{type_i}_net_{type_j}"
                    lower = self.lower[net]
                    upper = self.upper[net]
                    extrapolate, stride_1, stride_2, check_freq = (
                        self.table_config[0],
                        self.table_config[1],
                        self.table_config[2],
                        self.table_config[3],
                    )
                    info = [
                        lower,
                        upper,
                        upper * extrapolate,
                        stride_1,
                        stride_2,
                        check_freq,
                    ]
                    res_ij = op_module.tabulate_fusion_se_t(
                        tf.cast(self.table.data[net], self.filter_precision),
                        info,
                        ebd_env_ij,
                        env_ij,
                        last_layer_size=out_size,
                    )
                else:
                    # (natom x nei_i x nei_j) x out_size
                    ebd_env_ij = embedding_net(
                        ebd_env_ij,
                        self.filter_neuron,
                        self.filter_precision,
                        activation_fn=activation_fn,
                        resnet_dt=self.filter_resnet_dt,
                        name_suffix=f"_{type_i}_{type_j}",
                        stddev=stddev,
                        bavg=bavg,
                        seed=self.seed,
                        trainable=trainable,
                        uniform_seed=self.uniform_seed,
                        initial_variables=self.embedding_net_variables,
                    )
                    # Advance the seed so the next net is seeded differently.
                    if self.seed is not None and not self.uniform_seed:
                        self.seed += self.seed_shift
                    # natom x nei_i x nei_j x out_size
                    ebd_env_ij = tf.reshape(
                        ebd_env_ij, [-1, nei_i, nei_j, out_size]
                    )
                    # natom x out_size
                    res_ij = tf.einsum("ijk,ijkm->im", env_ij, ebd_env_ij)

                res_ij = res_ij * (1.0 / float(nei_i) / float(nei_j))
                if result is None:
                    result = res_ij
                else:
                    result += res_ij
    return result, None
[docs]
def serialize_network(
    self,
    ntypes: int,
    ndim: int,
    in_dim: int,
    neuron: list[int],
    activation_function: str,
    resnet_dt: bool,
    variables: dict,
    excluded_types: set[tuple[int, int]] = set(),
    trainable: bool = True,
    suffix: str = "",
) -> dict:
    """Convert embedding-net variables into a serialized NetworkCollection.

    Parameters
    ----------
    ntypes : int
        The number of types
    ndim : int
        The dimension of elements
    in_dim : int
        The input dimension
    neuron : list[int]
        The neuron list
    activation_function : str
        The activation function
    resnet_dt : bool
        Whether to use resnet
    variables : dict
        The input variables
    excluded_types : set[tuple[int, int]], optional
        The excluded types
    trainable : bool, optional
        Whether the network is trainable
    suffix : str, optional
        The suffix of the scope

    Returns
    -------
    dict
        The converted network data
    """
    assert ndim == 2, "Embeddings in descriptor 'se_e3' must have two dimensions."
    embeddings = NetworkCollection(
        ntypes=ntypes,
        ndim=ndim,
        network_type="embedding_network",
    )

    def make_net() -> EmbeddingNet:
        # Fresh embedding net with the configuration passed to this call.
        return EmbeddingNet(
            in_dim=in_dim,
            neuron=neuron,
            activation_function=activation_function,
            resnet_dt=resnet_dt,
            precision=self.precision.name,
            trainable=trainable,
        )

    def clear_ij(type_i: int, type_j: int) -> None:
        # Register an empty (cleared) network for this type pair.
        pair = (type_i, type_j)
        embeddings[pair] = make_net()
        embeddings[pair].clear()

    for i, j in excluded_types:
        clear_ij(i, j)
        clear_ij(j, i)
    # Pairs with j < i get cleared placeholders.
    for i in range(ntypes):
        for j in range(i):
            clear_ij(i, j)

    embedding_net_pattern = EMBEDDING_NET_PATTERN
    if suffix != "":
        # Insert the scope suffix in front of each weight-name group.
        for group in ("/(idt)", "/(bias)", "/(matrix)"):
            embedding_net_pattern = embedding_net_pattern.replace(
                group, suffix + group
            )
    matcher = re.compile(embedding_net_pattern)

    for key, value in variables.items():
        found = matcher.search(key)
        parts = [part for part in found.groups() if part is not None]
        typei = parts[3]
        typej = parts[4]
        layer_idx = int(parts[2]) - 1
        weight_name = parts[1]
        network_idx = (int(typei), int(typej))
        if embeddings[network_idx] is None:
            # initialize the network if it is not initialized
            embeddings[network_idx] = make_net()
        assert embeddings[network_idx] is not None
        if weight_name == "idt":
            value = value.ravel()
        embeddings[network_idx][layer_idx][weight_name] = value
    return embeddings.serialize()
@classmethod
[docs]
def deserialize_network(cls, data: dict, suffix: str = "") -> dict:
    """Expand serialized embeddings into a flat variable dictionary.

    Parameters
    ----------
    data : dict
        The input network data
    suffix : str, optional
        The suffix of the scope

    Returns
    -------
    variables : dict
        The input variables
    """
    embeddings = NetworkCollection.deserialize(data)
    assert embeddings.ndim == 2, (
        "Embeddings in descriptor 'se_e3' must have two dimensions."
    )

    embedding_net_variables = {}
    ntypes = embeddings.ntypes
    scope = f"filter_type_all{suffix}"
    for flat_idx in range(ntypes**2):
        # The first index varies fastest: flat_idx = first + ntypes * second.
        second, first = divmod(flat_idx, ntypes)
        net_idx = (first, second)
        network = embeddings[net_idx]
        assert network is not None
        for layer_idx, layer in enumerate(network.layers):
            tag = f"{layer_idx + 1}_{first}_{second}"
            embedding_net_variables[f"{scope}/matrix_{tag}"] = layer.w
            embedding_net_variables[f"{scope}/bias_{tag}"] = layer.b
            if layer.idt is None:
                # prevent keyError
                idt_value = 0.0
            else:
                idt_value = layer.idt.reshape(1, -1)
            embedding_net_variables[f"{scope}/idt_{tag}"] = idt_value
    return embedding_net_variables
@classmethod
[docs]
def deserialize(cls, data: dict, suffix: str = "") -> "DescrptSeT":
    """Rebuild a descriptor from its serialized form.

    Parameters
    ----------
    data : dict
        The serialized data
    suffix : str, optional
        The suffix of the scope, forwarded to ``deserialize_network``

    Returns
    -------
    Model
        The deserialized model

    Raises
    ------
    NotImplementedError
        If called on a class other than ``DescrptSeT``.
    """
    if cls is not DescrptSeT:
        raise NotImplementedError(f"Not implemented in class {cls.__name__}")

    # Work on a shallow copy so the caller's dictionary is left intact.
    config = data.copy()
    check_version_compatibility(config.pop("@version", 1), 3, 1)
    for meta_key in ("@class", "type"):
        config.pop(meta_key, None)
    embedding_net_variables = cls.deserialize_network(
        config.pop("embeddings"), suffix=suffix
    )
    config.pop("env_mat")
    variables = config.pop("@variables")
    config.pop("compress", None)  # tf uses frozen graph for compression

    # Whatever remains is passed to the constructor as keyword arguments.
    descriptor = cls(**config)
    descriptor.embedding_net_variables = embedding_net_variables
    descriptor.davg = variables["davg"].reshape(
        descriptor.ntypes, descriptor.ndescrpt
    )
    descriptor.dstd = variables["dstd"].reshape(
        descriptor.ntypes, descriptor.ndescrpt
    )
    return descriptor
[docs]
def serialize(self, suffix: str = "") -> dict:
    """Export the descriptor configuration, weights and statistics.

    Parameters
    ----------
    suffix : str, optional
        The suffix of the scope

    Returns
    -------
    dict
        The serialized data

    Raises
    ------
    NotImplementedError
        If called on an instance whose exact type is not ``DescrptSeT``.
    RuntimeError
        If ``embedding_net_variables`` has not been set.
    """
    if type(self) is not DescrptSeT:
        raise NotImplementedError(
            f"Not implemented in class {self.__class__.__name__}"
        )
    if self.embedding_net_variables is None:
        raise RuntimeError("init_variables must be called before serialize")
    assert self.davg is not None
    assert self.dstd is not None

    # Filled step by step so values are evaluated in key order.
    payload = {
        "@class": "Descriptor",
        "type": "se_e3",
        "@version": 2,
        "rcut": self.rcut_r,
        "rcut_smth": self.rcut_r_smth,
        "sel": self.sel_a,
        "neuron": self.filter_neuron,
        "resnet_dt": self.filter_resnet_dt,
        "set_davg_zero": self.set_davg_zero,
        "activation_function": self.activation_function_name,
        "precision": self.filter_precision.name,
    }
    payload["embeddings"] = self.serialize_network(
        ntypes=self.ntypes,
        ndim=2,
        in_dim=1,
        neuron=self.filter_neuron,
        activation_function=self.activation_function_name,
        resnet_dt=self.filter_resnet_dt,
        variables=self.embedding_net_variables,
        excluded_types=self.exclude_types,
        trainable=self.trainable,
        suffix=suffix,
    )
    payload["env_mat"] = EnvMat(self.rcut_r, self.rcut_r_smth).serialize()
    payload["exclude_types"] = list(self.orig_exclude_types)
    payload["env_protection"] = self.env_protection
    # Statistics are stored as (ntypes, nnei_a, 4).
    payload["@variables"] = {
        "davg": self.davg.reshape(self.ntypes, self.nnei_a, 4),
        "dstd": self.dstd.reshape(self.ntypes, self.nnei_a, 4),
    }
    payload["type_map"] = self.type_map
    payload["trainable"] = self.trainable
    return payload