# SPDX-License-Identifier: LGPL-3.0-or-later
"""DPA4 / SeZM → AOTInductor ``.pt2`` freeze path for the pt backend.
SeZM relies on a nested ``autograd.grad(create_graph=True)`` inside
``fit_output_to_model_output``; TorchScript cannot represent that
graph, so DPA4 / SeZM checkpoints are routed through AOTInductor instead.
The output archive layout follows the ``pt_expt`` convention, including the
metadata consumed by ``DeepPotPTExpt.cc`` and ``DeepSpinPTExpt.cc``.
Tracing runs on CPU (``make_fx`` with ``_allow_non_fake_inputs=True``
is brittle on CUDA because the proxy-tensor dispatcher does not set
up CUDA streams for the captured parameters). The exported program
is moved to the target device via ``move_to_device_pass`` before
``aoti_compile_and_package`` compiles and packages it.
``.pt2`` I/O is always float64, matching the C++ contract in
``DeepPotPTExpt::compute`` where LAMMPS coordinates are unconditionally
cast to ``torch::kFloat64``. SeZM's own ``_input_type_cast`` bridges
fp64 inputs to whatever internal compute dtype the checkpoint uses.
"""
from __future__ import (
annotations,
)
import json
import logging
import zipfile
from copy import (
deepcopy,
)
from typing import (
Any,
)
import numpy as np
import torch
from deepmd.dpmodel.utils.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.dpmodel.utils.region import (
normalize_coord,
)
from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)
from deepmd.pt.utils.env import (
DEVICE,
)
from deepmd.utils.model_branch_dict import (
get_model_dict,
)
# Module-level logger; head selection and the final save are reported here.
log: logging.Logger = logging.getLogger(__name__)
def _model_has_spin(model: torch.nn.Module) -> bool:
"""Return whether ``model`` uses the spin lower interface."""
has_spin = getattr(model, "has_spin", False)
return bool(has_spin() if callable(has_spin) else has_spin)
def _get_model_ntypes(model: torch.nn.Module) -> int:
"""Return atom type count even when the exported type map is empty."""
type_map = list(model.get_type_map())
if type_map:
return len(type_map)
descriptor = model.get_descriptor()
return int(descriptor.get_ntypes())
def _model_has_message_passing(model: torch.nn.Module) -> bool:
"""Return whether the regular .pt2 graph requires a real atom mapping."""
for obj in (
model,
getattr(model, "atomic_model", None),
model.get_descriptor() if hasattr(model, "get_descriptor") else None,
):
if obj is None or not hasattr(obj, "has_message_passing"):
continue
try:
return bool(obj.has_message_passing())
except (AttributeError, NotImplementedError):
continue
return False
def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
"""Remove deferred shape assertions from spin export graphs.
The spin lower path slices tensors using both ``nall`` and ``nloc`` after
virtual atom expansion. ``torch.export`` may turn valid dynamic cases into
deferred ``Ne(nall, nloc)`` assertions, even though the graph works for both
NoPBC and ghost-atom inputs. The generic pt_expt spin exporter applies the
same cleanup.
"""
graph = graph_module.graph
for node in list(graph.nodes):
if (
node.op == "call_function"
and node.target is torch.ops.aten._assert_scalar.default
):
graph.erase_node(node)
graph.eliminate_dead_code()
graph_module.recompile()
def _extract_state_and_params(
ckpt: Any,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Unwrap a ``torch.load`` result into ``(state_dict, model_params)``.
Accepts both the training-wrapper layout (weights under a top-level
``"model"`` key) and a bare state dict.
"""
inner = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
if not isinstance(inner, dict):
raise ValueError("Unsupported checkpoint: expected a dict-like state dict.")
extra = inner.get("_extra_state") or {}
params = extra.get("model_params")
if not isinstance(params, dict):
raise ValueError("Unsupported checkpoint: missing '_extra_state.model_params'.")
return inner, params
[docs]
def is_sezm_checkpoint(ckpt_path: str) -> bool:
    """Best-effort detection used by the CLI to route DPA4 / SeZM checkpoints.

    Returns ``True`` when the checkpoint's model type, or for multi-task
    checkpoints any branch's type, is ``sezm`` or ``dpa4`` (case-insensitive).
    Returns ``False`` for unreadable files, malformed checkpoints, or
    non-SeZM models. No exception leaks out, so the caller can treat this as
    a pure routing signal.

    Notes
    -----
    ``torch.load(..., weights_only=False)`` unpickles arbitrary objects and
    can execute code; only call this on trusted checkpoint files.
    """
    sezm_types = ("sezm", "dpa4")

    def _is_sezm_params(candidate: Any) -> bool:
        # Malformed (non-dict) params count as "not SeZM" rather than raising.
        return (
            isinstance(candidate, dict)
            and str(candidate.get("type", "")).lower() in sezm_types
        )

    try:
        raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except Exception:
        return False
    try:
        _, params = _extract_state_and_params(raw)
    except (ValueError, AttributeError, TypeError, RuntimeError):
        # Besides the documented ValueError, odd ``_extra_state`` payloads
        # can surface as attribute / type / tensor-truthiness errors.
        return False
    if "model_dict" in params:
        model_dict = params["model_dict"]
        if not isinstance(model_dict, dict):
            return False
        return any(_is_sezm_params(branch) for branch in model_dict.values())
    return _is_sezm_params(params)
def _select_model_head(
state_dict: dict[str, Any],
params: dict[str, Any],
head: str | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Extract a single selected model branch from a checkpoint."""
if "model_dict" not in params:
if head is not None:
raise NotImplementedError(
"SeZM .pt2 freeze does not yet support head selection for single-task checkpoints; pass head=None."
)
return state_dict, params
model_alias_dict, _ = get_model_dict(params["model_dict"])
model_keys = list(params["model_dict"])
if head is None and "Default" in model_alias_dict:
head = "Default"
log.info(
"Using default head %s for multitask SeZM freeze.", model_alias_dict[head]
)
if head is None:
raise ValueError(
"Head must be set for multitask SeZM/DPA4 freeze. "
f"Available heads are: {model_keys}."
)
if head not in model_alias_dict:
head_lower = head.lower()
for key in model_alias_dict:
if key.lower() == head_lower:
head = key
break
if head not in model_alias_dict:
raise ValueError(
f"No head or alias named {head!r} in model. Available heads are: {model_keys}."
)
branch = model_alias_dict[head]
branch_params = deepcopy(params["model_dict"][branch])
branch_state: dict[str, Any] = {
"_extra_state": deepcopy(state_dict.get("_extra_state", {})),
}
branch_state["_extra_state"]["model_params"] = branch_params
prefix = f"model.{branch}."
for key, value in state_dict.items():
if key.startswith(prefix):
branch_state[key.replace(prefix, "model.Default.")] = value
return branch_state, branch_params
def _to_py_list(value: Any) -> Any:
"""Coerce torch / numpy scalars into JSON-friendly Python values."""
if value is None:
return None
if isinstance(value, torch.Tensor):
return value.detach().cpu().tolist()
if isinstance(value, np.ndarray):
return value.tolist()
if isinstance(value, (list, tuple)):
return list(value)
if isinstance(value, (int, float, bool, str)):
return value
raise TypeError(f"Cannot JSON-serialize value of type {type(value)!r}")
def _collect_metadata(
    model: torch.nn.Module,
    output_keys: list[str],
    is_spin: bool | None = None,
) -> dict:
    """Build the flat ``metadata.json`` payload read by :class:`DeepPotPTExpt`.

    The field set follows the C++ reader in
    ``source/api_cc/src/DeepPotPTExpt.cc`` and the metadata-only loader
    ``deepmd.pt_expt.infer.deep_eval.DeepEval._init_from_metadata``. Every
    field that either side consumes must appear here.

    Parameters
    ----------
    model
        The loaded model being exported.
    output_keys
        Output names, in the order the loader zips them with the flat output
        vector of ``AOTIModelPackageLoader::run``.
    is_spin
        Whether the model uses the spin interface; detected from ``model``
        when ``None``.

    Returns
    -------
    dict
        JSON-serialisable metadata. Spin models also carry ``ntypes_spin``
        and ``use_spin``.
    """
    spin_enabled = _model_has_spin(model) if is_spin is None else is_spin

    def _describe(vdef: Any) -> dict[str, Any]:
        # Flatten one fitting output definition into plain JSON values.
        return {
            "name": vdef.name,
            "shape": list(vdef.shape),
            "reducible": vdef.reducible,
            "r_differentiable": vdef.r_differentiable,
            "c_differentiable": vdef.c_differentiable,
            "atomic": vdef.atomic,
            # OutputVariableCategory is an IntEnum; a plain int keeps the
            # JSON output deterministic across Python versions.
            "category": int(vdef.category),
            "r_hessian": vdef.r_hessian,
            # For spin models the energy output is also flagged magnetic.
            "magnetic": bool(
                vdef.magnetic or (spin_enabled and vdef.name == "energy")
            ),
            "intensive": vdef.intensive,
        }

    output_defs = model.atomic_output_def().get_data()
    fitting_output_defs = [_describe(vdef) for vdef in output_defs.values()]
    metadata = {
        "type_map": list(model.get_type_map()),
        "ntypes": _get_model_ntypes(model),
        "rcut": float(model.get_rcut()),
        "sel": [int(s) for s in model.get_sel()],
        "dim_fparam": int(model.get_dim_fparam()),
        "dim_aparam": int(model.get_dim_aparam()),
        "dim_chg_spin": int(model.get_dim_chg_spin()),
        "mixed_types": bool(model.mixed_types()),
        "has_message_passing": _model_has_message_passing(model),
        "has_comm_artifact": False,
        "has_default_fparam": bool(model.has_default_fparam()),
        "default_fparam": _to_py_list(model.get_default_fparam()),
        "default_chg_spin": _to_py_list(model.get_default_chg_spin()),
        "output_keys": list(output_keys),
        "fitting_output_defs": fitting_output_defs,
        # Consumed by DeepEval.get_sel_type() in metadata-only mode.
        "sel_type": [int(t) for t in model.get_sel_type()],
        "is_spin": bool(spin_enabled),
    }
    if not spin_enabled:
        return metadata
    spin = model.spin
    metadata["ntypes_spin"] = int(spin.get_ntypes_spin())
    metadata["use_spin"] = [bool(flag) for flag in spin.use_spin]
    return metadata
def _make_sample_inputs(
    model: torch.nn.Module,
    nframes: int,
    nloc: int,
    device: torch.device,
    has_spin: bool = False,
) -> tuple[torch.Tensor | None, ...]:
    """Build representative ``forward_common_lower`` inputs for tracing.

    Construction:

    - Local atoms sit at random (seed 42) positions in the central half of a
      cubic box with edge ``3 * rcut``.
    - Atom types cycle through ``0 .. ntypes - 1``.
    - Ghost atoms and neighbour lists come from the dpmodel numpy helpers.

    Tensors are float64 / int64, matching the ``.pt2`` I/O contract.

    Returns
    -------
    tuple
        ``(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin)``,
        with ``ext_spin`` inserted after ``ext_atype`` when ``has_spin``.
        ``fparam`` / ``aparam`` / ``charge_spin`` are zero tensors, or ``None``
        when the model's corresponding dimension is 0.

    Raises
    ------
    ValueError
        If the model reports no atom types.
    """
    rcut = float(model.get_rcut())
    sel = list(model.get_sel())
    # Same lookup as the metadata path: type map first, descriptor fallback.
    ntypes = _get_model_ntypes(model)
    if ntypes <= 0:
        raise ValueError("SeZM .pt2 freeze requires at least one atom type.")
    dim_fparam = int(model.get_dim_fparam())
    dim_aparam = int(model.get_dim_aparam())
    dim_chg_spin = int(model.get_dim_chg_spin())
    mixed_types = bool(model.mixed_types())

    box_size = rcut * 3.0
    box = np.eye(3, dtype=np.float64) * box_size
    box_np = box.reshape(1, 9)
    rng = np.random.default_rng(42)
    coord_np = rng.random((nframes, nloc, 3), dtype=np.float64) * box_size * 0.5
    coord_np += box_size * 0.25  # centre roughly in the middle of the cell
    # Atom i gets type i % ntypes in every frame.
    atype_np = np.tile((np.arange(nloc) % ntypes).astype(np.int32), (nframes, 1))

    spin_np = np.zeros((nframes, nloc, 3), dtype=np.float64)
    if has_spin:
        # Small, atom-dependent, non-zero spin vectors.
        atom_idx = np.arange(nloc, dtype=np.float64).reshape(1, nloc)
        spin_np[:, :, 0] = 0.10 + 0.01 * atom_idx
        spin_np[:, :, 1] = 0.20 + 0.02 * atom_idx
        spin_np[:, :, 2] = 0.05

    coord_normalized = normalize_coord(
        coord_np.reshape(nframes, nloc, 3),
        np.tile(box.reshape(1, 3, 3), (nframes, 1, 1)),
    )
    extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
        coord_normalized, atype_np, np.tile(box_np, (nframes, 1)), rcut
    )
    nlist = build_neighbor_list(
        extended_coord,
        extended_atype,
        nloc,
        rcut,
        sel,
        distinguish_types=not mixed_types,
    )
    extended_coord = extended_coord.reshape(nframes, -1, 3)

    ext_coord = torch.tensor(extended_coord, dtype=torch.float64, device=device)
    ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=device)
    nlist_t = torch.tensor(nlist, dtype=torch.int64, device=device)
    mapping_t = torch.tensor(mapping, dtype=torch.int64, device=device)
    ext_spin = None
    if has_spin:
        # Each extended atom takes the spin of the local atom ``mapping`` points to.
        extended_spin = np.take_along_axis(spin_np, mapping[..., None], axis=1)
        ext_spin = torch.tensor(extended_spin, dtype=torch.float64, device=device)
    fparam = (
        torch.zeros(nframes, dim_fparam, dtype=torch.float64, device=device)
        if dim_fparam > 0
        else None
    )
    aparam = (
        torch.zeros(nframes, nloc, dim_aparam, dtype=torch.float64, device=device)
        if dim_aparam > 0
        else None
    )
    charge_spin = (
        torch.zeros(nframes, dim_chg_spin, dtype=torch.float64, device=device)
        if dim_chg_spin > 0
        else None
    )
    if has_spin:
        return (
            ext_coord,
            ext_atype,
            ext_spin,
            nlist_t,
            mapping_t,
            fparam,
            aparam,
            charge_spin,
        )
    return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
