# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Callable,
)
from typing import (
Any,
)
import paddle
from deepmd.dpmodel.utils import EnvMat as DPEnvMat
from deepmd.dpmodel.utils.seed import (
child_seed,
)
from deepmd.pd.model.descriptor import (
DescriptorBlock,
)
from deepmd.pd.model.descriptor.env_mat import (
prod_env_mat,
)
from deepmd.pd.model.network.mlp import (
EmbeddingNet,
NetworkCollection,
)
from deepmd.pd.model.network.network import (
TypeEmbedNet,
TypeEmbedNetConsistent,
)
from deepmd.pd.utils import (
env,
)
from deepmd.pd.utils.env import (
PRECISION_DICT,
RESERVED_PRECISION_DICT,
)
from deepmd.pd.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pd.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pd.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.finetune import (
get_index_between_two_maps,
map_pair_exclude_types,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
from .base_descriptor import (
BaseDescriptor,
)
from .descriptor import (
extend_descrpt_stat,
)
@BaseDescriptor.register("se_e3_tebd")
class DescrptSeTTebd(BaseDescriptor, paddle.nn.Layer):
    r"""Construct an embedding net that takes angles between two neighboring atoms and type embeddings as input.

    Parameters
    ----------
    rcut
        The cut-off radius
    rcut_smth
        From where the environment matrix should be smoothed
    sel : Union[list[int], int]
        list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius
        int: the total maximum number of atoms in the cut-off radius
    ntypes : int
        Number of element types
    neuron : list[int]
        Number of neurons in each hidden layers of the embedding net
    tebd_dim : int
        Dimension of the type embedding
    tebd_input_mode : str
        The input mode of the type embedding. Supported modes are ["concat", "strip"].
        - "concat": Concatenate the type embedding with the smoothed angular information as the union input for the embedding network.
        - "strip": Use a separated embedding network for the type embedding and combine the output with the angular embedding network output.
    resnet_dt
        Time-step `dt` in the resnet construction:
        y = x + dt * \phi (Wx + b)
    set_davg_zero
        Set the shift of embedding net input to zero.
    activation_function
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    env_protection: float
        Protection parameter to prevent division by zero errors during environment matrix calculations.
    exclude_types : list[tuple[int, int]]
        The excluded pairs of types which have no interaction with each other.
        For example, `[[0, 1]]` means no interaction between type 0 and type 1.
    precision
        The precision of the embedding net parameters. Supported options are |PRECISION|
    trainable
        If the weights of embedding net are trainable.
    seed
        Random seed for initializing the network parameters.
    type_map: list[str], Optional
        A list of strings. Give the name to each type of atoms.
    concat_output_tebd: bool
        Whether to concat type embedding at the output of the descriptor.
    use_econf_tebd: bool, Optional
        Whether to use electronic configuration type embedding.
    use_tebd_bias : bool, Optional
        Whether to use bias in the type embedding layer.
    smooth: bool
        Whether to use smooth process in calculation.
    """

    def __init__(
        self,
        rcut: float,
        rcut_smth: float,
        sel: list[int] | int,
        ntypes: int,
        neuron: list = [2, 4, 8],
        tebd_dim: int = 8,
        tebd_input_mode: str = "concat",
        resnet_dt: bool = False,
        set_davg_zero: bool = True,
        activation_function: str = "tanh",
        env_protection: float = 0.0,
        exclude_types: list[tuple[int, int]] = [],
        precision: str = "float64",
        trainable: bool = True,
        seed: int | list[int] | None = None,
        type_map: list[str] | None = None,
        concat_output_tebd: bool = True,
        use_econf_tebd: bool = False,
        use_tebd_bias: bool = False,
        smooth: bool = True,
    ) -> None:
        super().__init__()
        # The descriptor block owns the environment matrix statistics and the
        # embedding networks; this wrapper adds the type embedding on top.
        self.se_ttebd = DescrptBlockSeTTebd(
            rcut,
            rcut_smth,
            sel,
            ntypes,
            neuron=neuron,
            tebd_dim=tebd_dim,
            tebd_input_mode=tebd_input_mode,
            set_davg_zero=set_davg_zero,
            activation_function=activation_function,
            precision=precision,
            resnet_dt=resnet_dt,
            exclude_types=exclude_types,
            env_protection=env_protection,
            smooth=smooth,
            seed=child_seed(seed, 1),
            trainable=trainable,
        )
        self.prec = PRECISION_DICT[precision]
        self.use_econf_tebd = use_econf_tebd
        self.type_map = type_map
        if type_map is not None:
            # Store the type map as character codes so it can live in a buffer.
            self.register_buffer(
                "buffer_type_map",
                paddle.to_tensor([ord(c) for c in " ".join(type_map)]),
            )
        # Read back by `serialize`.
        self.smooth = smooth
        self.type_embedding = TypeEmbedNet(
            ntypes,
            tebd_dim,
            precision=precision,
            seed=child_seed(seed, 2),
            use_econf_tebd=use_econf_tebd,
            type_map=type_map,
            use_tebd_bias=use_tebd_bias,
            trainable=trainable,
        )
        self.tebd_dim = tebd_dim
        # Read back by `forward` to decide whether the full embedding is needed.
        self.tebd_input_mode = tebd_input_mode
        self.concat_output_tebd = concat_output_tebd
        self.trainable = trainable
        # set trainable
        for param in self.parameters():
            param.stop_gradient = not trainable

    def get_rcut(self) -> float:
        """Returns the cut-off radius."""
        return self.se_ttebd.get_rcut()

    def get_rcut_smth(self) -> float:
        """Returns the radius where the neighbor information starts to smoothly decay to 0."""
        return self.se_ttebd.get_rcut_smth()

    def get_nsel(self) -> int:
        """Returns the number of selected atoms in the cut-off radius."""
        return self.se_ttebd.get_nsel()

    def get_sel(self) -> list[int]:
        """Returns the number of selected atoms for each type."""
        return self.se_ttebd.get_sel()

    def get_ntypes(self) -> int:
        """Returns the number of element types."""
        return self.se_ttebd.get_ntypes()

    def get_type_map(self) -> list[str]:
        """Get the name to each type of atoms."""
        return self.type_map

    def get_buffer_type_map(self) -> paddle.Tensor:
        """
        Return the type map as a buffer-style Tensor for JIT saving.

        The original type map (e.g., ['Ni', 'O']) is first joined into a single space-separated string
        (e.g., "Ni O"). Each character in this string is then converted to its ASCII code using `ord()`,
        and the resulting integer sequence is stored as a 1D paddle.Tensor of dtype int.
        This format allows the type map to be serialized as a raw byte buffer during JIT model saving.

        The buffer is only registered when a `type_map` was given at construction.
        """
        return self.buffer_type_map

    def get_dim_out(self) -> int:
        """Returns the output dimension."""
        ret = self.se_ttebd.get_dim_out()
        if self.concat_output_tebd:
            # The type embedding of the central atom is appended to the output.
            ret += self.tebd_dim
        return ret

    def get_dim_emb(self) -> int:
        """Returns the output dimension of the embedding."""
        return self.se_ttebd.dim_emb

    def mixed_types(self) -> bool:
        """If true, the descriptor
        1. assumes total number of atoms aligned across frames;
        2. requires a neighbor list that does not distinguish different atomic types.

        If false, the descriptor
        1. assumes total number of atoms of each atom type aligned across frames;
        2. requires a neighbor list that distinguishes different atomic types.

        """
        return self.se_ttebd.mixed_types()

    def has_message_passing(self) -> bool:
        """Returns whether the descriptor has message passing."""
        return self.se_ttebd.has_message_passing()

    def need_sorted_nlist_for_lower(self) -> bool:
        """Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
        return self.se_ttebd.need_sorted_nlist_for_lower()

    def get_env_protection(self) -> float:
        """Returns the protection of building environment matrix."""
        return self.se_ttebd.get_env_protection()

    def share_params(
        self, base_class: object, shared_level: int, resume: bool = False
    ) -> None:
        """
        Share the parameters of self to the base_class with shared_level during multitask training.
        If not start from checkpoint (resume is False),
        some separated parameters (e.g. mean and stddev) will be re-calculated across different classes.
        """
        assert self.__class__ == base_class.__class__, (
            "Only descriptors of the same type can share params!"
        )
        # For DPA1 descriptors, the user-defined share-level
        # shared_level: 0
        # share all parameters in both type_embedding and se_ttebd
        if shared_level == 0:
            self._sub_layers["type_embedding"] = base_class._sub_layers[
                "type_embedding"
            ]
            self.se_ttebd.share_params(base_class.se_ttebd, 0, resume=resume)
        # shared_level: 1
        # share all parameters in type_embedding
        elif shared_level == 1:
            self._sub_layers["type_embedding"] = base_class._sub_layers[
                "type_embedding"
            ]
        # Other shared levels
        else:
            raise NotImplementedError

    @property
    def dim_out(self) -> int:
        """Output dimension of the descriptor, see `get_dim_out`."""
        return self.get_dim_out()

    @property
    def dim_emb(self) -> int:
        """Output dimension of the embedding, see `get_dim_emb`."""
        return self.get_dim_emb()

    def set_stat_mean_and_stddev(
        self,
        mean: paddle.Tensor,
        stddev: paddle.Tensor,
    ) -> None:
        """Update mean and stddev for descriptor."""
        self.se_ttebd.mean = mean
        self.se_ttebd.stddev = stddev

    def get_stat_mean_and_stddev(self) -> tuple[paddle.Tensor, paddle.Tensor]:
        """Get mean and stddev for descriptor."""
        return self.se_ttebd.mean, self.se_ttebd.stddev

    def change_type_map(
        self, type_map: list[str], model_with_new_type_stat: Any | None = None
    ) -> None:
        """Change the type related params to new ones, according to `type_map` and the original one in the model.
        If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
        """
        assert self.type_map is not None, (
            "'type_map' must be defined when performing type changing!"
        )
        remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map)
        obj = self.se_ttebd
        obj.ntypes = len(type_map)
        self.type_map = type_map
        self.type_embedding.change_type_map(type_map=type_map)
        obj.reinit_exclude(map_pair_exclude_types(obj.exclude_types, remap_index))
        if has_new_type:
            # the avg and std of new types need to be updated
            extend_descrpt_stat(
                obj,
                type_map,
                des_with_stat=model_with_new_type_stat.se_ttebd
                if model_with_new_type_stat is not None
                else None,
            )
        # Reorder the per-type statistics to follow the new type map.
        obj["davg"] = obj["davg"][remap_index]
        obj["dstd"] = obj["dstd"][remap_index]

    def serialize(self) -> dict:
        """Serialize the descriptor (hyper-parameters, networks and statistics) to a dict."""
        obj = self.se_ttebd
        data = {
            "@class": "Descriptor",
            "type": "se_e3_tebd",
            "@version": 1,
            "rcut": obj.rcut,
            "rcut_smth": obj.rcut_smth,
            "sel": obj.sel,
            "ntypes": obj.ntypes,
            "neuron": obj.neuron,
            "tebd_dim": obj.tebd_dim,
            "tebd_input_mode": obj.tebd_input_mode,
            "set_davg_zero": obj.set_davg_zero,
            "activation_function": obj.activation_function,
            "resnet_dt": obj.resnet_dt,
            "concat_output_tebd": self.concat_output_tebd,
            "use_econf_tebd": self.use_econf_tebd,
            "type_map": self.type_map,
            # make deterministic
            "precision": RESERVED_PRECISION_DICT[obj.prec],
            "embeddings": obj.filter_layers.serialize(),
            "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
            "type_embedding": self.type_embedding.embedding.serialize(),
            "exclude_types": obj.exclude_types,
            "env_protection": obj.env_protection,
            "smooth": self.smooth,
            "@variables": {
                "davg": obj["davg"].numpy(),
                "dstd": obj["dstd"].numpy(),
            },
            "trainable": self.trainable,
        }
        if obj.tebd_input_mode in ["strip"]:
            # The separate type-embedding network only exists in "strip" mode.
            data.update({"embeddings_strip": obj.filter_layers_strip.serialize()})
        return data

    @classmethod
    def deserialize(cls, data: dict) -> "DescrptSeTTebd":
        """Rebuild a descriptor from the dict produced by `serialize`."""
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        data.pop("type")
        variables = data.pop("@variables")
        embeddings = data.pop("embeddings")
        type_embedding = data.pop("type_embedding")
        # "env_mat" is not a constructor argument; it must be removed before
        # the remaining entries are forwarded to `cls(**data)`.
        env_mat = data.pop("env_mat")
        tebd_input_mode = data["tebd_input_mode"]
        if tebd_input_mode in ["strip"]:
            embeddings_strip = data.pop("embeddings_strip")
        else:
            embeddings_strip = None
        obj = cls(**data)

        def t_cvt(xx: paddle.Tensor) -> paddle.Tensor:
            return paddle.to_tensor(xx, dtype=obj.se_ttebd.prec).to(device=env.DEVICE)

        obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize(
            type_embedding
        )
        obj.se_ttebd["davg"] = t_cvt(variables["davg"])
        obj.se_ttebd["dstd"] = t_cvt(variables["dstd"])
        obj.se_ttebd.filter_layers = NetworkCollection.deserialize(embeddings)
        if tebd_input_mode in ["strip"]:
            obj.se_ttebd.filter_layers_strip = NetworkCollection.deserialize(
                embeddings_strip
            )
        return obj

    def forward(
        self,
        extended_coord: paddle.Tensor,
        extended_atype: paddle.Tensor,
        nlist: paddle.Tensor,
        mapping: paddle.Tensor | None = None,
        comm_dict: list[paddle.Tensor] | None = None,
    ) -> tuple[
        paddle.Tensor,
        paddle.Tensor | None,
        paddle.Tensor | None,
        paddle.Tensor | None,
        paddle.Tensor,
    ]:
        """Compute the descriptor.

        Parameters
        ----------
        extended_coord
            The extended coordinates of atoms. shape: nf x (nallx3)
        extended_atype
            The extended atom types. shape: nf x nall
        nlist
            The neighbor list. shape: nf x nloc x nnei
        mapping
            The index mapping, not required by this descriptor.
        comm_dict
            The data needed for communication for parallel inference.
            Not used by this descriptor.

        Returns
        -------
        descriptor
            The descriptor. shape: nf x nloc x ng, or
            nf x nloc x (ng + tebd_dim) when `concat_output_tebd` is set.
        gr
            Always None for this descriptor.
        g2
            Always None for this descriptor.
        h2
            Always None for this descriptor.
        sw
            The smooth switch function. shape: nf x nloc x nnei

        """
        # cast the input to internal precision
        extended_coord = extended_coord.to(dtype=self.prec)
        del mapping
        nframes, nloc, nnei = nlist.shape
        g1_ext = self.type_embedding(extended_atype)
        # type embedding of the local (central) atoms only
        g1_inp = g1_ext[:, :nloc, :]
        if self.tebd_input_mode in ["strip"]:
            # "strip" mode looks up the embedding of every type pair, so the
            # block needs the full type-embedding table.
            type_embedding = self.type_embedding.get_full_embedding(g1_ext.place)
        else:
            type_embedding = None
        g1, _, _, _, sw = self.se_ttebd(
            nlist,
            extended_coord,
            extended_atype,
            g1_ext,
            mapping=None,
            type_embedding=type_embedding,
        )
        if self.concat_output_tebd:
            g1 = paddle.concat([g1, g1_inp], axis=-1)

        return (
            g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
            None,
            None,
            None,
            sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION),
        )

    @classmethod
    def update_sel(
        cls,
        train_data: DeepmdDataSystem,
        type_map: list[str] | None,
        local_jdata: dict,
    ) -> tuple[dict, float | None]:
        """Update the selection and perform neighbor statistics.

        Parameters
        ----------
        train_data : DeepmdDataSystem
            data used to do neighbor statistics
        type_map : list[str], optional
            The name of each type of atoms
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        dict
            The updated local data
        float
            The minimum distance between two atoms
        """
        # Work on a shallow copy so the caller's dict is left untouched.
        local_jdata_cpy = local_jdata.copy()
        min_nbor_dist, sel = UpdateSel().update_one_sel(
            train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
        )
        local_jdata_cpy["sel"] = sel[0]
        return local_jdata_cpy, min_nbor_dist
@DescriptorBlock.register("se_ttebd")
class DescrptBlockSeTTebd(DescriptorBlock):
    r"""Descriptor block that embeds the angular information between pairs of neighbors together with type embeddings.

    Parameters
    ----------
    rcut
        The cut-off radius
    rcut_smth
        From where the environment matrix should be smoothed
    sel : Union[list[int], int]
        The number of selected neighbors; an int is treated as a one-element list.
    ntypes : int
        Number of element types
    neuron : list[int]
        Number of neurons in each hidden layers of the embedding net
    tebd_dim : int
        Dimension of the type embedding
    tebd_input_mode : str
        The input mode of the type embedding. Supported modes are ["concat", "strip"].
    set_davg_zero
        Set the shift of embedding net input to zero.
    activation_function
        The activation function in the embedding net.
    precision
        The precision of the embedding net parameters.
    resnet_dt
        Time-step `dt` in the resnet construction.
    exclude_types : list[tuple[int, int]]
        The excluded pairs of types which have no interaction with each other.
    env_protection: float
        Protection parameter passed to the environment matrix calculation.
    smooth: bool
        Whether to multiply the stripped type-pair embedding by the switch function.
    seed
        Random seed for initializing the network parameters.
    trainable
        If the weights of embedding net are trainable.
    """

    def __init__(
        self,
        rcut: float,
        rcut_smth: float,
        sel: list[int] | int,
        ntypes: int,
        neuron: list = [25, 50, 100],
        tebd_dim: int = 8,
        tebd_input_mode: str = "concat",
        set_davg_zero: bool = True,
        activation_function: str = "tanh",
        precision: str = "float64",
        resnet_dt: bool = False,
        exclude_types: list[tuple[int, int]] = [],
        env_protection: float = 0.0,
        smooth: bool = True,
        seed: int | list[int] | None = None,
        trainable: bool = True,
    ) -> None:
        super().__init__()
        self.rcut = float(rcut)
        self.register_buffer("buffer_rcut", paddle.to_tensor(self.rcut))
        self.rcut_smth = float(rcut_smth)
        self.register_buffer("buffer_rcut_smth", paddle.to_tensor(self.rcut_smth))
        # Copy so instances never share (or mutate) the default list object.
        self.neuron = list(neuron)
        self.filter_neuron = self.neuron
        self.tebd_dim = tebd_dim
        self.tebd_input_mode = tebd_input_mode
        self.set_davg_zero = set_davg_zero
        self.activation_function = activation_function
        self.precision = precision
        self.prec = PRECISION_DICT[self.precision]
        self.resnet_dt = resnet_dt
        self.env_protection = env_protection
        self.seed = seed
        self.smooth = smooth

        if isinstance(sel, int):
            sel = [sel]

        self.ntypes = ntypes
        self.register_buffer(
            "buffer_ntypes", paddle.to_tensor(self.ntypes, dtype="int64")
        )
        self.sel = sel
        self.split_sel = self.sel
        # total number of neighbors, consistent with `get_nsel`
        self.nnei = sum(sel)
        self.ndescrpt = self.nnei * 4
        # order matters, placed after the assignment of self.ntypes
        self.reinit_exclude(exclude_types)

        wanted_shape = (self.ntypes, self.nnei, 4)
        mean = paddle.zeros(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to(
            device=env.DEVICE
        )
        stddev = paddle.ones(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to(
            device=env.DEVICE
        )
        self.register_buffer("mean", mean)
        self.register_buffer("stddev", stddev)

        # Both neighbors of an angle contribute a type embedding, so the
        # type-embedding part of the network input is twice `tebd_dim`.
        self.tebd_dim_input = self.tebd_dim * 2
        if self.tebd_input_mode in ["concat"]:
            # angular scalar + the two neighbor type embeddings
            self.embd_input_dim = 1 + self.tebd_dim_input
        else:
            # angular scalar only; type embeddings go through a separate net
            self.embd_input_dim = 1

        self.filter_layers = None
        self.filter_layers_strip = None
        filter_layers = NetworkCollection(
            ndim=0, ntypes=self.ntypes, network_type="embedding_network"
        )
        filter_layers[0] = EmbeddingNet(
            self.embd_input_dim,
            self.filter_neuron,
            activation_function=self.activation_function,
            precision=self.precision,
            resnet_dt=self.resnet_dt,
            seed=child_seed(self.seed, 1),
            trainable=trainable,
        )
        self.filter_layers = filter_layers
        if self.tebd_input_mode in ["strip"]:
            filter_layers_strip = NetworkCollection(
                ndim=0, ntypes=self.ntypes, network_type="embedding_network"
            )
            filter_layers_strip[0] = EmbeddingNet(
                self.tebd_dim_input,
                self.filter_neuron,
                activation_function=self.activation_function,
                precision=self.precision,
                resnet_dt=self.resnet_dt,
                seed=child_seed(self.seed, 2),
                trainable=trainable,
            )
            self.filter_layers_strip = filter_layers_strip
        # No statistics are available until they are computed; see `get_stats`.
        self.stats = None

    def get_rcut(self) -> float:
        """Returns the cut-off radius."""
        return self.rcut

    def get_rcut_smth(self) -> float:
        """Returns the radius where the neighbor information starts to smoothly decay to 0."""
        return self.rcut_smth

    def get_buffer_rcut(self) -> paddle.Tensor:
        """Returns the cut-off radius as a buffer-style Tensor."""
        return self.buffer_rcut

    def get_buffer_rcut_smth(self) -> paddle.Tensor:
        """Returns the radius where the neighbor information starts to smoothly decay to 0 as a buffer-style Tensor."""
        return self.buffer_rcut_smth

    def get_nsel(self) -> int:
        """Returns the number of selected atoms in the cut-off radius."""
        return sum(self.sel)

    def get_sel(self) -> list[int]:
        """Returns the number of selected atoms for each type."""
        return self.sel

    def get_ntypes(self) -> int:
        """Returns the number of element types.

        In dynamic mode the Python int is returned; otherwise the
        `buffer_ntypes` tensor is returned instead.
        """
        return self.ntypes if paddle.in_dynamic_mode() else self.buffer_ntypes

    def get_dim_in(self) -> int:
        """Returns the input dimension."""
        return self.dim_in

    def get_dim_out(self) -> int:
        """Returns the output dimension."""
        return self.dim_out

    def get_dim_emb(self) -> int:
        """Returns the output dimension of embedding."""
        return self.filter_neuron[-1]

    def __setitem__(self, key: str, value: paddle.Tensor) -> None:
        """Set the mean or stddev statistics by one of their accepted aliases."""
        if key in ("avg", "data_avg", "davg"):
            self.mean = value
        elif key in ("std", "data_std", "dstd"):
            self.stddev = value
        else:
            raise KeyError(key)

    def __getitem__(self, key: str) -> paddle.Tensor:
        """Get the mean or stddev statistics by one of their accepted aliases."""
        if key in ("avg", "data_avg", "davg"):
            return self.mean
        elif key in ("std", "data_std", "dstd"):
            return self.stddev
        else:
            raise KeyError(key)

    def mixed_types(self) -> bool:
        """If true, the descriptor
        1. assumes total number of atoms aligned across frames;
        2. requires a neighbor list that does not distinguish different atomic types.

        If false, the descriptor
        1. assumes total number of atoms of each atom type aligned across frames;
        2. requires a neighbor list that distinguishes different atomic types.

        """
        return True

    def get_env_protection(self) -> float:
        """Returns the protection of building environment matrix."""
        return self.env_protection

    @property
    def dim_out(self) -> int:
        """Returns the output dimension of this descriptor."""
        return self.filter_neuron[-1]

    @property
    def dim_in(self) -> int:
        """Returns the atomic input dimension of this descriptor."""
        return self.tebd_dim

    @property
    def dim_emb(self) -> int:
        """Returns the output dimension of embedding."""
        return self.get_dim_emb()

    def get_stats(self) -> dict[str, StatItem]:
        """Get the statistics of the descriptor.

        Raises
        ------
        RuntimeError
            If the statistics have not been computed yet.
        """
        if self.stats is None:
            raise RuntimeError(
                "The statistics of the descriptor has not been computed."
            )
        return self.stats

    def reinit_exclude(
        self,
        exclude_types: list[tuple[int, int]] = [],
    ) -> None:
        """Replace the excluded type pairs and rebuild the pair exclusion mask."""
        # Copy so the stored list is never the shared default argument object.
        self.exclude_types = list(exclude_types)
        self.emask = PairExcludeMask(self.ntypes, exclude_types=self.exclude_types)

    def forward(
        self,
        nlist: paddle.Tensor,
        extended_coord: paddle.Tensor,
        extended_atype: paddle.Tensor,
        extended_atype_embd: paddle.Tensor | None = None,
        mapping: paddle.Tensor | None = None,
        type_embedding: paddle.Tensor | None = None,
    ) -> tuple[
        paddle.Tensor,
        paddle.Tensor | None,
        paddle.Tensor | None,
        paddle.Tensor | None,
        paddle.Tensor,
    ]:
        """Compute the descriptor.

        Parameters
        ----------
        nlist
            The neighbor list. shape: nf x nloc x nnei
        extended_coord
            The extended coordinates of atoms. shape: nf x (nallx3)
        extended_atype
            The extended atom types. shape: nf x nall
        extended_atype_embd
            The extended type embedding of atoms. shape: nf x nall x nt
        mapping
            The index mapping, not required by this descriptor.
        type_embedding
            Full type embeddings. shape: (ntypes+1) x nt
            Required for stripped type embeddings.

        Returns
        -------
        result
            The descriptor. shape: nf x nloc x ng
        g2
            Always None for this descriptor block.
        h2
            Always None for this descriptor block.
        gr
            Always None for this descriptor block.
        sw
            The smooth switch function. shape: nf x nloc x nnei

        """
        del mapping
        assert extended_atype_embd is not None
        nframes, nloc, nnei = nlist.shape
        atype = extended_atype[:, :nloc]
        nb = nframes
        dmatrix, _, sw = prod_env_mat(
            extended_coord,
            nlist,
            atype,
            self.mean,
            self.stddev,
            self.rcut,
            self.rcut_smth,
            protection=self.env_protection,
        )
        # nb x nloc x nnei
        exclude_mask = self.emask(nlist, extended_atype)
        # Excluded pairs are treated like missing neighbors (-1).
        nlist = paddle.where(exclude_mask != 0, nlist, paddle.full_like(nlist, -1))
        nlist_mask = nlist != -1
        # Replace -1 by a valid index (0) so the gathers below stay in range;
        # the contribution of those entries is removed through the masks.
        nlist = paddle.where(nlist == -1, paddle.zeros_like(nlist), nlist)
        sw = paddle.squeeze(sw, -1)
        # nf x nall x nt
        nt = extended_atype_embd.shape[-1]
        # beyond the cutoff sw should be 0.0
        sw = sw.masked_fill(~nlist_mask, 0.0)
        # (nb x nloc) x nnei
        exclude_mask = exclude_mask.reshape([nb * nloc, nnei])
        assert self.filter_layers is not None
        # nfnl x nnei x 4
        dmatrix = dmatrix.reshape([-1, self.nnei, 4])
        nfnl = dmatrix.shape[0]
        # nfnl x nnei x 4
        rr = dmatrix
        rr = rr * exclude_mask[:, :, None].astype(rr.dtype)

        # nfnl x nt_i x 3
        rr_i = rr[:, :, 1:]
        # nfnl x nt_j x 3
        rr_j = rr[:, :, 1:]
        # nfnl x nt_i x nt_j
        # env_ij = paddle.einsum("ijm,ikm->ijk", rr_i, rr_j)
        env_ij = (
            # ij1m x i1km -> ijkm -> ijk
            rr_i.unsqueeze(2) * rr_j.unsqueeze(1)
        ).sum(-1)
        # nfnl x nt_i x nt_j x 1
        ss = env_ij.unsqueeze(-1)

        if self.tebd_input_mode in ["concat"]:
            # nb x (nloc x nnei) x nt
            index = nlist.reshape([nb, nloc * nnei]).unsqueeze(-1).expand([-1, -1, nt])
            # nb x (nloc x nnei) x nt
            atype_tebd_nlist = paddle.take_along_axis(
                extended_atype_embd, axis=1, indices=index, broadcast=False
            )
            # nb x nloc x nnei x nt
            atype_tebd_nlist = atype_tebd_nlist.reshape([nb, nloc, nnei, nt])
            # nfnl x nnei x tebd_dim
            nlist_tebd = atype_tebd_nlist.reshape([nfnl, nnei, self.tebd_dim])
            # nfnl x nt_i x nt_j x tebd_dim
            nlist_tebd_i = nlist_tebd.unsqueeze(2).expand([-1, -1, self.nnei, -1])
            nlist_tebd_j = nlist_tebd.unsqueeze(1).expand([-1, self.nnei, -1, -1])
            # nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
            ss = paddle.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
            # nfnl x nt_i x nt_j x ng
            gg = self.filter_layers.networks[0](ss)
        elif self.tebd_input_mode in ["strip"]:
            # nfnl x nt_i x nt_j x ng
            gg_s = self.filter_layers.networks[0](ss)
            assert self.filter_layers_strip is not None
            assert type_embedding is not None
            ng = self.filter_neuron[-1]
            ntypes_with_padding = type_embedding.shape[0]
            # nf x (nl x nnei)
            nlist_index = nlist.reshape([nb, nloc * nnei])
            # nf x (nl x nnei)
            nei_type = paddle.take_along_axis(
                extended_atype, indices=nlist_index, axis=1, broadcast=False
            )
            # nfnl x nnei
            nei_type = nei_type.reshape([nfnl, nnei])
            # nfnl x nnei x nnei
            nei_type_i = nei_type.unsqueeze(2).expand([-1, -1, nnei])
            nei_type_j = nei_type.unsqueeze(1).expand([-1, nnei, -1])
            # Flattened index of the (type_i, type_j) pair into the
            # (ntypes x ntypes) table built below.
            idx_i = nei_type_i * ntypes_with_padding
            idx_j = nei_type_j
            # (nf x nl x nt_i x nt_j) x ng
            idx = (
                (idx_i + idx_j)
                .reshape([-1, 1])
                .expand([-1, ng])
                .astype(paddle.int64)
                .to(paddle.int64)
            )
            # ntypes * (ntypes) * nt
            type_embedding_i = paddle.tile(
                type_embedding.reshape([ntypes_with_padding, 1, nt]),
                [1, ntypes_with_padding, 1],
            )
            # (ntypes) * ntypes * nt
            type_embedding_j = paddle.tile(
                type_embedding.reshape([1, ntypes_with_padding, nt]),
                [ntypes_with_padding, 1, 1],
            )
            # (ntypes * ntypes) * (nt+nt)
            two_side_type_embedding = paddle.concat(
                [type_embedding_i, type_embedding_j], -1
            ).reshape([-1, nt * 2])
            tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
            # (nfnl x nt_i x nt_j) x ng
            gg_t = paddle.take_along_axis(tt_full, indices=idx, axis=0, broadcast=False)
            # (nfnl x nt_i x nt_j) x ng
            gg_t = gg_t.reshape([nfnl, nnei, nnei, ng])
            if self.smooth:
                # Scale by the switch function of both neighbors.
                gg_t = (
                    gg_t
                    * sw.reshape([nfnl, self.nnei, 1, 1])
                    * sw.reshape([nfnl, 1, self.nnei, 1])
                )
            # nfnl x nt_i x nt_j x ng
            gg = gg_s * gg_t + gg_s
        else:
            raise NotImplementedError

        # nfnl x ng
        # res_ij = paddle.einsum("ijk,ijkm->im", env_ij, gg)
        res_ij = (
            # ijk1 x ijkm -> ijkm -> im
            env_ij.unsqueeze(-1) * gg
        ).sum([1, 2])
        # normalize by the number of neighbor pairs
        res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
        # nf x nl x ng
        result = res_ij.reshape([nframes, nloc, self.filter_neuron[-1]])
        return (
            result,
            None,
            None,
            None,
            sw,
        )

    def has_message_passing(self) -> bool:
        """Returns whether the descriptor block has message passing."""
        return False

    def need_sorted_nlist_for_lower(self) -> bool:
        """Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
        return False