deepmd.pt.model.descriptor.sezm_nn.ffn
Equivariant feed-forward layers for SeZM.
This module defines the full SO(3)-equivariant feed-forward network used inside SeZM interaction blocks and the descriptor output head.
Classes
EquivariantFFN | Full equivariant FFN operating on all spherical harmonic degrees.
Module Contents
- class deepmd.pt.model.descriptor.sezm_nn.ffn.EquivariantFFN(*, lmax: int, channels: int, hidden_channels: int, kmax: int = 1, grid_mlp: bool = False, grid_branch: int = 0, dtype: torch.dtype, s2_activation: bool = False, ffn_so3_grid: bool = False, lebedev_quadrature: bool = False, activation_function: str = 'silu', glu_activation: bool = True, mlp_bias: bool = False, trainable: bool, seed: int | list[int] | None = None)
Bases: torch.nn.Module
Full equivariant FFN operating on all spherical harmonic degrees.
- Standard structure (glu_activation=False):
SO3 linear (in -> hidden) -> GatedActivation -> SO3 linear (hidden -> out)
- Standard structure with GLU (glu_activation=True, the default):
SO3 linear (in -> 2*hidden) -> split -> GatedActivation(val, gate) -> SO3 linear (hidden -> out)
- Optional grid-FFN structure (s2_activation=True or ffn_so3_grid=True):
SO3 linear (in -> hidden) -> project the packed SO(3) coefficients onto the S2 or SO3 grid -> apply a grid GLU, polynomial MLP, or scalar-routed attention to the hidden features -> project the grid features back to packed SO(3) coefficients -> add a scalar LinearSwiGLU branch to the l=0 coefficients -> SO3 linear (hidden -> out)
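The project-to-grid / pointwise nonlinearity / project-back round trip of the grid path can be illustrated with a small NumPy sketch restricted to lmax = 1. This is not the SeZM implementation: the Fibonacci-sphere grid, the least-squares (pseudo-inverse) projector back to coefficients, the SiLU on the grid, and the helper names `fibonacci_sphere`, `real_sh_l1` and `grid_activation` are all assumptions standing in for the S2 (optionally Lebedev) or SO3 Wigner-D projectors the module actually uses.

```python
import numpy as np


def fibonacci_sphere(n):
    """Roughly uniform unit vectors on S2 (a stand-in for a Lebedev grid)."""
    i = np.arange(n) + 0.5
    z = 1.0 - 2.0 * i / n
    r = np.sqrt(1.0 - z * z)
    phi = np.pi * (1.0 + np.sqrt(5.0)) * i  # golden-ratio spiral in azimuth
    return np.stack([r * np.cos(phi), r * np.sin(phi), z], axis=-1)  # (G, 3)


def real_sh_l1(points):
    """Real spherical harmonics up to l=1, packed as [Y_00, Y_1-1, Y_10, Y_11] ~ [1, y, z, x]."""
    x, y, z = points.T
    c0 = 0.5 / np.sqrt(np.pi)
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    return np.stack([np.full_like(x, c0), c1 * y, c1 * z, c1 * x], axis=-1)  # (G, 4)


def grid_activation(coeffs, n_grid=64):
    """(4, C) packed coefficients -> pointwise SiLU on an S2 grid -> (4, C)."""
    to_grid = real_sh_l1(fibonacci_sphere(n_grid))  # (G, 4)
    from_grid = np.linalg.pinv(to_grid)             # (4, G), least-squares inverse
    f = to_grid @ coeffs                            # (G, C) function values on the sphere
    f = f / (1.0 + np.exp(-f))                      # SiLU applied point by point
    return from_grid @ f                            # back to (4, C) coefficients


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))  # D=4 coefficients (lmax=1), C=3 channels
print(grid_activation(x).shape)  # (4, 3)
```

Without the nonlinearity the round trip recovers the input exactly; with a pointwise nonlinearity on a finite grid the result is in general only approximately equivariant, so the grid resolution and quadrature rule matter.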
GatedActivation serves as the unified “activation” for equivariant networks, analogous to SiLU in standard MLPs, but respecting SO(3) equivariance:
- l=0: applies the specified activation function (or its GLU variant when glu_activation=True).
- l>0: scales the features by a sigmoid gate computed from the l=0 scalar features.
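A minimal NumPy sketch of this gating rule, assuming a packed layout in which degree l occupies rows l^2 to (l+1)^2 - 1, one gate per (degree, channel), SiLU for l=0, and no gate bias (the real layer may add one when mlp_bias=True). The names `gated_activation` and `w_gate` are hypothetical, and the F axis of the real input is omitted. Because each l>0 block is only multiplied by per-channel scalars computed from invariant l=0 features, any rotation acting on the m index commutes with the operation.

```python
import numpy as np


def silu(u):
    return u / (1.0 + np.exp(-u))


def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))


def gated_activation(x, w_gate, lmax):
    """x: (N, D, C) with D = (lmax+1)**2; w_gate: (C, lmax*C) hypothetical gate weights."""
    n, d, c = x.shape
    assert d == (lmax + 1) ** 2
    scalars = x[:, 0, :]                                   # l = 0 block, (N, C)
    gates = sigmoid(scalars @ w_gate).reshape(n, lmax, c)  # one gate per (l>0, channel)
    out = np.empty_like(x)
    out[:, 0, :] = silu(scalars)                           # ordinary activation on invariants
    for l in range(1, lmax + 1):
        sl = slice(l * l, (l + 1) ** 2)                    # packed rows of degree l
        out[:, sl, :] = x[:, sl, :] * gates[:, l - 1, None, :]
    return out
```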
When glu_activation=True, the first SO3 linear outputs 2*hidden_channels channels, which are split into value and gate halves. This turns silu into SwiGLU and gelu into GeGLU. A single wide linear followed by a split is more efficient than two separate linear layers because it needs only one matrix multiplication.
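The equivalence behind the fused split can be checked directly. In this sketch (plain NumPy, l=0 features only, hypothetical function names, and an assumed value-then-gate ordering of the two halves), concatenating the value and gate weight matrices and splitting the product gives the same SwiGLU output as two separate linears.

```python
import numpy as np


def silu(u):
    return u / (1.0 + np.exp(-u))


def swiglu_fused(s, w_fused):
    """One (C_in, 2*hidden) matmul, then split into value and gate halves."""
    value, gate = np.split(s @ w_fused, 2, axis=-1)  # assumed order: value first
    return value * silu(gate)


def swiglu_two_linears(s, w_value, w_gate):
    """Reference: two separate (C_in, hidden) matmuls."""
    return (s @ w_value) * silu(s @ w_gate)


rng = np.random.default_rng(0)
s = rng.standard_normal((5, 8))                       # (N, C_in) l=0 features
w_value = rng.standard_normal((8, 16))
w_gate = rng.standard_normal((8, 16))
w_fused = np.concatenate([w_value, w_gate], axis=1)   # (8, 32)
print(np.allclose(swiglu_fused(s, w_fused), swiglu_two_linears(s, w_value, w_gate)))  # True
```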
- Parameters:
- lmax
Maximum degree.
- channels
Number of channels per (l, m) coefficient.
- hidden_channels
Hidden dimension for the FFN.
- kmax
Maximum Wigner-D frame order (|k|) used by the SO3 Wigner-D FFN grid.
- grid_mlp
If True, select the polynomial grid MLP operation when the block-internal FFN grid path is enabled.
- grid_branch
Number of scalar-routed polynomial product branches used when the block-internal FFN grid path is enabled. 0 disables this branch mixer. Positive values take precedence over grid_mlp.
- dtype
Parameter dtype.
- s2_activation
If True, enable the S2 FFN grid path.
- ffn_so3_grid
If True, enable the SO3 Wigner-D FFN grid path.
- lebedev_quadrature
If True, use Lebedev quadrature for the S2 projector in this FFN.
- activation_function
Activation function for l=0 components (e.g., “silu”, “tanh”, “gelu”).
- glu_activation
If True, use GLU-style gating (e.g., silu -> swiglu, gelu -> geglu).
- mlp_bias
Whether to use bias in SO3Linear (l=0 bias), GatedActivation (gate linear bias), and the scalar point-wise projection when grid_mlp=True.
- trainable
Whether parameters are trainable.
- seed
Random seed for weight initialization.
- forward(x: torch.Tensor) -> torch.Tensor
- Parameters:
- x
Input with shape (N, D, F, C), where D = (lmax+1)^2 is the number of packed (l, m) coefficients.
- Returns:
torch.Tensor
Output with shape (N, D, F, C).
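The packed coefficient axis and the per-degree SO3 linear used at both ends of the FFN can be sketched as below. This is an illustration under stated assumptions, not the DeePMD-kit code: it assumes l-major packing (m = -l..l within each degree), one weight matrix per degree shared across all m, and a bias on l=0 only, matching the mlp_bias description; `degree_slices` and `so3_linear` are hypothetical names. The F axis is carried through unchanged.

```python
import numpy as np


def degree_slices(lmax):
    """Row ranges of each degree l in the packed D = (lmax+1)**2 axis (assumed l-major)."""
    return [slice(l * l, (l + 1) ** 2) for l in range(lmax + 1)]


def so3_linear(x, weights, bias=None):
    """Hypothetical per-degree linear on x: (N, D, F, C_in) -> (N, D, F, C_out).

    weights[l]: (C_in, C_out), shared by all 2l+1 orders m of degree l.
    bias: (C_out,), added to l = 0 only; a constant offset on l > 0 would break equivariance.
    """
    lmax = len(weights) - 1
    n, d, f, _ = x.shape
    assert d == (lmax + 1) ** 2
    c_out = weights[0].shape[1]
    out = np.empty((n, d, f, c_out), dtype=x.dtype)
    for l, sl in enumerate(degree_slices(lmax)):
        out[:, sl] = x[:, sl] @ weights[l]  # mixes channels only, never m
    if bias is not None:
        out[:, 0] += bias
    return out


rng = np.random.default_rng(0)
lmax, n, f, c, hidden = 2, 5, 1, 8, 16
x = rng.standard_normal((n, (lmax + 1) ** 2, f, c))
w_in = [rng.standard_normal((c, hidden)) for _ in range(lmax + 1)]
w_out = [rng.standard_normal((hidden, c)) for _ in range(lmax + 1)]
y = so3_linear(so3_linear(x, w_in, bias=np.zeros(hidden)), w_out)  # in -> hidden -> out
print(y.shape)  # (5, 9, 1, 8)
```

Sharing one weight matrix across all m of a degree is what keeps the layer equivariant: a rotation mixes the m index within each degree, and the layer only mixes channels.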