def _resolve_nframes(
    model: torch.nn.Module,
    nloc: int,
    device: torch.device,
    start: int = 2,
    has_spin: bool = False,
) -> tuple[int, tuple[torch.Tensor | None, ...]]:
    """Pick an ``nframes`` that does not collide with any other dim size.

    ``torch.export``'s duck-sizing unifies symbolic dims whose concrete
    sample values match. If ``nframes`` happens to equal, say, the spatial
    ``3`` or the virial ``9``, the ExportedProgram rejects later calls whose
    ``nframes`` differs. Sizes 0 and 1 are also avoided because
    ``torch.export`` specialises them.

    ``nframes`` is bumped upward from ``start`` until it differs from every
    non-batch dim of the generated sample and from those reserved sizes. The
    sample is rebuilt and re-checked after each bump.

    Returns
    -------
    tuple
        ``(nframes, sample_inputs)``, with ``sample_inputs`` built for that
        ``nframes``.
    """
    # 0 / 1: specialised by torch.export. 3 / 9: coordinate / virial widths;
    # 9 never appears among the input dims, so it is reserved explicitly.
    reserved = {0, 1, 3, 9}
    nframes = start
    while True:
        sample = _make_sample_inputs(
            model,
            nframes=nframes,
            nloc=nloc,
            device=device,
            has_spin=has_spin,
        )
        taken = set(reserved)
        for tensor in sample:
            if tensor is not None:
                taken.update(int(size) for size in tensor.shape[1:])
        if nframes not in taken:
            return nframes, sample
        # Strictly increasing, and ``taken`` is finite, so this terminates.
        while nframes in taken:
            nframes += 1
