# Source code for deepmd.pt_expt.fitting.dpa4_ener
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
ClassVar,
)
import torch
from deepmd.dpmodel.fitting.dpa4_ener import GLUFittingNet as GLUFittingNetDP
from deepmd.dpmodel.fitting.dpa4_ener import (
SeZMEnergyFittingNet as SeZMEnergyFittingNetDP,
)
from deepmd.dpmodel.fitting.dpa4_ener import (
SeZMNetworkCollection as SeZMNetworkCollectionDP,
)
from deepmd.pt_expt.common import (
register_dpmodel_mapping,
torch_module,
)
from .base_fitting import (
BaseFitting,
)
@torch_module
class GLUFittingNet(GLUFittingNetDP):
    """Torch-facing variant of the dpmodel ``GLUFittingNet``.

    All numerical work is inherited from ``GLUFittingNetDP``; this subclass
    only supplies ``forward`` so the object can be driven through the usual
    ``module(...)`` call path. ``torch_module`` presumably turns the class
    into a ``torch.nn.Module`` (NOTE(review): confirm in deepmd.pt_expt.common).
    """

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Evaluate the network by forwarding every argument to ``call``.

        Returns
        -------
        Any
            Whatever the inherited ``call`` returns, unchanged.
        """
        outputs = self.call(*args, **kwargs)
        return outputs
# Teach the pt_expt layer how to convert a dpmodel GLUFittingNet into the
# torch wrapper above: serialize the dpmodel object and rebuild it through
# the wrapper's (inherited) ``deserialize`` class method.
register_dpmodel_mapping(
    GLUFittingNetDP,
    lambda dp_net: GLUFittingNet.deserialize(dp_net.serialize()),
)
@torch_module
class SeZMNetworkCollection(SeZMNetworkCollectionDP):
    """Torch-facing variant of the dpmodel ``SeZMNetworkCollection``.

    Besides the storage inherited from the dpmodel base (``self._networks``),
    every stored network that is a ``torch.nn.Module`` is mirrored into a
    ``torch.nn.ModuleDict`` (``self._module_networks``) keyed by ``str`` of
    the converted key. This presumably lets torch track those networks as
    submodules (parameters, ``state_dict``, device moves).
    """

    # Maps a serialized network type name to the class used for it; points
    # at the torch GLUFittingNet wrapper rather than the dpmodel one.
    # Presumably consulted by the base class when rebuilding networks — confirm.
    NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = {
        "sezm_fitting_network": GLUFittingNet,
    }

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the module mirror, then run the dpmodel constructor.

        The ``ModuleDict`` must exist before ``super().__init__`` runs, since
        the base constructor presumably stores networks via ``__setitem__``,
        which writes into it. NOTE(review): assigning a module attribute here
        assumes ``torch_module`` has already initialised ``torch.nn.Module``.
        """
        self._module_networks = torch.nn.ModuleDict()
        super().__init__(*args, **kwargs)

    def __setitem__(self, key: int | tuple | str, value: Any) -> None:
        """Store ``value`` under ``key`` and keep the torch mirror in sync.

        After delegating to the base implementation, the stored network is
        read back. If it is a ``torch.nn.Module`` it is (re)registered in
        ``_module_networks``; otherwise any stale entry for that slot is
        removed.
        """
        super().__setitem__(key, value)
        index = self._convert_key(key)
        slot = str(index)
        stored = self._networks[index]
        if not isinstance(stored, torch.nn.Module):
            # Non-module value: drop a previously mirrored module, if any.
            if slot in self._module_networks:
                del self._module_networks[slot]
            return
        self._module_networks[slot] = stored
# Convert a dpmodel SeZMNetworkCollection into the torch-aware collection
# above by round-tripping it through serialize()/deserialize().
register_dpmodel_mapping(
    SeZMNetworkCollectionDP,
    lambda dp_collection: SeZMNetworkCollection.deserialize(
        dp_collection.serialize()
    ),
)
@BaseFitting.register("dpa4_ener")
@BaseFitting.register("sezm_ener")
@torch_module
class SeZMEnergyFittingNet(SeZMEnergyFittingNetDP):
    """Torch-facing variant of the dpmodel ``SeZMEnergyFittingNet``.

    Registered in the ``BaseFitting`` registry under both ``"dpa4_ener"`` and
    ``"sezm_ener"``. The computation is inherited from the dpmodel class;
    ``forward`` only routes the torch call path to it.
    """

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Evaluate the fitting net by forwarding all arguments to ``call``.

        Returns
        -------
        Any
            The unmodified result of the inherited ``call``.
        """
        fitted = self.call(*args, **kwargs)
        return fitted