Source code for deepmd.dpmodel.utils.env_mat

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
    Optional,
)

import array_api_compat

from deepmd.dpmodel import (
    NativeOP,
)
from deepmd.dpmodel.array_api import (
    Array,
    support_array_api,
    xp_take_along_axis,
)
from deepmd.dpmodel.utils.safe_gradient import (
    safe_for_vector_norm,
)


@support_array_api(version="2023.12")
def compute_smooth_weight(
    distance: Array,
    rmin: float,
    rmax: float,
) -> Array:
    """Compute smooth weight for descriptor elements.

    The distance is first clamped to ``[rmin, rmax]`` and rescaled to
    ``x = (r - rmin) / (rmax - rmin)`` in ``[0, 1]``. The weight is the
    quintic polynomial ``x**3 * (-6 x**2 + 15 x - 10) + 1``, which equals 1
    at ``x = 0`` (``r <= rmin``) and 0 at ``x = 1`` (``r >= rmax``).

    Parameters
    ----------
    distance
        Array of distances; any shape.
    rmin
        Distance below which the weight is 1.
    rmax
        Distance above which the weight is 0.

    Returns
    -------
    Array
        Weights with the same shape as ``distance``.

    Raises
    ------
    ValueError
        If ``rmin >= rmax``.
    """
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")
    xp = array_api_compat.array_namespace(distance)
    clamped = xp.clip(distance, min=rmin, max=rmax)
    # normalized position inside the smoothing window, in [0, 1]
    x = (clamped - rmin) / (rmax - rmin)
    x_sq = x * x
    poly = -6.0 * x_sq + 15.0 * x - 10.0
    return x_sq * x * poly + 1.0
@support_array_api(version="2023.12")
def compute_exp_sw(
    distance: Array,
    rmin: float,
    rmax: float,
) -> Array:
    """Compute the exponential switch function for neighbor update.

    The switch is ``exp(-exp(C / rmin * (r - rmin)))`` with ``C = 20``,
    evaluated on ``r = clip(distance, 0, rmax)``. It is close to 1 for
    ``r`` well below ``rmin``, equals ``exp(-1)`` at ``r == rmin`` and decays
    towards 0 beyond it. It is not exactly 0 at ``rmax``. Distances above
    ``rmax`` are clamped and therefore share the value at ``rmax``.

    Parameters
    ----------
    distance
        Array of distances; any shape.
    rmin
        Center of the switch; must be positive.
    rmax
        Upper clamp bound for ``distance``; must be greater than ``rmin``.

    Returns
    -------
    Array
        Switch values with the same shape as ``distance``.

    Raises
    ------
    ValueError
        If ``rmin >= rmax`` or ``rmin <= 0``.
    """
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")
    if rmin <= 0:
        # The steepness C / rmin is undefined for rmin == 0 (previously a bare
        # ZeroDivisionError) and reverses the switch direction for rmin < 0.
        raise ValueError("rmin should be positive for the exponential switch.")
    xp = array_api_compat.array_namespace(distance)
    distance = xp.clip(distance, min=0.0, max=rmax)
    C = 20
    # steepness of the switch, scaled by its center
    a = C / rmin
    # center of the switch
    b = rmin
    return xp.exp(-xp.exp(a * (distance - b)))
def _make_env_mat(
    nlist: Any,
    coord: Any,
    rcut: float,
    ruct_smth: float,
    radial_only: bool = False,
    protection: float = 0.0,
    use_exp_switch: bool = False,
) -> tuple[Any, Any, Any]:
    """Make smooth environment matrix.

    Parameters
    ----------
    nlist
        Neighbor indices into the extended atoms, shape nf x nloc x nnei.
        Negative entries mark empty neighbor slots.
    coord
        Extended coordinates, reshaped here to nf x nall x 3.
    rcut
        Cutoff radius (upper bound of the switch function).
    ruct_smth
        Smoothing start radius (lower bound of the switch function). The
        misspelled name is kept for backward compatibility.
    radial_only
        If True, only the ``1/r`` column is returned in the environment matrix.
    protection
        Constant added to distances before computing ``1/r`` and ``r_vec/r**2``.
    use_exp_switch
        Use ``compute_exp_sw`` instead of ``compute_smooth_weight``.

    Returns
    -------
    env_mat
        nf x nloc x nnei x (4 or 1); zero for empty slots.
    diff
        Relative neighbor coordinates, nf x nloc x nnei x 3; zero for empty slots.
    weight
        Switch values, nf x nloc x nnei x 1; zero for empty slots.
    """
    xp = array_api_compat.array_namespace(nlist)
    nf, nloc, nnei = nlist.shape
    # nf x nall x 3
    coord = xp.reshape(coord, (nf, -1, 3))
    # True for real neighbors, False for padded (negative) slots
    is_real = nlist >= 0
    # nf x nloc x nnei x 1, broadcastable against per-neighbor features
    is_real_4d = xp.expand_dims(is_real, axis=-1)
    # padded slots are redirected to index 0 so the gather stays in bounds
    safe_nlist = nlist * xp.astype(is_real, nlist.dtype)
    # nf x (nloc x nnei) x 3
    gather_idx = xp.tile(xp.reshape(safe_nlist, (nf, -1, 1)), (1, 1, 3))
    # nf x nloc x nnei x 3
    neighbor_coord = xp.reshape(
        xp_take_along_axis(coord, gather_idx, 1), (nf, nloc, nnei, 3)
    )
    # nf x nloc x 1 x 3
    center_coord = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3))
    # nf x nloc x nnei x 3
    rij = neighbor_coord - center_coord
    # nf x nloc x nnei x 1; the grad of JAX vector_norm is NaN at x=0
    rr = safe_for_vector_norm(rij, axis=-1, keepdims=True)
    # Padded slots point at atom 0, which yields zero length for local atom 0.
    # Shifting padded lengths by 1 keeps 1/r finite; their weight is zeroed below.
    rr = rr + xp.astype(~is_real_4d, rr.dtype)
    denom = rr + protection
    t0 = 1 / denom
    t1 = rij / denom**2
    if use_exp_switch:
        sw = compute_exp_sw(rr, ruct_smth, rcut)
    else:
        sw = compute_smooth_weight(rr, ruct_smth, rcut)
    sw = sw * xp.astype(is_real_4d, sw.dtype)
    if radial_only:
        env_mat = t0 * sw
    else:
        env_mat = xp.concat([t0, t1], axis=-1) * sw
    rij_masked = rij * xp.astype(is_real_4d, rij.dtype)
    return env_mat, rij_masked, sw