def _build_dynamic_shapes(
sample_inputs: tuple[torch.Tensor | None, ...],
) -> tuple:
"""Positional ``dynamic_shapes`` for the traced
``(ext_coord, ext_atype, nlist, mapping, fparam, aparam)`` signature.
"""
nframes_dim = torch.export.Dim("nframes", min=1)
has_spin = (
len(sample_inputs) >= 7
and sample_inputs[2] is not None
and sample_inputs[2].is_floating_point()
)
has_charge_spin = (has_spin and len(sample_inputs) == 8) or (
not has_spin and len(sample_inputs) == 7
)
# Spin export currently generates a valid lower-bound guard from its
# virtual-atom split/concat pattern. Matching the bound keeps export strict,
# while `_strip_shape_assertions` removes the spurious deferred guards later.
nall_dim = torch.export.Dim("nall", min=4 if has_spin else 1)
nloc_dim = torch.export.Dim("nloc", min=1)
fparam = sample_inputs[5] if has_spin else sample_inputs[4]
aparam = sample_inputs[6] if has_spin else sample_inputs[5]
charge_spin = None
if has_charge_spin:
charge_spin = sample_inputs[7] if has_spin else sample_inputs[6]
if has_spin:
shapes = (
{0: nframes_dim, 1: nall_dim}, # extended_coord
{0: nframes_dim, 1: nall_dim}, # extended_atype
{0: nframes_dim, 1: nall_dim}, # extended_spin
{0: nframes_dim, 1: nloc_dim}, # nlist
{0: nframes_dim, 1: nall_dim}, # mapping
{0: nframes_dim} if fparam is not None else None,
{0: nframes_dim, 1: nloc_dim} if aparam is not None else None,
)
if has_charge_spin:
shapes = (*shapes, {0: nframes_dim} if charge_spin is not None else None)
return shapes
shapes = (
{0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3)
{0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall)
{0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei)
{0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall)
{0: nframes_dim} if fparam is not None else None,
{0: nframes_dim, 1: nloc_dim} if aparam is not None else None,
)
if has_charge_spin:
shapes = (*shapes, {0: nframes_dim} if charge_spin is not None else None)
return shapes
[docs]
def _trace_lower_forward(
    model: torch.nn.Module,
    sample_inputs: tuple[torch.Tensor | None, ...],
    is_spin: bool,
) -> Any:
    """Trace ``model.forward_common_lower_exportable`` on the sample inputs.

    ``sample_inputs`` has the layout produced by ``_resolve_nframes``:

    - 4 leading positional tensors for regular models, or 5 for spin models
      (which add ``ext_spin``);
    - then ``fparam``, ``aparam`` and ``charge_spin``.

    ``do_atomic_virial=True`` pulls every key that DeepPotPTExpt may read
    (energy, energy_redu, energy_derv_r, energy_derv_c, energy_derv_c_redu)
    into the traced graph.
    """
    n_positional = 5 if is_spin else 4
    positional = sample_inputs[:n_positional]
    fparam, aparam, charge_spin = sample_inputs[n_positional:]
    return model.forward_common_lower_exportable(
        *positional,
        fparam=fparam,
        aparam=aparam,
        charge_spin=charge_spin,
        do_atomic_virial=True,
    )


