deepmd.pt.optimizer#
Submodules#
Classes#
- AdaMuonOptimizer: AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam.
- HybridMuonOptimizer: HybridMuon optimizer with 1D Adam path and matrix Muon path.
Package Contents#
- class deepmd.pt.optimizer.AdaMuonOptimizer(params: collections.abc.Iterable[torch.Tensor] | collections.abc.Iterable[dict[str, Any]], lr: float = 0.001, momentum: float = 0.95, weight_decay: float = 0.001, ns_steps: int = 5, adam_betas: tuple[float, float] = (0.9, 0.95), adam_eps: float = 1e-08, nesterov: bool = True, lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2, eps: float = 1e-08)[source]#
Bases: torch.optim.optimizer.Optimizer
AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam.
This optimizer applies different update rules based on parameter dimensionality:
- For 2D+ parameters (weight matrices): AdaMuon update with sign-stabilized Newton-Schulz orthogonalization and per-element v_buffer normalization.
- For 1D parameters (biases, layer norms): standard Adam update.
Key AdaMuon features:
- Sign-stabilized orthogonal direction: applies sign() before orthogonalization.
- Per-element second-moment normalization using the momentum coefficient.
- RMS-aligned global scaling: 0.2 * sqrt(m * n) / norm.
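To make the orthogonalization step concrete, here is a minimal pure-Python sketch of the textbook cubic Newton-Schulz iteration, X ← 1.5 X − 0.5 (X Xᵀ) X. It is an illustration only: the optimizer's actual iteration coefficients, its sign() pre-processing, and its tensor backend are not shown on this page, and the helper names `matmul`, `transpose`, and `newton_schulz_cubic` are hypothetical.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def newton_schulz_cubic(x, steps=20):
    """Approximately orthogonalize a wide matrix x (rows <= cols).

    Assumption: this is the classical cubic iteration, not the tuned
    variant used inside AdaMuonOptimizer.
    """
    # Normalize by the Frobenius norm so every singular value is at most 1,
    # which keeps the cubic iteration inside its convergence region.
    fro = math.sqrt(sum(v * v for row in x for v in row))
    x = [[v / fro for v in row] for row in x]
    for _ in range(steps):
        gram = matmul(x, transpose(x))  # (rows x rows) Gram matrix
        gx = matmul(gram, x)            # (rows x cols)
        x = [[1.5 * a - 0.5 * b for a, b in zip(r1, r2)] for r1, r2 in zip(x, gx)]
    return x


update = newton_schulz_cubic([[3.0, 1.0, 0.0], [0.0, 2.0, 1.0]])
rows_gram = matmul(update, transpose(update))  # approaches the 2x2 identity
```

After enough steps the rows of `update` are close to orthonormal, which is the property the Muon-style update relies on; `ns_steps` controls how many such iterations the optimizer runs.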
- Parameters:
  - params : iterable
    Iterable of parameters to optimize.
  - lr : float
    Learning rate with default 1e-3.
  - momentum : float
    Momentum coefficient for AdaMuon with default 0.95.
  - weight_decay : float
    Weight decay coefficient (applied only to >=2D params) with default 0.001.
  - ns_steps : int
    Number of Newton-Schulz iterations with default 5.
  - adam_betas : tuple[float, float]
    Adam beta coefficients with default (0.9, 0.95).
  - adam_eps : float
    Adam epsilon with default 1e-8.
  - nesterov : bool
    Whether to use Nesterov momentum for AdaMuon with default True.
  - lr_adjust : float
    Learning rate adjustment factor for Adam (1D params).
    - If lr_adjust <= 0: use match-RMS scaling for the AdaMuon update, scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly.
    - If lr_adjust > 0: use rectangular correction for the AdaMuon update, scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as its learning rate.
    Default is 10.0 (Adam lr = lr/10).
  - lr_adjust_coeff : float
    Coefficient for match-RMS scaling with default 0.2. Only effective when lr_adjust <= 0.
  - eps : float
    Epsilon for v_buffer sqrt and global scaling normalization with default 1e-8.
Examples
>>> optimizer = AdaMuonOptimizer(model.parameters(), lr=1e-3)
>>> for epoch in range(epochs):
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
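The two `lr_adjust` regimes described in the parameter list can be sketched as a small helper. This is a hedged illustration of the documented formulas only; the function name `muon_scale_and_adam_lr` is hypothetical, and treating `m` as the row count and `n` as the column count is an assumption, since the page does not define them.

```python
import math


def muon_scale_and_adam_lr(lr, lr_adjust, lr_adjust_coeff, m, n):
    """Return (matrix-update scale, Adam learning rate) for an m x n matrix.

    Assumption: m is the row count and n is the column count.
    """
    if lr_adjust <= 0:
        # Match-RMS scaling; the Adam path uses lr unchanged.
        scale = lr_adjust_coeff * math.sqrt(max(m, n))
        adam_lr = lr
    else:
        # Rectangular correction; the Adam path uses a reduced learning rate.
        scale = math.sqrt(max(1.0, m / n))
        adam_lr = lr / lr_adjust
    return scale, adam_lr


# Default AdaMuonOptimizer settings (lr_adjust=10.0) on a 256 x 64 matrix.
rect_scale, rect_adam_lr = muon_scale_and_adam_lr(1e-3, 10.0, 0.2, 256, 64)

# Match-RMS regime (lr_adjust <= 0) on a 100 x 400 matrix.
rms_scale, rms_adam_lr = muon_scale_and_adam_lr(1e-3, 0.0, 0.2, 100, 400)
```

With `lr_adjust > 0` the coefficient `lr_adjust_coeff` has no effect, which matches the note that it is only effective when `lr_adjust <= 0`.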
- step(closure: collections.abc.Callable[[], torch.Tensor] | None = None) → torch.Tensor | None[source]#
Perform a single optimization step.
- Parameters:
  - closure : callable(), optional
    A closure that reevaluates the model and returns the loss.
- Returns:
  - loss : torch.Tensor, optional
    The loss value if closure is provided.
- class deepmd.pt.optimizer.HybridMuonOptimizer(params: collections.abc.Iterable[torch.Tensor] | collections.abc.Iterable[dict[str, Any]], lr: float = 0.0005, momentum: float = 0.95, weight_decay: float = 0.001, adam_betas: tuple[float, float] = (0.9, 0.95), lr_adjust: float = 0.0, lr_adjust_coeff: float = 0.18, muon_mode: str = 'slice', named_parameters: collections.abc.Iterable[tuple[str, torch.Tensor]] | None = None, enable_gram: bool = True, flash_muon: bool = True, magma_muon: bool = True, use_foreach: bool | None = None)[source]#
Bases: torch.optim.optimizer.Optimizer
HybridMuon optimizer with 1D Adam path and matrix Muon path.
This optimizer applies different update rules based on parameter dimensionality, parameter names, and muon_mode:
- Parameters with final effective name segment containing bias (case-insensitive), or starting with adam_ (case-insensitive): standard Adam update.
- Parameters with final effective name segment starting with adamw_ (case-insensitive): Adam with decoupled weight decay (AdamW-style).
- 1D parameters: standard Adam update.
- Parameters are routed by effective shape (singleton dimensions removed).
- muon_mode="2d":
  - effective rank 2 parameters use Muon.
  - effective rank >2 parameters use Adam.
- muon_mode="flat":
  - effective rank >=2 parameters use flattened matrix-view Muon.
- muon_mode="slice":
  - effective rank 2 parameters use Muon.
  - effective rank >=3 parameters apply Muon independently on each trailing (m, n) slice.
Naming convention for explicit Adam routing:
- Parameters representing bias terms should include bias in their final effective name segment (case-insensitive).
- Parameters that are not semantic bias but should still use Adam should use an adam_ prefix in their final effective name segment (case-insensitive).
- Parameters that should use Adam with decoupled weight decay should use an adamw_ prefix in their final effective name segment (case-insensitive).
This hybrid approach is effective because Muon’s orthogonalization is designed for weight matrices, while Adam is more suitable for biases and normalization params.
- Parameters:
  - params : iterable
    Iterable of parameters to optimize.
  - lr : float
    Learning rate.
  - momentum : float
    Momentum coefficient for Muon with default 0.95.
  - weight_decay : float
    Weight decay coefficient with default 0.001. Applied to Muon-routed parameters and >=2D Adam-routed parameters with AdamW-style decoupled decay. Not applied to 1D Adam parameters.
  - adam_betas : tuple[float, float]
    Adam beta coefficients with default (0.9, 0.95).
  - lr_adjust : float
    Learning rate adjustment mode for Muon scaling and Adam learning rate.
    - If lr_adjust <= 0: use match-RMS scaling for Muon, scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly.
    - If lr_adjust > 0: use rectangular correction for Muon, scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust.
    Default is 0.0 (match-RMS scaling).
  - lr_adjust_coeff : float
    Coefficient with default 0.18 for match-RMS scaling when lr_adjust <= 0: scale = lr_adjust_coeff * sqrt(max(m, n)). 0.18 is the value calibrated by DeepSeek-V4 so that Muon’s per-element update RMS matches AdamW’s typical RMS, enabling reuse of AdamW learning rates across both paths. The Moonlight reference uses 0.2; both are empirically viable.
  - muon_mode : str
    Muon routing mode with default "slice".
    - "2d": only 2D parameters are Muon candidates.
    - "flat": >=2D parameters use flattened matrix-view routing.
    - "slice": >=3D parameters use per-slice Muon routing on last two dims.
  - named_parameters : iterable[tuple[str, torch.Tensor]] | None
    Optional named parameter iterable used for name-based routing. Parameters with final effective name segment containing bias (case-insensitive), or starting with adam_ (case-insensitive), are forced to Adam (no weight decay). Parameters starting with adamw_ are forced to AdamW-style decoupled decay path.
  - enable_gram : bool
    Enable the compiled Gram Newton-Schulz path for rectangular Muon matrices. Square matrices continue to use the current standard Newton-Schulz implementation. Default is True.
  - flash_muon : bool
    Enable triton-accelerated Newton-Schulz orthogonalization. Requires triton and CUDA. Falls back to PyTorch implementation when triton is unavailable or running on CPU. Ignored when enable_gram=True. Default is True.
  - magma_muon : bool
    Enable Magma-lite damping on Muon updates with default True. This computes momentum-gradient cosine alignment per Muon block, applies EMA smoothing, and rescales Muon updates in [0.1, 1.0]. Adam/AdamW paths are unchanged. Empirically beneficial for MLIP / SeZM training under heavy-tailed gradient noise from conservative-force (second-order) autograd.
Examples
>>> optimizer = HybridMuonOptimizer(model.parameters(), lr=5e-4)
>>> for epoch in range(epochs):
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
- _routing_built = False#
- _use_flash = False#
- _ns_buffers: dict[tuple[int, torch.device], tuple[torch.Tensor, torch.Tensor]]#
- _gram_orthogonalizer: _GramNewtonSchulzOrthogonalizer | None = None#
- _use_foreach#
- set_param_names(named_parameters: collections.abc.Iterable[tuple[str, torch.Tensor]]) → None[source]#
Set runtime-only parameter names used for name-based routing.
The mapping intentionally stays outside optimizer defaults and param_groups so optimizer checkpoints do not persist full (name, Parameter) tuples. Under ZeRO-1 this avoids gathering a duplicate model-sized object graph during consolidate_state_dict.
- static _resolve_foreach(use_foreach: bool | None) → bool[source]#
Resolve the use_foreach flag for torch._foreach_* kernels.
Foreach fuses per-parameter loops into single kernel launches, eliminating Python overhead. When use_foreach is None the default is True because plain torch.Tensor (single-GPU, DDP, ZeRO-1) always supports these ops; callers that hit DTensor dispatch errors under FSDP2 must pass use_foreach=False explicitly.
- _compute_magma_scales_merged(bucket_entries: list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]], rows: int, cols: int) → list[torch.Tensor][source]#
Compute Magma-lite scales for a merged bucket with variable batch_sizes.
Like _compute_magma_scales_for_bucket but handles entries whose batch_size may differ (produced by the merged-bucket strategy that keys on (rows, cols) instead of (batch_size, rows, cols)).
- _compute_magma_scale(param: torch.Tensor, grad: torch.Tensor, momentum_buffer: torch.Tensor, batch_size: int, rows: int, cols: int) → torch.Tensor[source]#
Compute Magma-lite Muon damping scales from momentum-gradient alignment.
Implements a stabilized version of Magma (Momentum-Aligned Gradient Masking) adapted for MLIP force-field training. Computes block-wise alignment scores between Muon momentum and current gradients, applies EMA smoothing, and rescales Muon updates to improve stability under heavy-tailed gradient noise.
- Parameters:
  - param : torch.Tensor
    Parameter updated by Muon.
  - grad : torch.Tensor
    Current gradient tensor with shape compatible with (batch_size, rows, cols).
  - momentum_buffer : torch.Tensor
    Muon momentum buffer (updated m_t) with same shape as grad.
  - batch_size : int
    Number of Muon blocks (1 for 2d/flat mode, >1 for slice mode).
  - rows : int
    Matrix row count per block.
  - cols : int
    Matrix column count per block.
- Returns:
  - torch.Tensor
    Damping scales with shape (batch_size,) in [MAGMA_MIN_SCALE, 1.0].
Notes
For each Muon block b:
1. Compute cosine similarity between momentum and gradient:
   cos(b) = <μ_t^(b), g_t^(b)> / (||μ_t^(b)|| * ||g_t^(b)||)
2. Apply sigmoid with range stretching to [0, 1]:
   s_raw^(b) = (sigmoid(cos(b) / τ) - s_min) / (s_max - s_min)
   where τ=2.0, s_min=sigmoid(-1/τ), s_max=sigmoid(1/τ). This stretches the narrow sigmoid range [0.38, 0.62] to [0, 1].
3. Apply EMA smoothing:
   s̃_t^(b) = a * s̃_{t-1}^(b) + (1-a) * s_raw^(b)
   where a=0.9 (MAGMA_EMA_DECAY).
4. Map to damping scale in [s_min_scale, 1.0]:
   scale^(b) = s_min_scale + (1 - s_min_scale) * s̃_t^(b)
   where s_min_scale=0.1 (MAGMA_MIN_SCALE).
5. Apply damping to Muon update:
   Δ̃^(b) = scale^(b) * Δ^(b) (soft scaling, no Bernoulli masking)
Key differences from the original Magma paper:
- Sigmoid range stretching: The paper uses the raw sigmoid with its narrow range [0.38, 0.62]. This implementation stretches it to [0, 1] for better discrimination between aligned and misaligned blocks.
- Soft scaling: The paper uses Bernoulli masking (50% skip probability). This implementation uses continuous soft scaling in [0.1, 1.0] for stability in MLIP training.
- Minimum scale: The paper allows scale=0 (complete skip). This implementation enforces scale >= 0.1 so that every block keeps a nonzero effective learning rate.
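The per-block computation in the Notes can be sketched for a single block in plain Python. This is an illustration of the documented formulas, not the optimizer's batched tensor code: the function name `magma_scale`, the flat-list representation of a block, the `eps` guard against zero norms, and passing the previous EMA score as an argument are all assumptions.

```python
import math

TAU = 2.0          # τ from the Notes
EMA_DECAY = 0.9    # a (MAGMA_EMA_DECAY)
MIN_SCALE = 0.1    # s_min_scale (MAGMA_MIN_SCALE)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def magma_scale(momentum, grad, prev_score, eps=1e-12):
    """Return (damping scale, updated EMA score) for one Muon block.

    momentum and grad are flattened blocks given as equal-length lists.
    eps is a hypothetical guard against division by zero.
    """
    # Step 1: cosine similarity between momentum and gradient.
    dot = sum(m * g for m, g in zip(momentum, grad))
    norm = math.sqrt(sum(m * m for m in momentum)) * math.sqrt(sum(g * g for g in grad))
    cos = dot / max(norm, eps)
    # Step 2: sigmoid stretched from [s_min, s_max] to [0, 1].
    s_min, s_max = sigmoid(-1.0 / TAU), sigmoid(1.0 / TAU)
    s_raw = (sigmoid(cos / TAU) - s_min) / (s_max - s_min)
    # Step 3: EMA smoothing of the alignment score.
    score = EMA_DECAY * prev_score + (1.0 - EMA_DECAY) * s_raw
    # Step 4: map the smoothed score into [MIN_SCALE, 1.0].
    scale = MIN_SCALE + (1.0 - MIN_SCALE) * score
    return scale, score


aligned_scale, _ = magma_scale([1.0, 2.0], [2.0, 4.0], prev_score=1.0)
opposed_scale, _ = magma_scale([1.0, 2.0], [-2.0, -4.0], prev_score=0.0)
orthogonal_scale, _ = magma_scale([1.0, 0.0], [0.0, 1.0], prev_score=0.5)
```

A fully aligned block with a saturated history keeps its full update, a fully opposed block with no history is damped to the floor of 0.1, and an orthogonal block lands midway at 0.55.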
- _compute_magma_scales_for_bucket(bucket_entries: list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]], batch_size: int, rows: int, cols: int) → list[torch.Tensor][source]#
Compute Magma-lite damping scales for one Muon bucket in a batched way.
- Parameters:
  - bucket_entries : list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]]
    Bucket entries as (entry, update_tensor, grad, momentum_buffer).
  - batch_size : int
    Number of Muon blocks per parameter in this bucket.
  - rows : int
    Matrix row count for this bucket.
  - cols : int
    Matrix column count for this bucket.
- Returns:
  - list[torch.Tensor]
    Magma scales for each bucket entry. Each tensor has shape (batch_size,).
- _get_ns_buffers(M: int, device: torch.device) → tuple[torch.Tensor, torch.Tensor][source]#
Get or lazily allocate pre-allocated buffers for flash Newton-Schulz.
- Parameters:
  - M : int
    Square buffer dimension (= min(rows, cols) of the update matrix).
  - device : torch.device
    Target CUDA device.
- Returns:
  - tuple[torch.Tensor, torch.Tensor]
    (buf1, buf2), each with shape (M, M) in bfloat16.
- _get_gram_orthogonalizer() → _GramNewtonSchulzOrthogonalizer[source]#
Lazily initialize the compiled Gram orthogonalizer.
- Returns:
  - _GramNewtonSchulzOrthogonalizer
    Shared Gram orthogonalizer instance for the optimizer.
- _process_merged_gram_buckets(gram_buckets: dict[tuple[int, int, torch.device, torch.dtype], list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]]], lr: float, lr_adjust: float, lr_adjust_coeff: float, magma_scales_map: dict[int, torch.Tensor]) → None[source]#
Column-pad merge across rectangular buckets sharing the same min_dim.
Rectangular Muon matrices with the same min(rows, cols) can be fused into a single Gram Newton-Schulz call by zero-padding the column (large) dimension to the group maximum. This reduces the number of compiled Gram NS dispatches and improves GPU occupancy.
Mathematical equivalence proof for column-padding:
Both Standard NS and Gram NS operate on the wide orientation X (m x n), m <= n. The Gram matrix is R = X @ X^T (m x m).
Let X_pad = [X | 0] (m x (n+p)) where the last p columns are zero. Then:
1. Frobenius norm is unchanged: ||X_pad||_F = ||X||_F because the zero columns contribute nothing.
2. Gram matrix is unchanged: R_pad = X_pad @ X_pad^T = X @ X^T + 0 @ 0^T = R
3. Since all NS iterations (both standard quintic and Gram/Polar-Express) depend only on R (which is m x m regardless of n), every intermediate Q_k is identical.
4. The restart step X_new = Q @ X_pad = [Q @ X | 0] also preserves the invariant R_new = Q @ R @ Q^T, so subsequent iterations remain identical.
5. The final output is Q_last @ X_pad = [Q_last @ X | 0]. Truncating to the first n columns exactly recovers the unpadded result.
Constraint: Only the column (large) dimension may be padded. Padding rows would change the size of R and break equivalence.
Per-entry scale and Magma damping are applied after unpadding, since different original shapes have different max(rows, cols).
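The first two steps of the equivalence argument (unchanged Frobenius norm and unchanged Gram matrix under zero column padding) can be checked numerically with a small pure-Python sketch. The helper names `gram`, `pad_columns`, and `frobenius_sq` are hypothetical and only illustrate the algebra, not the optimizer's batched implementation.

```python
def gram(x):
    """Return R = X @ X^T for a matrix stored as a list of rows."""
    return [[sum(a * b for a, b in zip(r1, r2)) for r2 in x] for r1 in x]


def pad_columns(x, p):
    """Return [X | 0]: append p zero columns to every row."""
    return [row + [0.0] * p for row in x]


def frobenius_sq(x):
    """Return the squared Frobenius norm of a matrix."""
    return sum(v * v for row in x for v in row)


x = [[3.0, 1.0, 0.0], [0.0, 2.0, 1.0]]   # wide orientation, m=2 <= n=3
x_pad = pad_columns(x, 2)                # m=2, n+p=5

same_gram = gram(x_pad) == gram(x)
same_norm = frobenius_sq(x_pad) == frobenius_sq(x)

# Truncating back to the first n columns recovers the original matrix.
unpadded = [row[:3] for row in x_pad]
```

Padding rows instead would grow the Gram matrix from 2 x 2 to a larger size, which is why only the column dimension may be padded.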
- _build_param_routing() → None[source]#
Classify parameters into Muon, Adam, and AdamW routes (static routing).
Routing logic:
- name-based adam_ prefix or contains bias → Adam (no decay)
- name-based adamw_ prefix → AdamW (decoupled weight decay)
- effective shape rank <2 → Adam (no decay)
- non-matrix effective shape for current muon_mode → AdamW (decoupled)
- remaining eligible matrix params → Muon path
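The routing order above can be sketched as a standalone function. This is a hedged reading of the documented rules, not the method's actual code: the name `route_parameter` is hypothetical, the "final effective name segment" is assumed to be the last dot-separated part of the parameter name, and "non-matrix effective shape" is assumed to mean effective rank above 2 in "2d" mode.

```python
def route_parameter(name, shape, muon_mode="slice"):
    """Return "adam", "adamw", or "muon" for one parameter.

    Assumption: the final name segment is the text after the last dot.
    """
    segment = name.split(".")[-1].lower()
    # Name-based routing takes priority over shape-based routing.
    if segment.startswith("adam_") or "bias" in segment:
        return "adam"
    if segment.startswith("adamw_"):
        return "adamw"
    # Effective shape: singleton dimensions removed.
    effective_shape = tuple(d for d in shape if d != 1)
    rank = len(effective_shape)
    if rank < 2:
        return "adam"
    # Assumption: only "2d" mode has non-matrix shapes at rank > 2;
    # "flat" and "slice" treat every rank >= 2 parameter as Muon-eligible.
    if muon_mode == "2d" and rank > 2:
        return "adamw"
    return "muon"


routes = {
    "bias": route_parameter("layer.bias", (64,)),
    "singleton": route_parameter("fc.weight", (1, 64, 32)),
    "rank3_2d": route_parameter("conv.weight", (4, 64, 32), muon_mode="2d"),
    "rank3_slice": route_parameter("conv.weight", (4, 64, 32), muon_mode="slice"),
    "adamw_name": route_parameter("net.adamw_scale", (8, 8)),
    "adam_name": route_parameter("net.adam_gate", (8, 8)),
}
```

Checking names before shapes means a square matrix named with an `adam_` or `adamw_` prefix never reaches the Muon path, which is the purpose of the naming convention.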
- _adam_update_moments(exp_avgs: list[torch.Tensor], exp_avg_sqs: list[torch.Tensor], grads_fp32: list[torch.Tensor], beta1: float, beta2: float) → None[source]#
Update Adam first/second moment estimates, foreach-accelerated when safe.
exp_avg = beta1 * exp_avg + (1 - beta1) * grad
exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
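The two moment recurrences can be written out element-wise as a short sketch. It uses plain Python lists rather than tensors and returns new lists instead of updating in place, so it only illustrates the arithmetic; the name `adam_update_moments` (without the leading underscore) is hypothetical.

```python
def adam_update_moments(exp_avg, exp_avg_sq, grad, beta1, beta2):
    """Return updated (first moment, second moment) lists for one parameter."""
    new_avg = [beta1 * m + (1.0 - beta1) * g for m, g in zip(exp_avg, grad)]
    new_avg_sq = [beta2 * v + (1.0 - beta2) * g * g for v, g in zip(exp_avg_sq, grad)]
    return new_avg, new_avg_sq


# First step from zero-initialized moments with the default adam_betas.
m1, v1 = adam_update_moments([0.0, 0.0], [0.0, 0.0], [2.0, -1.0], 0.9, 0.95)
```

With zero-initialized state, the first step yields `(1 - beta1) * grad` and `(1 - beta2) * grad^2`, which is why Adam implementations commonly apply a bias correction afterwards; whether this optimizer does so is not stated on this page.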
- _weight_decay_inplace(params: list[torch.Tensor], factor: float) → None[source]#
Apply multiplicative weight decay, foreach-accelerated when safe.
- step(closure: collections.abc.Callable[[], torch.Tensor] | None = None) → torch.Tensor | None[source]#
Perform a single optimization step.
- Parameters:
  - closure : callable(), optional
    A closure that reevaluates the model and returns the loss.
- Returns:
  - torch.Tensor | None
    The loss value if closure is provided, otherwise None.
- class deepmd.pt.optimizer.KFOptimizerWrapper(model: torch.nn.Module, optimizer: torch.optim.optimizer.Optimizer, atoms_selected: int, atoms_per_group: int, is_distributed: bool = False)[source]#
- model#
- optimizer#
- atoms_selected#
- atoms_per_group#
- is_distributed = False#
- update_denoise_coord(inputs: dict, clean_coord: torch.Tensor, update_prefactor: float = 1, mask_loss_coord: bool = True, coord_mask: torch.Tensor = None) → None[source]#
- class deepmd.pt.optimizer.LKFOptimizer(params: Any, kalman_lambda: float = 0.98, kalman_nue: float = 0.9987, block_size: int = 5120)[source]#
Bases: torch.optim.optimizer.Optimizer
- _params#
- _state#
- dist_init#
- rank#
- dindex = []#
- remainder = 0#
- __split_weights(weight: torch.Tensor) → list[torch.Tensor][source]#
- __update(H: torch.Tensor, error: torch.Tensor, weights: torch.Tensor) → None[source]#
- step(error: torch.Tensor) → None[source]#