deepmd.dpmodel.descriptor.dpa4_nn.attention

Attention utilities for DPA4/SeZM message passing.

This module is the dpmodel port of deepmd.pt.model.descriptor.sezm_nn.attention. It implements the destination-wise envelope-gated softmax used by the SO(2) attention path.

Padded-edge adaptation

The pt version consumes a sparse edge list and reduces per destination node with scatter_reduce(amax) / scatter_add keyed by dst. In the dpmodel padded layout (see edge_cache.EdgeCache), the edge axis has length E = n_nodes * nnei and slot (i, j) belongs to node i. Every destination-wise reduction therefore becomes a plain reduction over the nnei axis after reshaping to (n_nodes, nnei, ...), and invalid (padded) slots are removed by folding edge_mask into the non-negative per-edge weight.
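The equivalence between the two layouts can be checked with a small NumPy sketch. The sizes, the random data, and the use of np.add.at / np.maximum.at as stand-ins for the pt scatter_add / scatter_reduce(amax) are illustrative assumptions, not code from deepmd.

```python
import numpy as np

# Hypothetical sizes for illustration only.
n_nodes, nnei, F, H = 3, 4, 2, 1
E = n_nodes * nnei
rng = np.random.default_rng(0)

values = rng.normal(size=(E, F, H))
edge_mask = rng.random(E) > 0.3            # False marks padded slots
dst = np.repeat(np.arange(n_nodes), nnei)  # slot (i, j) -> destination node i

# Sparse-style reductions keyed by dst (stand-ins for the pt scatter ops);
# only valid edges are scattered, as in a sparse edge list.
scatter_sum = np.zeros((n_nodes, F, H))
np.add.at(scatter_sum, dst[edge_mask], values[edge_mask])
scatter_max = np.full((n_nodes, F, H), -np.inf)
np.maximum.at(scatter_max, dst[edge_mask], values[edge_mask])

# Padded-layout reductions: reshape, then reduce over the nnei axis.
padded = values.reshape(n_nodes, nnei, F, H)
mask = edge_mask.reshape(n_nodes, nnei, 1, 1)
padded_sum = (padded * mask).sum(axis=1)                    # mask folded in as a weight
padded_max = np.where(mask, padded, -np.inf).max(axis=1)    # masked slots cannot win
```

Both pairs of arrays agree: padded slots contribute nothing to the sum and cannot be the maximum, just like edges that are absent from a sparse list.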

Functions

segment_envelope_gated_softmax(logits, edge_env, n_nodes, z_bias_raw, eps[, src_weight, edge_mask]) → Any

Compute destination-wise envelope-gated softmax attention.

Module Contents

deepmd.dpmodel.descriptor.dpa4_nn.attention.segment_envelope_gated_softmax(logits: Any, edge_env: Any, n_nodes: int, z_bias_raw: Any, eps: float, src_weight: Any = None, edge_mask: Any = None) → Any

Compute destination-wise envelope-gated softmax attention.

All array arguments must live in the same array namespace.

Parameters:
logits

Attention logits with shape (E, F, H), padded-edge layout with E = n_nodes * nnei.

edge_env

Cutoff envelope weights with shape (E, 1) or (E,).

n_nodes

Number of nodes. The dst argument of the pt version is dropped: in the padded layout, the destination of edge slot (i, j) is implicitly node i.

z_bias_raw

Unconstrained denominator bias with shape (F, H). Softplus is applied to keep the bias strictly positive.

eps

Small epsilon for denominator stability.

src_weight

Optional per-edge source-side multiplier with shape (E, 1) or (E,). When provided, the per-edge weight becomes edge_env**2 * src_weight and the attention reduces to edge_env**2 * src_weight * exp(logits) / (zeta + sum(edge_env**2 * src_weight * exp(logits))), where zeta is the softplus of z_bias_raw and the sum runs over the edges of the same destination node. Setting src_weight = 0 therefore removes that source from both the numerator and the denominator. SFPG relies on this so that a muted source does not leak into the softmax normalization either.

edge_mask

Optional padded-edge validity mask with shape (E,) or (E, 1); zero marks invalid slots. Folded into the non-negative per-edge weight, so invalid slots drop out of the group max, the numerator, and the denominator exactly like absent edges in the pt sparse layout.

Returns:
Array

Normalized edge weights with shape (E, F, H). Zero on invalid slots.
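For reference, the documented behaviour can be sketched in NumPy as follows. This is a hypothetical re-implementation: envelope_gated_softmax_sketch and softplus are illustrative names, not part of deepmd. It assumes NumPy as the array namespace, zeta = softplus(z_bias_raw), and eps added to zeta in the denominator. The real function may place eps and perform the max shift differently.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def envelope_gated_softmax_sketch(logits, edge_env, n_nodes, z_bias_raw, eps,
                                  src_weight=None, edge_mask=None):
    """Hypothetical NumPy sketch of the documented envelope-gated softmax."""
    E, F, H = logits.shape
    nnei = E // n_nodes
    lg = logits.reshape(n_nodes, nnei, F, H)

    # Non-negative per-edge weight: edge_env**2 [* src_weight] [* edge_mask].
    w = np.reshape(edge_env, (E,)) ** 2
    if src_weight is not None:
        w = w * np.reshape(src_weight, (E,))
    if edge_mask is not None:
        w = w * np.reshape(edge_mask, (E,))
    w = w.reshape(n_nodes, nnei, 1, 1)
    valid = w > 0

    # Group max over the nnei axis, taken only over slots with positive weight.
    m = np.where(valid, lg, -np.inf).max(axis=1, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)  # nodes without any valid slot

    # Shift only valid slots; invalid slots get exp(0) * 0 = 0 without overflow.
    num = w * np.exp(np.where(valid, lg - m, 0.0))
    zeta = softplus(z_bias_raw)  # shape (F, H), strictly positive
    # Numerator and denominator of the documented formula are both divided by
    # exp(m), which leaves the ratio unchanged (assumed eps placement).
    den = (zeta + eps) * np.exp(-m) + num.sum(axis=1, keepdims=True)
    return (num / den).reshape(E, F, H)
```

Because zeta > 0, the weights of each destination node sum to S / (zeta + S) < 1, where S is the node's weighted sum of exponentials. Masked slots and slots with src_weight = 0 come out as exact zeros.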