# Source code for deepmd.pt.entrypoints.freeze_pt2

# SPDX-License-Identifier: LGPL-3.0-or-later
"""DPA4 / SeZM → AOTInductor ``.pt2`` freeze path for the pt backend.

SeZM relies on a nested ``autograd.grad(create_graph=True)`` inside
``fit_output_to_model_output``; TorchScript cannot represent that
graph, so DPA4 / SeZM checkpoints are routed through AOTInductor instead.
The output archive layout follows the ``pt_expt`` convention, including the
metadata consumed by ``DeepPotPTExpt.cc`` and ``DeepSpinPTExpt.cc``.

Tracing runs on CPU (``make_fx`` with ``_allow_non_fake_inputs=True``
is brittle on CUDA because the proxy-tensor dispatcher does not set
up CUDA streams for the captured parameters).  The compiled package
is moved to the target device via ``move_to_device_pass`` before
``aoti_compile_and_package``.

``.pt2`` I/O is always float64, matching the C++ contract in
``DeepPotPTExpt::compute`` where LAMMPS coordinates are unconditionally
cast to ``torch::kFloat64``.  SeZM's own ``_input_type_cast`` bridges
fp64 inputs to whatever internal compute dtype the checkpoint uses.
"""

from __future__ import (
    annotations,
)

import json
import logging
import zipfile
from copy import (
    deepcopy,
)
from typing import (
    Any,
)

import numpy as np
import torch

from deepmd.dpmodel.utils.nlist import (
    build_neighbor_list,
    extend_coord_with_ghosts,
)
from deepmd.dpmodel.utils.region import (
    normalize_coord,
)
from deepmd.pt.model.model import (
    get_model,
)
from deepmd.pt.train.wrapper import (
    ModelWrapper,
)
from deepmd.pt.utils.env import (
    DEVICE,
)
from deepmd.utils.model_branch_dict import (
    get_model_dict,
)

log = logging.getLogger(__name__)


def _model_has_spin(model: torch.nn.Module) -> bool:
    """Return whether ``model`` uses the spin lower interface."""
    has_spin = getattr(model, "has_spin", False)
    return bool(has_spin() if callable(has_spin) else has_spin)


def _get_model_ntypes(model: torch.nn.Module) -> int:
    """Return atom type count even when the exported type map is empty."""
    type_map = list(model.get_type_map())
    if type_map:
        return len(type_map)
    descriptor = model.get_descriptor()
    return int(descriptor.get_ntypes())


def _model_has_message_passing(model: torch.nn.Module) -> bool:
    """Return whether the regular .pt2 graph requires a real atom mapping."""
    for obj in (
        model,
        getattr(model, "atomic_model", None),
        model.get_descriptor() if hasattr(model, "get_descriptor") else None,
    ):
        if obj is None or not hasattr(obj, "has_message_passing"):
            continue
        try:
            return bool(obj.has_message_passing())
        except (AttributeError, NotImplementedError):
            continue
    return False


def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
    """Remove deferred shape assertions from spin export graphs.

    The spin lower path slices tensors using both ``nall`` and ``nloc`` after
    virtual atom expansion. ``torch.export`` may turn valid dynamic cases into
    deferred ``Ne(nall, nloc)`` assertions, even though the graph works for both
    NoPBC and ghost-atom inputs. The generic pt_expt spin exporter applies the
    same cleanup.
    """
    graph = graph_module.graph
    for node in list(graph.nodes):
        if (
            node.op == "call_function"
            and node.target is torch.ops.aten._assert_scalar.default
        ):
            graph.erase_node(node)
    graph.eliminate_dead_code()
    graph_module.recompile()


