deepmd.pt.utils.exclude_mask

Classes

AtomExcludeMask

Computes the type exclusion mask for atoms.

PairExcludeMask

Computes the type exclusion mask for atom pairs.

Module Contents

class deepmd.pt.utils.exclude_mask.AtomExcludeMask(ntypes: int, exclude_types: list[int] = [])

Bases: torch.nn.Module

Computes the type exclusion mask for atoms.

reinit(ntypes: int, exclude_types: list[int] = []) -> None
get_exclude_types()
get_type_mask()
forward(atype: torch.Tensor) -> torch.Tensor

Compute the type exclusion mask for atoms.

Parameters:
atype

The extended atom types. shape: nf x natom

Returns:
mask

The type exclusion mask for atoms. shape: nf x natom. Element [ff, ii] is 0 if the type of atom ii is excluded and 1 otherwise.
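The presence of get_type_mask() suggests the mask is produced from a per-type lookup table. The following is a minimal sketch of that idea, assuming the table holds 0 for excluded types and 1 otherwise and is indexed directly by atom type. The class name SimpleAtomExcludeMask is hypothetical and is not part of deepmd; it only illustrates the documented input/output contract.

```python
from __future__ import annotations

from typing import Sequence

import torch


class SimpleAtomExcludeMask(torch.nn.Module):
    """Hypothetical illustration of an atom type exclusion mask (not the deepmd class)."""

    def __init__(self, ntypes: int, exclude_types: Sequence[int] = ()) -> None:
        super().__init__()
        self.ntypes = ntypes
        self.exclude_types = list(exclude_types)
        # Assumed lookup table of length ntypes: 0 for excluded types, 1 otherwise.
        table = torch.tensor(
            [0 if t in self.exclude_types else 1 for t in range(ntypes)],
            dtype=torch.int32,
        )
        # Registering as a buffer lets the table follow the module across .to(device).
        self.register_buffer("type_mask", table)

    def forward(self, atype: torch.Tensor) -> torch.Tensor:
        # atype: nf x natom integer types; advanced indexing preserves that shape.
        return self.type_mask[atype]


mask_mod = SimpleAtomExcludeMask(ntypes=3, exclude_types=[1])
print(mask_mod(torch.tensor([[0, 1, 2, 1]])))  # tensor([[1, 0, 1, 0]], dtype=torch.int32)
```

Building the table once at construction time means each forward call is a single indexing operation, with no Python loop over atoms.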

class deepmd.pt.utils.exclude_mask.PairExcludeMask(ntypes: int, exclude_types: list[tuple[int, int]] = [])

Bases: torch.nn.Module

Computes the type exclusion mask for atom pairs.

reinit(ntypes: int, exclude_types: list[tuple[int, int]] = []) -> None
get_exclude_types()
forward(nlist: torch.Tensor, atype_ext: torch.Tensor) -> torch.Tensor

Compute the type exclusion mask for atom pairs.

Parameters:
nlist

The neighbor list. shape: nf x nloc x nnei

atype_ext

The extended atom types. shape: nf x nall

Returns:
mask

The type exclusion mask. shape: nf x nloc x nnei. Element [ff, ii, jj] is 0 if the type pair (type(ii), type(nlist[ff, ii, jj])) is excluded and 1 otherwise.
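The pair mask can be sketched in the same spirit, with an (ntypes + 1) x (ntypes + 1) lookup table indexed by the type of the central atom and the type of each neighbor. The sketch below makes several assumptions that the reference above does not state: empty neighbor slots in nlist are padded with -1, such slots are mapped to an extra virtual type ntypes that is never excluded (so they come out as 1), and each excluded pair is treated as unordered, so excluding (a, b) also excludes (b, a). The class name SimplePairExcludeMask is hypothetical and is not the deepmd implementation.

```python
from __future__ import annotations

from typing import Iterable

import torch


class SimplePairExcludeMask(torch.nn.Module):
    """Hypothetical illustration of a pair type exclusion mask (not the deepmd class)."""

    def __init__(self, ntypes: int, exclude_types: Iterable[tuple[int, int]] = ()) -> None:
        super().__init__()
        self.ntypes = ntypes
        # Assumption: pairs are unordered, so (a, b) also excludes (b, a).
        excluded: set[tuple[int, int]] = set()
        for a, b in exclude_types:
            excluded.add((a, b))
            excluded.add((b, a))
        # Index ntypes is a virtual type for padded neighbor slots; it is never excluded.
        n = ntypes + 1
        table = torch.ones((n, n), dtype=torch.int32)
        for a, b in excluded:
            table[a, b] = 0
        self.register_buffer("type_mask", table)
        self.no_exclusion = not excluded

    def forward(self, nlist: torch.Tensor, atype_ext: torch.Tensor) -> torch.Tensor:
        nf, nloc, nnei = nlist.shape
        nall = atype_ext.shape[1]
        if self.no_exclusion:
            # Nothing is excluded, so every entry of the mask is 1.
            return torch.ones_like(nlist, dtype=torch.int32)
        # Append one virtual atom of type ntypes at position nall: nf x (nall + 1).
        pad = torch.full((nf, 1), self.ntypes, dtype=atype_ext.dtype, device=atype_ext.device)
        ae = torch.cat([atype_ext, pad], dim=1)
        # Redirect padded entries (-1) to the virtual atom, then look up neighbor types.
        index = torch.where(nlist == -1, torch.full_like(nlist, nall), nlist)
        type_j = torch.gather(ae, 1, index.reshape(nf, nloc * nnei)).reshape(nf, nloc, nnei)
        # Central atoms are the first nloc entries of the extended types.
        type_i = atype_ext[:, :nloc].unsqueeze(-1).expand(nf, nloc, nnei)
        return self.type_mask[type_i, type_j]
```

For example, with ntypes=2, exclude_types=[(0, 1)], atype_ext = [[0, 1, 0]] and nlist = [[[1, 2, -1], [0, 2, -1]]], this sketch returns [[[0, 1, 1], [0, 0, 1]]]: every 0-1 pair is masked out, while same-type pairs and padded slots stay at 1.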