deepmd_gnn.mace_network
=======================

.. py:module:: deepmd_gnn.mace_network

.. autoapi-nested-parse::

   MACE network construction and freeze-time conversion helpers.

   This module keeps MACE-package glue out of the DeePMD ``MaceModel`` wrapper.
   The functions here build plain or TorchScript-scripted ``ScaleShiftMACE``
   instances, point the parameters and buffers of one module at another module's
   tensors so that compile traces see shared state, and convert cuEquivariance
   checkpoints back to e3nn weights when DeePMD freezes a model for export.



Functions
---------

.. autoapisummary::

   deepmd_gnn.mace_network.make_cueq_config
   deepmd_gnn.mace_network.disable_cueq_in_model_params
   deepmd_gnn.mace_network.transfer_cueq_to_e3nn
   deepmd_gnn.mace_network.make_mace_network
   deepmd_gnn.mace_network._set_module_state_tensor
   deepmd_gnn.mace_network.link_module_state


Module Contents
---------------

.. py:function:: make_cueq_config(enable_cueq: bool) -> object | None

   Return a MACE cuEquivariance configuration when ``enable_cueq`` is true, and ``None`` otherwise.


.. py:function:: disable_cueq_in_model_params(model_params: collections.abc.MutableMapping[str, Any]) -> bool

   Disable the MACE cuEquivariance flags in a model-definition parameter mapping, modifying ``model_params`` in place.

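   A minimal sketch of this kind of in-place edit, assuming (hypothetically) that the flag is stored under an ``enable_cueq`` key, that it may also appear in nested sub-mappings, and that the returned ``bool`` reports whether anything was changed. ``disable_flag_sketch`` is an illustrative name, not part of ``deepmd_gnn``:

   .. code-block:: python

      from collections.abc import MutableMapping
      from typing import Any


      def disable_flag_sketch(
          params: MutableMapping[str, Any], key: str = "enable_cueq"
      ) -> bool:
          """Set every truthy ``key`` entry to False; return True if any changed."""
          changed = False
          # Snapshot the items so edits made during the walk cannot disturb iteration.
          for name, value in list(params.items()):
              if name == key and value:
                  params[name] = False
                  changed = True
              elif isinstance(value, MutableMapping):
                  # Recurse first so nested flags are always visited.
                  changed = disable_flag_sketch(value, key) or changed
          return changed


      params = {"type": "mace", "enable_cueq": True, "head": {"enable_cueq": True}}
      print(disable_flag_sketch(params))  # True
      print(disable_flag_sketch(params))  # False: nothing left to disable

   Editing the mapping in place fits the ``MutableMapping`` annotation in the signature; the boolean result lets a caller skip rebuilding anything when no flag was set.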

.. py:function:: transfer_cueq_to_e3nn(source_model: torch.nn.Module, target_model: torch.nn.Module, *, hidden_irreps: str, correlation: int, num_interactions: int) -> None

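   Transfer the weights of a cuEquivariance-backed MACE model (``source_model``) to an e3nn MACE model (``target_model``). The keyword-only ``hidden_irreps``, ``correlation``, and ``num_interactions`` arguments give the MACE architecture hyperparameters.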


.. py:function:: make_mace_network(*, r_max: float, num_radial_basis: int, num_cutoff_basis: int, max_ell: int, interaction: str, num_interactions: int, num_elements: int, hidden_irreps: str, atomic_numbers: list[int], avg_num_neighbors: float, pair_repulsion: bool, distance_transform: str, correlation: int, gate: str, MLP_irreps: str, std: float, radial_MLP: list[int], radial_type: str, enable_cueq: bool, script_model: bool) -> torch.nn.Module

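   Create a configured ``ScaleShiftMACE`` subnetwork for DeePMD-GNN. ``enable_cueq`` selects the cuEquivariance backend, and ``script_model`` returns a TorchScript-scripted module instead of a plain one.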


.. py:function:: _set_module_state_tensor(root: torch.nn.Module, name: str, tensor: torch.Tensor) -> None

   Rebind the parameter or buffer called ``name`` on ``root`` to ``tensor``.

.. py:function:: link_module_state(target: torch.nn.Module, source: torch.nn.Module) -> None

   Point ``target`` parameters and buffers at the tensors owned by ``source``.

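   A minimal sketch of the idea in plain PyTorch, assuming ``target`` and ``source`` have identical module structure (same parameter and buffer names, no internal aliasing). ``rebind_state_tensor`` and ``link_state_sketch`` are hypothetical names; this is not the ``deepmd_gnn`` implementation. It writes into ``Module._parameters`` and ``Module._buffers`` directly so that the source's tensor objects, not copies, end up registered on ``target``:

   .. code-block:: python

      import torch
      from torch import nn


      def rebind_state_tensor(root: nn.Module, name: str, tensor: torch.Tensor) -> None:
          """Hypothetical helper: make ``root.<name>`` refer to ``tensor`` itself."""
          # Split "0.weight" into the owning submodule path ("0") and the leaf ("weight").
          module_path, _, leaf = name.rpartition(".")
          owner = root.get_submodule(module_path) if module_path else root
          if leaf in owner._parameters:
              # Register the source Parameter object so both modules share it.
              owner._parameters[leaf] = tensor
          elif leaf in owner._buffers:
              # Writing into _buffers leaves the buffer's persistence flag unchanged.
              owner._buffers[leaf] = tensor
          else:
              raise KeyError(f"{name!r} is neither a parameter nor a buffer of target")


      def link_state_sketch(target: nn.Module, source: nn.Module) -> None:
          """Point every parameter and buffer of ``target`` at ``source``'s tensors."""
          for name, param in source.named_parameters():
              rebind_state_tensor(target, name, param)
          for name, buf in source.named_buffers():
              rebind_state_tensor(target, name, buf)


      src = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
      tgt = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3))
      link_state_sketch(tgt, src)

      with torch.no_grad():
          src[0].weight.fill_(2.0)  # an in-place update on the source...
      print(bool((tgt[0].weight == 2.0).all()))  # ...is visible through target: True

   Because the tensors are shared rather than copied, ``target.state_dict()`` and ``source.state_dict()`` refer to the same storage after linking.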

