# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)
import torch
from deepmd.dpmodel import (
FittingOutputDef,
OutputVariableDef,
fitting_check_output,
)
from deepmd.pt.model.network.network import (
MaskLMHead,
NonLinearHead,
)
from deepmd.pt.model.task.fitting import (
Fitting,
)
from deepmd.pt.utils import (
env,
)
@fitting_check_output
[docs]
class DenoiseNet(Fitting):
def __init__(
self,
feature_dim,
ntypes,
attn_head=8,
prefactor=[0.5, 0.5],
activation_function="gelu",
**kwargs,
) -> None:
"""Construct a denoise net.
Args:
- ntypes: Element count.
- embedding_width: Embedding width per atom.
- neuron: Number of neurons in each hidden layers of the fitting net.
- bias_atom_e: Average energy per atom for each element.
- resnet_dt: Using time-step in the ResNet construction.
"""
super().__init__()
[docs]
self.feature_dim = feature_dim
[docs]
self.attn_head = attn_head
[docs]
self.prefactor = torch.tensor(
prefactor, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
[docs]
self.lm_head = MaskLMHead(
embed_dim=self.feature_dim,
output_dim=ntypes,
activation_fn=activation_function,
weight=None,
)
if not isinstance(self.attn_head, list):
self.pair2coord_proj = NonLinearHead(
self.attn_head, 1, activation_fn=activation_function
)
else:
self.pair2coord_proj = []
self.ndescriptor = len(self.attn_head)
for ii in range(self.ndescriptor):
_pair2coord_proj = NonLinearHead(
self.attn_head[ii], 1, activation_fn=activation_function
)
self.pair2coord_proj.append(_pair2coord_proj)
self.pair2coord_proj = torch.nn.ModuleList(self.pair2coord_proj)
[docs]
def output_def(self):
return FittingOutputDef(
[
OutputVariableDef(
"updated_coord",
[3],
reducible=False,
r_differentiable=False,
c_differentiable=False,
),
OutputVariableDef(
"logits",
[-1],
reducible=False,
r_differentiable=False,
c_differentiable=False,
),
]
)
[docs]
def forward(
self,
pair_weights,
diff,
nlist_mask,
features,
sw,
masked_tokens: Optional[torch.Tensor] = None,
):
"""Calculate the updated coord.
Args:
- coord: Input noisy coord with shape [nframes, nloc, 3].
- pair_weights: Input pair weights with shape [nframes, nloc, nnei, head].
- diff: Input pair relative coord list with shape [nframes, nloc, nnei, 3].
- nlist_mask: Input nlist mask with shape [nframes, nloc, nnei].
Returns
-------
- denoised_coord: Denoised updated coord with shape [nframes, nloc, 3].
"""
# [nframes, nloc, nnei, 1]
logits = self.lm_head(features, masked_tokens=masked_tokens)
if not isinstance(self.attn_head, list):
attn_probs = self.pair2coord_proj(pair_weights)
out_coord = (attn_probs * diff).sum(dim=-2) / (
sw.sum(dim=-1).unsqueeze(-1) + 1e-6
)
else:
assert len(self.prefactor) == self.ndescriptor
all_coord_update = []
assert len(pair_weights) == len(diff) == len(nlist_mask) == self.ndescriptor
for ii in range(self.ndescriptor):
_attn_probs = self.pair2coord_proj[ii](pair_weights[ii])
_coord_update = (_attn_probs * diff[ii]).sum(dim=-2) / (
nlist_mask[ii].sum(dim=-1).unsqueeze(-1) + 1e-6
)
all_coord_update.append(_coord_update)
out_coord = self.prefactor[0] * all_coord_update[0]
for ii in range(self.ndescriptor - 1):
out_coord += self.prefactor[ii + 1] * all_coord_update[ii + 1]
return {
"updated_coord": out_coord,
"logits": logits,
}