deepmd.dpmodel.descriptor.dpa4_nn.edge_cache

Edge cache construction for the dpmodel DPA4/SeZM descriptor.

This module defines the EdgeCache dataclass (the dpmodel counterpart of the pt EdgeFeatureCache NamedTuple from deepmd.pt.model.descriptor.sezm_nn.edge_cache) and build_edge_cache(), the padded-layout counterpart of pt’s sparse build_edge_cache.

Padded-edge layout

The pt implementation extracts a sparse edge list with torch.nonzero: only valid neighbor slots become edges, and per-edge tensors have a data-dependent length E. The dpmodel implementation instead uses a padded and frame-explicit edge layout: every neighbor slot of the DeePMD neighbor list contributes one edge, so

E = nf * nloc * nnei

with per-edge tensors flattened from (nf, nloc, nnei, ...) in row-major order. Invalid slots (nlist == -1 padding, excluded type pairs, and neighbors whose mapped source index falls outside the local range) stay in the arrays and are marked by edge_mask == 0. Edge slot (f, i, j) always belongs to destination node f * nloc + i, so destination aggregation is a masked sum over the nnei axis instead of a scatter.
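The layout contract can be checked with a small NumPy sketch on toy random data (the names msg, agg_scatter, and agg_padded are illustrative, not part of deepmd): with dst = np.repeat(np.arange(nf * nloc), nnei), a scatter-add over dst and a masked sum over the reshaped nnei axis give the same per-node result.

```python
import numpy as np

# Toy sizes; any nf, nloc, nnei behave the same way.
nf, nloc, nnei, C = 2, 3, 4, 5
rng = np.random.default_rng(0)

# Per-edge messages laid out as (nf, nloc, nnei, C), flattened row-major to (E, C).
msg = rng.standard_normal((nf, nloc, nnei, C))
mask = rng.random((nf, nloc, nnei)) > 0.3  # True marks a real edge
E = nf * nloc * nnei
msg_flat = msg.reshape(E, C)
mask_flat = mask.reshape(E)

# Slot (f, i, j) belongs to destination node f * nloc + i.
dst = np.repeat(np.arange(nf * nloc), nnei)

# Scatter-style aggregation, as a sparse edge list would need.
agg_scatter = np.zeros((nf * nloc, C))
np.add.at(agg_scatter, dst, msg_flat * mask_flat[:, None])

# Padded-layout aggregation: a masked sum over the nnei axis, no scatter.
agg_padded = (msg_flat * mask_flat[:, None]).reshape(nf * nloc, nnei, C).sum(axis=1)
```

Because dst is node-contiguous, the reshape to (nf * nloc, nnei, C) groups exactly the slots of each destination node, which is why no index-based scatter is required.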

Classes

EdgeCache

Global edge feature cache created once per forward().

Functions

_build_edge_mask_and_src(→ tuple[Any, Any, Any])

Build the padded edge validity mask and safe source-local indices.

build_edge_cache(→ EdgeCache)

Build the global padded edge cache from a DeePMD padded neighbor list.

Module Contents

class deepmd.dpmodel.descriptor.dpa4_nn.edge_cache.EdgeCache

Global edge feature cache created once per forward().

All per-edge arrays are aligned on the same padded edge axis (E = nf * nloc * nnei); see the module docstring for the layout contract. Node-level arrays use the local node axis N = nf * nloc.

An EdgeCache must not be reused across forward passes: D_to_m_cache/Dt_from_m_cache are keyed only by "lmax:mmax", not by the contents of D_full, so reuse with different Wigner blocks would silently return stale projections.
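The single-forward-pass restriction follows from how the cache is keyed. The sketch below uses a hypothetical helper, project_cached, and a stand-in projection, toy_project; the exact key normalization (here f"{lmax}:{mmax}") is an assumption. A second lookup with different Wigner blocks but the same lmax/mmax silently returns the first result.

```python
import numpy as np


def project_cached(cache: dict, lmax: int, mmax: int, D_full, project):
    """Hypothetical lazy lookup keyed only by "lmax:mmax" (not by D_full)."""
    key = f"{int(lmax)}:{int(mmax)}"  # assumed key normalization
    if key not in cache:
        cache[key] = project(D_full, lmax, mmax)
    return cache[key]


def toy_project(D, lmax, mmax):
    # Stand-in for the real m-projection; it just scales the input.
    return 2.0 * D


cache: dict = {}
D_first = np.eye(4)[None]         # (E=1, D=4, D=4) with lmax=1
D_second = 3.0 * np.eye(4)[None]  # different Wigner blocks, same lmax/mmax

a = project_cached(cache, 1, 1, D_first, toy_project)
# Stale hit: b is the projection of D_first, because the key ignores D_full.
b = project_cached(cache, 1, 1, D_second, toy_project)
```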

Parameters:
src

Source (neighbor) node indices with shape (E,), pointing into the local node axis N = nf * nloc. Invalid slots must hold a safe in-range index (their contribution is masked out by edge_mask).

dst

Destination (center) node indices with shape (E,). In the padded layout this is implied by the slot position and MUST equal np.repeat(np.arange(nf * nloc), nnei), i.e. each node index repeated nnei consecutive times (node-contiguous order); aggregation code relies on this ordering.

edge_type_feat

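Per-edge type embeddings with shape (E, C), computed as the sum of the source and destination node type embeddings (type_ebed[src] + type_ebed[dst]). Zero on invalid slots.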

edge_vec

Edge vectors with shape (E, 3) in Å.

edge_rbf

Radial basis with shape (E, n_radial). The C^3 cutoff envelope is already baked in.

edge_env

C^3 cutoff envelope weights with shape (E, 1). Zero on invalid slots.

deg

Envelope-squared smooth degree with shape (N,), computed as the masked sum(edge_env**2) over each node’s nnei slots.

inv_sqrt_deg

Inverse square root of the smooth degree, 1 / sqrt(deg + deg_norm_floor), used for normalization, with shape (N, 1, 1).

D_full

Block-diagonal Wigner-D matrix with shape (E, D, D) where D=(lmax+1)^2. Used for efficient batched rotation. None if not available.

Dt_full

Transpose of D_full with shape (E, D, D). None if not available.

D_to_m_cache

Lazy cache for projected D matrices keyed by a normalized "lmax:mmax" identifier. The key does not capture the contents of D_full, so the cache is only valid for the forward pass that created this EdgeCache (see the class docstring).

Dt_from_m_cache

Lazy cache for projected Dt matrices keyed by a normalized "lmax:mmax" identifier. Same single-forward-pass validity caveat as D_to_m_cache.

edge_src_gate

Optional per-edge Source Freeze Propagation Gate (SFPG) weight with shape (E, 1). Present only in bridging mode; None otherwise.

edge_quat

Per-edge global-to-local quaternion with shape (E, 4), used to build D_full and Dt_full. None if not available.

edge_mask

Validity mask for the padded-edge layout with shape (E,) or (E, 1); nonzero (1) marks a real edge, zero marks a padded/invalid slot. None means all slots are valid. This field has no pt counterpart: pt’s sparse edge list contains valid edges only, while dpmodel keeps the padded nf * nloc * nnei slots and masks the invalid ones.
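As a rough illustration of the node-level fields above, the sketch below derives edge_type_feat, deg, and inv_sqrt_deg from per-node embeddings and the envelope. It assumes inv_sqrt_deg = 1 / sqrt(deg + deg_norm_floor), following the description of deg_norm_floor in build_edge_cache; the helper name node_edge_features is hypothetical.

```python
import numpy as np


def node_edge_features(type_ebed, src, dst, edge_env, edge_mask, nnei, deg_norm_floor):
    """Hypothetical sketch of edge_type_feat, deg and inv_sqrt_deg (padded layout).

    type_ebed: (N, C); src, dst: (E,) int; edge_env: (E, 1); edge_mask: (E,).
    """
    m = edge_mask.astype(type_ebed.dtype)[:, None]          # (E, 1)
    # Sum of source and destination type embeddings, zeroed on invalid slots.
    edge_type_feat = (type_ebed[src] + type_ebed[dst]) * m  # (E, C)
    env = edge_env * m                                       # (E, 1)
    N = type_ebed.shape[0]
    # Masked sum of env**2 over each node's nnei consecutive slots.
    deg = (env[:, 0] ** 2).reshape(N, nnei).sum(axis=1)     # (N,)
    # Assumed form: floor added before the inverse square root.
    inv_sqrt_deg = (1.0 / np.sqrt(deg + deg_norm_floor)).reshape(N, 1, 1)
    return edge_type_feat, deg, inv_sqrt_deg
```

The reshape to (N, nnei) relies on the node-contiguous dst ordering described for dst, so the degree needs no scatter either.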

src: Any
dst: Any
edge_type_feat: Any
edge_vec: Any
edge_rbf: Any
edge_env: Any
deg: Any
inv_sqrt_deg: Any
D_full: Any = None
Dt_full: Any = None
D_to_m_cache: dict[str, Any]
Dt_from_m_cache: dict[str, Any]
edge_src_gate: Any = None
edge_quat: Any = None
edge_mask: Any = None
deepmd.dpmodel.descriptor.dpa4_nn.edge_cache._build_edge_mask_and_src(xp: Any, nlist: Any, mapping: Any, pair_keep_mask: Any, nall: int) → tuple[Any, Any, Any]

Build the padded edge validity mask and safe source-local indices.

Mirrors the pt edge-keep semantics of sezm_nn.edge_cache._build_standard_edge_index exactly:

  • padding slots (nlist == -1) are invalid;

  • excluded type pairs (pair_keep_mask == False) are invalid;

  • after mapping the neighbor’s extended index to a local index, slots whose source falls outside [0, nloc) are invalid (pt’s src_ok filter; e.g. broken mapping or ghost-only neighbors);

  • no distance-based filtering: edges beyond rcut stay valid and are zeroed naturally by the smooth envelope.

Instead of being dropped (as pt does with torch.nonzero), invalid slots are kept with mask == False and a safe placeholder index of 0.
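The keep rules above can be sketched in NumPy as follows. This is an illustrative re-implementation under stated assumptions, not the deepmd function: the name edge_mask_and_src_sketch is hypothetical, it uses NumPy directly instead of an xp namespace, and it takes nloc from nlist.shape rather than accepting nall.

```python
import numpy as np


def edge_mask_and_src_sketch(nlist, mapping, pair_keep_mask):
    """Hypothetical sketch of the padded keep rules (not the deepmd code)."""
    nf, nloc, nnei = nlist.shape
    is_real = nlist >= 0                          # padding slots hold -1
    nlist_clamped = np.where(is_real, nlist, 0)   # safe index for the gather
    if mapping is None:
        src_local = nlist_clamped                 # indices are already local
    else:
        # mapping[f, nlist[f, i, j]] for every slot, gathered per frame.
        src_local = np.take_along_axis(
            mapping, nlist_clamped.reshape(nf, nloc * nnei), axis=1
        ).reshape(nf, nloc, nnei)
    src_ok = (src_local >= 0) & (src_local < nloc)  # pt's src_ok filter
    mask = is_real & pair_keep_mask.astype(bool) & src_ok
    # No distance test: beyond-rcut slots stay valid; the envelope zeroes them.
    nlist_safe = np.where(mask, nlist, 0).astype(np.int64)
    src_local_safe = np.where(mask, src_local, 0).astype(np.int64)
    return mask, nlist_safe, src_local_safe
```

For example, with mapping = [[0, 1, 0, 7]], a slot pointing at extended atom 3 maps to local index 7 (outside [0, nloc)) and is marked invalid, while a slot pointing at ghost atom 2 maps to local atom 0 and is kept.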

Parameters:
xp

Array namespace.

nlist

Neighbor list with shape (nf, nloc, nnei); -1 marks padding.

mapping

Extended-to-local mapping with shape (nf, nall), or None if the neighbor indices are already local.

pair_keep_mask

Pair exclusion keep mask with shape (nf, nloc, nnei). True means keep.

nall

Number of atoms on the extended axis per frame.

Returns:
tuple[Any, Any, Any]

(mask, nlist_safe, src_local_safe), all with shape (nf, nloc, nnei). mask is boolean; the two index arrays are int64 with 0 substituted on invalid slots.

deepmd.dpmodel.descriptor.dpa4_nn.edge_cache.build_edge_cache(*, type_ebed: Any, extended_coord: Any, nlist: Any, mapping: Any, pair_keep_mask: Any, eps: float, deg_norm_floor: float, edge_envelope: Any, radial_basis: Any, n_radial: int, random_gamma: bool, wigner_calc: Any, gamma: Any = None) → EdgeCache

Build the global padded edge cache from a DeePMD padded neighbor list.

Padded counterpart of pt sezm_nn.edge_cache.build_edge_cache. Instead of extracting a sparse edge list with torch.nonzero (data-dependent length), every neighbor slot becomes one edge slot: E = nf * nloc * nnei, flattened row-major, with invalid slots marked by edge_mask == 0 (see the EdgeCache layout contract). In particular, dst == np.repeat(np.arange(nf * nloc), nnei) always, and there is no empty-cache special case (E is determined by the array shapes alone).
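A minimal sketch of how the per-frame slots could map onto the global edge axis (the helper padded_edge_indices is hypothetical): source indices are shifted by the frame offset f * nloc so they point into N = nf * nloc, and dst follows purely from slot position. Invalid slots keep the placeholder 0, which becomes f * nloc, so every src stays in range.

```python
import numpy as np


def padded_edge_indices(src_local_safe, mask):
    """Hypothetical: flatten (nf, nloc, nnei) slots onto the padded edge axis."""
    nf, nloc, nnei = src_local_safe.shape
    # Shift frame-local source indices onto the N = nf * nloc node axis.
    frame_offset = (np.arange(nf) * nloc)[:, None, None]
    src = (src_local_safe + frame_offset).reshape(-1)  # (E,)
    # dst is implied by the slot position: slot (f, i, j) -> f * nloc + i.
    dst = np.repeat(np.arange(nf * nloc), nnei)        # (E,)
    edge_mask = mask.reshape(-1)                       # (E,)
    return src, dst, edge_mask
```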

Masked-slot safety: gathered edge vectors on invalid slots are garbage (they use the placeholder index 0) and can even be exactly zero (a self-difference), which would produce 0/0 in the normalization inside the quaternion construction. Although the forward contribution of such slots is masked out downstream, a NaN there would still poison the backward pass, because where() propagates NaN gradients from the unselected branch. Invalid slots are therefore rewritten to the safe dummy unit vector +z BEFORE any norm, quaternion, or Wigner evaluation, and their envelope, radial basis, and type features are multiplied by the mask so they are exactly zero.
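The masked-slot rewrite can be sketched as below. Assumptions not confirmed by this page: the center atom i sits at extended index i, and the edge vector is neighbor minus center (the usual DeePMD convention). The helper safe_edge_vectors is hypothetical, and autodiff is not shown; the sketch only demonstrates that the forward norm and unit vector stay finite on invalid slots, including the zero self-difference case.

```python
import numpy as np


def safe_edge_vectors(extended_coord, nlist_safe, mask, eps=1e-8):
    """Hypothetical: gather edge vectors and make invalid slots numerically safe."""
    nf, nloc, nnei = nlist_safe.shape
    # Assumption: center i is extended atom i; vector = neighbor - center.
    nei = np.take_along_axis(
        extended_coord, nlist_safe.reshape(nf, nloc * nnei, 1), axis=1
    ).reshape(nf, nloc, nnei, 3)
    center = extended_coord[:, :nloc, None, :]
    edge_vec = (nei - center).reshape(-1, 3)  # (E, 3)
    m = mask.reshape(-1, 1)                   # (E, 1) bool
    # Rewrite invalid slots to the dummy unit vector +z BEFORE any norm.
    dummy = np.array([0.0, 0.0, 1.0], dtype=edge_vec.dtype)
    edge_vec = np.where(m, edge_vec, dummy)
    r = np.linalg.norm(edge_vec, axis=-1, keepdims=True)  # exactly 1 on invalid slots
    unit = edge_vec / np.maximum(r, eps)                   # eps guard for real edges
    return edge_vec, r, unit
```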

Parameters:
type_ebed

Per-node type embedding with shape (N, C), where N = nf * nloc.

extended_coord

Extended coordinates with shape (nf, nall, 3).

nlist

Neighbor list with shape (nf, nloc, nnei); -1 marks padding.

mapping

Mapping from extended to local indices with shape (nf, nall), or None when the neighbor indices are already local.

pair_keep_mask

Pair keep mask from PairExcludeMask with shape (nf, nloc, nnei). True means keep.

eps

Small positive epsilon for safe norm / quaternion construction.

deg_norm_floor

Floor added to the envelope-squared degree before the inverse-sqrt normalization.

edge_envelope

C^3 edge envelope callable (E, 1) -> (E, 1).

radial_basis

Radial basis callable (E, 1) -> (E, n_radial) (envelope baked in).

n_radial

Number of radial basis channels. Unused in the padded layout (kept for signature parity with pt, where it sizes the empty cache).

random_gamma

Whether to apply a random roll around the local +Z axis before constructing Wigner-D blocks.

wigner_calc

Callable converting edge quaternions (E, 4) into packed Wigner-D blocks (D_full, Dt_full).

gamma

Optional per-edge roll angles with shape (E,), used only when random_gamma is True. pt draws gamma internally with torch.rand and the draw cannot be reproduced here, so callers needing determinism (e.g. tests) inject the angles explicitly. When None, angles are drawn from numpy.random.default_rng() uniformly in [0, 2*pi), matching pt’s distribution.

Returns:
EdgeCache

Padded per-edge cache with edge_mask set.
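A sketch of the gamma contract described above. The helper resolve_gamma and its rng argument are hypothetical additions for reproducible sampling; build_edge_cache itself only accepts gamma. Injected angles are used as-is, otherwise angles are drawn uniformly in [0, 2*pi) with numpy.random.default_rng().

```python
import numpy as np


def resolve_gamma(num_edges, random_gamma, gamma=None, rng=None):
    """Hypothetical: per-edge roll angles in [0, 2*pi), or None when rolling is off."""
    if not random_gamma:
        return None  # gamma is ignored unless random_gamma is True
    if gamma is not None:
        gamma = np.asarray(gamma, dtype=np.float64)
        if gamma.shape != (num_edges,):
            raise ValueError(f"gamma must have shape ({num_edges},), got {gamma.shape}")
        return gamma  # injected angles: fully deterministic
    rng = np.random.default_rng() if rng is None else rng
    return rng.uniform(0.0, 2.0 * np.pi, size=num_edges)
```

Passing explicit angles (or, in this sketch, a seeded Generator) is what makes test runs repeatable, since pt's torch.rand draw cannot be reproduced on the NumPy side.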