# SPDX-License-Identifier: LGPL-3.0-or-later
"""Conservative helpers for loading selected MACE-OFF checkpoints.
The helpers in this module intentionally support a narrower scope than the
upstream discussion in PR #96: they load ordinary MACE-OFF checkpoints into a
DeePMD-GNN ``MaceModel`` wrapper for standard MD use, but they do *not* infer
DPRc/QM/MM atom-type semantics from the checkpoint.
"""
from __future__ import annotations
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, TypedDict
import torch
from deepmd.pt import model as _deepmd_pt_model # noqa: F401
from e3nn.util.jit import script as e3nn_script
from mace.modules import ScaleShiftMACE, gate_dict
from deepmd_gnn.mace import ELEMENTS, MaceModel
from deepmd_gnn.mace_off_cli import download_mace_off_model
if TYPE_CHECKING:
from collections.abc import Iterator
# Key suffixes that may be reported as missing by ``load_state_dict`` without
# failing validation; kept as a tuple so it can be passed to ``str.endswith``.
_ALLOWED_MISSING_STATE_DICT_SUFFIXES = ("_zeroed",)

# Class name of the checkpoint's distance transform (``None`` when the radial
# embedding has no transform) -> ``distance_transform`` value given to MaceModel.
_SUPPORTED_DISTANCE_TRANSFORMS = {
    None: "None",
    "AgnesiTransform": "Agnesi",
    "SoftTransform": "Soft",
}

# Class name of the checkpoint's radial basis -> ``radial_type`` value given to
# MaceModel.
_SUPPORTED_RADIAL_BASES = {"BesselBasis": "bessel"}
class _InferredMaceConfig(TypedDict):
    """Hyperparameters recovered from a loaded MACE checkpoint.

    The keys are the keyword names that ``load_mace_off_model`` forwards to
    ``MaceModel``; ``sel`` is not part of this mapping because it is supplied
    by the caller rather than inferred.
    """

    # Element symbols looked up in ELEMENTS from the checkpoint's atomic numbers.
    type_map: list[str]
    # Read from the checkpoint's ``r_max``.
    r_max: float
    # Number of entries in the radial embedding's ``bessel_fn.bessel_weights``.
    num_radial_basis: int
    # Read from the radial embedding's ``cutoff_fn.p``.
    num_cutoff_basis: int
    # Largest ``l`` among the spherical-harmonics output irreps.
    max_ell: int
    # Class name of the later interaction blocks (or of the only block).
    interaction: str
    # Read from the checkpoint's ``num_interactions``.
    num_interactions: int
    # String form of the first interaction block's ``hidden_irreps``.
    hidden_irreps: str
    # Always False here: pair-repulsion checkpoints are rejected beforehand.
    pair_repulsion: bool
    # One of the values in _SUPPORTED_DISTANCE_TRANSFORMS.
    distance_transform: str
    # Correlation order shared by every product block.
    correlation: int
    # Key of ``gate_dict`` matching the readout activation, or "None".
    gate: str
    # String form of the last readout's ``hidden_irreps``.
    MLP_irreps: str
    # One of the values in _SUPPORTED_RADIAL_BASES.
    radial_type: str
    # Widths read from the first interaction block's ``conv_tp_weights`` layers.
    radial_MLP: list[int]
    # The ``scale`` entry of the checkpoint's ``scale_shift`` state dict.
    std: float
    # Read from the first interaction block's ``avg_num_neighbors``.
    avg_num_neighbors: float
@contextmanager
def _temporary_default_dtype(dtype: torch.dtype) -> Iterator[None]:
    """Make ``dtype`` torch's default dtype for the duration of a ``with`` block.

    The default dtype that was active on entry is reinstated when the block
    exits, including when the block raises.
    """
    previous_default = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Always undo the global change so callers never leak a dtype switch.
        torch.set_default_dtype(previous_default)
def _validate_atomic_numbers(atomic_numbers: list[int]) -> None:
    """Reject an empty or out-of-range list of atomic numbers.

    Raises
    ------
    ValueError
        If the list is empty, or if any entry is smaller than 1 or larger
        than ``len(ELEMENTS)``.
    """
    if not atomic_numbers:
        msg = "Loaded model has empty atomic_numbers"
        raise ValueError(msg)

    highest_allowed = len(ELEMENTS)
    for number in atomic_numbers:
        if number < 1 or number > highest_allowed:
            # Report the complete list, not just the first offending entry.
            msg = (
                f"Invalid atomic numbers found: {atomic_numbers}. "
                f"Must be between 1 and {highest_allowed}"
            )
            raise ValueError(msg)
