deepmd.pt.utils.exclude_mask

Module Contents

Classes

AtomExcludeMask: Computes the type exclusion mask for atoms.

PairExcludeMask: Computes the type exclusion mask for atom pairs.

class deepmd.pt.utils.exclude_mask.AtomExcludeMask(ntypes: int, exclude_types: List[int] = [])

Bases: torch.nn.Module

Computes the type exclusion mask for atoms.

reinit(ntypes: int, exclude_types: List[int] = [])
get_exclude_types()
get_type_mask()
forward(atype: torch.Tensor) → torch.Tensor

Compute the type exclusion mask for atoms.

Parameters:
atype

The extended atom types, of shape nf x natom.

Returns:
mask

The type exclusion mask for atoms, of shape nf x natom. Element [ff, ii] is 0 if the type of atom ii in frame ff is excluded, and 1 otherwise.
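The masking rule above can be illustrated with a minimal pure-Python sketch. It is not the DeePMD-kit implementation, which operates on torch.Tensor objects; nested lists stand in for nf x natom tensors, and the helper names `build_type_mask` and `atom_exclude_mask` are hypothetical. The sketch assumes the mask is built as a per-type 0/1 lookup table that is then indexed by each atom's type.

```python
from typing import List


def build_type_mask(ntypes: int, exclude_types: List[int]) -> List[int]:
    """Hypothetical helper: entry t is 0 if type t is excluded, else 1."""
    excluded = set(exclude_types)
    return [0 if t in excluded else 1 for t in range(ntypes)]


def atom_exclude_mask(
    atype: List[List[int]], ntypes: int, exclude_types: List[int]
) -> List[List[int]]:
    """Hypothetical sketch of the rule: mask[ff][ii] = type_mask[atype[ff][ii]]."""
    type_mask = build_type_mask(ntypes, exclude_types)
    # Look up each atom's type in the per-type table, frame by frame.
    return [[type_mask[t] for t in frame] for frame in atype]


# Two frames (nf = 2) of three atoms each (natom = 3), with type 1 excluded.
atype = [[0, 1, 2], [1, 1, 0]]
print(atom_exclude_mask(atype, ntypes=3, exclude_types=[1]))
# [[1, 0, 1], [0, 0, 1]]
```

Building the per-type table once and then indexing it keeps the per-atom work to a single lookup, which mirrors how a tensor version would gather from a small mask by type index.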

class deepmd.pt.utils.exclude_mask.PairExcludeMask(ntypes: int, exclude_types: List[Tuple[int, int]] = [])

Bases: torch.nn.Module

Computes the type exclusion mask for atom pairs.

reinit(ntypes: int, exclude_types: List[Tuple[int, int]] = [])
get_exclude_types()
forward(nlist: torch.Tensor, atype_ext: torch.Tensor) → torch.Tensor

Compute the type exclusion mask for atom pairs.

Parameters:
nlist

The neighbor list, of shape nf x nloc x nnei.

atype_ext

The extended atom types, of shape nf x nall.

Returns:
mask

The type exclusion mask, of shape nf x nloc x nnei. Element [ff, ii, jj] is 0 if the type pair (type(ii), type(nlist[ff, ii, jj])) is excluded, and 1 otherwise.
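The pair rule above can be sketched in pure Python as follows. This is an illustration of the documented semantics, not the DeePMD-kit implementation; nested lists stand in for tensors, and `build_pair_set` and `pair_exclude_mask` are hypothetical names. Three assumptions are not stated in this reference and are labeled in the code: excluded pairs are treated as unordered, so (a, b) also excludes (b, a); the first nloc entries of atype_ext are the local atoms indexed by ii; and padded neighbor-list entries are -1 and are never excluded.

```python
from typing import List, Sequence, Set, Tuple


def build_pair_set(exclude_types: Sequence[Tuple[int, int]]) -> Set[Tuple[int, int]]:
    """Hypothetical helper: store each excluded pair in both orders (assumed symmetric)."""
    pairs: Set[Tuple[int, int]] = set()
    for ti, tj in exclude_types:
        pairs.add((ti, tj))
        pairs.add((tj, ti))
    return pairs


def pair_exclude_mask(
    nlist: List[List[List[int]]],
    atype_ext: List[List[int]],
    exclude_types: Sequence[Tuple[int, int]],
) -> List[List[List[int]]]:
    """Hypothetical sketch: mask[ff][ii][jj] is 0 if the type pair of ii and its neighbor is excluded."""
    pairs = build_pair_set(exclude_types)
    mask = []
    for ff, frame in enumerate(nlist):
        frame_mask = []
        for ii, neighbors in enumerate(frame):
            type_i = atype_ext[ff][ii]  # assumption: local atoms come first in atype_ext
            row = []
            for jj in neighbors:
                if jj == -1:
                    row.append(1)  # assumption: -1 marks padding and is never excluded
                else:
                    row.append(0 if (type_i, atype_ext[ff][jj]) in pairs else 1)
            frame_mask.append(row)
        mask.append(frame_mask)
    return mask


# One frame: nloc = 2 local atoms, nall = 3 extended atoms, nnei = 2.
atype_ext = [[0, 1, 1]]
nlist = [[[1, 2], [0, -1]]]
print(pair_exclude_mask(nlist, atype_ext, exclude_types=[(0, 1)]))
# [[[0, 0], [0, 1]]]
```

In the example, excluding (0, 1) also masks the (1, 0) interaction of local atom 1 with atom 0 because of the assumed symmetry, while the padded slot stays 1.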