def _append_extra_files(
    package_path: str,
    metadata: dict,
    params: dict[str, Any],
) -> None:
    """Append the pt_expt side files to an existing ``.pt2`` zip archive.

    Two files are written:

    - ``model/extra/metadata.json`` holds ``metadata``;
    - ``model/extra/model_def_script.json`` keeps the raw training params, so
      ``dp change-bias`` and other downstream tooling can recover the exact
      training config. ``default=str`` is a safety net for exotic nested
      values.
    """
    with zipfile.ZipFile(package_path, "a") as zf:
        zf.writestr("model/extra/metadata.json", json.dumps(metadata))
        zf.writestr(
            "model/extra/model_def_script.json",
            json.dumps(params, default=str),
        )


def freeze_sezm_to_pt2(
    ckpt_path: str,
    out_path: str,
    *,
    device: torch.device | None = None,
    head: str | None = None,
) -> None:
    """Freeze a SeZM checkpoint into an AOTInductor ``.pt2`` archive.

    Parameters
    ----------
    ckpt_path
        Path to the SeZM training checkpoint (``.pt``).
    out_path
        Destination file. A ``.pt2`` suffix is expected.
    device
        Target device for the compiled shared library: a :class:`torch.device`
        or anything ``torch.device`` accepts (e.g. ``"cuda:0"``). Defaults to
        :data:`DEVICE`. Tracing itself always runs on CPU.
    head
        Model head to export from a multi-task checkpoint. If omitted, the
        ``Default`` head is used when present; otherwise multi-task checkpoints
        must pass an explicit head. Single-task checkpoints must pass ``None``.

    Raises
    ------
    ValueError
        If the checkpoint is malformed, the head cannot be resolved, or the
        selected model is not SeZM / DPA4.
    NotImplementedError
        If ``head`` is given for a single-task checkpoint.

    Notes
    -----
    The checkpoint is loaded with ``weights_only=False`` (full unpickling);
    only freeze trusted files.
    """
    from torch._inductor import (
        aoti_compile_and_package,
    )

    # Normalise so string devices (e.g. "cuda:0") also support ``.type`` below.
    target_device = torch.device(device if device is not None else DEVICE)
    raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict, params = _extract_state_and_params(raw)
    state_dict, params = _select_model_head(state_dict, params, head)
    model_type = str(params.get("type", "")).lower()
    if model_type not in ("sezm", "dpa4"):
        raise ValueError(
            f"freeze_sezm_to_pt2 expects a SeZM/DPA4 checkpoint, got type={params.get('type')!r}."
        )
    model = get_model(params)
    is_spin = _model_has_spin(model)
    ModelWrapper(model).load_state_dict(state_dict)
    model.eval()
    model.to("cpu")
    _, sample_inputs_cpu = _resolve_nframes(
        model,
        nloc=7,
        device=torch.device("cpu"),
        has_spin=is_spin,
    )
    traced = _trace_lower_forward(model, sample_inputs_cpu, is_spin)
    # Output key order is taken from a concrete run; Python dict order
    # is stable and matches what DeepPotPTExpt::extract_outputs zips
    # against AOTIModelPackageLoader::run's output vector.
    with torch.no_grad():
        sample_out = traced(*sample_inputs_cpu)
    output_keys = list(sample_out.keys())
    # Collect metadata before the slow AOTI compile so a failure here aborts
    # early instead of leaving a .pt2 on disk without its metadata.
    metadata = _collect_metadata(model, output_keys=output_keys, is_spin=is_spin)
    exported = torch.export.export(
        traced,
        sample_inputs_cpu,
        dynamic_shapes=_build_dynamic_shapes(sample_inputs_cpu),
        strict=False,
        prefer_deferred_runtime_asserts_over_guards=True,
    )
    if is_spin:
        _strip_shape_assertions(exported.graph_module)
    # move_to_device_pass handles FakeTensor device propagation cleanly;
    # a naive .to(device) on the exported program does not.
    if target_device.type != "cpu":
        from torch.export.passes import (
            move_to_device_pass,
        )

        exported = move_to_device_pass(exported, target_device)
    out_path_str = str(out_path)
    aoti_compile_and_package(exported, package_path=out_path_str)
    _append_extra_files(out_path_str, metadata, params)
    log.info(
        "Saved SeZM .pt2 to %s (device=%s, output_keys=%s)",
        out_path_str,
        target_device,
        output_keys,
    )
# Public API of this module: the SeZM/DPA4 ``.pt2`` freeze entry point and
# the checkpoint detector used to route checkpoints to it.
__all__ = [
    "freeze_sezm_to_pt2",
    "is_sezm_checkpoint",
]