def _infer_gate_name(mace_model: ScaleShiftMACE) -> str:
    """Return the ``gate_dict`` key matching the last readout's activation.

    ``"None"`` is returned when the model has no readouts or when the last
    readout has no ``non_linearity`` attribute.

    Raises
    ------
    ValueError
        If the non-linearity does not expose exactly one activation with an
        ``f`` attribute, or if that function is not one of the ``gate_dict``
        values.
    """
    readouts = mace_model.readouts
    if not readouts:
        return "None"

    final_readout = readouts[-1]
    if not hasattr(final_readout, "non_linearity"):
        return "None"

    activations = getattr(final_readout.non_linearity, "acts", None)
    has_single_activation = (
        activations is not None
        and len(activations) == 1
        and hasattr(activations[0], "f")
    )
    if not has_single_activation:
        msg = "Unsupported MACE non-linearity structure"
        raise ValueError(msg)

    gate_fn = activations[0].f
    # Match by identity: the gate is recognised only when it is the very same
    # callable object that ``gate_dict`` holds.
    matching_names = [
        name
        for name, known_fn in gate_dict.items()
        if known_fn is not None and known_fn is gate_fn
    ]
    if matching_names:
        return matching_names[0]

    msg = f"Unsupported MACE gate function: {gate_fn}"
    raise ValueError(msg)
def _infer_distance_transform(mace_model: ScaleShiftMACE) -> str:
    """Map the radial embedding's distance transform to its config name.

    A missing (or ``None``) ``distance_transform`` attribute is looked up
    under the ``None`` key of ``_SUPPORTED_DISTANCE_TRANSFORMS``.

    Raises
    ------
    ValueError
        If the transform's class name is not in
        ``_SUPPORTED_DISTANCE_TRANSFORMS``.
    """
    transform = getattr(mace_model.radial_embedding, "distance_transform", None)
    if transform is None:
        transform_name = None
    else:
        transform_name = transform.__class__.__name__

    if transform_name in _SUPPORTED_DISTANCE_TRANSFORMS:
        return _SUPPORTED_DISTANCE_TRANSFORMS[transform_name]

    msg = (
        "Unsupported MACE distance transform for conservative loader: "
        f"{transform_name}"
    )
    raise ValueError(msg)
def _infer_radial_type(mace_model: ScaleShiftMACE) -> str:
    """Map the class name of the radial embedding's ``bessel_fn`` to its config name.

    Raises
    ------
    ValueError
        If the class name is not in ``_SUPPORTED_RADIAL_BASES``.
    """
    basis_class_name = mace_model.radial_embedding.bessel_fn.__class__.__name__
    if basis_class_name in _SUPPORTED_RADIAL_BASES:
        return _SUPPORTED_RADIAL_BASES[basis_class_name]

    msg = f"Unsupported radial basis for conservative loader: {basis_class_name}"
    raise ValueError(msg)
def _infer_radial_mlp(mace_model: ScaleShiftMACE) -> list[int]:
    """Read the radial MLP widths from the first interaction block.

    Only entries of ``conv_tp_weights`` that have a ``weight`` attribute are
    considered. The result holds ``weight.shape[1]`` for each of them except
    the last one.
    NOTE(review): this presumes weights are laid out as (in, out) - confirm
    against the layer implementation in use.

    Raises
    ------
    ValueError
        If fewer than two weighted layers are found.
    """
    first_interaction = mace_model.interactions[0]
    weighted_layers = [
        module
        for module in first_interaction.conv_tp_weights
        if hasattr(module, "weight")
    ]
    if len(weighted_layers) < 2:
        msg = "Unsupported radial MLP structure in MACE checkpoint"
        raise ValueError(msg)

    # The last weighted layer is excluded from the reported widths.
    *leading_layers, _final_layer = weighted_layers
    return [int(module.weight.shape[1]) for module in leading_layers]
