Source code for deepmd.pt.model.atomic_model.base_atomic_model

# SPDX-License-Identifier: LGPL-3.0-or-later


from typing import (
    Dict,
    List,
    Optional,
    Tuple,
)

import torch

from deepmd.dpmodel.atomic_model import (
    make_base_atomic_model,
)
from deepmd.dpmodel.output_def import (
    FittingOutputDef,
    OutputVariableDef,
)
from deepmd.pt.utils import (
    AtomExcludeMask,
    PairExcludeMask,
)

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


[docs]class BaseAtomicModel(BaseAtomicModel_): def __init__( self, atom_exclude_types: List[int] = [], pair_exclude_types: List[Tuple[int, int]] = [], ): super().__init__() self.reinit_atom_exclude(atom_exclude_types) self.reinit_pair_exclude(pair_exclude_types)
[docs] def reinit_atom_exclude( self, exclude_types: List[int] = [], ): self.atom_exclude_types = exclude_types if exclude_types == []: self.atom_excl = None else: self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)
[docs] def reinit_pair_exclude( self, exclude_types: List[Tuple[int, int]] = [], ): self.pair_exclude_types = exclude_types if exclude_types == []: self.pair_excl = None else: self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)
# export public methods that are not abstract get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel) get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei) get_ntypes = torch.jit.export(BaseAtomicModel_.get_ntypes)
[docs] @torch.jit.export def get_model_def_script(self) -> str: return self.model_def_script
[docs] def atomic_output_def(self) -> FittingOutputDef: old_def = self.fitting_output_def() if self.atom_excl is None: return old_def else: old_list = list(old_def.get_data().values()) return FittingOutputDef( old_list # noqa:RUF005 + [ OutputVariableDef( name="mask", shape=[1], reduciable=False, r_differentiable=False, c_differentiable=False, ) ] )
[docs] def forward_common_atomic( self, extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: _, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] if self.pair_excl is not None: pair_mask = self.pair_excl(nlist, extended_atype) # exclude neighbors in the nlist nlist = torch.where(pair_mask == 1, nlist, -1) ret_dict = self.forward_atomic( extended_coord, extended_atype, nlist, mapping=mapping, fparam=fparam, aparam=aparam, ) if self.atom_excl is not None: atom_mask = self.atom_excl(atype) for kk in ret_dict.keys(): out_shape = ret_dict[kk].shape ret_dict[kk] = ( ret_dict[kk].reshape([out_shape[0], out_shape[1], -1]) * atom_mask[:, :, None] ).reshape(out_shape) ret_dict["mask"] = atom_mask return ret_dict
[docs] def serialize(self) -> dict: return { "atom_exclude_types": self.atom_exclude_types, "pair_exclude_types": self.pair_exclude_types, }