deepmd.pt.optimizer.adamuon#

AdaMuon optimizer for DeePMD-kit PyTorch backend.

AdaMuon combines Newton-Schulz orthogonalization with adaptive per-element second-moment normalization and RMS-aligned global scaling. It applies sign() before orthogonalization, giving a sign-stabilized orthogonal direction intended to improve training stability.

Key improvements over vanilla Muon:

  • Sign-stabilized orthogonal direction

  • Per-element second-moment (v_buffer) normalization

  • RMS-aligned global scaling

References#

[1] Ethan Smith et al., “AdaMuon: Adaptive Muon Optimizer,” arXiv:2507.11005, 2025. https://arxiv.org/abs/2507.11005

[2] AdaMuon GitHub Repository. ethansmith2000/AdaMuon

Classes#

AdaMuonOptimizer

AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam.

Functions#

zeropower_via_newtonschulz5(→ torch.Tensor)

Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.

_prepare_muon_momentum(→ tuple[torch.Tensor, ...)

Prepare momentum update and reshape for batched Newton-Schulz.

Module Contents#

deepmd.pt.optimizer.adamuon.zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-08) → torch.Tensor

Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.

Uses a quintic Newton-Schulz iteration to approximate the orthogonal factor of the input matrix. For the SVD G = USV^T, this factor is UV^T, that is, G with all of its singular values replaced by 1.

This implementation always performs Newton-Schulz in bfloat16 and returns a bfloat16 tensor.

Parameters:
G : torch.Tensor

Input matrix to orthogonalize with shape (…, M, N).

steps : int

Number of Newton-Schulz iterations with default 5.

eps : float

Numerical stability epsilon for norm clamping with default 1e-8.

Returns:
torch.Tensor

Orthogonalized matrix in bfloat16 with same shape as input.

Raises:
ValueError

If G has fewer than 2 dimensions.

ValueError

If steps >= 100 (guard for efficiency).
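The quintic iteration described for zeropower_via_newtonschulz5 can be sketched as follows. This is a minimal illustration and not the DeePMD-kit source: the name newton_schulz_sketch is hypothetical, the coefficients (3.4445, -4.7750, 2.0315) are assumed from the commonly used Muon reference recipe, and the sketch runs in float32 instead of bfloat16 so that it behaves the same on CPU.

```python
import torch


def newton_schulz_sketch(G: torch.Tensor, steps: int = 5, eps: float = 1e-8) -> torch.Tensor:
    """Approximate the orthogonal factor U V^T of G (hypothetical helper)."""
    if G.ndim < 2:
        raise ValueError("G must have at least 2 dimensions")
    # Assumed quintic coefficients from the public Muon recipe.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.to(torch.float32)
    # Work on the wide orientation so that X @ X^T is the smaller Gram matrix.
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT
    # Normalize so every singular value is at most 1 before iterating.
    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.mT if transposed else X


torch.manual_seed(0)
G = torch.randn(8, 16)
X = newton_schulz_sketch(G)
print(torch.linalg.svdvals(X))  # singular values should cluster loosely around 1
```

With only a few steps the singular values are approximately, not exactly, 1, which is usually acceptable for an update direction.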

deepmd.pt.optimizer.adamuon._prepare_muon_momentum(grad: torch.Tensor, momentum_buffer: torch.Tensor, beta: float, nesterov: bool) → tuple[torch.Tensor, tuple[int, ...]]

Prepare momentum update and reshape for batched Newton-Schulz.

Parameters:
grad : torch.Tensor

Gradient tensor.

momentum_buffer : torch.Tensor

Momentum buffer (will be updated in-place).

beta : float

Momentum coefficient.

nesterov : bool

Whether to use Nesterov momentum.

Returns:
update : torch.Tensor

Reshaped update tensor with shape (M, N).

original_shape : tuple[int, …]

Original shape before reshape.
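The momentum preparation described above can be illustrated with a short sketch. It is an assumption-laden outline, not the library's code: the name prepare_momentum_sketch is hypothetical, the buffer is assumed to be an exponential moving average of gradients, and trailing dimensions are assumed to be flattened into the second axis.

```python
import torch


def prepare_momentum_sketch(
    grad: torch.Tensor, momentum_buffer: torch.Tensor, beta: float, nesterov: bool
) -> tuple[torch.Tensor, tuple[int, ...]]:
    # Exponential moving average of gradients, written into the buffer in place:
    # buffer <- beta * buffer + (1 - beta) * grad
    momentum_buffer.lerp_(grad, 1.0 - beta)
    # Nesterov blends the fresh gradient with the updated buffer (assumed form).
    update = grad.lerp(momentum_buffer, beta) if nesterov else momentum_buffer
    original_shape = tuple(update.shape)
    # Keep the first axis and flatten the rest so the update is one (M, N) matrix.
    return update.reshape(update.size(0), -1), original_shape


grad = torch.ones(4, 3, 2)
buf = torch.zeros(4, 3, 2)
update, original_shape = prepare_momentum_sketch(grad, buf, beta=0.95, nesterov=True)
print(update.shape, original_shape)
```

The returned original_shape lets the caller restore the parameter's layout after the orthogonalized update has been computed on the 2D view.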

class deepmd.pt.optimizer.adamuon.AdaMuonOptimizer(params: collections.abc.Iterable[torch.Tensor] | collections.abc.Iterable[dict[str, Any]], lr: float = 0.001, momentum: float = 0.95, weight_decay: float = 0.001, ns_steps: int = 5, adam_betas: tuple[float, float] = (0.9, 0.95), adam_eps: float = 1e-08, nesterov: bool = True, lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2, eps: float = 1e-08)

Bases: torch.optim.optimizer.Optimizer

AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam.

This optimizer applies different update rules based on parameter dimensionality:

  • For 2D+ parameters (weight matrices): AdaMuon update with sign-stabilized Newton-Schulz orthogonalization and per-element v_buffer normalization.

  • For 1D parameters (biases, layer norms): Standard Adam update.
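To make the dimensionality rule concrete, the sketch below classifies the parameters of a small model by their number of dimensions. The helper name update_rule_for and the toy model are hypothetical. The optimizer performs this routing internally, so the snippet only illustrates which tensors would fall on each side.

```python
import torch


def update_rule_for(param: torch.Tensor) -> str:
    # 2D+ tensors are treated as weight matrices; everything else falls back to Adam.
    return "adamuon" if param.ndim >= 2 else "adam"


model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 1),
)
routing = {name: update_rule_for(p) for name, p in model.named_parameters()}
for name, rule in routing.items():
    print(f"{name}: {rule}")
```

Note that layer-norm weights are 1D, so they are handled by the Adam branch along with the biases.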

Key AdaMuon features:

  • Sign-stabilized orthogonal direction: Applies sign() before orthogonalization.

  • Per-element second-moment normalization using momentum coefficient.

  • RMS-aligned global scaling: 0.2 * sqrt(m * n) / norm.
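The three features listed above can be combined in a rough sketch of how one update direction might be formed. Several details are assumptions: the name adamuon_direction_sketch is hypothetical, an exact SVD stands in for the Newton-Schulz iteration, the second moment is assumed to be an exponential moving average without bias correction, and the order of operations may differ from the actual implementation.

```python
import torch


def adamuon_direction_sketch(
    momentum: torch.Tensor, v_buffer: torch.Tensor, beta: float = 0.95, eps: float = 1e-8
) -> torch.Tensor:
    # 1. Sign-stabilize, then orthogonalize (exact SVD used here as a stand-in).
    U, _, Vh = torch.linalg.svd(torch.sign(momentum), full_matrices=False)
    ortho = U @ Vh
    # 2. Per-element second moment, updated in place with the momentum coefficient.
    v_buffer.mul_(beta).addcmul_(ortho, ortho, value=1.0 - beta)
    normalized = ortho / (v_buffer.sqrt() + eps)
    # 3. RMS-aligned global scaling: 0.2 * sqrt(m * n) / norm.
    m, n = normalized.shape
    return normalized * (0.2 * (m * n) ** 0.5 / (normalized.norm() + eps))


torch.manual_seed(0)
update = adamuon_direction_sketch(torch.randn(8, 16), torch.zeros(8, 16))
rms = update.pow(2).mean().sqrt()
print(rms)  # close to 0.2 by construction of the global scaling
```

Because the final factor divides by the Frobenius norm and multiplies by 0.2 * sqrt(m * n), the root-mean-square of the update entries lands near 0.2 regardless of the matrix shape.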

Parameters:
params : iterable

Iterable of parameters to optimize.

lr : float

Learning rate with default 1e-3.

momentum : float

Momentum coefficient for AdaMuon with default 0.95.

weight_decay : float

Weight decay coefficient (applied only to >=2D params) with default 0.001.

ns_steps : int

Number of Newton-Schulz iterations with default 5.

adam_betas : tuple[float, float]

Adam beta coefficients with default (0.9, 0.95).

adam_eps : float

Adam epsilon with default 1e-8.

nesterov : bool

Whether to use Nesterov momentum for AdaMuon with default True.

lr_adjust : float

Learning rate adjustment factor for Adam (1D params).

  • If lr_adjust <= 0: use match-RMS scaling for AdaMuon update, scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly.

  • If lr_adjust > 0: use rectangular correction for AdaMuon update, scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as learning rate.

Default is 10.0 (Adam lr = lr/10).

lr_adjust_coeff : float

Coefficient for match-RMS scaling with default 0.2. Only effective when lr_adjust <= 0.

eps : float

Epsilon for v_buffer sqrt and global scaling normalization with default 1e-8.
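The two lr_adjust modes documented above reduce to a small piece of arithmetic, sketched here in plain Python. The function name scale_and_adam_lr is hypothetical, and m and n are assumed to be the row and column counts of the reshaped weight matrix.

```python
import math


def scale_and_adam_lr(
    lr: float, m: int, n: int, lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2
) -> tuple[float, float]:
    """Return (AdaMuon update scale, Adam learning rate) per the documented rule."""
    if lr_adjust <= 0:
        # Match-RMS scaling; Adam uses lr directly.
        return lr_adjust_coeff * math.sqrt(max(m, n)), lr
    # Rectangular correction; Adam uses a reduced learning rate.
    return math.sqrt(max(1.0, m / n)), lr / lr_adjust


# A 256 x 64 weight matrix under both modes.
print(scale_and_adam_lr(1e-3, 256, 64))                 # rectangular correction
print(scale_and_adam_lr(1e-3, 256, 64, lr_adjust=0.0))  # match-RMS scaling
```

For a 256 x 64 matrix, the rectangular correction gives sqrt(256 / 64) = 2 with an Adam learning rate of lr / 10, while match-RMS scaling gives 0.2 * sqrt(256) = 3.2 with Adam using lr unchanged.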

Examples

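>>> from deepmd.pt.optimizer.adamuon import AdaMuonOptimizer
>>> optimizer = AdaMuonOptimizer(model.parameters(), lr=1e-3)
>>> for epoch in range(epochs):
...     optimizer.zero_grad()
...     loss = loss_fn(model(inputs), targets)  # loss_fn, inputs, targets are placeholders
...     loss.backward()
...     optimizer.step()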
step(closure: collections.abc.Callable[[], torch.Tensor] | None = None) → torch.Tensor | None

Perform a single optimization step.

Parameters:
closure : callable, optional

A closure that reevaluates the model and returns the loss.

Returns:
loss : torch.Tensor, optional

The loss returned by closure if one is provided, otherwise None.
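The closure argument follows the usual torch.optim.Optimizer convention. The sketch below shows the pattern with torch.optim.SGD as a stand-in so that it runs without DeePMD-kit installed. Substituting AdaMuonOptimizer is assumed to work the same way, since it derives from the same base class. The toy model and data are hypothetical.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
inputs, targets = torch.randn(32, 4), torch.randn(32, 1)

# Stand-in optimizer; the closure protocol is shared by torch.optim optimizers.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def closure() -> torch.Tensor:
    # Re-evaluate the model and return the loss, as step() expects.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss


loss = optimizer.step(closure)  # returns the closure's loss
no_loss = optimizer.step()      # returns None when no closure is given
```

Passing a closure is mainly useful when the caller wants the loss back from step() or when an optimizer needs to re-evaluate the objective.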