def _infer_interaction_name(mace_model: ScaleShiftMACE) -> str:
    """Return the interaction class name to use for the wrapper.

    The first interaction block must be a ``RealAgnosticInteractionBlock``
    and all later blocks must share one class. The class name of the later
    blocks is returned, or the first block's name when there is only one.

    Raises
    ------
    ValueError
        If there are no interaction blocks, the first block has another
        class, or the later blocks are of mixed classes.
    """
    class_names = [layer.__class__.__name__ for layer in mace_model.interactions]
    if not class_names:
        msg = "Loaded MACE model has no interaction blocks"
        raise ValueError(msg)

    first_name, *later_names = class_names
    if first_name != "RealAgnosticInteractionBlock":
        msg = (
            "Conservative MACE-OFF loader only supports checkpoints whose first "
            f"interaction block is RealAgnosticInteractionBlock, got {first_name}"
        )
        raise ValueError(msg)

    # More than one distinct class after the first block means a mixed stack.
    if len(set(later_names)) > 1:
        msg = f"Mixed later interaction blocks are unsupported: {class_names}"
        raise ValueError(msg)

    if later_names:
        return later_names[0]
    return first_name
def _infer_hidden_irreps(mace_model: ScaleShiftMACE) -> str:
    """Return the first interaction block's ``hidden_irreps`` as a string."""
    first_interaction = mace_model.interactions[0]
    return str(first_interaction.hidden_irreps)
def _infer_max_ell(mace_model: ScaleShiftMACE) -> int:
    """Return the largest ``l`` among the spherical-harmonics output irreps.

    Each entry of ``spherical_harmonics.irreps_out`` is unpacked as a pair
    whose second element carries the ``l`` attribute.
    """
    degrees = [irrep.l for _mul, irrep in mace_model.spherical_harmonics.irreps_out]
    return max(degrees)
def _infer_correlation(mace_model: ScaleShiftMACE) -> int:
    """Return the correlation order shared by every product block.

    The value is read from the first contraction of each product block's
    ``symmetric_contractions``.

    Raises
    ------
    ValueError
        If the product blocks do not all report exactly one common value
        (this includes a model without product blocks).
    """
    per_layer = [
        int(block_.symmetric_contractions.contractions[0].correlation)
        for block_ in mace_model.products
    ]
    distinct_orders = set(per_layer)
    if len(distinct_orders) != 1:
        msg = f"Layer-dependent correlation is unsupported: {per_layer}"
        raise ValueError(msg)
    return per_layer[0]
def _infer_mlp_irreps(mace_model: ScaleShiftMACE) -> str:
    """Return the last readout's ``hidden_irreps`` as a string.

    Raises
    ------
    ValueError
        If the last readout has no ``hidden_irreps`` attribute.
    """
    final_readout = mace_model.readouts[-1]
    not_found = object()
    irreps = getattr(final_readout, "hidden_irreps", not_found)
    if irreps is not_found:
        msg = "Checkpoint does not expose non-linear readout hidden_irreps"
        raise ValueError(msg)
    return str(irreps)
def _infer_scale(mace_model: ScaleShiftMACE) -> float:
    """Return the ``scale`` entry of the ``scale_shift`` state dict as a float.

    ``.item()`` is used for the conversion, so the stored tensor has to hold
    exactly one element.
    """
    scale_tensor = mace_model.scale_shift.state_dict()["scale"]
    return float(scale_tensor.detach().cpu().item())
def _infer_avg_num_neighbors(mace_model: ScaleShiftMACE) -> float:
    """Return the first interaction block's ``avg_num_neighbors`` as a float."""
    first_interaction = mace_model.interactions[0]
    return float(first_interaction.avg_num_neighbors)
def _validate_checkpoint_scope(mace_model: ScaleShiftMACE) -> None:
    """Reject checkpoints that fall outside the conservative loader's scope.

    Raises
    ------
    ValueError
        If the atomic numbers are invalid, ``heads`` is neither ``None`` nor
        ``["Default"]``, ``embedding_specs`` is present and not ``None``, or
        ``pair_repulsion`` is truthy.
    """
    _validate_atomic_numbers(mace_model.atomic_numbers.tolist())

    supported_heads = (None, ["Default"])
    heads = getattr(mace_model, "heads", None)
    if heads not in supported_heads:
        msg = f"Multi-head checkpoints are unsupported: heads={heads}"
        raise ValueError(msg)

    # A missing attribute is treated the same as an explicit ``None``.
    if getattr(mace_model, "embedding_specs", None) is not None:
        msg = "Joint-embedding checkpoints are unsupported by the conservative loader"
        raise ValueError(msg)

    if getattr(mace_model, "pair_repulsion", False):
        msg = "Pair-repulsion checkpoints are unsupported by the conservative loader"
        raise ValueError(msg)
