deepmd.dpmodel.descriptor.dpa4_nn.attention#
Attention utilities for DPA4/SeZM message passing.
This module is the dpmodel port of deepmd.pt.model.descriptor.sezm_nn.attention. It implements the destination-wise envelope-gated softmax used by the SO(2) attention path.
Padded-edge adaptation#
The pt version consumes a sparse edge list and reduces per destination node with scatter_reduce(amax) / scatter_add keyed by dst. In the dpmodel padded layout (see edge_cache.EdgeCache), the edge axis has length E = n_nodes * nnei, and slot (i, j) belongs to node i. Every destination-wise reduction therefore becomes a plain reduction over the nnei axis after a (n_nodes, nnei, ...) reshape. Invalid (padded) slots are removed by folding edge_mask into the non-negative per-edge weight.
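The equivalence between the two layouts can be illustrated with a small NumPy sketch. It is not taken from the module: the array names `values`, `padded_sum`, and `sparse_sum` are hypothetical, and NumPy stands in for whichever array namespace the caller uses. It compares a masked reshape-and-sum over the nnei axis with a scatter-add over the valid edges only.

```python
import numpy as np

n_nodes, nnei = 3, 4
rng = np.random.default_rng(0)

# Padded layout: slot (i, j) sits at flat index i * nnei + j and belongs to node i.
values = rng.random(n_nodes * nnei)
edge_mask = np.array([1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0], dtype=values.dtype)

# Padded reduction: fold the mask into the value, reshape, reduce over nnei.
padded_sum = (values * edge_mask).reshape(n_nodes, nnei).sum(axis=1)

# Sparse reference: keep only the valid edges and scatter-add keyed by dst.
dst = np.repeat(np.arange(n_nodes), nnei)
valid = edge_mask > 0
sparse_sum = np.zeros(n_nodes)
np.add.at(sparse_sum, dst[valid], values[valid])
```

Both arrays hold the same per-node sums, because a masked slot contributes exactly zero to the padded reduction.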
Functions#
segment_envelope_gated_softmax(logits, edge_env, n_nodes, ...) | Compute destination-wise envelope-gated softmax attention.
Module Contents#
- deepmd.dpmodel.descriptor.dpa4_nn.attention.segment_envelope_gated_softmax(logits: Any, edge_env: Any, n_nodes: int, z_bias_raw: Any, eps: float, src_weight: Any = None, edge_mask: Any = None) -> Any
Compute destination-wise envelope-gated softmax attention.
All array arguments must live in the same array namespace.
- Parameters:
- logits
Attention logits with shape (E, F, H), padded-edge layout with E = n_nodes * nnei.
- edge_env
Cutoff envelope weights with shape (E, 1) or (E,).
- n_nodes
Number of nodes. The pt dst argument is dropped: in the padded layout the destination of edge slot (i, j) is implicitly node i.
- z_bias_raw
Unconstrained denominator bias with shape (F, H). Softplus is applied to keep the bias strictly positive.
- eps
Small epsilon for denominator stability.
- src_weight
Optional per-edge source-side multiplier with shape (E, 1) or (E,). When provided, the per-edge weight becomes edge_env**2 * src_weight and the attention reduces to edge_env**2 * src_weight * exp(logits) / (zeta + sum(edge_env**2 * src_weight * exp(logits))). src_weight = 0 therefore removes the source from both the numerator and the denominator, which is what SFPG needs so that a muted source does not even leak through the softmax normalization.
- edge_mask
Optional padded-edge validity mask with shape (E,) or (E, 1); zero marks invalid slots. Folded into the non-negative per-edge weight, so invalid slots drop out of the group max, the numerator, and the denominator exactly like absent edges in the pt sparse layout.
- Returns:
Array
Normalized edge weights with shape (E, F, H). Zero on invalid slots.
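The formula and masking behaviour described above can be sketched in NumPy as follows. This is an illustrative reconstruction, not the module's implementation. The function name `envelope_gated_softmax_sketch` is hypothetical. Reading zeta as softplus(z_bias_raw), the placement of eps in the denominator, and the restriction of the group max to slots with positive weight are all assumptions.

```python
import numpy as np

def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)).
    return np.logaddexp(0.0, x)

def envelope_gated_softmax_sketch(logits, edge_env, n_nodes, z_bias_raw, eps,
                                  src_weight=None, edge_mask=None):
    """Hypothetical sketch of a destination-wise envelope-gated softmax."""
    E, F, H = logits.shape
    nnei = E // n_nodes
    # Non-negative per-edge weight: env**2 times optional src_weight and mask.
    w = np.reshape(edge_env, (E,)) ** 2
    if src_weight is not None:
        w = w * np.reshape(src_weight, (E,))
    if edge_mask is not None:
        w = w * np.reshape(edge_mask, (E,))
    w = w.reshape(n_nodes, nnei, 1, 1)
    x = logits.reshape(n_nodes, nnei, F, H)
    # Group max over slots with positive weight only (assumption).
    active = w > 0
    m = np.max(np.where(active, x, -np.inf), axis=1, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)  # groups with no active slot
    num = w * np.exp(np.where(active, x - m, 0.0))
    # Assumption: zeta = softplus(z_bias_raw), rescaled by exp(-m) to match num.
    zeta = softplus(z_bias_raw)[None, None]
    # Assumption: eps is added to the rescaled denominator.
    den = zeta * np.exp(-m) + num.sum(axis=1, keepdims=True) + eps
    return (num / den).reshape(E, F, H)

# Usage on a tiny padded graph: 2 nodes, 3 neighbor slots each.
rng = np.random.default_rng(0)
n_nodes, nnei, F, H = 2, 3, 1, 2
E = n_nodes * nnei
logits = rng.normal(size=(E, F, H))
edge_env = rng.random((E, 1))
edge_mask = np.array([1, 1, 0, 1, 0, 0], dtype=float)
z_bias_raw = np.zeros((F, H))
alpha = envelope_gated_softmax_sketch(
    logits, edge_env, n_nodes, z_bias_raw, 1e-12, edge_mask=edge_mask
)
```

Because zeta is strictly positive, the weights of each destination node sum to less than one in this sketch, and masked slots come out exactly zero.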