class EnvMat(NativeOP):
    """Environment-matrix operator wrapping ``_make_env_mat``.

    Parameters
    ----------
    rcut
        Cutoff radius (upper bound of the switch function).
    rcut_smth
        Radius where smoothing starts (lower bound of the switch function).
    protection
        Constant added to neighbor distances before computing ``1/r`` and
        ``r_vec/r**2``.
    use_exp_switch
        If True, use the exponential switch instead of the polynomial one.
    """

    def __init__(
        self,
        rcut: float,
        rcut_smth: float,
        protection: float = 0.0,
        use_exp_switch: bool = False,
    ) -> None:
        self.rcut = rcut
        self.rcut_smth = rcut_smth
        self.protection = protection
        self.use_exp_switch = use_exp_switch

    def call(
        self,
        coord_ext: Array,
        atype_ext: Array,
        nlist: Array,
        davg: Optional[Array] = None,
        dstd: Optional[Array] = None,
        radial_only: bool = False,
    ) -> tuple[Array, Array, Array]:
        """Compute the environment matrix.

        Parameters
        ----------
        coord_ext
            The extended coordinates of atoms. shape: nf x (nallx3)
        atype_ext
            The extended atom types. shape: nf x nall
        nlist
            The neighbor list. shape: nf x nloc x nnei.
            Negative entries mark empty neighbor slots.
        davg
            The data avg, selected per local atom type and subtracted from
            the environment matrix. shape: nt x nnei x (4 or 1)
        dstd
            The data std, selected per local atom type; the environment
            matrix is divided by it. shape: nt x nnei x (4 or 1)
        radial_only
            Whether to only compute radial part of the environment matrix.
            If True, the output will be of shape nf x nloc x nnei x 1.
            Otherwise, the output will be of shape nf x nloc x nnei x 4.
            Default: False.

        Returns
        -------
        env_mat
            The environment matrix. shape: nf x nloc x nnei x (4 or 1)
        diff
            The relative coordinate of neighbors (zero for empty slots).
            shape: nf x nloc x nnei x 3
        switch
            The value of switch function (zero for empty slots).
            shape: nf x nloc x nnei x 1
        """
        xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
        em, diff, sw = self._call(nlist, coord_ext, radial_only)
        _, nloc, _ = nlist.shape
        # types of the local atoms, flattened to (nf * nloc,) for the lookups
        flat_atype = xp.reshape(atype_ext[:, :nloc], (-1,))
        if davg is not None:
            em -= xp.reshape(xp.take(davg, flat_atype, axis=0), em.shape)
        if dstd is not None:
            em /= xp.reshape(xp.take(dstd, flat_atype, axis=0), em.shape)
        return em, diff, sw

    def _call(
        self, nlist: Any, coord_ext: Any, radial_only: bool
    ) -> tuple[Any, Any, Any]:
        """Run ``_make_env_mat`` with this operator's settings."""
        em, diff, ww = _make_env_mat(
            nlist,
            coord_ext,
            self.rcut,
            self.rcut_smth,
            radial_only=radial_only,
            protection=self.protection,
            use_exp_switch=self.use_exp_switch,
        )
        return em, diff, ww

    def serialize(
        self,
    ) -> dict:
        """Return the constructor arguments as a dict (inverse of ``deserialize``)."""
        return {
            "rcut": self.rcut,
            "rcut_smth": self.rcut_smth,
            "protection": self.protection,
            "use_exp_switch": self.use_exp_switch,
        }

    @classmethod
    def deserialize(
        cls,
        data: dict,
    ) -> "EnvMat":
        """Rebuild an ``EnvMat`` from the dict produced by ``serialize``."""
        return cls(**data)