def _infer_deepmd_config(mace_model: ScaleShiftMACE) -> _InferredMaceConfig:
    """Collect the ``MaceModel`` keyword arguments implied by a checkpoint.

    The checkpoint is first validated with ``_validate_checkpoint_scope``;
    every value is then read from the model object or derived by one of the
    ``_infer_*`` helpers.

    Raises
    ------
    ValueError
        If the checkpoint is out of scope or one of the helpers cannot map
        its structure to a supported setting.
    """
    _validate_checkpoint_scope(mace_model)

    element_symbols = [ELEMENTS[z - 1] for z in mace_model.atomic_numbers.tolist()]
    cutoff_radius = float(mace_model.r_max)
    radial_embedding = mace_model.radial_embedding
    n_radial = len(radial_embedding.bessel_fn.bessel_weights)
    n_cutoff = int(radial_embedding.cutoff_fn.p.item())

    return {
        "type_map": element_symbols,
        "r_max": cutoff_radius,
        "num_radial_basis": n_radial,
        "num_cutoff_basis": n_cutoff,
        "max_ell": _infer_max_ell(mace_model),
        "interaction": _infer_interaction_name(mace_model),
        "num_interactions": int(mace_model.num_interactions),
        "hidden_irreps": _infer_hidden_irreps(mace_model),
        # Pair-repulsion checkpoints were already rejected by the scope check.
        "pair_repulsion": False,
        "distance_transform": _infer_distance_transform(mace_model),
        "correlation": _infer_correlation(mace_model),
        "gate": _infer_gate_name(mace_model),
        "MLP_irreps": _infer_mlp_irreps(mace_model),
        "radial_type": _infer_radial_type(mace_model),
        "radial_MLP": _infer_radial_mlp(mace_model),
        "std": _infer_scale(mace_model),
        "avg_num_neighbors": _infer_avg_num_neighbors(mace_model),
    }
def _load_mace_checkpoint(model_path: Path, device: str) -> ScaleShiftMACE:
    """Unpickle a trusted MACE checkpoint and return the model object.

    Notes
    -----
    ``torch.load`` is called with ``weights_only=False`` because the loader
    needs the original ``ScaleShiftMACE`` object and not only a state dict.
    That means ordinary Python unpickling, which can execute arbitrary code:
    only pass checkpoint files you trust, whether they came from the official
    download helper or from a local path.

    Raises
    ------
    TypeError
        If the unpickled object is not a ``ScaleShiftMACE`` instance.
    """
    # SECURITY: full unpickling - never point this at untrusted files.
    loaded = torch.load(str(model_path), map_location=device, weights_only=False)
    if isinstance(loaded, ScaleShiftMACE):
        return loaded

    loaded_cls = loaded.__class__
    msg = (
        "Loaded checkpoint is not a ScaleShiftMACE model: "
        f"{loaded_cls.__module__}.{loaded_cls.__name__}"
    )
    raise TypeError(msg)
def _validate_load_result(load_result: object) -> None:
    """Check the key report of a non-strict ``load_state_dict`` call.

    Missing keys ending in one of ``_ALLOWED_MISSING_STATE_DICT_SUFFIXES``
    are tolerated. Absent ``missing_keys`` / ``unexpected_keys`` attributes
    are treated as empty.

    Raises
    ------
    RuntimeError
        If any other key is missing or any key is unexpected.
    """
    missing = list(getattr(load_result, "missing_keys", []))
    unexpected = list(getattr(load_result, "unexpected_keys", []))

    blocking_missing = [
        name
        for name in missing
        if not name.endswith(_ALLOWED_MISSING_STATE_DICT_SUFFIXES)
    ]
    if not blocking_missing and not unexpected:
        return

    msg = (
        "Failed to load MACE checkpoint into DeePMD-GNN wrapper. "
        f"missing={blocking_missing}, unexpected={unexpected}"
    )
    raise RuntimeError(msg)