def _extract_state_and_params(
    ckpt: Any,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Unwrap a ``torch.load`` result into ``(state_dict, model_params)``.

    Accepts both the training-wrapper layout (weights under a top-level
    ``"model"`` key) and a bare state dict.
    """
    inner = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
    if not isinstance(inner, dict):
        raise ValueError("Unsupported checkpoint: expected a dict-like state dict.")
    extra = inner.get("_extra_state") or {}
    params = extra.get("model_params")
    if not isinstance(params, dict):
        raise ValueError("Unsupported checkpoint: missing '_extra_state.model_params'.")
    return inner, params


def is_sezm_checkpoint(ckpt_path: str) -> bool:
    """Best-effort detection used by the CLI to route DPA4 / SeZM checkpoints.

    Returns ``False`` for unreadable files or non-SeZM checkpoints; no
    exception leaks out so the caller can treat this as a pure routing
    signal. A multi-task checkpoint (one with ``model_dict``) counts as SeZM
    when any of its branches has type ``sezm`` or ``dpa4``, compared
    case-insensitively.

    Parameters
    ----------
    ckpt_path
        Path to a ``torch.save``-d training checkpoint.

    Returns
    -------
    bool
        ``True`` if the checkpoint should go through the SeZM ``.pt2`` path.
    """
    # NOTE(review): ``weights_only=False`` unpickles arbitrary objects, so
    # only point this at trusted checkpoint files.
    try:
        raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except Exception:
        # Deliberately broad: any load failure just means "not routable here".
        return False
    try:
        _, params = _extract_state_and_params(raw)
    except (ValueError, AttributeError, TypeError):
        return False

    def _is_sezm_params(branch_params: Any) -> bool:
        # Malformed (non-dict) branch entries are simply "not SeZM".
        return isinstance(branch_params, dict) and str(
            branch_params.get("type", "")
        ).lower() in ("sezm", "dpa4")

    if "model_dict" in params:
        model_dict = params["model_dict"]
        if not isinstance(model_dict, dict):
            return False
        return any(_is_sezm_params(branch) for branch in model_dict.values())
    return _is_sezm_params(params)
def _select_model_head( state_dict: dict[str, Any], params: dict[str, Any], head: str | None, ) -> tuple[dict[str, Any], dict[str, Any]]: """Extract a single selected model branch from a checkpoint.""" if "model_dict" not in params: if head is not None: raise NotImplementedError( "SeZM .pt2 freeze does not yet support head selection for single-task checkpoints; pass head=None." ) return state_dict, params model_alias_dict, _ = get_model_dict(params["model_dict"]) model_keys = list(params["model_dict"]) if head is None and "Default" in model_alias_dict: head = "Default" log.info( "Using default head %s for multitask SeZM freeze.", model_alias_dict[head] ) if head is None: raise ValueError( "Head must be set for multitask SeZM/DPA4 freeze. " f"Available heads are: {model_keys}." ) if head not in model_alias_dict: head_lower = head.lower() for key in model_alias_dict: if key.lower() == head_lower: head = key break if head not in model_alias_dict: raise ValueError( f"No head or alias named {head!r} in model. Available heads are: {model_keys}." ) branch = model_alias_dict[head] branch_params = deepcopy(params["model_dict"][branch]) branch_state: dict[str, Any] = { "_extra_state": deepcopy(state_dict.get("_extra_state", {})), } branch_state["_extra_state"]["model_params"] = branch_params prefix = f"model.{branch}." 
for key, value in state_dict.items(): if key.startswith(prefix): branch_state[key.replace(prefix, "model.Default.")] = value return branch_state, branch_params def _to_py_list(value: Any) -> Any: """Coerce torch / numpy scalars into JSON-friendly Python values.""" if value is None: return None if isinstance(value, torch.Tensor): return value.detach().cpu().tolist() if isinstance(value, np.ndarray): return value.tolist() if isinstance(value, (list, tuple)): return list(value) if isinstance(value, (int, float, bool, str)): return value raise TypeError(f"Cannot JSON-serialize value of type {type(value)!r}") def _collect_metadata( model: torch.nn.Module, output_keys: list[str], is_spin: bool | None = None, ) -> dict: """Assemble the flat metadata dict expected by :class:`DeepPotPTExpt`. Mirrors the reader contract at ``source/api_cc/src/DeepPotPTExpt.cc`` and the metadata-only load path in ``deepmd.pt_expt.infer.deep_eval.DeepEval``: every field consumed by C++ LAMMPS inference **and** every field consumed by ``DeepEval._init_from_metadata`` must be present here. ``output_keys`` is the insertion order that the loader zips with ``AOTIModelPackageLoader::run``'s flat output vector. """ if is_spin is None: is_spin = _model_has_spin(model) fitting_output_defs: list[dict[str, Any]] = [] for vdef in model.atomic_output_def().get_data().values(): fitting_output_defs.append( { "name": vdef.name, "shape": list(vdef.shape), "reducible": vdef.reducible, "r_differentiable": vdef.r_differentiable, "c_differentiable": vdef.c_differentiable, "atomic": vdef.atomic, # OutputVariableCategory is an IntEnum; force plain int for # deterministic JSON serialisation across Python versions. 
"category": int(vdef.category), "r_hessian": vdef.r_hessian, "magnetic": bool(vdef.magnetic or (is_spin and vdef.name == "energy")), "intensive": vdef.intensive, } ) metadata = { "type_map": list(model.get_type_map()), "ntypes": _get_model_ntypes(model), "rcut": float(model.get_rcut()), "sel": [int(s) for s in model.get_sel()], "dim_fparam": int(model.get_dim_fparam()), "dim_aparam": int(model.get_dim_aparam()), "dim_chg_spin": int(model.get_dim_chg_spin()), "mixed_types": bool(model.mixed_types()), "has_message_passing": _model_has_message_passing(model), "has_comm_artifact": False, "has_default_fparam": bool(model.has_default_fparam()), "default_fparam": _to_py_list(model.get_default_fparam()), "default_chg_spin": _to_py_list(model.get_default_chg_spin()), "output_keys": list(output_keys), "fitting_output_defs": fitting_output_defs, # sel_type feeds DeepEval.get_sel_type() in metadata-only mode. # SeZM energy models return [] (every type selected). "sel_type": [int(t) for t in model.get_sel_type()], "is_spin": bool(is_spin), } if is_spin: metadata["ntypes_spin"] = int(model.spin.get_ntypes_spin()) metadata["use_spin"] = [bool(v) for v in model.spin.use_spin] return metadata def _make_sample_inputs( model: torch.nn.Module, nframes: int, nloc: int, device: torch.device, has_spin: bool = False, ) -> tuple[torch.Tensor | None, ...]: """Build representative ``forward_common_lower`` inputs for tracing. Tensors are float64 / int64 (matching the ``.pt2`` I/O contract). 
""" rcut = float(model.get_rcut()) sel = list(model.get_sel()) ntypes = len(model.get_type_map()) if ntypes == 0: ntypes = int(model.get_descriptor().get_ntypes()) if ntypes <= 0: raise ValueError("SeZM .pt2 freeze requires at least one atom type.") dim_fparam = int(model.get_dim_fparam()) dim_aparam = int(model.get_dim_aparam()) dim_chg_spin = int(model.get_dim_chg_spin()) mixed_types = bool(model.mixed_types()) box_size = rcut * 3.0 box = np.eye(3, dtype=np.float64) * box_size box_np = box.reshape(1, 9) rng = np.random.default_rng(42) coord_np = rng.random((nframes, nloc, 3), dtype=np.float64) * box_size * 0.5 coord_np += box_size * 0.25 # centre roughly in the middle of the cell atype_np = np.zeros((nframes, nloc), dtype=np.int32) for i in range(nloc): atype_np[:, i] = i % ntypes spin_np = np.zeros((nframes, nloc, 3), dtype=np.float64) if has_spin: atom_idx = np.arange(nloc, dtype=np.float64).reshape(1, nloc) spin_np[:, :, 0] = 0.10 + 0.01 * atom_idx spin_np[:, :, 1] = 0.20 + 0.02 * atom_idx spin_np[:, :, 2] = 0.05 coord_normalized = normalize_coord( coord_np.reshape(nframes, nloc, 3), np.tile(box.reshape(1, 3, 3), (nframes, 1, 1)), ) extended_coord, extended_atype, mapping = extend_coord_with_ghosts( coord_normalized, atype_np, np.tile(box_np, (nframes, 1)), rcut ) nlist = build_neighbor_list( extended_coord, extended_atype, nloc, rcut, sel, distinguish_types=not mixed_types, ) extended_coord = extended_coord.reshape(nframes, -1, 3) ext_coord = torch.tensor(extended_coord, dtype=torch.float64, device=device) ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=device) nlist_t = torch.tensor(nlist, dtype=torch.int64, device=device) mapping_t = torch.tensor(mapping, dtype=torch.int64, device=device) if has_spin: extended_spin = np.take_along_axis(spin_np, mapping[..., None], axis=1) ext_spin = torch.tensor(extended_spin, dtype=torch.float64, device=device) fparam = ( torch.zeros(nframes, dim_fparam, dtype=torch.float64, device=device) if dim_fparam 
> 0 else None ) aparam = ( torch.zeros(nframes, nloc, dim_aparam, dtype=torch.float64, device=device) if dim_aparam > 0 else None ) charge_spin = None if dim_chg_spin > 0: charge_spin = torch.zeros( nframes, dim_chg_spin, dtype=torch.float64, device=device ) if has_spin: return ( ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam, charge_spin, ) return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin def _resolve_nframes( model: torch.nn.Module, nloc: int, device: torch.device, start: int = 2, has_spin: bool = False, ) -> tuple[int, tuple[torch.Tensor | None, ...]]: """Pick an ``nframes`` that does not collide with any other dim size. ``torch.export``'s duck-sizing unifies symbolic dims whose concrete sample values match; if ``nframes`` happens to equal, say, the spatial ``3`` or the virial ``9``, the ExportedProgram rejects later calls whose ``nframes`` differs. Bumping ``nframes`` until no collision is left keeps the export safe. """ nframes = start sample = _make_sample_inputs( model, nframes=nframes, nloc=nloc, device=device, has_spin=has_spin, ) other_dims: set[int] = set() for t in sample: if t is not None: other_dims.update(t.shape[1:]) while nframes in other_dims: nframes += 1 if nframes != start: sample = _make_sample_inputs( model, nframes=nframes, nloc=nloc, device=device, has_spin=has_spin, ) return nframes, sample def _build_dynamic_shapes( sample_inputs: tuple[torch.Tensor | None, ...], ) -> tuple: """Positional ``dynamic_shapes`` for the traced ``(ext_coord, ext_atype, nlist, mapping, fparam, aparam)`` signature. """ nframes_dim = torch.export.Dim("nframes", min=1) has_spin = ( len(sample_inputs) >= 7 and sample_inputs[2] is not None and sample_inputs[2].is_floating_point() ) has_charge_spin = (has_spin and len(sample_inputs) == 8) or ( not has_spin and len(sample_inputs) == 7 ) # Spin export currently generates a valid lower-bound guard from its # virtual-atom split/concat pattern. 
Matching the bound keeps export strict, # while `_strip_shape_assertions` removes the spurious deferred guards later. nall_dim = torch.export.Dim("nall", min=4 if has_spin else 1) nloc_dim = torch.export.Dim("nloc", min=1) fparam = sample_inputs[5] if has_spin else sample_inputs[4] aparam = sample_inputs[6] if has_spin else sample_inputs[5] charge_spin = None if has_charge_spin: charge_spin = sample_inputs[7] if has_spin else sample_inputs[6] if has_spin: shapes = ( {0: nframes_dim, 1: nall_dim}, # extended_coord {0: nframes_dim, 1: nall_dim}, # extended_atype {0: nframes_dim, 1: nall_dim}, # extended_spin {0: nframes_dim, 1: nloc_dim}, # nlist {0: nframes_dim, 1: nall_dim}, # mapping {0: nframes_dim} if fparam is not None else None, {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, ) if has_charge_spin: shapes = (*shapes, {0: nframes_dim} if charge_spin is not None else None) return shapes shapes = ( {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) {0: nframes_dim} if fparam is not None else None, {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, ) if has_charge_spin: shapes = (*shapes, {0: nframes_dim} if charge_spin is not None else None) return shapes
def freeze_sezm_to_pt2(
    ckpt_path: str,
    out_path: str,
    *,
    device: torch.device | None = None,
    head: str | None = None,
) -> None:
    """Freeze a SeZM checkpoint into an AOTInductor ``.pt2`` archive.

    Parameters
    ----------
    ckpt_path
        Path to the SeZM training checkpoint (``.pt``).
    out_path
        Destination file. A ``.pt2`` suffix is expected.
    device
        Target device for the compiled shared library. Defaults to
        :data:`DEVICE`. Anything accepted by :class:`torch.device` (e.g.
        ``"cuda:0"``) also works. Tracing itself always runs on CPU.
    head
        Model head to export from a multi-task checkpoint. If omitted, the
        ``Default`` head is used when present; otherwise multi-task
        checkpoints must pass an explicit head. Single-task checkpoints must
        pass ``None``.

    Raises
    ------
    ValueError
        If the checkpoint is malformed, the head cannot be resolved, or the
        selected model is not of type ``sezm`` / ``dpa4``.
    NotImplementedError
        If ``head`` is given for a single-task checkpoint.
    """
    from torch._inductor import (
        aoti_compile_and_package,
    )

    # Normalise so ``target_device.type`` below also works for string devices.
    target_device = torch.device(device) if device is not None else DEVICE
    # NOTE(review): ``weights_only=False`` unpickles arbitrary objects, so
    # only freeze trusted checkpoints.
    raw = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict, params = _extract_state_and_params(raw)
    state_dict, params = _select_model_head(state_dict, params, head)
    model_type = str(params.get("type", "")).lower()
    if model_type not in ("sezm", "dpa4"):
        raise ValueError(
            f"freeze_sezm_to_pt2 expects a SeZM/DPA4 checkpoint, got type={params.get('type')!r}."
        )

    model = get_model(params)
    is_spin = _model_has_spin(model)
    ModelWrapper(model).load_state_dict(state_dict)
    model.eval()
    model.to("cpu")

    _, sample_inputs_cpu = _resolve_nframes(
        model,
        nloc=7,
        device=torch.device("cpu"),
        has_spin=is_spin,
    )

    # Sample layout from _make_sample_inputs:
    #   (ext_coord, ext_atype[, ext_spin], nlist, mapping, fparam, aparam, charge_spin)
    # The leading tensors are passed positionally and the last three by keyword.
    # do_atomic_virial=True pulls every key that DeepPotPTExpt may read
    # (energy, energy_redu, energy_derv_r, energy_derv_c, energy_derv_c_redu)
    # into the traced graph.
    *lower_args, fparam, aparam, charge_spin = sample_inputs_cpu
    traced = model.forward_common_lower_exportable(
        *lower_args,
        fparam=fparam,
        aparam=aparam,
        charge_spin=charge_spin,
        do_atomic_virial=True,
    )

    # Output key order is taken from a concrete run; Python dict order
    # is stable and matches what DeepPotPTExpt::extract_outputs zips
    # against AOTIModelPackageLoader::run's output vector.
    with torch.no_grad():
        sample_out = traced(*sample_inputs_cpu)
    output_keys = list(sample_out.keys())

    # Serialise everything that goes into model/extra/ before the slow
    # compile. A failure here then cannot leave a .pt2 on disk without
    # metadata. The raw training params are preserved so `dp change-bias`
    # and other downstream tooling can recover the exact training config;
    # ``default=str`` is a safety net for exotic nested values.
    metadata = _collect_metadata(model, output_keys=output_keys, is_spin=is_spin)
    metadata_json = json.dumps(metadata)
    model_def_json = json.dumps(params, default=str)

    exported = torch.export.export(
        traced,
        sample_inputs_cpu,
        dynamic_shapes=_build_dynamic_shapes(sample_inputs_cpu),
        strict=False,
        prefer_deferred_runtime_asserts_over_guards=True,
    )
    if is_spin:
        _strip_shape_assertions(exported.graph_module)

    # move_to_device_pass handles FakeTensor device propagation cleanly;
    # a naive .to(device) on the exported program does not.
    if target_device.type != "cpu":
        from torch.export.passes import (
            move_to_device_pass,
        )

        exported = move_to_device_pass(exported, target_device)

    out_path_str = str(out_path)
    aoti_compile_and_package(exported, package_path=out_path_str)

    with zipfile.ZipFile(out_path_str, "a") as zf:
        zf.writestr("model/extra/metadata.json", metadata_json)
        zf.writestr("model/extra/model_def_script.json", model_def_json)

    log.info(
        "Saved SeZM .pt2 to %s (device=%s, output_keys=%s)",
        out_path_str,
        target_device,
        output_keys,
    )
# Public entry points of this module: the CLI routing probe and the freezer.
__all__ = ["freeze_sezm_to_pt2", "is_sezm_checkpoint"]