deepmd.dpmodel.descriptor.dpa4_nn.edge_cache
Edge cache construction for the dpmodel DPA4/SeZM descriptor.
This module defines the EdgeCache dataclass (the dpmodel counterpart of the pt EdgeFeatureCache NamedTuple from deepmd.pt.model.descriptor.sezm_nn.edge_cache) and build_edge_cache(), the padded-layout counterpart of pt’s sparse build_edge_cache.
Padded-edge layout
The pt implementation extracts a sparse edge list with torch.nonzero: only valid neighbor slots become edges, and per-edge tensors have a data-dependent length E. The dpmodel implementation instead uses a padded and frame-explicit edge layout: every neighbor slot of the DeePMD neighbor list contributes one edge, so
E = nf * nloc * nnei
with per-edge tensors flattened from (nf, nloc, nnei, ...) in row-major order. Invalid slots (nlist == -1 padding, excluded type pairs) stay in the arrays and are marked by edge_mask == 0. Edge slot (f, i, j) always belongs to destination node f * nloc + i, so destination aggregation is a masked sum over the nnei axis instead of a scatter.
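The layout contract can be checked with a small NumPy sketch. It assumes toy sizes and a random neighbor list, and the names `msg`, `padded` and `sparse` are purely illustrative. It builds the slot-implicit `dst` array and shows that the masked sum over the `nnei` axis matches a sparse-style scatter-add over the valid edges only.

```python
import numpy as np

# Toy sizes (hypothetical): 2 frames, 3 local atoms, 4 neighbor slots per atom.
nf, nloc, nnei = 2, 3, 4
E = nf * nloc * nnei

rng = np.random.default_rng(0)
# Neighbor list with some -1 padding slots; values are drawn from [-1, nloc).
nlist = rng.integers(-1, nloc, size=(nf, nloc, nnei))
edge_mask = (nlist != -1).reshape(E)  # (E,), True on real edges

# dst is slot-implicit: slot (f, i, j) belongs to node f * nloc + i.
dst = np.repeat(np.arange(nf * nloc), nnei)  # (E,)

# An arbitrary one-channel per-edge message.
msg = rng.normal(size=(E, 1))

# Padded aggregation: zero the invalid slots, then sum over the nnei axis.
padded = (msg * edge_mask[:, None]).reshape(nf * nloc, nnei, 1).sum(axis=1)

# Sparse-style reference: scatter-add only the valid edges into their dst node.
sparse = np.zeros((nf * nloc, 1))
np.add.at(sparse, dst[edge_mask], msg[edge_mask])
```

The reshape works because the row-major flattening keeps each node's `nnei` slots contiguous. That contiguity is why the padded layout can replace the scatter with a plain reduction.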
Classes
| EdgeCache | Global edge feature cache created once per forward(). |
Functions
| _build_edge_mask_and_src | Build the padded edge validity mask and safe source-local indices. |
| build_edge_cache | Build the global padded edge cache from a DeePMD padded neighbor list. |
Module Contents
- class deepmd.dpmodel.descriptor.dpa4_nn.edge_cache.EdgeCache
Global edge feature cache created once per forward().
All per-edge arrays are aligned on the same padded edge axis (E = nf * nloc * nnei); see the module docstring for the layout contract. Node-level arrays use the local node axis N = nf * nloc.
An EdgeCache must not be reused across forward passes. D_to_m_cache and Dt_from_m_cache are keyed only by "lmax:mmax", not by the contents of D_full, so reusing the cache with different Wigner blocks would silently return stale projections.
- Parameters:
- src
Source (neighbor) node indices with shape (E,), pointing into the local node axis N = nf * nloc. Invalid slots must hold a safe in-range index; their contribution is masked out by edge_mask.
- dst
Destination (center) node indices with shape (E,). In the padded layout this array is slot-implicit and MUST equal arange(nf * nloc) with each index repeated nnei consecutive times (i.e. np.repeat(np.arange(nf * nloc), nnei), node-contiguous order). Aggregation code relies on this ordering.
- edge_type_feat
Per-edge type embeddings with shape (E, C), computed as the sum of the source and destination node type embeddings (src + dst).
- edge_vec
Edge vectors with shape (E, 3) in Å.
- edge_rbf
Radial basis with shape (E, n_radial). The C^3 cutoff envelope is already baked in.
- edge_env
C^3 cutoff envelope weights with shape (E, 1). Zero on invalid slots.
- deg
Envelope-squared smooth degree with shape (N,), computed as the masked sum(edge_env**2) over each node's nnei slots.
- inv_sqrt_deg
Inverse square root smooth degree normalization with shape (N, 1, 1).
- D_full
Block-diagonal Wigner-D matrix with shape (E, D, D) where D=(lmax+1)^2. Used for efficient batched rotation. None if not available.
- Dt_full
Transpose of D_full with shape (E, D, D). None if not available.
- D_to_m_cache
Lazy cache for projected D matrices, keyed by a normalized "lmax:mmax" identifier. The key does not capture the contents of D_full, so the cache is only valid for the forward pass that created this EdgeCache (see the class docstring).
- Dt_from_m_cache
Lazy cache for projected Dt matrices, keyed by a normalized "lmax:mmax" identifier. The same single-forward-pass validity caveat as for D_to_m_cache applies.
- edge_src_gate
Optional per-edge Source Freeze Propagation Gate (SFPG) weight with shape (E, 1). Present only in bridging mode; None otherwise.
- edge_quat
Per-edge global-to-local quaternion with shape (E, 4), used to build D_full and Dt_full. None if not available.
- edge_mask
Validity mask for the padded-edge layout with shape (E,) or (E, 1). Nonzero (1) marks a real edge and zero marks a padded or invalid slot. None means all slots are valid. This field has no pt counterpart: pt's sparse edge list contains only valid edges, while dpmodel keeps all nf * nloc * nnei padded slots and masks the invalid ones.
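The node-level fields can be illustrated with a small NumPy sketch. The helpers `node_degree_terms` and `edge_type_features` are hypothetical and not part of deepmd. Two readings are assumptions: that `inv_sqrt_deg` is `1 / sqrt(deg + deg_norm_floor)`, based on the `deg_norm_floor` description of build_edge_cache, and that "src+dst" means `type_ebed[src] + type_ebed[dst]`.

```python
import numpy as np


def node_degree_terms(edge_env, edge_mask, nf, nloc, nnei, deg_norm_floor):
    """Hypothetical helper: masked envelope-squared degree and its inverse sqrt."""
    n_nodes = nf * nloc
    mask = edge_mask.reshape(-1, 1).astype(edge_env.dtype)  # (E, 1), 1.0 on real edges
    env2 = (edge_env * mask) ** 2                           # (E, 1), zero on invalid slots
    # Padded layout: each node's nnei slots are contiguous, so reshape and sum.
    deg = env2.reshape(n_nodes, nnei).sum(axis=1)           # (N,)
    # Assumed form: the floor is added before the inverse square root.
    inv_sqrt_deg = (1.0 / np.sqrt(deg + deg_norm_floor)).reshape(n_nodes, 1, 1)
    return deg, inv_sqrt_deg


def edge_type_features(type_ebed, src, dst, edge_mask):
    """Assumed reading of "src+dst": sum of endpoint type embeddings, masked."""
    feat = type_ebed[src] + type_ebed[dst]                  # (E, C)
    return feat * edge_mask.reshape(-1, 1)                   # exactly zero on invalid slots


# Toy case (hypothetical numbers): nf=1, nloc=2, nnei=2, and slot 2 is invalid.
# Slot 2 gets a nonzero envelope on purpose, to show that the mask zeroes it.
nf, nloc, nnei = 1, 2, 2
edge_env = np.array([[0.5], [1.0], [0.9], [0.8]])
edge_mask = np.array([1, 1, 0, 1])
deg, inv_sqrt_deg = node_degree_terms(edge_env, edge_mask, nf, nloc, nnei, deg_norm_floor=0.0)

type_ebed = np.array([[1.0, 0.0], [0.0, 2.0]])
src = np.array([1, 0, 0, 0])                 # slot 2 holds placeholder index 0
dst = np.repeat(np.arange(nf * nloc), nnei)  # [0, 0, 1, 1]
edge_type_feat = edge_type_features(type_ebed, src, dst, edge_mask)
```

Here `deg` comes out as `[1.25, 0.64]`, and the invalid slot's type feature is exactly zero even though its placeholder source index points at a real node.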
- deepmd.dpmodel.descriptor.dpa4_nn.edge_cache._build_edge_mask_and_src(xp: Any, nlist: Any, mapping: Any, pair_keep_mask: Any, nall: int) → tuple[Any, Any, Any]
Build the padded edge validity mask and safe source-local indices.
Mirrors the pt edge-keep semantics of sezm_nn.edge_cache._build_standard_edge_index exactly:
- padding slots (nlist == -1) are invalid;
- excluded type pairs (pair_keep_mask == False) are invalid;
- after the neighbor's extended index is mapped to a local index, slots whose source falls outside [0, nloc) are invalid (pt's src_ok filter; e.g. a broken mapping or ghost-only neighbors);
- there is no distance-based filtering: edges beyond rcut stay valid and are zeroed naturally by the smooth envelope.
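The four keep rules can be sketched in plain NumPy as follows. `build_edge_mask_and_src_sketch` is a hypothetical name and is not the deepmd function. It takes no `xp` or `nall` arguments, returns only the mask and the flattened source indices (not the real three-element tuple), and assumes the placeholder index 0 is applied on the flattened node axis.

```python
import numpy as np


def build_edge_mask_and_src_sketch(nlist, mapping, pair_keep_mask):
    """Hypothetical NumPy sketch of the edge-keep rules; not the deepmd function."""
    nf, nloc, nnei = nlist.shape
    # Rules 1 and 2: padding slots and excluded type pairs are invalid.
    valid = (nlist != -1) & pair_keep_mask
    # Replace invalid slots by index 0 before gathering so the gather stays in range.
    safe_ext = np.where(valid, nlist, 0)
    if mapping is not None:
        # Map each extended neighbor index to a local index, frame by frame.
        src_local = np.take_along_axis(mapping, safe_ext.reshape(nf, -1), axis=1)
        src_local = src_local.reshape(nf, nloc, nnei)
    else:
        src_local = safe_ext  # neighbor indices are already local
    # Rule 3: the source must land in [0, nloc) (pt's src_ok filter).
    valid = valid & (src_local >= 0) & (src_local < nloc)
    # Rule 4: no distance cut here; edges beyond rcut are left to the envelope.
    # Offset by frame so src points into the flattened node axis N = nf * nloc.
    src = src_local + (np.arange(nf) * nloc)[:, None, None]
    # Invalid slots stay in the arrays with placeholder index 0.
    src = np.where(valid, src, 0)
    return valid.reshape(-1), src.reshape(-1)


# Toy frame (hypothetical): nloc=2, nnei=3, nall=4. Extended atom 2 is a ghost
# image of local atom 0, and extended atom 3 has a broken mapping (-1).
nlist = np.array([[[1, 3, -1], [0, 2, 3]]])
mapping = np.array([[0, 1, 0, -1]])
pair_keep_mask = np.array([[[True, True, True], [False, True, True]]])
edge_mask, src = build_edge_mask_and_src_sketch(nlist, mapping, pair_keep_mask)
```

In this toy frame only two slots survive. The first is the neighbor mapped to local atom 1. The second is the ghost image mapped back to local atom 0. The remaining slots are removed by padding, exclusion or the broken mapping.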
Instead of being dropped (pt's torch.nonzero), invalid slots are kept with mask == False and safe placeholder indices (index 0).
- Parameters:
- xp
Array namespace.
- nlist
Neighbor list with shape (nf, nloc, nnei); -1 marks padding.
- mapping
Extended-to-local mapping with shape (nf, nall), or None if the neighbor indices are already local.
- pair_keep_mask
Pair exclusion keep mask with shape (nf, nloc, nnei). True means keep.
- nall
Number of atoms on the extended axis per frame.
- Returns:
- deepmd.dpmodel.descriptor.dpa4_nn.edge_cache.build_edge_cache(*, type_ebed: Any, extended_coord: Any, nlist: Any, mapping: Any, pair_keep_mask: Any, eps: float, deg_norm_floor: float, edge_envelope: Any, radial_basis: Any, n_radial: int, random_gamma: bool, wigner_calc: Any, gamma: Any = None) → EdgeCache
Build the global padded edge cache from a DeePMD padded neighbor list.
Padded counterpart of pt sezm_nn.edge_cache.build_edge_cache. Instead of extracting a sparse edge list with torch.nonzero (data-dependent length), every neighbor slot becomes one edge slot. This gives E = nf * nloc * nnei slots, flattened row-major, with invalid slots marked by edge_mask == 0 (see the EdgeCache layout contract). In particular, dst == np.repeat(np.arange(nf * nloc), nnei) always holds, and there is no empty-cache special case, because E is determined by the array shapes.
Masked-slot safety: edge vectors gathered on invalid slots are meaningless because they come from placeholder index 0. They can even be exactly zero (a self-difference), which would produce 0/0 in the normalization inside the quaternion construction. The forward contribution of such slots is masked out downstream, but a NaN there would still poison the backward pass, since where propagates NaN gradients from the unselected branch. Invalid slots are therefore rewritten to the safe dummy unit vector +z BEFORE any norm, quaternion or Wigner evaluation. Their envelope, radial basis and type features are multiplied by the mask so that they are exactly zero.
- Parameters:
- type_ebed
Per-node type embedding with shape (N, C), where N = nf * nloc.
- extended_coord
Extended coordinates with shape (nf, nall, 3).
- nlist
Neighbor list with shape (nf, nloc, nnei); -1 marks padding.
- mapping
Mapping from extended to local indices with shape (nf, nall), or None when the neighbor indices are already local.
- pair_keep_mask
Pair keep mask from
PairExcludeMaskwith shape (nf, nloc, nnei). True means keep.- eps
Small positive epsilon for safe norm / quaternion construction.
- deg_norm_floor
Floor added to the envelope-squared degree before the inverse-sqrt normalization.
- edge_envelope
C^3 edge envelope callable, (E, 1) -> (E, 1).
- radial_basis
Radial basis callable, (E, 1) -> (E, n_radial), with the envelope baked in.
- n_radial
Number of radial basis channels. Unused in the padded layout (kept for signature parity with pt, where it sizes the empty cache).
- random_gamma
Whether to apply a random roll around the local +Z axis before constructing Wigner-D blocks.
- wigner_calc
Callable converting edge quaternions (E, 4) into packed Wigner-D blocks (D_full, Dt_full).
- gamma
Optional per-edge roll angles with shape (E,), used only when random_gamma is True. pt draws gamma internally with torch.rand, and that draw cannot be reproduced here, so callers that need determinism (e.g. tests) inject the angles explicitly. When None, angles are drawn from numpy.random.default_rng() uniformly in [0, 2*pi), matching pt's distribution.
- Returns:
EdgeCache: padded per-edge cache with edge_mask set.
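The masked-slot safety rule can be illustrated with a short NumPy sketch. `safe_edge_vectors` is a hypothetical helper and the envelope values are made up. NumPy has no autograd, so the sketch shows only the forward 0/0. The backward-pass NaN described above arises in autodiff backends from the same unguarded normalization.

```python
import numpy as np


def safe_edge_vectors(edge_vec, edge_mask):
    """Hypothetical helper: rewrite invalid slots to +z before any norm is taken."""
    mask = edge_mask.reshape(-1, 1).astype(bool)             # (E, 1)
    plus_z = np.array([0.0, 0.0, 1.0], dtype=edge_vec.dtype)
    return np.where(mask, edge_vec, plus_z)                  # (E, 3)


# Slot 1 is invalid, and its gathered vector is an exact zero (self-difference).
edge_vec = np.array([[1.0, 2.0, 2.0],
                     [0.0, 0.0, 0.0]])
edge_mask = np.array([1, 0])

# Naive normalization: 0/0 yields NaN on the invalid slot (warning suppressed).
with np.errstate(invalid="ignore"):
    naive = edge_vec / np.linalg.norm(edge_vec, axis=-1, keepdims=True)

# Safe path: rewrite first, then normalize; every slot stays finite.
safe = safe_edge_vectors(edge_vec, edge_mask)
unit = safe / np.linalg.norm(safe, axis=-1, keepdims=True)

# Per-slot quantities are then multiplied by the mask (toy envelope values).
edge_env = np.array([[0.7], [0.3]]) * edge_mask.reshape(-1, 1)
```

The rewrite happens before the norm rather than after it. Masking a NaN afterwards would hide it in the forward output but not in the gradient.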