Source code for deepmd_gnn.pt_expt
"""PyTorch exportable backend plugin registration."""
import sys
[docs]
def load() -> None:
    """Serve as a no-op entry point for this plugin module.

    The function body performs no work. Plugin registration happens when
    the module is imported, so importing the module that defines ``load``
    is what triggers it.
    """
    return None
[docs]
def _is_partially_initialized(module_name: str, attr_name: str) -> bool:
    """Report whether a loaded module does not expose an attribute yet.

    Parameters
    ----------
    module_name
        Key to look up in ``sys.modules``.
    attr_name
        Name of the attribute expected on the module object.

    Returns
    -------
    bool
        ``True`` when ``sys.modules`` holds a non-``None`` entry for
        ``module_name`` that has no attribute called ``attr_name``.
        ``False`` when the entry is absent, is ``None``, or already has
        the attribute.
    """
    cached = sys.modules.get(module_name)
    if cached is None:
        # Never imported (or explicitly blocked with a ``None`` entry).
        return False
    return not hasattr(cached, attr_name)
[docs]
def _register() -> None:
    """Register ``MaceModel`` under the ``"mace"`` key for pt_expt.

    Registration is skipped when ``_is_partially_initialized`` reports
    that ``deepmd_gnn.mace`` is loaded but has no ``MaceModel`` attribute
    yet (presumably because that module is still mid-import; confirm
    against the package's import order).
    """
    mace_still_loading = _is_partially_initialized(
        "deepmd_gnn.mace",
        "MaceModel",
    )
    if not mace_still_loading:
        # Imported lazily so they only run once the guard above has passed.
        from deepmd.pt_expt.model.model import (  # noqa: PLC0415
            BaseModel as ExportableBaseModel,
        )

        from deepmd_gnn.mace import MaceModel  # noqa: PLC0415

        # NeQuIP 0.6/e3nn specializes atom and edge counts during
        # torch.export, so pt_expt registration stays limited to models
        # with dynamic-shape export.
        register_as_mace = ExportableBaseModel.register("mace")
        register_as_mace(MaceModel)
# Run the registration as a side effect of importing this module.
_register()