deepmd_gnn.mace_network

MACE network construction and freeze-time conversion helpers.

This module keeps MACE-package glue out of the DeePMD MaceModel wrapper. The functions here build raw/scripted ScaleShiftMACE instances, preserve parameter sharing for compile traces, and convert cuEquivariance checkpoints back to e3nn weights when DeePMD freezes an exportable model.

Functions

make_cueq_config(→ object | None)

Build the MACE cuEquivariance config when requested.

disable_cueq_in_model_params(→ bool)

Disable MACE cuEquivariance flags in model definition metadata.

transfer_cueq_to_e3nn(→ None)

Transfer MACE cuEquivariance weights to an e3nn MACE model.

make_mace_network(→ torch.nn.Module)

Create a configured MACE subnetwork for DeePMD-GNN.

_set_module_state_tensor(→ None)

link_module_state(→ None)

Point target parameters and buffers at the tensors owned by source.

Module Contents

deepmd_gnn.mace_network.make_cueq_config(enable_cueq: bool) → object | None

Build the MACE cuEquivariance config when requested.

deepmd_gnn.mace_network.disable_cueq_in_model_params(model_params: collections.abc.MutableMapping[str, Any]) → bool

Disable MACE cuEquivariance flags in model definition metadata.
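
The in-place flag flip described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the helper name `disable_cueq_flag`, the key name `enable_cueq` (borrowed from the `make_mace_network` keyword), and the reading of the `bool` return value as "the mapping was changed" are all hypothetical and may differ from the real function.

```python
from collections.abc import MutableMapping
from typing import Any


def disable_cueq_flag(
    model_params: MutableMapping[str, Any],
    key: str = "enable_cueq",  # assumed key name, for illustration only
) -> bool:
    """Set a truthy cuEquivariance flag to False in place.

    Returns True if the mapping was modified, False if the flag was
    already absent or falsy (assumed meaning of the return value).
    """
    if model_params.get(key):
        model_params[key] = False
        return True
    return False


# Hypothetical model-definition metadata.
params = {"type": "mace", "enable_cueq": True}
changed = disable_cueq_flag(params)
```

Reporting whether anything changed lets a caller decide if a weight conversion step is needed before export; that motivation is an inference from the module description, not a documented contract.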

deepmd_gnn.mace_network.transfer_cueq_to_e3nn(source_model: torch.nn.Module, target_model: torch.nn.Module, *, hidden_irreps: str, correlation: int, num_interactions: int) → None

Transfer MACE cuEquivariance weights to an e3nn MACE model.

deepmd_gnn.mace_network.make_mace_network(*, r_max: float, num_radial_basis: int, num_cutoff_basis: int, max_ell: int, interaction: str, num_interactions: int, num_elements: int, hidden_irreps: str, atomic_numbers: list[int], avg_num_neighbors: float, pair_repulsion: bool, distance_transform: str, correlation: int, gate: str, MLP_irreps: str, std: float, radial_MLP: list[int], radial_type: str, enable_cueq: bool, script_model: bool) → torch.nn.Module

Create a configured MACE subnetwork for DeePMD-GNN.

deepmd_gnn.mace_network._set_module_state_tensor(root: torch.nn.Module, name: str, tensor: torch.Tensor) → None

Point target parameters and buffers at the tensors owned by source.
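
Pointing one module's parameters and buffers at tensors owned by another can be illustrated in plain PyTorch. The sketch below is not the module's implementation: the helper names `set_state_tensor` and `share_state` are hypothetical, and writing directly into the private `_parameters` and `_buffers` dictionaries is an assumed approach chosen for brevity. It also assumes both modules have identical structure.

```python
import torch
from torch import nn


def set_state_tensor(root: nn.Module, name: str, tensor: torch.Tensor) -> None:
    """Replace the parameter or buffer at dotted path `name` under `root`."""
    module_path, _, leaf = name.rpartition(".")
    owner = root.get_submodule(module_path) if module_path else root
    if leaf in owner._parameters:
        owner._parameters[leaf] = tensor
    elif leaf in owner._buffers:
        owner._buffers[leaf] = tensor
    else:
        raise KeyError(f"{name!r} is neither a parameter nor a buffer")


def share_state(target: nn.Module, source: nn.Module) -> None:
    """Make `target` reuse the exact tensor objects owned by `source`."""
    for name, param in source.named_parameters():
        set_state_tensor(target, name, param)
    for name, buf in source.named_buffers():
        set_state_tensor(target, name, buf)


# Two modules with the same layout but independently initialized state.
source = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
target = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
share_state(target, source)
```

After the call, `target[0].weight` and `source[0].weight` are the same object, so an optimizer step or in-place update on one is visible through the other without copying.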