Source code for deepmd.pt.model.network.layernorm

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
    Optional,
    Union,
)

import numpy as np
import torch
import torch.nn as nn

from deepmd.dpmodel.utils.network import LayerNorm as DPLayerNorm
from deepmd.pt.model.network.init import (
    normal_,
    ones_,
    zeros_,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.env import (
    DEFAULT_PRECISION,
    PRECISION_DICT,
)
from deepmd.pt.utils.utils import (
    get_generator,
    to_numpy_array,
    to_torch_tensor,
)

# Torch device on which this module allocates tensors (consumed by empty_t).
device = env.DEVICE
def empty_t(shape, precision):
    """Allocate an uninitialized tensor on the module-level ``device``.

    Parameters
    ----------
    shape
        Shape of the tensor, forwarded unchanged to ``torch.empty``.
    precision
        Torch dtype of the tensor.

    Returns
    -------
    torch.Tensor
        Tensor of the requested shape and dtype; its contents are
        uninitialized and must be filled by the caller.
    """
    uninitialized = torch.empty(shape, dtype=precision, device=device)
    return uninitialized
class LayerNorm(nn.Module):
    """Layer normalization over the last axis with an element-wise scale and shift.

    The input is normalized to zero mean and unit (biased) variance along its
    last dimension, then multiplied by ``matrix`` and shifted by ``bias``,
    both of shape ``(num_in,)``.

    Parameters
    ----------
    num_in : int
        Length of the ``matrix`` and ``bias`` vectors.
    eps : float, optional
        Constant added to the variance before taking the square root.
    uni_init : bool, optional
        If True, ``matrix`` is initialized to ones and ``bias`` to zeros.
        Otherwise both are drawn from normal distributions (see ``bavg`` and
        ``stddev``).
    bavg : float, optional
        Mean of the normal distribution used for ``bias`` when ``uni_init``
        is False.
    stddev : float, optional
        Standard deviation used for ``bias`` when ``uni_init`` is False;
        ``matrix`` uses ``stddev / sqrt(num_in)``.
    precision : str, optional
        Key into ``PRECISION_DICT`` selecting the parameter dtype.
    trainable : bool, optional
        If False, ``matrix`` and ``bias`` do not require gradients.
    seed : int or list of int, optional
        Passed to ``get_generator`` to build the random generator used by the
        normal initialization.
    """

    def __init__(
        self,
        num_in: int,
        eps: float = 1e-5,
        uni_init: bool = True,
        bavg: float = 0.0,
        stddev: float = 1.0,
        precision: str = DEFAULT_PRECISION,
        trainable: bool = True,
        seed: Optional[Union[int, List[int]]] = None,
    ):
        super().__init__()
        self.eps = eps
        self.uni_init = uni_init
        self.num_in = num_in
        self.precision = precision
        self.prec = PRECISION_DICT[self.precision]
        # Storage is allocated uninitialized and filled by the branch below.
        self.matrix = nn.Parameter(data=empty_t((num_in,), self.prec))
        self.bias = nn.Parameter(
            data=empty_t([num_in], self.prec),
        )
        random_generator = get_generator(seed)
        if self.uni_init:
            # Identity affine transform: scale 1, shift 0.
            ones_(self.matrix.data)
            zeros_(self.bias.data)
        else:
            # The bias is drawn before the matrix; the order matters because
            # both draws consume the same generator.
            normal_(self.bias.data, mean=bavg, std=stddev, generator=random_generator)
            normal_(
                self.matrix.data,
                std=stddev / np.sqrt(self.num_in),
                generator=random_generator,
            )
        self.trainable = trainable
        if not self.trainable:
            self.matrix.requires_grad = False
            self.bias.requires_grad = False

    def dim_out(self) -> int:
        """Return the length of the scale vector ``matrix``."""
        return self.matrix.shape[0]

    def forward(
        self,
        xx: torch.Tensor,
    ) -> torch.Tensor:
        """Apply layer normalization along the last dimension of the input.

        Parameters
        ----------
        xx : torch.Tensor
            The input tensor; it is normalized along its last dimension.

        Returns
        -------
        yy: torch.Tensor
            The normalized input, scaled by ``matrix`` and shifted by ``bias``
            when both are present. An input with no elements skips the
            normalization step.
        """
        # mean = xx.mean(dim=-1, keepdim=True)
        # variance = xx.var(dim=-1, unbiased=False, keepdim=True)
        # The following operation is the same as above, but will not raise error when using jit model to inference.
        # See https://github.com/pytorch/pytorch/issues/85792
        if xx.numel() > 0:
            variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
            yy = (xx - mean) / torch.sqrt(variance + self.eps)
        else:
            yy = xx
        # deserialize() may leave either parameter as None; the affine step is
        # applied only when both exist.
        if self.matrix is not None and self.bias is not None:
            yy = yy * self.matrix + self.bias
        return yy

    def serialize(self) -> dict:
        """Serialize the layer to a dict.

        The parameters are copied as numpy arrays into a ``DPLayerNorm`` with
        the same size, ``eps``, ``trainable`` and ``precision``, and that
        object's own ``serialize`` output is returned.

        Returns
        -------
        dict
            The serialized layer.
        """
        nl = DPLayerNorm(
            self.matrix.shape[0],
            eps=self.eps,
            trainable=self.trainable,
            precision=self.precision,
        )
        nl.w = to_numpy_array(self.matrix)
        nl.b = to_numpy_array(self.bias)
        data = nl.serialize()
        return data

    @classmethod
    def deserialize(cls, data: dict) -> "LayerNorm":
        """Deserialize the layer from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.

        Returns
        -------
        LayerNorm
            A layer whose ``matrix`` and ``bias`` hold the stored values. The
            loaded parameters require gradients only if the stored
            ``trainable`` flag is true.
        """
        nl = DPLayerNorm.deserialize(data)
        obj = cls(
            nl["matrix"].shape[0],
            eps=nl["eps"],
            trainable=nl["trainable"],
            precision=nl["precision"],
        )

        def check_load_param(ss):
            if nl[ss] is None:
                return None
            # A fresh nn.Parameter defaults to requires_grad=True, which would
            # silently undo the freezing done in __init__ for trainable=False.
            return nn.Parameter(
                data=to_torch_tensor(nl[ss]),
                requires_grad=bool(obj.trainable),
            )

        obj.matrix = check_load_param("matrix")
        obj.bias = check_load_param("bias")
        return obj