deepmd.pt.optimizer.adamuon
AdaMuon optimizer for DeePMD-kit PyTorch backend.
AdaMuon combines Newton-Schulz orthogonalization with adaptive per-element second-moment normalization and RMS-aligned global scaling. It applies a sign-stabilized orthogonal direction to improve training stability.
Key improvements over vanilla Muon:
- Sign-stabilized orthogonal direction
- Per-element second-moment (v_buffer) normalization
- RMS-aligned global scaling
References
Ethan Smith et al., “AdaMuon: Adaptive Muon Optimizer,” arXiv:2507.11005, 2025. https://arxiv.org/abs/2507.11005
AdaMuon GitHub Repository. ethansmith2000/AdaMuon
Classes
- AdaMuonOptimizer: AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam.
Functions
- zeropower_via_newtonschulz5: Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.
- _prepare_muon_momentum: Prepare the momentum update and reshape it for batched Newton-Schulz.
Module Contents
- deepmd.pt.optimizer.adamuon.zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-08) -> torch.Tensor
Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.
Uses a quintic Newton-Schulz iteration to compute the orthogonal component of the input matrix. For the singular value decomposition G = USV^T, this approximates the orthogonal factor UV^T (all nonzero singular values replaced by values near 1).
This implementation always performs Newton-Schulz in bfloat16 and returns a bfloat16 tensor.
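A short derivation of why the iteration recovers UV^T, assuming the standard Muon-style quintic form and that eps clamps the Frobenius norm used for the initial scaling (the coefficients a, b, c used by this module are not stated in this documentation). Scaling by the Frobenius norm puts every singular value in [0, 1], and each step applies the same odd polynomial to the singular values while leaving the singular vectors unchanged:

```latex
\begin{aligned}
X_0 &= \frac{G}{\max(\lVert G \rVert_F,\ \epsilon)}, \\
X_{k+1} &= a X_k + b\,(X_k X_k^\top) X_k + c\,(X_k X_k^\top)^2 X_k, \\
X_k = U S_k V^\top &\;\Longrightarrow\; X_{k+1} = U \left( a S_k + b S_k^3 + c S_k^5 \right) V^\top .
\end{aligned}
```

Repeated steps therefore push each nonzero singular value toward 1, so X_k approaches UV^T. In the original Muon optimizer the coefficients are (a, b, c) = (3.4445, -4.7750, 2.0315), tuned for fast progress rather than exact convergence, so the singular values only end up near 1.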
- Parameters:
  - G : torch.Tensor
    Input matrix to orthogonalize, with shape (…, M, N).
  - steps : int
    Number of Newton-Schulz iterations. Default is 5.
  - eps : float
    Numerical-stability epsilon for norm clamping. Default is 1e-8.
- Returns:
  - torch.Tensor
    Orthogonalized matrix in bfloat16 with the same shape as the input.
- Raises:
  - ValueError: If G has fewer than 2 dimensions.
  - ValueError: If steps >= 100 (a guard for efficiency).
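A dependency-free sketch of a quintic Newton-Schulz orthogonalization on nested Python lists, for readers who want to trace the arithmetic. It runs in float64 rather than bfloat16, assumes Muon's published coefficients and Frobenius-norm scaling with the norm clamped at eps, transposes tall inputs so the smaller Gram matrix is used (as in the reference Muon code), and mirrors only the steps >= 100 guard. The names newton_schulz5, matmul, and transpose are hypothetical helpers, not the deepmd API.

```python
import math

# ASSUMPTION: coefficients from the original Muon optimizer; deepmd's
# zeropower_via_newtonschulz5 may use different values.
A, B, C = 3.4445, -4.7750, 2.0315


def matmul(x: list[list[float]], y: list[list[float]]) -> list[list[float]]:
    """Plain triple-loop matrix product x @ y."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))] for i in range(len(x))]


def transpose(x: list[list[float]]) -> list[list[float]]:
    return [list(row) for row in zip(*x)]


def newton_schulz5(g: list[list[float]], steps: int = 5,
                   eps: float = 1e-8) -> list[list[float]]:
    """Approximate the orthogonal factor U V^T of g = U S V^T (float64 sketch)."""
    if steps >= 100:
        raise ValueError("steps must be < 100")
    # Work on the wide orientation so X X^T is the smaller Gram matrix.
    transposed = len(g) > len(g[0])
    x = transpose(g) if transposed else [row[:] for row in g]
    # Frobenius scaling puts all singular values in [0, 1]; clamp the norm at eps.
    norm = max(math.sqrt(sum(v * v for row in x for v in row)), eps)
    x = [[v / norm for v in row] for row in x]
    for _ in range(steps):
        gram = matmul(x, transpose(x))          # X X^T
        gram2 = matmul(gram, gram)              # (X X^T)^2
        poly = [[B * p + C * q for p, q in zip(r1, r2)]
                for r1, r2 in zip(gram, gram2)]  # b*XX^T + c*(XX^T)^2
        px = matmul(poly, x)
        # X <- a*X + (b*XX^T + c*(XX^T)^2) X
        x = [[A * xv + pv for xv, pv in zip(xr, pr)] for xr, pr in zip(x, px)]
    return transpose(x) if transposed else x
```

For example, newton_schulz5([[3.0, 0.0], [0.0, 1.0]]) has the identity as its exact orthogonal factor; five steps return a diagonal matrix whose entries are near 1 but not equal to it, which is the expected trade-off of the speed-tuned coefficients.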
- deepmd.pt.optimizer.adamuon._prepare_muon_momentum(grad: torch.Tensor, momentum_buffer: torch.Tensor, beta: float, nesterov: bool) -> tuple[torch.Tensor, tuple[int, ...]]
Prepare momentum update and reshape for batched Newton-Schulz.
- Parameters:
  - grad : torch.Tensor
    Gradient tensor.
  - momentum_buffer : torch.Tensor
    Momentum buffer (updated in place).
  - beta : float
    Momentum coefficient.
  - nesterov : bool
    Whether to use Nesterov momentum.
- Returns:
  - update : torch.Tensor
    Update tensor reshaped to 2D, with shape (M, N).
  - original_shape : tuple[int, ...]
    Original shape before reshaping.
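A sketch of what a Muon-style momentum step followed by a 2D reshape could look like, using flat row-major Python lists. The exponential-moving-average buffer update, the Nesterov blend grad + beta * (buf - grad), and collapsing trailing dimensions into N = prod(shape[1:]) all follow the reference Muon implementation and are assumptions about this private function; prepare_momentum is a hypothetical name.

```python
from math import prod


def prepare_momentum(grad: list[float], buf: list[float], shape: tuple[int, ...],
                     beta: float, nesterov: bool) -> tuple[list[list[float]], tuple[int, ...]]:
    """Update `buf` in place and return the update as an M x N nested list.

    ASSUMPTIONS: buf <- beta*buf + (1-beta)*grad; the Nesterov update is
    grad + beta*(buf - grad) (i.e. grad.lerp(buf, beta)); trailing dims are
    flattened so M = shape[0] and N = prod(shape[1:]).
    """
    for i, g in enumerate(grad):
        buf[i] = beta * buf[i] + (1.0 - beta) * g
    if nesterov:
        upd = [g + beta * (b - g) for g, b in zip(grad, buf)]
    else:
        upd = list(buf)
    m = shape[0]
    n = prod(shape[1:])  # e.g. (64, 8, 3) -> N = 24; a 1D shape gives N = 1
    rows = [upd[r * n:(r + 1) * n] for r in range(m)]
    return rows, tuple(shape)
```

Returning the original shape alongside the 2D view lets the caller reshape the orthogonalized update back to the parameter's shape before applying it.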
- class deepmd.pt.optimizer.adamuon.AdaMuonOptimizer(params: collections.abc.Iterable[torch.Tensor] | collections.abc.Iterable[dict[str, Any]], lr: float = 0.001, momentum: float = 0.95, weight_decay: float = 0.001, ns_steps: int = 5, adam_betas: tuple[float, float] = (0.9, 0.95), adam_eps: float = 1e-08, nesterov: bool = True, lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2, eps: float = 1e-08)
Bases: torch.optim.optimizer.Optimizer
AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam.
This optimizer applies different update rules based on parameter dimensionality:
- For 2D+ parameters (weight matrices): an AdaMuon update with sign-stabilized Newton-Schulz orthogonalization and per-element v_buffer normalization.
- For 1D parameters (biases, layer norms): a standard Adam update.
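The dimensionality rule above amounts to a partition of parameters by number of dimensions. The sketch below works on hypothetical parameter names and shapes; treating 0-D scalars like 1D parameters is an assumption, and route is a hypothetical helper rather than part of the module.

```python
# Hypothetical parameter names and shapes (illustration only).
shapes = {
    "fitting.layer0.weight": (240, 240),
    "fitting.layer0.bias": (240,),
    "embedding.conv.weight": (64, 8, 3),
    "norm.weight": (240,),
}


def route(shapes: dict[str, tuple[int, ...]]) -> tuple[list[str], list[str]]:
    """Split names into (AdaMuon group, Adam group) by ndim.

    ndim >= 2 -> AdaMuon (weight decay applies); ndim < 2 -> Adam (no weight decay).
    """
    adamuon, adam = [], []
    for name, shape in shapes.items():
        (adamuon if len(shape) >= 2 else adam).append(name)
    return adamuon, adam


adamuon_names, adam_names = route(shapes)
```

Here route(shapes) sends both weight tensors, including the 3D one, to AdaMuon and the bias and norm vectors to Adam.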
Key AdaMuon features:
- Sign-stabilized orthogonal direction: sign() is applied before orthogonalization.
- Per-element second-moment normalization using the momentum coefficient.
- RMS-aligned global scaling: 0.2 * sqrt(m * n) / norm, for an m-by-n update.
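A sketch of the second-moment normalization and RMS-aligned scaling for one m-by-n matrix, starting from an already orthogonalized direction o (the Newton-Schulz output of sign(momentum)). It assumes an exponential-moving-average v_buffer update with beta equal to the momentum coefficient, eps added after the square root, and "norm" meaning the Frobenius norm of the normalized update; adamuon_direction is a hypothetical helper, not part of the module.

```python
import math


def adamuon_direction(o: list[list[float]], v: list[list[float]],
                      beta: float, eps: float = 1e-8) -> list[list[float]]:
    """Normalize `o` per element by sqrt(v) and rescale to RMS ~0.2.

    `v` (the v_buffer) is updated in place.
    ASSUMPTIONS: v <- beta*v + (1-beta)*o**2; divide by sqrt(v) + eps;
    global scale 0.2*sqrt(m*n)/||normed||_F.
    """
    m, n = len(o), len(o[0])
    normed = []
    for i in range(m):
        row = []
        for j in range(n):
            v[i][j] = beta * v[i][j] + (1.0 - beta) * o[i][j] ** 2
            row.append(o[i][j] / (math.sqrt(v[i][j]) + eps))
        normed.append(row)
    fro = math.sqrt(sum(x * x for r in normed for x in r))
    scale = 0.2 * math.sqrt(m * n) / (fro + eps)
    return [[scale * x for x in r] for r in normed]
```

Because ||scale * normed||_F is 0.2 * sqrt(m * n), the update's RMS is 0.2 regardless of the matrix shape. On the first step, with v_buffer at zero, every normalized entry has magnitude about 1/sqrt(1 - beta), so the update is approximately 0.2 * sign(o).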
- Parameters:
  - params : iterable
    Iterable of parameters, or of parameter-group dicts, to optimize.
  - lr : float
    Learning rate. Default is 1e-3.
  - momentum : float
    Momentum coefficient for AdaMuon. Default is 0.95.
  - weight_decay : float
    Weight decay coefficient, applied only to parameters with 2 or more dimensions. Default is 0.001.
  - ns_steps : int
    Number of Newton-Schulz iterations. Default is 5.
  - adam_betas : tuple[float, float]
    Adam beta coefficients. Default is (0.9, 0.95).
  - adam_eps : float
    Adam epsilon. Default is 1e-8.
  - nesterov : bool
    Whether to use Nesterov momentum for AdaMuon. Default is True.
  - lr_adjust : float
    Learning rate adjustment factor. It selects the scaling mode of the AdaMuon update and sets the Adam learning rate for 1D parameters:
    - If lr_adjust <= 0: use match-RMS scaling for the AdaMuon update, scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly.
    - If lr_adjust > 0: use rectangular correction for the AdaMuon update, scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as its learning rate.
    Default is 10.0 (Adam lr = lr/10).
  - lr_adjust_coeff : float
    Coefficient for match-RMS scaling. Default is 0.2. Only effective when lr_adjust <= 0.
  - eps : float
    Epsilon for the v_buffer square root and the global scaling normalization. Default is 1e-8.
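The two lr_adjust modes can be sketched as a function that returns the AdaMuon scale factor and the Adam learning rate for an m-by-n update. It only transcribes the formulas listed under lr_adjust; how the implementation combines the scale with lr is not described here, and scaling_rules is a hypothetical name.

```python
import math


def scaling_rules(m: int, n: int, lr: float, lr_adjust: float = 10.0,
                  lr_adjust_coeff: float = 0.2) -> tuple[float, float]:
    """Return (adamuon_scale, adam_lr) following the documented lr_adjust rules."""
    if lr_adjust <= 0:
        # Match-RMS scaling for AdaMuon; Adam uses lr unchanged.
        return lr_adjust_coeff * math.sqrt(max(m, n)), lr
    # Rectangular correction for AdaMuon; Adam lr is divided by lr_adjust.
    return math.sqrt(max(1.0, m / n)), lr / lr_adjust
```

With the defaults and a 64-by-16 matrix, scaling_rules(64, 16, 1e-3) gives a scale of 2.0 and an Adam learning rate of 1e-4; a wide 16-by-64 matrix gets a scale of 1.0, so only tall matrices are boosted in the rectangular mode.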
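Examples
>>> import torch
>>> from deepmd.pt.optimizer.adamuon import AdaMuonOptimizer
>>> model = torch.nn.Linear(4, 2)
>>> x, y = torch.randn(8, 4), torch.randn(8, 2)
>>> optimizer = AdaMuonOptimizer(model.parameters(), lr=1e-3)
>>> for epoch in range(10):
...     optimizer.zero_grad()
...     loss = torch.nn.functional.mse_loss(model(x), y)
...     loss.backward()
...     optimizer.step()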
- step(closure: collections.abc.Callable[[], torch.Tensor] | None = None) -> torch.Tensor | None
Perform a single optimization step.
- Parameters:
  - closure : callable, optional
    A closure that reevaluates the model and returns the loss.
- Returns:
  - loss : torch.Tensor or None
    The loss value if closure is provided, otherwise None.