[docs]
def load_mace_off_model(
    model_name: str | None = "small",
    *,
    sel: int,
    model_path: Path | None = None,
    cache_dir: Path | None = None,
    device: str = "cpu",
) -> MaceModel:
    """Load a supported MACE-OFF checkpoint as a DeePMD-GNN ``MaceModel``.

    Parameters
    ----------
    model_name
        Name of the official MACE-OFF checkpoint to download. Not used when
        ``model_path`` is given.
    sel
        Neighbor-list cap for the DeePMD-GNN wrapper. It must be supplied
        explicitly: MACE-OFF checkpoints do not record DeePMD's runtime
        neighbor-list limit, and a guessed value can silently truncate
        neighbors.
    model_path
        Path to a local checkpoint. When omitted, the official model selected
        by ``model_name`` is downloaded.
    cache_dir
        Cache directory handed to the download helper.
    device
        Device onto which the original checkpoint is loaded.

    Returns
    -------
    MaceModel
        The wrapper in eval mode, with its ``model`` attribute set to the
        loaded checkpoint object.

    Raises
    ------
    ValueError
        If ``sel`` is not positive, if neither ``model_name`` nor
        ``model_path`` is given, or if the checkpoint is outside the
        supported scope.
    TypeError
        If the checkpoint file does not contain a ``ScaleShiftMACE`` object.
    RuntimeError
        If the checkpoint's state dict does not fit the wrapper built from
        the inferred configuration.

    Notes
    -----
    The scope is deliberately narrower than DeePMD-GNN's full DPRc machinery:
    only ordinary element ``type_map`` entries are derived from
    ``atomic_numbers``; ``mH`` / ``HW`` / ``OW`` and other QM/MM-specific type
    semantics are not inferred.
    """
    if sel <= 0:
        msg = f"sel must be positive, got {sel}"
        raise ValueError(msg)

    if model_path is not None:
        checkpoint_path = Path(model_path)
    elif model_name is not None:
        checkpoint_path = download_mace_off_model(model_name, cache_dir=cache_dir)
    else:
        msg = "Either model_name or model_path must be provided"
        raise ValueError(msg)

    checkpoint = _load_mace_checkpoint(checkpoint_path, device=device)
    inferred = _infer_deepmd_config(checkpoint)

    # Build the wrapper while torch's default dtype equals the dtype of the
    # checkpoint's atomic energies.
    checkpoint_dtype = checkpoint.atomic_energies_fn.atomic_energies.dtype
    with _temporary_default_dtype(checkpoint_dtype):
        # The inferred mapping holds exactly the remaining MaceModel keywords.
        wrapper = MaceModel(sel=sel, **inferred)

    # Loading the checkpoint weights into the freshly built network is what
    # gets validated: any key mismatch beyond the tolerated ones raises.
    key_report = wrapper.model.load_state_dict(checkpoint.state_dict(), strict=False)
    _validate_load_result(key_report)

    # After validation the wrapper takes over the checkpoint module itself.
    wrapper.model = checkpoint
    wrapper.eval()
    return wrapper
[docs]
def convert_mace_off_to_deepmd(
    output_file: str,
    *,
    sel: int,
    model_name: str | None = "small",
    model_path: Path | None = None,
    cache_dir: Path | None = None,
    device: str = "cpu",
) -> Path:
    """Serialize a loaded MACE-OFF wrapper as a TorchScript DeePMD-GNN model.

    The wrapper is obtained from ``load_mace_off_model`` with the given
    arguments, compiled to TorchScript and saved to ``output_file``.

    Returns
    -------
    Path
        ``output_file`` as a ``Path``.

    Notes
    -----
    Writing the file does not, by itself, prove end-to-end deployment in
    external engines such as LAMMPS or AMBER; callers should still validate
    the deployment path they care about.
    """
    wrapper = load_mace_off_model(
        model_name=model_name,
        sel=sel,
        model_path=model_path,
        cache_dir=cache_dir,
        device=device,
    )
    destination = Path(output_file)

    # The inner network is scripted with e3nn's helper first; the whole
    # wrapper is scripted with torch.jit afterwards.
    wrapper.model = e3nn_script(wrapper.model)
    torch.jit.save(torch.jit.script(wrapper), str(destination))
    return destination
# Public names of this module. ``download_mace_off_model`` is imported from
# ``deepmd_gnn.mace_off_cli`` and re-exported here.
__all__ = [
    "convert_mace_off_to_deepmd",
    "download_mace_off_model",
    "load_mace_off_model",
]