deepmd.dpmodel.descriptor.dpa4_nn.so3
SO(3)-equivariant linear layers for DPA4/SeZM.
This module is the dpmodel port of deepmd.pt.model.descriptor.sezm_nn.so3. It defines the channel-only and focus-aware linear maps used by the DPA4 SO(3) feature transformations. All three pt classes are ported: FocusLinear (used by so2, grid_net, activation), ChannelLinear (used by so2, grid_net), and SO3Linear (used by so2, ffn).
Serialization contract: SO3Linear mirrors the pt serialize() format exactly (same config and @variables keys), so pt serialize() output deserializes directly. The pt FocusLinear and ChannelLinear define no serialize() (they only appear nested inside larger modules’ state_dicts); their dpmodel serialize()/deserialize() use @variables keys equal to the pt state_dict key names (weight, bias) so that pt state-dict fragments load directly.
Weight initialization is distribution-equivalent to the pt version rather than value-identical: samples are drawn from np.random.default_rng instead of the torch generator stream. This is the same convention as utils.init_trunc_normal_fan_in_out.
Classes
| FocusLinear | Per-focus linear projection on the last feature axis. |
| ChannelLinear | Channel-only linear projection on the last feature axis. |
| SO3Linear | Focus-aware degree-wise linear self-interaction. |
Module Contents
- class deepmd.dpmodel.descriptor.dpa4_nn.so3.FocusLinear(*, in_channels: int, out_channels: int, n_focus: int, precision: str = DEFAULT_PRECISION, bias: bool = True, trainable: bool = True, seed: int | list[int] | None = None, init_std: float | None = None)
Bases: deepmd.dpmodel.NativeOP
Per-focus linear projection on the last feature axis.
- Parameters:
- in_channels
int Input feature dimension.
- out_channels
int Output feature dimension.
- n_focus
int Number of focus streams.
- precision
str Parameter precision.
- bias
bool Whether to use bias.
- trainable
bool Whether parameters are trainable.
- seed
int | list[int] | None Random seed for initialization.
- init_std
float | None If given, use normal(0, init_std) instead of the default uniform init. Useful for gate projections where small initial logits are desired.
Notes
Parameters are stored in (in, out) convention to match Muon’s rectangular correction assumption (rows=fan_in, cols=fan_out):
- weight: (in_channels, n_focus * out_channels)
- bias: (n_focus * out_channels,)
- call(x: Any) -> Any
Apply the per-focus linear projection.
- Parameters:
- x
Array Input array with shape (B, F, Cin).
- Returns:
Array Projected array with shape (B, F, Cout).
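The shapes above suggest how a per-focus projection can be computed. The following is a minimal NumPy sketch, not the library implementation: `focus_linear` is a hypothetical helper, and it assumes the packed output axis of `weight` is laid out focus-major, i.e. as (n_focus, out_channels).

```python
import numpy as np


def focus_linear(x, weight, bias, n_focus):
    """Hypothetical stand-in for a per-focus projection (not the library code)."""
    c_in, packed = weight.shape
    c_out = packed // n_focus
    # Assumption: the packed output axis is ordered as (F, C_out).
    w = weight.reshape(c_in, n_focus, c_out)
    # Focus stream f is projected by its own (C_in, C_out) matrix w[:, f, :].
    y = np.einsum("bfi,ifo->bfo", x, w)
    if bias is not None:
        y = y + bias.reshape(n_focus, c_out)
    return y


rng = np.random.default_rng(0)
B, F, C_in, C_out = 4, 2, 3, 5
x = rng.normal(size=(B, F, C_in))
weight = rng.normal(size=(C_in, F * C_out))  # stored as (in, out)
bias = rng.normal(size=(F * C_out,))
y = focus_linear(x, weight, bias, n_focus=F)
print(y.shape)  # (4, 2, 5)
```

Under this assumed layout, stream 1 is equivalent to `x[:, 1] @ weight[:, C_out:2 * C_out] + bias[C_out:2 * C_out]`, so each focus stream sees an independent slice of the stored matrix.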
- class deepmd.dpmodel.descriptor.dpa4_nn.so3.ChannelLinear(*, in_channels: int, out_channels: int, precision: str = DEFAULT_PRECISION, bias: bool = True, trainable: bool = True, seed: int | list[int] | None = None, init_std: float | None = None)
Bases: deepmd.dpmodel.NativeOP
Channel-only linear projection on the last feature axis.
- Parameters:
- in_channels
int Input feature dimension.
- out_channels
int Output feature dimension.
- precision
str Parameter precision.
- bias
bool Whether to use bias.
- trainable
bool Whether parameters are trainable.
- seed
int | list[int] | None Random seed for initialization.
- init_std
float | None If given, use normal(0, init_std) instead of the default uniform init. Useful for gate projections where small initial logits are desired.
Notes
Parameters are stored in (in, out) convention to match Muon’s rectangular correction assumption (rows=fan_in, cols=fan_out):
- weight: (in_channels, out_channels)
- bias: (out_channels,)
- call(x: Any) -> Any
Apply the channel-only linear projection.
- Parameters:
- x
Array Input array with shape (..., C_in).
- Returns:
Array Projected array with shape (..., C_out).
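Because the weight is stored as (in_channels, out_channels), a channel-only projection reduces to a right-multiplication on the last axis. The sketch below illustrates this with plain NumPy; `channel_linear` is a hypothetical helper written for this example, not the class's actual `call` method.

```python
import numpy as np


def channel_linear(x, weight, bias=None):
    """Hypothetical stand-in for a channel-only projection (not the library code)."""
    # matmul broadcasts the 2-D weight over all leading axes of x.
    y = x @ weight
    if bias is not None:
        y = y + bias
    return y


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 9, 4, 3))  # (..., C_in) with C_in = 3
weight = rng.normal(size=(3, 6))   # (in_channels, out_channels)
bias = rng.normal(size=(6,))
y = channel_linear(x, weight, bias)
print(y.shape)  # (2, 9, 4, 6)
```

With the (in, out) layout no transpose is needed, unlike the (out, in) layout used by `torch.nn.Linear`.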
- class deepmd.dpmodel.descriptor.dpa4_nn.so3.SO3Linear(*, lmax: int, in_channels: int, out_channels: int, n_focus: int = 1, precision: str = DEFAULT_PRECISION, mlp_bias: bool = False, trainable: bool = True, seed: int | list[int] | None = None, init_std: float | None = None)
Bases: deepmd.dpmodel.NativeOP
Focus-aware degree-wise linear self-interaction.
The key insight is that weights are shared across all m components within each l block.
- Parameters:
- lmax
int Maximum spherical harmonic degree.
- in_channels
int Number of input channels per (l, m) coefficient.
- out_channels
int Number of output channels per (l, m) coefficient.
- n_focus
int Number of focus streams.
- precision
str Parameter precision.
- mlp_bias
bool Whether to use bias for l=0 (scalar) components.
- trainable
bool Whether parameters are trainable.
- seed
int | list[int] | None Random seed for weight initialization.
- init_std
float | None If given, use normal(0, init_std) for all weights instead of the default trunc-normal fan-in/fan-out init. Use 0.0 for zero initialization.
Notes
- Weight storage: (lmax+1, C_in, F*C_out).
- Bias storage: (F*C_out,), only applied to l=0 scalar components.
- Runtime view restores weights to (lmax+1, C_in, F, C_out) via reshape.
- expand_index maps each packed (l,m) position to its l value.
- The pt einsum ndfi,difo->ndfo is expressed as a broadcast batched matmul, which keeps the whole multi-focus path vectorized.
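The notes above can be illustrated with a small NumPy sketch. It is written under stated assumptions rather than taken from the library: the input is assumed to have shape (N, D, F, C_in) with D = (lmax+1)^2, the packed (l, m) axis is assumed to be ordered by increasing l with 2l+1 components per degree (so l=0 sits at index 0), and the way `expand_index` is built here is a guess at an equivalent construction.

```python
import numpy as np

lmax, F, C_in, C_out, N = 2, 2, 3, 4, 5
D = (lmax + 1) ** 2  # number of packed (l, m) coefficients

# Assumed packing: degree l occupies 2l+1 consecutive positions.
degrees = np.arange(lmax + 1)
expand_index = np.repeat(degrees, 2 * degrees + 1)  # [0 1 1 1 2 2 2 2 2]

rng = np.random.default_rng(0)
weight = rng.normal(size=(lmax + 1, C_in, F * C_out))  # storage layout
bias = rng.normal(size=(F * C_out,))
x = rng.normal(size=(N, D, F, C_in))

# Runtime view, then one weight matrix per packed position
# (shared across all m components of the same l).
w_view = weight.reshape(lmax + 1, C_in, F, C_out)
w_packed = w_view[expand_index]  # (D, C_in, F, C_out)

# Reference: the einsum named in the notes.
ref = np.einsum("ndfi,difo->ndfo", x, w_packed)

# Same contraction as a broadcast batched matmul.
w_mat = np.transpose(w_packed, (0, 2, 1, 3))  # (D, F, C_in, C_out)
out = (x[..., None, :] @ w_mat)[..., 0, :]    # (N, D, F, C_out)

# Bias touches only the l=0 scalar position (index 0 under the assumed packing).
y = out.copy()
y[:, 0, :, :] += bias.reshape(F, C_out)

print(np.allclose(out, ref), y.shape)
```

Sharing one matrix per degree l across its m components is what keeps the map SO(3)-equivariant, and restricting the bias to l=0 follows from the same requirement, since only scalar components are invariant under rotation.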