Source code for deepmd.pt.model.descriptor.repformer_layer

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
    Optional,
)

import torch
import torch.nn as nn

from deepmd.pt.model.network.layernorm import (
    LayerNorm,
)
from deepmd.pt.model.network.mlp import (
    MLPLayer,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.env import (
    PRECISION_DICT,
)
from deepmd.pt.utils.utils import (
    ActivationFn,
    to_numpy_array,
    to_torch_tensor,
)
from deepmd.utils.version import (
    check_version_compatibility,
)


def get_residual(
    _dim: int,
    _scale: float,
    _mode: str = "norm",
    trainable: bool = True,
    precision: str = "float64",
) -> torch.Tensor:
    r"""
    Get residual tensor for one update vector.

    Parameters
    ----------
    _dim : int
        The dimension of the update vector.
    _scale
        The initial scale of the residual tensor. See `_mode` for details.
    _mode
        The mode of residual initialization for the residual tensor.
        - "norm" (default): init residual using normal with `_scale` std.
        - "const": init residual using element-wise constants of `_scale`.
    trainable
        Whether the residual tensor is trainable.
    precision
        The precision of the residual tensor.
    """
    residual = nn.Parameter(
        data=torch.zeros(_dim, dtype=PRECISION_DICT[precision], device=env.DEVICE),
        requires_grad=trainable,
    )
    if _mode == "norm":
        nn.init.normal_(residual.data, std=_scale)
    elif _mode == "const":
        nn.init.constant_(residual.data, val=_scale)
    else:
        raise RuntimeError(f"Unsupported initialization mode '{_mode}'!")
    return residual

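# A minimal usage sketch of ``get_residual`` (the helper below is hypothetical and
# for illustration only): residual coefficients of this kind scale each extra
# update term in the "res_residual" update style.
def _example_get_residual() -> None:
    r = get_residual(8, 0.001, "norm", trainable=True, precision="float64")
    update = torch.ones(8, dtype=r.dtype, device=r.device)
    scaled = r * update  # element-wise scaling, as in list_update_res_residual
    assert scaled.shape == (8,)
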
# common ops
def _make_nei_g1(
    g1_ext: torch.Tensor,
    nlist: torch.Tensor,
) -> torch.Tensor:
    """
    Make neighbor-wise atomic invariant rep.

    Parameters
    ----------
    g1_ext
        Extended atomic invariant rep, with shape nb x nall x ng1.
    nlist
        Neighbor list, with shape nb x nloc x nnei.

    Returns
    -------
    gg1: torch.Tensor
        Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
    """
    # nlist: nb x nloc x nnei
    nb, nloc, nnei = nlist.shape
    # g1_ext: nb x nall x ng1
    ng1 = g1_ext.shape[-1]
    # index: nb x (nloc x nnei) x ng1
    index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
    # gg1 : nb x (nloc x nnei) x ng1
    gg1 = torch.gather(g1_ext, dim=1, index=index)
    # gg1 : nb x nloc x nnei x ng1
    gg1 = gg1.view(nb, nloc, nnei, ng1)
    return gg1

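# A small worked example (hypothetical helper, illustrative values) showing how
# ``_make_nei_g1`` gathers per-neighbor rows from the extended rep via ``nlist``.
def _example_make_nei_g1() -> None:
    nb, nloc, nall, nnei, ng1 = 1, 2, 3, 2, 4
    g1_ext = torch.arange(nb * nall * ng1, dtype=torch.float64).view(nb, nall, ng1)
    nlist = torch.tensor([[[1, 2], [0, 2]]])  # nb x nloc x nnei
    gg1 = _make_nei_g1(g1_ext, nlist)
    assert gg1.shape == (nb, nloc, nnei, ng1)
    # the second neighbor of local atom 0 is atom 2 of the extended system
    assert torch.equal(gg1[0, 0, 1], g1_ext[0, 2])
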
def _apply_nlist_mask(
    gg: torch.Tensor,
    nlist_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Apply nlist mask to neighbor-wise rep tensors.

    Parameters
    ----------
    gg
        Neighbor-wise rep tensors, with shape nf x nloc x nnei x d.
    nlist_mask
        Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
    """
    # gg:  nf x nloc x nnei x d
    # msk: nf x nloc x nnei
    return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0)

def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor:
    """
    Apply switch function to neighbor-wise rep tensors.

    Parameters
    ----------
    gg
        Neighbor-wise rep tensors, with shape nf x nloc x nnei x d.
    sw
        The switch function, which equals 1 within the rcut_smth range, smoothly
        decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut,
        with shape nf x nloc x nnei.
    """
    # gg: nf x nloc x nnei x d
    # sw: nf x nloc x nnei
    return gg * sw.unsqueeze(-1)

class Atten2Map(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        head_num: int,
        has_gate: bool = False,  # apply gate to attn map
        smooth: bool = True,
        attnw_shift: float = 20.0,
        precision: str = "float64",
    ):
        """Return neighbor-wise multi-head self-attention maps, with gate mechanism."""
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.mapqk = MLPLayer(
            input_dim, hidden_dim * 2 * head_num, bias=False, precision=precision
        )
        self.has_gate = has_gate
        self.smooth = smooth
        self.attnw_shift = attnw_shift
        self.precision = precision
    def forward(
        self,
        g2: torch.Tensor,  # nb x nloc x nnei x ng2
        h2: torch.Tensor,  # nb x nloc x nnei x 3
        nlist_mask: torch.Tensor,  # nb x nloc x nnei
        sw: torch.Tensor,  # nb x nloc x nnei
    ) -> torch.Tensor:
        (
            nb,
            nloc,
            nnei,
            _,
        ) = g2.shape
        nd, nh = self.hidden_dim, self.head_num
        # nb x nloc x nnei x nd x (nh x 2)
        g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2)
        # nb x nloc x (nh x 2) x nnei x nd
        g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3))
        # nb x nloc x nh x nnei x nd
        g2q, g2k = torch.split(g2qk, nh, dim=2)
        # g2q = torch.nn.functional.normalize(g2q, dim=-1)
        # g2k = torch.nn.functional.normalize(g2k, dim=-1)
        # nb x nloc x nh x nnei x nnei
        attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5
        if self.has_gate:
            gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3)
            attnw = attnw * gate
        # mask the attenmap, nb x nloc x 1 x 1 x nnei
        attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2)
        # mask the attenmap, nb x nloc x 1 x nnei x 1
        attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1)
        if self.smooth:
            attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[
                :, :, None, None, :
            ] - self.attnw_shift
        else:
            attnw = attnw.masked_fill(
                attnw_mask,
                float("-inf"),
            )
        attnw = torch.softmax(attnw, dim=-1)
        attnw = attnw.masked_fill(
            attnw_mask,
            0.0,
        )
        # nb x nloc x nh x nnei x nnei
        attnw = attnw.masked_fill(
            attnw_mask_c,
            0.0,
        )
        if self.smooth:
            attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :]
        # nb x nloc x nnei x nnei
        h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5
        # nb x nloc x nh x nnei x nnei
        ret = attnw * h2h2t[:, :, None, :, :]
        # ret = torch.softmax(g2qk, dim=-1)
        # nb x nloc x nnei x nnei x nh
        ret = torch.permute(ret, (0, 1, 3, 4, 2))
        return ret
    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        return {
            "@class": "Atten2Map",
            "@version": 1,
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "head_num": self.head_num,
            "has_gate": self.has_gate,
            "smooth": self.smooth,
            "attnw_shift": self.attnw_shift,
            "precision": self.precision,
            "mapqk": self.mapqk.serialize(),
        }
    @classmethod
    def deserialize(cls, data: dict) -> "Atten2Map":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        mapqk = data.pop("mapqk")
        obj = cls(**data)
        obj.mapqk = MLPLayer.deserialize(mapqk)
        return obj

class Atten2MultiHeadApply(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        head_num: int,
        precision: str = "float64",
    ):
        super().__init__()
        self.input_dim = input_dim
        self.head_num = head_num
        self.mapv = MLPLayer(
            input_dim, input_dim * head_num, bias=False, precision=precision
        )
        self.head_map = MLPLayer(input_dim * head_num, input_dim, precision=precision)
        self.precision = precision
    def forward(
        self,
        AA: torch.Tensor,  # nf x nloc x nnei x nnei x nh
        g2: torch.Tensor,  # nf x nloc x nnei x ng2
    ) -> torch.Tensor:
        nf, nloc, nnei, ng2 = g2.shape
        nh = self.head_num
        # nf x nloc x nnei x ng2 x nh
        g2v = self.mapv(g2).view(nf, nloc, nnei, ng2, nh)
        # nf x nloc x nh x nnei x ng2
        g2v = torch.permute(g2v, (0, 1, 4, 2, 3))
        # g2v = torch.nn.functional.normalize(g2v, dim=-1)
        # nf x nloc x nh x nnei x nnei
        AA = torch.permute(AA, (0, 1, 4, 2, 3))
        # nf x nloc x nh x nnei x ng2
        ret = torch.matmul(AA, g2v)
        # nf x nloc x nnei x ng2 x nh
        ret = torch.permute(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh))
        # nf x nloc x nnei x ng2
        return self.head_map(ret)
    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        return {
            "@class": "Atten2MultiHeadApply",
            "@version": 1,
            "input_dim": self.input_dim,
            "head_num": self.head_num,
            "precision": self.precision,
            "mapv": self.mapv.serialize(),
            "head_map": self.head_map.serialize(),
        }
    @classmethod
    def deserialize(cls, data: dict) -> "Atten2MultiHeadApply":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        mapv = data.pop("mapv")
        head_map = data.pop("head_map")
        obj = cls(**data)
        obj.mapv = MLPLayer.deserialize(mapv)
        obj.head_map = MLPLayer.deserialize(head_map)
        return obj

class Atten2EquiVarApply(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        head_num: int,
        precision: str = "float64",
    ):
        super().__init__()
        self.input_dim = input_dim
        self.head_num = head_num
        self.head_map = MLPLayer(head_num, 1, bias=False, precision=precision)
        self.precision = precision
    def forward(
        self,
        AA: torch.Tensor,  # nf x nloc x nnei x nnei x nh
        h2: torch.Tensor,  # nf x nloc x nnei x 3
    ) -> torch.Tensor:
        nf, nloc, nnei, _ = h2.shape
        nh = self.head_num
        # nf x nloc x nh x nnei x nnei
        AA = torch.permute(AA, (0, 1, 4, 2, 3))
        h2m = torch.unsqueeze(h2, dim=2)
        # nf x nloc x nh x nnei x 3
        h2m = torch.tile(h2m, [1, 1, nh, 1, 1])
        # nf x nloc x nh x nnei x 3
        ret = torch.matmul(AA, h2m)
        # nf x nloc x nnei x 3 x nh
        ret = torch.permute(ret, (0, 1, 3, 4, 2)).view(nf, nloc, nnei, 3, nh)
        # nf x nloc x nnei x 3
        return torch.squeeze(self.head_map(ret), dim=-1)
    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        return {
            "@class": "Atten2EquiVarApply",
            "@version": 1,
            "input_dim": self.input_dim,
            "head_num": self.head_num,
            "precision": self.precision,
            "head_map": self.head_map.serialize(),
        }
    @classmethod
    def deserialize(cls, data: dict) -> "Atten2EquiVarApply":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        head_map = data.pop("head_map")
        obj = cls(**data)
        obj.head_map = MLPLayer.deserialize(head_map)
        return obj

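# A minimal shape sketch (illustrative sizes, hypothetical helper) of the g2
# self-attention pipeline: Atten2Map builds the per-pair attention maps, which
# Atten2MultiHeadApply applies to g2 and Atten2EquiVarApply applies to h2.
def _example_atten2_pipeline() -> None:
    nb, nloc, nnei, ng2, nh = 1, 2, 4, 16, 2
    dtype = PRECISION_DICT["float64"]
    g2 = torch.rand(nb, nloc, nnei, ng2, dtype=dtype, device=env.DEVICE)
    h2 = torch.rand(nb, nloc, nnei, 3, dtype=dtype, device=env.DEVICE)
    nlist_mask = torch.ones(nb, nloc, nnei, dtype=torch.bool, device=env.DEVICE)
    sw = torch.ones(nb, nloc, nnei, dtype=dtype, device=env.DEVICE)
    attn = Atten2Map(ng2, hidden_dim=8, head_num=nh)
    aa = attn(g2, h2, nlist_mask, sw)  # nb x nloc x nnei x nnei x nh
    assert aa.shape == (nb, nloc, nnei, nnei, nh)
    g2_new = Atten2MultiHeadApply(ng2, nh)(aa, g2)  # nb x nloc x nnei x ng2
    h2_new = Atten2EquiVarApply(ng2, nh)(aa, h2)  # nb x nloc x nnei x 3
    assert g2_new.shape == g2.shape and h2_new.shape == h2.shape
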
class LocalAtten(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        head_num: int,
        smooth: bool = True,
        attnw_shift: float = 20.0,
        precision: str = "float64",
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.mapq = MLPLayer(
            input_dim, hidden_dim * 1 * head_num, bias=False, precision=precision
        )
        self.mapkv = MLPLayer(
            input_dim,
            (hidden_dim + input_dim) * head_num,
            bias=False,
            precision=precision,
        )
        self.head_map = MLPLayer(input_dim * head_num, input_dim, precision=precision)
        self.smooth = smooth
        self.attnw_shift = attnw_shift
        self.precision = precision
    def forward(
        self,
        g1: torch.Tensor,  # nb x nloc x ng1
        gg1: torch.Tensor,  # nb x nloc x nnei x ng1
        nlist_mask: torch.Tensor,  # nb x nloc x nnei
        sw: torch.Tensor,  # nb x nloc x nnei
    ) -> torch.Tensor:
        nb, nloc, nnei = nlist_mask.shape
        ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num
        assert ni == g1.shape[-1]
        assert ni == gg1.shape[-1]
        # nb x nloc x nd x nh
        g1q = self.mapq(g1).view(nb, nloc, nd, nh)
        # nb x nloc x nh x nd
        g1q = torch.permute(g1q, (0, 1, 3, 2))
        # nb x nloc x nnei x (nd+ni) x nh
        gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh)
        gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3))
        # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1
        gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1)

        # nb x nloc x nh x 1 x nnei
        attnw = (
            torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5
        )
        # nb x nloc x nh x nnei
        attnw = attnw.squeeze(-2)
        # mask the attenmap, nb x nloc x 1 x nnei
        attnw_mask = ~nlist_mask.unsqueeze(-2)
        # nb x nloc x nh x nnei
        if self.smooth:
            attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift
        else:
            attnw = attnw.masked_fill(
                attnw_mask,
                float("-inf"),
            )
        attnw = torch.softmax(attnw, dim=-1)
        attnw = attnw.masked_fill(
            attnw_mask,
            0.0,
        )
        if self.smooth:
            attnw = attnw * sw.unsqueeze(-2)

        # nb x nloc x nh x ng1
        ret = (
            torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
        )
        # nb x nloc x ng1
        ret = self.head_map(ret)
        return ret
    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        return {
            "@class": "LocalAtten",
            "@version": 1,
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "head_num": self.head_num,
            "smooth": self.smooth,
            "attnw_shift": self.attnw_shift,
            "precision": self.precision,
            "mapq": self.mapq.serialize(),
            "mapkv": self.mapkv.serialize(),
            "head_map": self.head_map.serialize(),
        }
    @classmethod
    def deserialize(cls, data: dict) -> "LocalAtten":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        mapq = data.pop("mapq")
        mapkv = data.pop("mapkv")
        head_map = data.pop("head_map")
        obj = cls(**data)
        obj.mapq = MLPLayer.deserialize(mapq)
        obj.mapkv = MLPLayer.deserialize(mapkv)
        obj.head_map = MLPLayer.deserialize(head_map)
        return obj

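# A minimal usage sketch of ``LocalAtten`` (hypothetical helper, illustrative
# sizes): the single-atom channel g1 attends over its neighborhood rep gg1.
def _example_local_atten() -> None:
    nb, nloc, nnei, ng1 = 1, 2, 4, 32
    dtype = PRECISION_DICT["float64"]
    g1 = torch.rand(nb, nloc, ng1, dtype=dtype, device=env.DEVICE)
    gg1 = torch.rand(nb, nloc, nnei, ng1, dtype=dtype, device=env.DEVICE)
    nlist_mask = torch.ones(nb, nloc, nnei, dtype=torch.bool, device=env.DEVICE)
    sw = torch.ones(nb, nloc, nnei, dtype=dtype, device=env.DEVICE)
    atten = LocalAtten(input_dim=ng1, hidden_dim=16, head_num=2)
    g1_update = atten(g1, gg1, nlist_mask, sw)  # nb x nloc x ng1
    assert g1_update.shape == (nb, nloc, ng1)
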
class RepformerLayer(torch.nn.Module):
    def __init__(
        self,
        rcut,
        rcut_smth,
        sel: int,
        ntypes: int,
        g1_dim=128,
        g2_dim=16,
        axis_neuron: int = 4,
        update_chnnl_2: bool = True,
        update_g1_has_conv: bool = True,
        update_g1_has_drrd: bool = True,
        update_g1_has_grrg: bool = True,
        update_g1_has_attn: bool = True,
        update_g2_has_g1g1: bool = True,
        update_g2_has_attn: bool = True,
        update_h2: bool = False,
        attn1_hidden: int = 64,
        attn1_nhead: int = 4,
        attn2_hidden: int = 16,
        attn2_nhead: int = 4,
        attn2_has_gate: bool = False,
        activation_function: str = "tanh",
        update_style: str = "res_avg",
        update_residual: float = 0.001,
        update_residual_init: str = "norm",
        smooth: bool = True,
        precision: str = "float64",
        trainable_ln: bool = True,
        ln_eps: Optional[float] = 1e-5,
    ):
        super().__init__()
        self.epsilon = 1e-4  # protection of 1./nnei
        self.rcut = rcut
        self.rcut_smth = rcut_smth
        self.ntypes = ntypes
        sel = [sel] if isinstance(sel, int) else sel
        self.nnei = sum(sel)
        assert len(sel) == 1
        self.sel = sel
        self.sec = self.sel
        self.axis_neuron = axis_neuron
        self.activation_function = activation_function
        self.act = ActivationFn(activation_function)
        self.update_g1_has_grrg = update_g1_has_grrg
        self.update_g1_has_drrd = update_g1_has_drrd
        self.update_g1_has_conv = update_g1_has_conv
        self.update_g1_has_attn = update_g1_has_attn
        self.update_chnnl_2 = update_chnnl_2
        self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False
        self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False
        self.update_h2 = update_h2 if self.update_chnnl_2 else False
        del update_g2_has_g1g1, update_g2_has_attn, update_h2
        self.attn1_hidden = attn1_hidden
        self.attn1_nhead = attn1_nhead
        self.attn2_hidden = attn2_hidden
        self.attn2_nhead = attn2_nhead
        self.attn2_has_gate = attn2_has_gate
        self.update_style = update_style
        self.update_residual = update_residual
        self.update_residual_init = update_residual_init
        self.smooth = smooth
        self.g1_dim = g1_dim
        self.g2_dim = g2_dim
        self.trainable_ln = trainable_ln
        self.ln_eps = ln_eps
        self.precision = precision
        assert update_residual_init in [
            "norm",
            "const",
        ], "'update_residual_init' only supports 'norm' or 'const'!"

        self.g1_residual = []
        self.g2_residual = []
        self.h2_residual = []
        if self.update_style == "res_residual":
            self.g1_residual.append(
                get_residual(
                    g1_dim,
                    self.update_residual,
                    self.update_residual_init,
                    precision=precision,
                )
            )

        g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron)
        self.linear1 = MLPLayer(g1_in_dim, g1_dim, precision=precision)
        self.linear2 = None
        self.proj_g1g2 = None
        self.proj_g1g1g2 = None
        self.attn2g_map = None
        self.attn2_mh_apply = None
        self.attn2_lm = None
        self.attn2_ev_apply = None
        self.loc_attn = None

        if self.update_chnnl_2:
            self.linear2 = MLPLayer(g2_dim, g2_dim, precision=precision)
            if self.update_style == "res_residual":
                self.g2_residual.append(
                    get_residual(
                        g2_dim,
                        self.update_residual,
                        self.update_residual_init,
                        precision=precision,
                    )
                )
        if self.update_g1_has_conv:
            self.proj_g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision)
        if self.update_g2_has_g1g1:
            self.proj_g1g1g2 = MLPLayer(
                g1_dim, g2_dim, bias=False, precision=precision
            )
            if self.update_style == "res_residual":
                self.g2_residual.append(
                    get_residual(
                        g2_dim,
                        self.update_residual,
                        self.update_residual_init,
                        precision=precision,
                    )
                )
        if self.update_g2_has_attn or self.update_h2:
            self.attn2g_map = Atten2Map(
                g2_dim,
                attn2_hidden,
                attn2_nhead,
                attn2_has_gate,
                self.smooth,
                precision=precision,
            )
        if self.update_g2_has_attn:
            self.attn2_mh_apply = Atten2MultiHeadApply(
                g2_dim, attn2_nhead, precision=precision
            )
            self.attn2_lm = LayerNorm(
                g2_dim, eps=ln_eps, trainable=trainable_ln, precision=precision
            )
            if self.update_style == "res_residual":
                self.g2_residual.append(
                    get_residual(
                        g2_dim,
                        self.update_residual,
                        self.update_residual_init,
                        precision=precision,
                    )
                )
        if self.update_h2:
            self.attn2_ev_apply = Atten2EquiVarApply(
                g2_dim, attn2_nhead, precision=precision
            )
            if self.update_style == "res_residual":
                self.h2_residual.append(
                    get_residual(
                        1,
                        self.update_residual,
                        self.update_residual_init,
                        precision=precision,
                    )
                )
        if self.update_g1_has_attn:
            self.loc_attn = LocalAtten(
                g1_dim, attn1_hidden, attn1_nhead, self.smooth, precision=precision
            )
            if self.update_style == "res_residual":
                self.g1_residual.append(
                    get_residual(
                        g1_dim,
                        self.update_residual,
                        self.update_residual_init,
                        precision=precision,
                    )
                )

        self.g1_residual = nn.ParameterList(self.g1_residual)
        self.g2_residual = nn.ParameterList(self.g2_residual)
        self.h2_residual = nn.ParameterList(self.h2_residual)
    def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int:
        ret = g1d
        if self.update_g1_has_grrg:
            ret += g2d * ax
        if self.update_g1_has_drrd:
            ret += g1d * ax
        if self.update_g1_has_conv:
            ret += g2d
        return ret
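    # Worked example with the default flags (conv, drrd and grrg all enabled):
    # for g1_dim=128, g2_dim=16 and axis_neuron=4 the concatenated input of
    # ``linear1`` is 128 + 16*4 (grrg) + 128*4 (drrd) + 16 (conv) = 720.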
    def _update_h2(
        self,
        h2: torch.Tensor,
        attn: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculate the attention weights update for pair-wise equivariant rep.

        Parameters
        ----------
        h2
            Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3.
        attn
            Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2.
        """
        assert self.attn2_ev_apply is not None
        # nf x nloc x nnei x 3
        h2_1 = self.attn2_ev_apply(attn, h2)
        return h2_1
    def _update_g1_conv(
        self,
        gg1: torch.Tensor,
        g2: torch.Tensor,
        nlist_mask: torch.Tensor,
        sw: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculate the convolution update for atomic invariant rep.

        Parameters
        ----------
        gg1
            Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
        g2
            Pair invariant rep, with shape nb x nloc x nnei x ng2.
        nlist_mask
            Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
        sw
            The switch function, which equals 1 within the rcut_smth range, smoothly
            decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut,
            with shape nb x nloc x nnei.
        """
        assert self.proj_g1g2 is not None
        nb, nloc, nnei, _ = g2.shape
        ng1 = gg1.shape[-1]
        ng2 = g2.shape[-1]
        # gg1 : nb x nloc x nnei x ng2
        gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2)
        # nb x nloc x nnei x ng2
        gg1 = _apply_nlist_mask(gg1, nlist_mask)
        if not self.smooth:
            # normalized by number of neighbors, not smooth
            # nb x nloc x 1
            # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy
            invnnei = 1.0 / (
                self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1)
            ).unsqueeze(-1)
        else:
            gg1 = _apply_switch(gg1, sw)
            invnnei = (1.0 / float(nnei)) * torch.ones(
                (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device
            )
        # nb x nloc x ng2
        g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei
        return g1_11
    @staticmethod
    def _cal_hg(
        g2: torch.Tensor,
        h2: torch.Tensor,
        nlist_mask: torch.Tensor,
        sw: torch.Tensor,
        smooth: bool = True,
        epsilon: float = 1e-4,
    ) -> torch.Tensor:
        """
        Calculate the transposed rotation matrix.

        Parameters
        ----------
        g2
            Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2.
        h2
            Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3.
        nlist_mask
            Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
        sw
            The switch function, which equals 1 within the rcut_smth range, smoothly
            decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut,
            with shape nb x nloc x nnei.
        smooth
            Whether to use smoothness in processes such as attention weights calculation.
        epsilon
            Protection of 1./nnei.

        Returns
        -------
        hg
            The transposed rotation matrix, with shape nb x nloc x 3 x ng2.
        """
        # g2:  nb x nloc x nnei x ng2
        # h2:  nb x nloc x nnei x 3
        # msk: nb x nloc x nnei
        nb, nloc, nnei, _ = g2.shape
        ng2 = g2.shape[-1]
        # nb x nloc x nnei x ng2
        g2 = _apply_nlist_mask(g2, nlist_mask)
        if not smooth:
            # nb x nloc
            # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy
            invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1))
            # nb x nloc x 1 x 1
            invnnei = invnnei.unsqueeze(-1).unsqueeze(-1)
        else:
            g2 = _apply_switch(g2, sw)
            invnnei = (1.0 / float(nnei)) * torch.ones(
                (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device
            )
        # nb x nloc x 3 x ng2
        h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei
        return h2g2
    @staticmethod
    def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor:
        """
        Calculate the atomic invariant rep.

        Parameters
        ----------
        h2g2
            The transposed rotation matrix, with shape nb x nloc x 3 x ng2.
        axis_neuron
            Size of the submatrix.

        Returns
        -------
        grrg
            Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2).
        """
        # nb x nloc x 3 x ng2
        nb, nloc, _, ng2 = h2g2.shape
        # nb x nloc x 3 x axis
        h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0]
        # nb x nloc x axis x ng2
        g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1)
        # nb x nloc x (axis x ng2)
        g1_13 = g1_13.view(nb, nloc, axis_neuron * ng2)
        return g1_13
    def symmetrization_op(
        self,
        g2: torch.Tensor,
        h2: torch.Tensor,
        nlist_mask: torch.Tensor,
        sw: torch.Tensor,
        axis_neuron: int,
        smooth: bool = True,
        epsilon: float = 1e-4,
    ) -> torch.Tensor:
        """
        Symmetrization operator to obtain atomic invariant rep.

        Parameters
        ----------
        g2
            Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2.
        h2
            Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3.
        nlist_mask
            Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
        sw
            The switch function, which equals 1 within the rcut_smth range, smoothly
            decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut,
            with shape nb x nloc x nnei.
        axis_neuron
            Size of the submatrix.
        smooth
            Whether to use smoothness in processes such as attention weights calculation.
        epsilon
            Protection of 1./nnei.

        Returns
        -------
        grrg
            Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2).
        """
        # g2:  nb x nloc x nnei x ng2
        # h2:  nb x nloc x nnei x 3
        # msk: nb x nloc x nnei
        nb, nloc, nnei, _ = g2.shape
        # nb x nloc x 3 x ng2
        h2g2 = self._cal_hg(g2, h2, nlist_mask, sw, smooth=smooth, epsilon=epsilon)
        # nb x nloc x (axis x ng2)
        g1_13 = self._cal_grrg(h2g2, axis_neuron)
        return g1_13
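    # Shape sketch: for g2 of shape nb x nloc x nnei x ng2 and h2 of shape
    # nb x nloc x nnei x 3, _cal_hg contracts over the neighbor axis to give
    # h2g2 with shape nb x nloc x 3 x ng2; _cal_grrg then multiplies the first
    # axis_neuron columns of h2g2 against the full h2g2 and flattens the result
    # into the rotationally invariant nb x nloc x (axis_neuron x ng2) output.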
    def _update_g2_g1g1(
        self,
        g1: torch.Tensor,  # nb x nloc x ng1
        gg1: torch.Tensor,  # nb x nloc x nnei x ng1
        nlist_mask: torch.Tensor,  # nb x nloc x nnei
        sw: torch.Tensor,  # nb x nloc x nnei
    ) -> torch.Tensor:
        """
        Update the g2 using element-wise dot g1_i * g1_j.

        Parameters
        ----------
        g1
            Atomic invariant rep, with shape nb x nloc x ng1.
        gg1
            Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
        nlist_mask
            Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
        sw
            The switch function, which equals 1 within the rcut_smth range, smoothly
            decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut,
            with shape nb x nloc x nnei.
        """
        ret = g1.unsqueeze(-2) * gg1
        # nb x nloc x nnei x ng1
        ret = _apply_nlist_mask(ret, nlist_mask)
        if self.smooth:
            ret = _apply_switch(ret, sw)
        return ret
    def forward(
        self,
        g1_ext: torch.Tensor,  # nf x nall x ng1
        g2: torch.Tensor,  # nf x nloc x nnei x ng2
        h2: torch.Tensor,  # nf x nloc x nnei x 3
        nlist: torch.Tensor,  # nf x nloc x nnei
        nlist_mask: torch.Tensor,  # nf x nloc x nnei
        sw: torch.Tensor,  # switch func, nf x nloc x nnei
    ):
        """
        Parameters
        ----------
        g1_ext : nf x nall x ng1
            extended single-atom channel
        g2 : nf x nloc x nnei x ng2
            pair-atom channel, invariant
        h2 : nf x nloc x nnei x 3
            pair-atom channel, equivariant
        nlist : nf x nloc x nnei
            neighbor list (padded neis are set to 0)
        nlist_mask : nf x nloc x nnei
            masks of the neighbor list. real nei 1 otherwise 0
        sw : nf x nloc x nnei
            switch function

        Returns
        -------
        g1 : nf x nloc x ng1
            updated single-atom channel
        g2 : nf x nloc x nnei x ng2
            updated pair-atom channel, invariant
        h2 : nf x nloc x nnei x 3
            updated pair-atom channel, equivariant
        """
        cal_gg1 = (
            self.update_g1_has_drrd
            or self.update_g1_has_conv
            or self.update_g1_has_attn
            or self.update_g2_has_g1g1
        )

        nb, nloc, nnei, _ = g2.shape
        nall = g1_ext.shape[1]
        g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1)
        assert (nb, nloc) == g1.shape[:2]
        assert (nb, nloc, nnei) == h2.shape[:3]

        g2_update: List[torch.Tensor] = [g2]
        h2_update: List[torch.Tensor] = [h2]
        g1_update: List[torch.Tensor] = [g1]
        g1_mlp: List[torch.Tensor] = [g1]

        if cal_gg1:
            gg1 = _make_nei_g1(g1_ext, nlist)
        else:
            gg1 = None

        if self.update_chnnl_2:
            # mlp(g2)
            assert self.linear2 is not None
            # nb x nloc x nnei x ng2
            g2_1 = self.act(self.linear2(g2))
            g2_update.append(g2_1)

            if self.update_g2_has_g1g1:
                # linear(g1_i * g1_j)
                assert gg1 is not None
                assert self.proj_g1g1g2 is not None
                g2_update.append(
                    self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw))
                )

            if self.update_g2_has_attn or self.update_h2:
                # gated_attention(g2, h2)
                assert self.attn2g_map is not None
                # nb x nloc x nnei x nnei x nh
                AAg = self.attn2g_map(g2, h2, nlist_mask, sw)

                if self.update_g2_has_attn:
                    assert self.attn2_mh_apply is not None
                    assert self.attn2_lm is not None
                    # nb x nloc x nnei x ng2
                    g2_2 = self.attn2_mh_apply(AAg, g2)
                    g2_2 = self.attn2_lm(g2_2)
                    g2_update.append(g2_2)

                if self.update_h2:
                    # linear_head(attention_weights * h2)
                    h2_update.append(self._update_h2(h2, AAg))

        if self.update_g1_has_conv:
            assert gg1 is not None
            g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw))

        if self.update_g1_has_grrg:
            g1_mlp.append(
                self.symmetrization_op(
                    g2,
                    h2,
                    nlist_mask,
                    sw,
                    self.axis_neuron,
                    smooth=self.smooth,
                    epsilon=self.epsilon,
                )
            )

        if self.update_g1_has_drrd:
            assert gg1 is not None
            g1_mlp.append(
                self.symmetrization_op(
                    gg1,
                    h2,
                    nlist_mask,
                    sw,
                    self.axis_neuron,
                    smooth=self.smooth,
                    epsilon=self.epsilon,
                )
            )

        # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)]
        #                  conv   grrg      drrd
        g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1)))
        g1_update.append(g1_1)

        if self.update_g1_has_attn:
            assert gg1 is not None
            assert self.loc_attn is not None
            g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw))

        # update
        if self.update_chnnl_2:
            g2_new = self.list_update(g2_update, "g2")
            h2_new = self.list_update(h2_update, "h2")
        else:
            g2_new, h2_new = g2, h2
        g1_new = self.list_update(g1_update, "g1")
        return g1_new, g2_new, h2_new
    @torch.jit.export
    def list_update_res_avg(
        self,
        update_list: List[torch.Tensor],
    ) -> torch.Tensor:
        nitem = len(update_list)
        uu = update_list[0]
        for ii in range(1, nitem):
            uu = uu + update_list[ii]
        return uu / (float(nitem) ** 0.5)
    @torch.jit.export
    def list_update_res_incr(self, update_list: List[torch.Tensor]) -> torch.Tensor:
        nitem = len(update_list)
        uu = update_list[0]
        scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0
        for ii in range(1, nitem):
            uu = uu + scale * update_list[ii]
        return uu
    @torch.jit.export
    def list_update_res_residual(
        self, update_list: List[torch.Tensor], update_name: str = "g1"
    ) -> torch.Tensor:
        nitem = len(update_list)
        uu = update_list[0]
        # make jit happy
        if update_name == "g1":
            for ii, vv in enumerate(self.g1_residual):
                uu = uu + vv * update_list[ii + 1]
        elif update_name == "g2":
            for ii, vv in enumerate(self.g2_residual):
                uu = uu + vv * update_list[ii + 1]
        elif update_name == "h2":
            for ii, vv in enumerate(self.h2_residual):
                uu = uu + vv * update_list[ii + 1]
        else:
            raise NotImplementedError
        return uu
    @torch.jit.export
    def list_update(
        self, update_list: List[torch.Tensor], update_name: str = "g1"
    ) -> torch.Tensor:
        if self.update_style == "res_avg":
            return self.list_update_res_avg(update_list)
        elif self.update_style == "res_incr":
            return self.list_update_res_incr(update_list)
        elif self.update_style == "res_residual":
            return self.list_update_res_residual(update_list, update_name=update_name)
        else:
            raise RuntimeError(f"unknown update style {self.update_style}")
    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.
        """
        data = {
            "@class": "RepformerLayer",
            "@version": 1,
            "rcut": self.rcut,
            "rcut_smth": self.rcut_smth,
            "sel": self.sel,
            "ntypes": self.ntypes,
            "g1_dim": self.g1_dim,
            "g2_dim": self.g2_dim,
            "axis_neuron": self.axis_neuron,
            "update_chnnl_2": self.update_chnnl_2,
            "update_g1_has_conv": self.update_g1_has_conv,
            "update_g1_has_drrd": self.update_g1_has_drrd,
            "update_g1_has_grrg": self.update_g1_has_grrg,
            "update_g1_has_attn": self.update_g1_has_attn,
            "update_g2_has_g1g1": self.update_g2_has_g1g1,
            "update_g2_has_attn": self.update_g2_has_attn,
            "update_h2": self.update_h2,
            "attn1_hidden": self.attn1_hidden,
            "attn1_nhead": self.attn1_nhead,
            "attn2_hidden": self.attn2_hidden,
            "attn2_nhead": self.attn2_nhead,
            "attn2_has_gate": self.attn2_has_gate,
            "activation_function": self.activation_function,
            "update_style": self.update_style,
            "smooth": self.smooth,
            "precision": self.precision,
            "trainable_ln": self.trainable_ln,
            "ln_eps": self.ln_eps,
            "linear1": self.linear1.serialize(),
        }
        if self.update_chnnl_2:
            data.update(
                {
                    "linear2": self.linear2.serialize(),
                }
            )
        if self.update_g1_has_conv:
            data.update(
                {
                    "proj_g1g2": self.proj_g1g2.serialize(),
                }
            )
        if self.update_g2_has_g1g1:
            data.update(
                {
                    "proj_g1g1g2": self.proj_g1g1g2.serialize(),
                }
            )
        if self.update_g2_has_attn or self.update_h2:
            data.update(
                {
                    "attn2g_map": self.attn2g_map.serialize(),
                }
            )
        if self.update_g2_has_attn:
            data.update(
                {
                    "attn2_mh_apply": self.attn2_mh_apply.serialize(),
                    "attn2_lm": self.attn2_lm.serialize(),
                }
            )
        if self.update_h2:
            data.update(
                {
                    "attn2_ev_apply": self.attn2_ev_apply.serialize(),
                }
            )
        if self.update_g1_has_attn:
            data.update(
                {
                    "loc_attn": self.loc_attn.serialize(),
                }
            )
        if self.update_style == "res_residual":
            data.update(
                {
                    "g1_residual": [to_numpy_array(t) for t in self.g1_residual],
                    "g2_residual": [to_numpy_array(t) for t in self.g2_residual],
                    "h2_residual": [to_numpy_array(t) for t in self.h2_residual],
                }
            )
        return data
    @classmethod
    def deserialize(cls, data: dict) -> "RepformerLayer":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version"), 1, 1)
        data.pop("@class")
        linear1 = data.pop("linear1")
        update_chnnl_2 = data["update_chnnl_2"]
        update_g1_has_conv = data["update_g1_has_conv"]
        update_g2_has_g1g1 = data["update_g2_has_g1g1"]
        update_g2_has_attn = data["update_g2_has_attn"]
        update_h2 = data["update_h2"]
        update_g1_has_attn = data["update_g1_has_attn"]
        update_style = data["update_style"]

        linear2 = data.pop("linear2", None)
        proj_g1g2 = data.pop("proj_g1g2", None)
        proj_g1g1g2 = data.pop("proj_g1g1g2", None)
        attn2g_map = data.pop("attn2g_map", None)
        attn2_mh_apply = data.pop("attn2_mh_apply", None)
        attn2_lm = data.pop("attn2_lm", None)
        attn2_ev_apply = data.pop("attn2_ev_apply", None)
        loc_attn = data.pop("loc_attn", None)
        g1_residual = data.pop("g1_residual", [])
        g2_residual = data.pop("g2_residual", [])
        h2_residual = data.pop("h2_residual", [])

        obj = cls(**data)
        obj.linear1 = MLPLayer.deserialize(linear1)
        if update_chnnl_2:
            assert isinstance(linear2, dict)
            obj.linear2 = MLPLayer.deserialize(linear2)
        if update_g1_has_conv:
            assert isinstance(proj_g1g2, dict)
            obj.proj_g1g2 = MLPLayer.deserialize(proj_g1g2)
        if update_g2_has_g1g1:
            assert isinstance(proj_g1g1g2, dict)
            obj.proj_g1g1g2 = MLPLayer.deserialize(proj_g1g1g2)
        if update_g2_has_attn or update_h2:
            assert isinstance(attn2g_map, dict)
            obj.attn2g_map = Atten2Map.deserialize(attn2g_map)
        if update_g2_has_attn:
            assert isinstance(attn2_mh_apply, dict)
            assert isinstance(attn2_lm, dict)
            obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply)
            obj.attn2_lm = LayerNorm.deserialize(attn2_lm)
        if update_h2:
            assert isinstance(attn2_ev_apply, dict)
            obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply)
        if update_g1_has_attn:
            assert isinstance(loc_attn, dict)
            obj.loc_attn = LocalAtten.deserialize(loc_attn)
        if update_style == "res_residual":
            for ii, t in enumerate(obj.g1_residual):
                t.data = to_torch_tensor(g1_residual[ii])
            for ii, t in enumerate(obj.g2_residual):
                t.data = to_torch_tensor(g2_residual[ii])
            for ii, t in enumerate(obj.h2_residual):
                t.data = to_torch_tensor(h2_residual[ii])
        return obj

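# A minimal end-to-end sketch of one repformer layer on toy inputs (hypothetical
# helper; all sizes and values below are illustrative assumptions, not
# recommended settings).
def _example_repformer_layer() -> None:
    nf, nloc, nall, nnei = 1, 2, 4, 4
    g1_dim, g2_dim = 32, 16
    dtype = PRECISION_DICT["float64"]
    layer = RepformerLayer(
        rcut=4.0, rcut_smth=0.5, sel=nnei, ntypes=2, g1_dim=g1_dim, g2_dim=g2_dim
    )
    g1_ext = torch.rand(nf, nall, g1_dim, dtype=dtype, device=env.DEVICE)
    g2 = torch.rand(nf, nloc, nnei, g2_dim, dtype=dtype, device=env.DEVICE)
    h2 = torch.rand(nf, nloc, nnei, 3, dtype=dtype, device=env.DEVICE)
    # padded neighbor slots point at atom 0 and are masked out below
    nlist = torch.tensor([[[1, 2, 3, 0], [0, 2, 3, 0]]], device=env.DEVICE)
    nlist_mask = torch.tensor(
        [[[True, True, True, False], [True, True, True, False]]], device=env.DEVICE
    )
    sw = nlist_mask.to(dtype)  # toy switch: 1 for real neighbors, 0 for padding
    g1_new, g2_new, h2_new = layer(g1_ext, g2, h2, nlist, nlist_mask, sw)
    assert g1_new.shape == (nf, nloc, g1_dim)
    assert g2_new.shape == g2.shape and h2_new.shape == h2.shape
    # the serialize/deserialize pair rebuilds an equivalent layer
    layer2 = RepformerLayer.deserialize(layer.serialize())
    assert isinstance(layer2, RepformerLayer)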