deepmd.dpmodel.descriptor.dpa4_nn.lora
LoRA low-rank fine-tuning support for DPA4/SeZM.
This module adds two things:
- LoRASO3 and LoRASO2 subclasses that wrap the corresponding base equivariant linear operators (SO3Linear / SO2Linear). Each one freezes the pre-trained weights and registers rank-R adapter parameters A/B. The shapes of these parameters share the base's batch layout: per-l for SO(3) and per-|m|-group for SO(2). The LoRA delta is folded into the effective weight before the single large einsum that already exists in the base module. The FLOPs of that einsum are therefore identical to the base. The only overhead is an O(R) weight-side matmul whose cost does not depend on the number of edges or nodes.
- apply_lora_to_sezm, merge_lora_into_base, and a few helpers. These helpers drive the fine-tune policy, meaning which submodules stay trainable and which remain frozen. They also drive the merged-checkpoint export used by Trainer.save_model_merged.
Naming convention: the LoRA parameter names – A_by_l, B_by_l, A_m0, B_m0, A_m, B_m – intentionally do not start with adam_ / adamw_ and do not contain bias. HybridMuon.get_adam_route therefore classifies them as muon and, because the tensors have the same rank structure as the corresponding base weights, the slice-mode matrix view gives per-l / per-|m|-group Newton-Schulz updates that match the base training recipe.
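The sketch below makes the naming convention concrete. It shows the kind of name-based routing described above. The route_for function is hypothetical and only encodes the two rules stated in this paragraph: an adam_ / adamw_ prefix, or a bias substring. The real HybridMuon.get_adam_route may check more than this.

```python
def route_for(leaf_name: str) -> str:
    """Hypothetical simplification of the name-based routing rule (not HybridMuon's code)."""
    if leaf_name.startswith(("adam_", "adamw_")) or "bias" in leaf_name:
        return "adam"
    return "muon"


lora_names = ["A_by_l", "B_by_l", "A_m0", "B_m0", "A_m", "B_m"]
print({name: route_for(name) for name in lora_names})  # every value is "muon"
print(route_for("mlp_bias"), route_for("adamw_scale"))  # adam adam
```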
This module is the dpmodel (array-API) port of deepmd.pt.model.descriptor.sezm_nn.lora.
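Attributes

| Attribute |
|---|
| _UNFREEZE_SUBMODULE_PATHS |
| _UNFREEZE_PER_BLOCK_SUBPATHS |
| _BLOCKS_PATH |

Classes

| Class | Description |
|---|---|
| LoRASO3 | Per-l ELoRA adapter for SO3Linear. |
| LoRASO2 | Per-\|m\|-group LoRA adapter for SO2Linear. |

Functions

| Function | Description |
|---|---|
| _iter_named_modules | Yield (dotted_name, module) for root and every nested NativeOP. |
| _iter_named_parameters | Yield (dotted_name, owner, array) for every numpy-array parameter. |
| _module_state_dict | Flat dotted {name: array} dict over the whole module tree. |
| _leaf_name | Return the trailing non-numeric segment of a parameter name. |
| _get_submodule_or_none | |
| _clear_sezm_compile_cache | No-op retained for parity with the PyTorch backend. |
| _swap_submodule | Replace parent.attr with new_module. |
| has_lora | Return True iff any submodule is a LoRA adapter. |
| apply_lora_to_sezm | Inject LoRA adapters into every SO3Linear / SO2Linear of a SeZM model and apply the SeZM fine-tune freeze/unfreeze policy in place. |
| fold_lora_state_dict_keys | Fold LoRA adapter keys into base weight keys in state_dict (in-place). |
| build_merged_state_dict | Produce a plain (LoRA-free) state dict from a LoRA-augmented module. |
| strip_lora_from_extra_state | Drop any lora entry from _extra_state["model_params"]. |
| merge_lora_into_base | Destructively replace every LoRASO3 / LoRASO2 with its merged plain base module. |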
Module Contents
- class deepmd.dpmodel.descriptor.dpa4_nn.lora.LoRASO3(*, lmax: int, in_channels: int, out_channels: int, n_focus: int = 1, precision: str = DEFAULT_PRECISION, mlp_bias: bool = False, trainable: bool = False, seed: int | list[int] | None = None, lora_rank: int, lora_alpha: float | None = None)
  Bases: deepmd.dpmodel.descriptor.dpa4_nn.so3.SO3Linear

  Per-l ELoRA adapter for SO3Linear.

  The pre-trained weight self.weight (shape (lmax+1, C_in, F*C_out)) is frozen. Two new 3D parameters are added: A_by_l (shape (lmax+1, rank, C_in)) and B_by_l (shape (lmax+1, F*C_out, rank)). Both share the same lmax+1 batch axis as the base, so muon_mode="slice" updates every l-block independently. SO(3) equivariance is preserved because the per-l delta only mixes channels within each l-block and acts identically on its 2l+1 components. There is no cross-l mixing.

  - Parameters:
    - lmax, in_channels, out_channels, n_focus, precision, mlp_bias, trainable, seed
      Forwarded to SO3Linear to build the frozen base weight.
    - lora_rank
      LoRA rank. Must satisfy lora_rank >= 1.
    - lora_alpha
      Scaling numerator; the effective scaling is lora_alpha / lora_rank. None defaults to lora_alpha = lora_rank (scaling 1.0).
  - _compute_delta_weight(xp: Any, device: Any) → deepmd.dpmodel.array_api.Array
    Return ΔW with shape (lmax+1, C_in, F*C_out).
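The shapes above pin down the contraction. The sketch below assumes ΔW[l] = scaling * (B_by_l[l] @ A_by_l[l]).T. The transpose is inferred from the documented shapes and is not confirmed by the source. The rank axis is contracted separately for each l. The cost scales with (lmax+1)·R·C_in·F·C_out and is independent of the number of nodes or edges.

```python
import numpy as np


def so3_delta_weight(a_by_l: np.ndarray, b_by_l: np.ndarray, scaling: float) -> np.ndarray:
    """Per-l LoRA delta built from the documented shapes (sketch).

    a_by_l: (lmax+1, rank, C_in)
    b_by_l: (lmax+1, F*C_out, rank)
    returns: (lmax+1, C_in, F*C_out), i.e. scaling * (B_l @ A_l).T for every l
    """
    # Contract the rank axis r independently for every degree l: no l-mixing.
    return np.einsum("lor,lri->lio", b_by_l, a_by_l) * scaling


lmax, rank, c_in, n_focus, c_out = 2, 4, 8, 2, 3
alpha = 8.0
rng = np.random.default_rng(0)
a = rng.standard_normal((lmax + 1, rank, c_in))
b = rng.standard_normal((lmax + 1, n_focus * c_out, rank))
delta = so3_delta_weight(a, b, scaling=alpha / rank)
print(delta.shape)  # (3, 8, 6)
```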
  - call(x: deepmd.dpmodel.array_api.Array) → deepmd.dpmodel.array_api.Array
    - Parameters:
      - x
        Input features with shape (N, D, F, C_in), where D = (lmax+1)^2.
    - Returns:
      Array
        Output features with shape (N, D, F, C_out).
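The NumPy sketch below shows why folding the delta keeps the forward cost unchanged. It rests on three assumptions:
- The F*C_out axis of the weight is laid out as (F, C_out).
- Each of the 2l+1 rows of degree l in the D axis uses weight[l].
- so3_linear_sketch is a hypothetical stand-in for SO3Linear.call.

The operator is linear in its weight. One pass with weight + ΔW therefore equals the base output plus the adapter output, yet it reads the N-sized input only once.

```python
import numpy as np


def so3_linear_sketch(x: np.ndarray, weight: np.ndarray, n_focus: int) -> np.ndarray:
    """Apply a per-l weight to x of shape (N, D, F, C_in), D = (lmax+1)**2.

    ASSUMPTION: the last weight axis F*C_out is laid out as (F, C_out).
    """
    lmax_plus_1, c_in, fc_out = weight.shape
    w = weight.reshape(lmax_plus_1, c_in, n_focus, fc_out // n_focus)
    degrees = np.arange(lmax_plus_1)
    # Row d of the D axis belongs to degree l, which spans 2l+1 rows.
    l_of_d = np.repeat(degrees, 2 * degrees + 1)
    w_d = w[l_of_d]  # (D, C_in, F, C_out): one matrix shared by all m of a given l
    return np.einsum("ndfi,difo->ndfo", x, w_d)


lmax, n_nodes, n_focus, c_in, c_out = 2, 5, 2, 4, 3
rng = np.random.default_rng(1)
x = rng.standard_normal((n_nodes, (lmax + 1) ** 2, n_focus, c_in))
w_base = rng.standard_normal((lmax + 1, c_in, n_focus * c_out))
delta = rng.standard_normal((lmax + 1, c_in, n_focus * c_out))  # any ΔW works (linearity)

# Folded: one pass over the N-sized input instead of two.
y_folded = so3_linear_sketch(x, w_base + delta, n_focus)
y_two_pass = so3_linear_sketch(x, w_base, n_focus) + so3_linear_sketch(x, delta, n_focus)
print(y_folded.shape)  # (5, 9, 2, 3)
```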
  - merge_into_base() → deepmd.dpmodel.descriptor.dpa4_nn.so3.SO3Linear
    Build a plain SO3Linear whose weight has absorbed the LoRA delta.
- class deepmd.dpmodel.descriptor.dpa4_nn.lora.LoRASO2(*, lmax: int, mmax: int | None = None, in_channels: int, out_channels: int, n_focus: int = 1, precision: str = DEFAULT_PRECISION, mlp_bias: bool = False, seed: int | list[int] | None = None, trainable: bool = False, lora_rank: int, lora_alpha: float | None = None)
  Bases: deepmd.dpmodel.descriptor.dpa4_nn.so2.SO2Linear

  Per-|m|-group LoRA adapter for SO2Linear.

  weight_m0 (shape (num_in_m0, F*num_out_m0)) and each weight_m[i] (shape (num_in_m, F*2*num_out_m)) get an independent 2D LoRA pair A/B. SO(2) equivariance is preserved because the |m|>0 2x2 complex block [[W_u, -W_v], [W_v, W_u]] stays intact. ΔW_m is absorbed into the concatenated [W_u | W_v] layout before _build_so2_weight splits that layout. The shared input basis A then splits naturally into ΔW_u = B_u A and ΔW_v = B_v A.

  The base call logic is inherited unchanged. Only _build_so2_weight is overridden, to fold the LoRA delta into each base block before assembling the block-diagonal weight. Building ΔW_m does not depend on the edge count E, so the forward FLOPs remain identical to the base.

  - Parameters:
    - lmax, mmax, in_channels, out_channels, n_focus, precision, mlp_bias, trainable, seed
      Forwarded to SO2Linear to build the frozen base weights.
    - lora_rank
      LoRA rank.
    - lora_alpha
      Scaling numerator; scaling is lora_alpha / lora_rank. None defaults to lora_alpha = lora_rank (scaling 1.0).
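The equivariance argument can be checked numerically. This sketch makes several simplifications:
- It uses (out, in) matrices for readability.
- It takes F = 1.
- It assumes the u-half precedes the v-half along the concatenated output axis.

The stored SO2Linear layout (num_in_m, F*2*num_out_m) is not reproduced. The sketch shows two things. First, ΔW = B A splits row-wise into B_u A and B_v A. Second, any block of the form [[W_u, -W_v], [W_v, W_u]] commutes with an SO(2) rotation, so adding the delta cannot break equivariance.

```python
import numpy as np


def so2_block(w_u: np.ndarray, w_v: np.ndarray) -> np.ndarray:
    """The 2x2 complex block [[W_u, -W_v], [W_v, W_u]] on stacked (u, v) parts."""
    return np.block([[w_u, -w_v], [w_v, w_u]])


def rotation(phi: float, n: int) -> np.ndarray:
    """SO(2) rotation by phi acting on a stacked pair of n-dimensional parts."""
    c, s = np.cos(phi), np.sin(phi)
    eye = np.eye(n)
    return np.block([[c * eye, -s * eye], [s * eye, c * eye]])


n_in, n_out, rank, scaling = 4, 3, 2, 0.5
rng = np.random.default_rng(2)
w_u = rng.standard_normal((n_out, n_in))
w_v = rng.standard_normal((n_out, n_in))

# One shared input basis A; B is split row-wise into its u and v halves.
a = rng.standard_normal((rank, n_in))
b = rng.standard_normal((2 * n_out, rank))
delta = scaling * (b @ a)  # concatenated [dW_u; dW_v], shape (2*n_out, n_in)
d_u, d_v = delta[:n_out], delta[n_out:]  # scaling * B_u A and scaling * B_v A

m = so2_block(w_u + d_u, w_v + d_v)
rot_in, rot_out = rotation(0.7, n_in), rotation(0.7, n_out)
# Equivariance: rotate-then-apply equals apply-then-rotate.
print(np.allclose(m @ rot_in, rot_out @ m))  # True
```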
  - A_m: list[numpy.ndarray] = []
  - B_m: list[numpy.ndarray] = []
  - _compute_delta_m0(xp: Any, device: Any) → deepmd.dpmodel.array_api.Array
    Return ΔW_m0 with shape (num_in_m0, F*num_out_m0).
  - _compute_delta_m(m_idx: int, xp: Any, device: Any) → deepmd.dpmodel.array_api.Array
    Return ΔW_m[m_idx] with the same shape as weight_m[m_idx].
  - _build_so2_weight(xp: Any, device: Any) → deepmd.dpmodel.array_api.Array
    Assemble the block-diagonal weight with the LoRA delta folded in.
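The fold-then-assemble order can be sketched as follows. Each base block gets its delta added first. Only then are the blocks placed on the diagonal, so no entries appear between different |m| groups. The block shapes are arbitrary. block_diag and build_weight_with_lora are hypothetical helpers, not the module's code. The real method also builds the 2x2 complex structure for |m|>0, which is omitted here.

```python
import numpy as np


def block_diag(blocks: list[np.ndarray]) -> np.ndarray:
    """Place 2D blocks along the diagonal of a zero matrix."""
    rows = sum(blk.shape[0] for blk in blocks)
    cols = sum(blk.shape[1] for blk in blocks)
    out = np.zeros((rows, cols), dtype=blocks[0].dtype)
    r = c = 0
    for blk in blocks:
        out[r : r + blk.shape[0], c : c + blk.shape[1]] = blk
        r += blk.shape[0]
        c += blk.shape[1]
    return out


def build_weight_with_lora(base_blocks: list[np.ndarray], delta_blocks: list[np.ndarray]) -> np.ndarray:
    """Hypothetical: fold each delta into its own block, then assemble."""
    return block_diag([w + d for w, d in zip(base_blocks, delta_blocks)])


rng = np.random.default_rng(3)
shapes = [(4, 3), (6, 8), (2, 4)]  # e.g. an m=0 block followed by |m|>0 blocks
base = [rng.standard_normal(s) for s in shapes]
deltas = [rng.standard_normal(s) for s in shapes]
full = build_weight_with_lora(base, deltas)
print(full.shape)  # (12, 15)
```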
  - merge_into_base() → deepmd.dpmodel.descriptor.dpa4_nn.so2.SO2Linear
    Build a plain SO2Linear whose weights have absorbed every LoRA delta.
- deepmd.dpmodel.descriptor.dpa4_nn.lora._UNFREEZE_SUBMODULE_PATHS: tuple[str, Ellipsis] = ('atomic_model.fitting_net', 'atomic_model.dens_fitting_net',...
- deepmd.dpmodel.descriptor.dpa4_nn.lora._UNFREEZE_PER_BLOCK_SUBPATHS: tuple[str, Ellipsis] = ('full_attn_res_so2', 'full_attn_res_ffns', 'block_attn_res_so2', 'block_attn_res_ffns',...
- deepmd.dpmodel.descriptor.dpa4_nn.lora._BLOCKS_PATH: str = 'atomic_model.descriptor.blocks'
- deepmd.dpmodel.descriptor.dpa4_nn.lora._iter_named_modules(root: deepmd.dpmodel.NativeOP, prefix: str = '', memo: set[int] | None = None) → collections.abc.Iterator[tuple[str, deepmd.dpmodel.NativeOP]]
  Yield (dotted_name, module) for root and every nested NativeOP.
  root is yielded first under prefix. The walk then descends into two kinds of children:
  - every attribute value that is a NativeOP;
  - every NativeOP element of a list/tuple.

  It builds dotted paths of the form attr and attr.{i}. A shared-module memo de-duplicates repeated references.
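The traversal can be sketched in plain Python. Op is a hypothetical stand-in for NativeOP, that is, any object whose attributes hold child modules. The memo is keyed on id(). A module referenced twice is therefore yielded only under the first path that reaches it.

```python
from __future__ import annotations

from collections.abc import Iterator


class Op:
    """Hypothetical stand-in for deepmd.dpmodel.NativeOP."""


def iter_named_modules(
    root: Op, prefix: str = "", memo: set[int] | None = None
) -> Iterator[tuple[str, Op]]:
    if memo is None:
        memo = set()
    if id(root) in memo:  # already reached through another path
        return
    memo.add(id(root))
    yield prefix, root
    for attr, value in vars(root).items():
        path = f"{prefix}.{attr}" if prefix else attr
        if isinstance(value, Op):
            yield from iter_named_modules(value, path, memo)
        elif isinstance(value, (list, tuple)):
            for i, item in enumerate(value):
                if isinstance(item, Op):
                    yield from iter_named_modules(item, f"{path}.{i}", memo)


shared = Op()
model = Op()
model.descriptor = Op()
model.descriptor.blocks = [Op(), Op()]
model.descriptor.blocks[0].so3 = shared
model.descriptor.blocks[1].so3 = shared  # same object referenced twice
names = [name for name, _ in iter_named_modules(model)]
print(names)
```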
- deepmd.dpmodel.descriptor.dpa4_nn.lora._iter_named_parameters(root: deepmd.dpmodel.NativeOP) → collections.abc.Iterator[tuple[str, deepmd.dpmodel.NativeOP, numpy.ndarray]]
  Yield (dotted_name, owner, array) for every numpy-array parameter.
  A dpmodel “parameter” is a numpy array stored as a module attribute. It can also be a numpy element of a list/tuple attribute, the equivalent of an nn.ParameterList. owner is the module holding the array. The dpmodel tracks trainability per module (module.trainable) rather than per tensor. Callers therefore toggle owner.trainable where the PyTorch code toggles param.requires_grad.
- deepmd.dpmodel.descriptor.dpa4_nn.lora._module_state_dict(root: deepmd.dpmodel.NativeOP) → dict[str, numpy.ndarray]
  Flat dotted {name: array} dict over the whole module tree.
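Below is a matching sketch for parameter enumeration and the flat state dict. It again uses a hypothetical Op stand-in for NativeOP. The sketch reports each array together with the module that owns it, whether the array sits directly on the module or inside a list/tuple attribute. That owner is what lets trainability be toggled per module. The shared-module de-duplication is an assumption carried over from the module walk.

```python
from __future__ import annotations

from collections.abc import Iterator

import numpy as np


class Op:
    """Hypothetical stand-in for deepmd.dpmodel.NativeOP."""


def iter_named_parameters(root: Op) -> Iterator[tuple[str, Op, np.ndarray]]:
    memo: set[int] = set()

    def walk(mod: Op, prefix: str) -> Iterator[tuple[str, Op, np.ndarray]]:
        if id(mod) in memo:
            return
        memo.add(id(mod))
        for attr, value in vars(mod).items():
            path = f"{prefix}.{attr}" if prefix else attr
            if isinstance(value, np.ndarray):
                yield path, mod, value  # array stored directly on the module
            elif isinstance(value, Op):
                yield from walk(value, path)
            elif isinstance(value, (list, tuple)):
                for i, item in enumerate(value):
                    if isinstance(item, np.ndarray):  # ParameterList-like element
                        yield f"{path}.{i}", mod, item
                    elif isinstance(item, Op):
                        yield from walk(item, f"{path}.{i}")

    yield from walk(root, "")


def module_state_dict(root: Op) -> dict[str, np.ndarray]:
    return {name: array for name, _, array in iter_named_parameters(root)}


so2 = Op()
so2.weight_m0 = np.zeros((4, 6))
so2.A_m = [np.zeros((2, 4)), np.zeros((2, 4))]
model = Op()
model.so2 = so2
state = module_state_dict(model)
print(sorted(state))  # ['so2.A_m.0', 'so2.A_m.1', 'so2.weight_m0']
```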
- deepmd.dpmodel.descriptor.dpa4_nn.lora._leaf_name(param_name: str) → str
  Return the trailing non-numeric segment of a parameter name.
  nn.ParameterList children show up as foo.0, foo.1, … get_adam_route strips those numeric indices before routing, so this helper keeps the policy in sync.
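The sketch below illustrates the described behavior. leaf_name is hypothetical, not the module's source. It walks the dotted segments from the right and returns the first one that is not a pure integer index.

```python
def leaf_name(param_name: str) -> str:
    """Return the last dotted segment that is not a pure integer index."""
    for segment in reversed(param_name.split(".")):
        if not segment.isdigit():
            return segment
    return param_name  # hypothetical fallback: every segment was numeric


print(leaf_name("descriptor.blocks.3.so2.A_m.1"))  # A_m
print(leaf_name("descriptor.blocks.3.so3.A_by_l"))  # A_by_l
```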
- deepmd.dpmodel.descriptor.dpa4_nn.lora._get_submodule_or_none(root: deepmd.dpmodel.NativeOP, path: str) → Any
- deepmd.dpmodel.descriptor.dpa4_nn.lora._clear_sezm_compile_cache(model: deepmd.dpmodel.NativeOP) → None
  No-op retained for parity with the PyTorch backend.
  In PyTorch, LoRA injection or merging replaces submodules. That invalidates any torch.compile / inductor callable captured on the module graph, which must be cleared before the next forward pass. The dpmodel (array-API) backend compiles nothing. There is no cache to clear, so this function intentionally does nothing.
- deepmd.dpmodel.descriptor.dpa4_nn.lora._swap_submodule(parent: Any, attr: str, new_module: deepmd.dpmodel.NativeOP) → None
  Replace parent.attr with new_module.
  Numeric attribute names address list/tuple children, the dpmodel analogue of nn.ModuleList / nn.ParameterList elements. These are assigned by index. Every other name is a plain attribute assignment.
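The path handling can be sketched with two hypothetical helpers:
- get_path_or_none resolves a dotted path in the spirit of _get_submodule_or_none. Numeric segments index into lists, other segments use getattr, and the result is None when any step is missing.
- swap_submodule assigns by index or by attribute.

Only mutable lists are handled. A tuple child would need the tuple rebuilt, which this sketch does not do.

```python
from typing import Any


def get_path_or_none(root: Any, path: str) -> Any:
    """Resolve a dotted path; numeric segments index into lists. None if missing."""
    obj = root
    try:
        for segment in path.split("."):
            obj = obj[int(segment)] if segment.isdigit() else getattr(obj, segment)
    except (AttributeError, IndexError, TypeError):
        return None
    return obj


def swap_submodule(parent: Any, attr: str, new_module: Any) -> None:
    """Numeric attr: assign by index (parent is the list); otherwise setattr."""
    if attr.isdigit():
        parent[int(attr)] = new_module
    else:
        setattr(parent, attr, new_module)


class Node:
    """Hypothetical container standing in for a module."""


root = Node()
root.blocks = [Node(), Node()]
root.blocks[1].so3 = "SO3Linear"
parent_path, _, attr = "blocks.1.so3".rpartition(".")
swap_submodule(get_path_or_none(root, parent_path), attr, "LoRASO3")
print(root.blocks[1].so3)  # LoRASO3
print(get_path_or_none(root, "blocks.7"))  # None
```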
- deepmd.dpmodel.descriptor.dpa4_nn.lora.has_lora(module: deepmd.dpmodel.NativeOP) → bool
  Return True iff any submodule is a LoRA adapter.
- deepmd.dpmodel.descriptor.dpa4_nn.lora.apply_lora_to_sezm(model: deepmd.dpmodel.NativeOP, *, rank: int, alpha: float | None = None) → deepmd.dpmodel.NativeOP
  Inject LoRA adapters into every SO3Linear / SO2Linear of a SeZM model and apply the SeZM fine-tune freeze/unfreeze policy in place.
  The function is idempotent. LoRASO3 subclasses SO3Linear, so an isinstance test would match it. The exact-type test type(mod) is SO3Linear does not, which prevents re-wrapping a LoRASO3 that is already present.
  - Parameters:
    - model
      A SeZMModel instance, or any NativeOP containing SeZM SO3Linear / SO2Linear submodules.
    - rank
      LoRA rank applied uniformly to every adapter.
    - alpha
      LoRA scaling numerator; scaling is alpha / rank. None defaults to alpha = rank (scaling 1.0).
  - Returns:
    NativeOP
      The same model after injection (returned for chaining).
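The idempotence argument rests on the exact-type check. In this sketch, SO3Linear and LoRASO3 are hypothetical stand-in classes. The real adapter takes keyword arguments, not a base instance. The point is that isinstance(LoRASO3(...), SO3Linear) is True while type(...) is SO3Linear is False. A second pass therefore leaves existing adapters alone.

```python
class SO3Linear:
    """Hypothetical stand-in for the base operator."""


class LoRASO3(SO3Linear):
    """Hypothetical stand-in adapter; it subclasses the base, like the real one."""

    def __init__(self, base: SO3Linear) -> None:
        self.base = base


def wrap_all(modules: list) -> list:
    # Exact-type test: an existing LoRASO3 passes isinstance(m, SO3Linear)
    # but fails type(m) is SO3Linear, so it is never wrapped again.
    return [LoRASO3(m) if type(m) is SO3Linear else m for m in modules]


once = wrap_all([SO3Linear(), SO3Linear()])
twice = wrap_all(once)
print([type(m).__name__ for m in twice])  # ['LoRASO3', 'LoRASO3']
print(all(a is b for a, b in zip(once, twice)))  # True
```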
- deepmd.dpmodel.descriptor.dpa4_nn.lora.fold_lora_state_dict_keys(state_dict: dict[str, numpy.ndarray], prefix: str) → None
  Fold LoRA adapter keys into base weight keys in state_dict (in-place).
  Scans for SO3-style A_by_l / B_by_l pairs and SO2-style A_m0 / B_m0 / A_m.* / B_m.* groups under prefix. For each pair whose corresponding base weight key also exists, the delta einsum(B, A) * scaling is added to the weight and the adapter keys are popped. lora_scaling is read from state_dict when present. Otherwise 1.0 is assumed, which is the default when alpha == rank.
  Called by DescrptSeZM._load_from_state_dict so that a LoRA-trained checkpoint can be loaded transparently into a plain (non-LoRA) descriptor.
  - Parameters:
    - state_dict
      Flat state dict to mutate in place.
    - prefix
      Key prefix that scopes the scan (e.g. "model.Default.atomic_model.descriptor.").
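The sketch below covers the SO3 half of the folding. It rests on these assumptions:
- Adapter keys are named <module>.A_by_l / <module>.B_by_l next to <module>.weight.
- The optional scaling is stored as <module>.lora_scaling and is popped along with the adapters.
- ΔW[l] = scaling * (B[l] @ A[l]).T.

The SO2 groups (A_m0 / B_m0 / A_m.* / B_m.*) are not handled here. fold_so3_keys is hypothetical.

```python
import numpy as np


def fold_so3_keys(state_dict: dict[str, np.ndarray], prefix: str) -> None:
    """Hypothetical SO3-only fold: weight += scaling * (B[l] @ A[l]).T, in place."""
    suffix = ".A_by_l"
    a_keys = [k for k in state_dict if k.startswith(prefix) and k.endswith(suffix)]
    for a_key in a_keys:
        mod = a_key[: -len(suffix)]
        b_key, w_key = f"{mod}.B_by_l", f"{mod}.weight"
        if b_key not in state_dict or w_key not in state_dict:
            continue  # fold only when the base weight key also exists
        scaling = float(state_dict.pop(f"{mod}.lora_scaling", 1.0))  # assumed key name
        a = state_dict.pop(a_key)  # (lmax+1, rank, C_in)
        b = state_dict.pop(b_key)  # (lmax+1, F*C_out, rank)
        delta = np.einsum("lor,lri->lio", b, a) * scaling
        state_dict[w_key] = state_dict[w_key] + delta


pre = "model.Default.atomic_model.descriptor."
rng = np.random.default_rng(4)
a0 = rng.standard_normal((2, 1, 3))
b0 = rng.standard_normal((2, 4, 1))
sd = {
    pre + "blocks.0.so3.weight": np.zeros((2, 3, 4)),
    pre + "blocks.0.so3.A_by_l": a0,
    pre + "blocks.0.so3.B_by_l": b0,
}
fold_so3_keys(sd, pre)
print(sorted(sd))  # only the weight key remains
```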
- deepmd.dpmodel.descriptor.dpa4_nn.lora.build_merged_state_dict(module: deepmd.dpmodel.NativeOP, state_dict: dict[str, numpy.ndarray] | None = None, *, prefix: str = '') → dict[str, numpy.ndarray]
  Produce a plain (LoRA-free) state dict from a LoRA-augmented module.
  Walks module.named_modules(). For every LoRASO3 / LoRASO2 submodule, it folds ΔW = BA·scaling into the base weight key and removes the A/B keys. The returned dict has the same key set as a same-topology SeZM that has never been LoRA-wrapped. It is suitable for loading into a plain SeZM model with strict=True.
  The function never modifies module itself:
  - When state_dict is None, a deep copy of module.state_dict() is taken.
  - When the caller provides a state_dict, it is assumed to already be a detached copy (e.g. the fully gathered state dict from FSDP2). That dict is mutated in place for efficiency.
  - Parameters:
    - module
      The LoRA-augmented module tree. It is used only for structural information: LoRA submodule prefixes, scaling, and weight_m length. Its parameters are not modified.
    - state_dict
      Optional pre-collected state dict (e.g. gathered from FSDP2). If None, deepcopy(module.state_dict()) is used.
    - prefix
      Prefix to prepend to every LoRA submodule name when looking up keys in state_dict. Use this when the caller's state is keyed under an outer wrapper (for example "model.Default.").
  - Returns:
    dict
      Flat state dict with LoRA adapters folded into base weights.
- deepmd.dpmodel.descriptor.dpa4_nn.lora.strip_lora_from_extra_state(extra_state: dict[str, Any]) → dict[str, Any]
  Drop any lora entry from _extra_state["model_params"].
  Handles two layouts:
  - single-task, where model_params is the model config;
  - multi-task, where model_params["model_dict"][<branch>] is each branch's config.

  Returns a deep-copied dict; the input is not mutated.
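Below is a minimal sketch of the described behavior. strip_lora is a hypothetical name. The config keys in the usage lines are made up for illustration. The deep copy is taken first, so the caller's extra_state is never mutated.

```python
import copy
from typing import Any


def strip_lora(extra_state: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical sketch: drop 'lora' from single- or multi-task model_params."""
    out = copy.deepcopy(extra_state)  # never mutate the caller's dict
    params = out.get("model_params")
    if not isinstance(params, dict):
        return out
    if "model_dict" in params:  # multi-task: one config per branch
        for branch_cfg in params["model_dict"].values():
            if isinstance(branch_cfg, dict):
                branch_cfg.pop("lora", None)
    else:  # single-task: model_params is the model config itself
        params.pop("lora", None)
    return out


single = {"model_params": {"descriptor": {}, "lora": {"rank": 8}}}
multi = {"model_params": {"model_dict": {"a": {"lora": {"rank": 8}}, "b": {}}}}
print(strip_lora(single)["model_params"])  # {'descriptor': {}}
print(strip_lora(multi)["model_params"]["model_dict"])  # {'a': {}, 'b': {}}
print("lora" in single["model_params"])  # True: the input is untouched
```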
- deepmd.dpmodel.descriptor.dpa4_nn.lora.merge_lora_into_base(model: deepmd.dpmodel.NativeOP) → deepmd.dpmodel.NativeOP
  Destructively replace every LoRASO3 / LoRASO2 with its merged plain base module.
  After this call the model no longer contains LoRA submodules. The optimizer, EMA state, and any compiled callables that reference the old submodules become invalid. Prefer build_merged_state_dict() for non-destructive checkpoint export during or after training. This function is primarily useful in tests and offline scripts.