Source code for deepmd_utils.model_format.network

# SPDX-License-Identifier: LGPL-3.0-or-later
"""Native DP model format for multiple backends.

See issue #2982 for more information.
"""
import copy
import itertools
import json
from typing import (
    ClassVar,
    Dict,
    List,
    Optional,
    Union,
)

import h5py
import numpy as np

try:
    from deepmd_utils._version import version as __version__
except ImportError:
    __version__ = "unknown"

from .common import (
    DEFAULT_PRECISION,
    PRECISION_DICT,
    NativeOP,
)


def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False):
    """Traverse a model dict and call callback on each variable.

    Dicts and lists are walked recursively and updated in place. Every leaf
    that lives (at any depth) below a dict key named ``"@variables"`` is
    replaced by ``callback(leaf)``; all other leaves are left untouched.
    ``None`` leaves are never passed to the callback, so an absent variable
    (for example a layer serialized without bias) stays ``None``.

    Parameters
    ----------
    model_obj : object
        The model object to traverse.
    callback : callable
        The callback function to call on each variable.
    is_variable : bool, optional
        Whether the current node is a variable.

    Returns
    -------
    object
        The model object after traversing. For dicts and lists this is the
        same (mutated) container that was passed in.
    """
    if isinstance(model_obj, dict):
        # Only existing keys are reassigned, so the dict size does not change
        # while it is being iterated.
        for kk, vv in model_obj.items():
            model_obj[kk] = traverse_model_dict(
                vv, callback, is_variable=is_variable or kk == "@variables"
            )
    elif isinstance(model_obj, list):
        for ii, vv in enumerate(model_obj):
            model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable)
    elif model_obj is None:
        # A missing variable has nothing to convert; keep the placeholder.
        return None
    elif is_variable:
        model_obj = callback(model_obj)
    return model_obj
class Counter:
    """A callable that hands out consecutive integers starting at zero.

    The most recently returned value is kept in the ``count`` attribute
    (``-1`` before the first call).

    Examples
    --------
    >>> counter = Counter()
    >>> counter()
    0
    >>> counter()
    1
    """

    def __init__(self):
        # Start one below zero so that the first call yields 0.
        self.count = -1

    def __call__(self):
        current = self.count + 1
        self.count = current
        return current
def _copy_containers(obj):
    """Recursively copy dicts and lists while sharing all other (leaf) objects."""
    if isinstance(obj, dict):
        return {kk: _copy_containers(vv) for kk, vv in obj.items()}
    if isinstance(obj, list):
        return [_copy_containers(vv) for vv in obj]
    return obj


def save_dp_model(filename: str, model_dict: dict, extra_info: Optional[dict] = None):
    """Save a DP model to a file in the native format.

    Every variable found by ``traverse_model_dict`` is written to its own
    HDF5 dataset named ``variable_NNNN`` and replaced, in the saved
    structure, by that dataset's name. The remaining structure is stored as
    a compact JSON string in the ``"json"`` attribute of the file, together
    with the ``"software"`` and ``"version"`` keys and any ``extra_info``.

    Parameters
    ----------
    filename : str
        The filename to save to.
    model_dict : dict
        The model dict to save. It is not modified.
    extra_info : dict, optional
        Extra meta information to save.
    """
    # traverse_model_dict rewrites dicts and lists in place, so a shallow copy
    # would leave dataset names inside the caller's nested "@variables"
    # entries. Copy the container structure only; the arrays are shared.
    model_dict = _copy_containers(model_dict)
    variable_counter = Counter()
    extra_info = {} if extra_info is None else extra_info.copy()
    with h5py.File(filename, "w") as f:
        model_dict = traverse_model_dict(
            model_dict,
            lambda x: f.create_dataset(
                f"variable_{variable_counter():04d}", data=x
            ).name,
        )
        save_dict = {
            "model": model_dict,
            "software": "deepmd-kit",
            "version": __version__,
            **extra_info,
        }
        f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))
def load_dp_model(filename: str) -> dict:
    """Read a DP model stored in the native format.

    The JSON document kept in the file's ``"json"`` attribute is decoded,
    and every variable entry in it (a dataset name) is replaced by a copy
    of the data read from that dataset.

    Parameters
    ----------
    filename : str
        The filename to load from.

    Returns
    -------
    dict
        The loaded model dict, including meta information.
    """
    with h5py.File(filename, "r") as h5file:

        def read_variable(dataset_name):
            # Copy the value read from the dataset before handing it out.
            return h5file[dataset_name][()].copy()

        stored = json.loads(h5file.attrs["json"])
        loaded = traverse_model_dict(stored, read_variable)
    return loaded
class NativeLayer(NativeOP):
    """Native representation of a layer.

    The forward pass computes ``activation(x @ w + b)``, optionally scaled by
    ``idt`` and, for residual layers, with the input added back.

    Parameters
    ----------
    num_in : int
        The input dimension of the layer (number of rows of ``w``).
    num_out : int
        The output dimension of the layer (number of columns of ``w``).
    bias : bool, optional
        Whether the layer has a bias ``b`` of shape ``(num_out,)``.
    use_timestep : bool, optional
        Whether the layer has a timestep ``idt`` of shape ``(num_out,)``.
    activation_function : str, optional
        The activation function of the layer. ``None`` is stored as
        ``"none"``. The forward pass supports ``"tanh"`` and ``"none"``.
    resnet : bool, optional
        Whether the layer is a residual layer.
    precision : str, optional
        The precision name, looked up case-insensitively in
        ``PRECISION_DICT``.

    Notes
    -----
    ``w``, ``b`` and ``idt`` are initialized with normally distributed random
    numbers; ``deserialize`` replaces them with the stored values.
    """

    def __init__(
        self,
        num_in,
        num_out,
        bias: bool = True,
        use_timestep: bool = False,
        activation_function: Optional[str] = None,
        resnet: bool = False,
        precision: str = DEFAULT_PRECISION,
    ) -> None:
        prec = PRECISION_DICT[precision.lower()]
        self.precision = precision
        rng = np.random.default_rng()
        self.w = rng.normal(size=(num_in, num_out)).astype(prec)
        self.b = rng.normal(size=(num_out,)).astype(prec) if bias else None
        self.idt = rng.normal(size=(num_out,)).astype(prec) if use_timestep else None
        self.activation_function = (
            activation_function if activation_function is not None else "none"
        )
        self.resnet = resnet
        self.check_type_consistency()
        self.check_shape_consistency()

    def serialize(self) -> dict:
        """Serialize the layer to a dict.

        Returns
        -------
        dict
            The serialized layer. The arrays are stored under the
            ``"@variables"`` key; an absent ``b`` or ``idt`` is stored as
            ``None``.
        """
        data = {
            "w": self.w,
            "b": self.b,
            "idt": self.idt,
        }
        return {
            "bias": self.b is not None,
            "use_timestep": self.idt is not None,
            "activation_function": self.activation_function,
            "resnet": self.resnet,
            "precision": self.precision,
            "@variables": data,
        }

    @classmethod
    def deserialize(cls, data: dict) -> "NativeLayer":
        """Deserialize the layer from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from. It is not modified.

        Returns
        -------
        NativeLayer
            The layer holding the stored variables.
        """
        data = copy.deepcopy(data)
        variables = data.pop("@variables")
        assert variables["w"] is not None and len(variables["w"].shape) == 2
        num_in, num_out = variables["w"].shape
        # The remaining keys are forwarded to __init__ as keyword arguments.
        obj = cls(
            num_in,
            num_out,
            **data,
        )
        # Replace the randomly initialized parameters by the stored ones.
        obj.w, obj.b, obj.idt = (
            variables["w"],
            variables.get("b", None),
            variables.get("idt", None),
        )
        obj.check_shape_consistency()
        return obj

    def check_shape_consistency(self):
        """Check that ``b`` and ``idt`` match the output dimension of ``w``.

        Raises
        ------
        ValueError
            If the length of ``b`` or ``idt`` differs from ``w.shape[1]``.
        """
        if self.b is not None and self.w.shape[1] != self.b.shape[0]:
            raise ValueError(
                f"dim 1 of w {self.w.shape[1]} is not equal to shape "
                f"of b {self.b.shape[0]}",
            )
        if self.idt is not None and self.w.shape[1] != self.idt.shape[0]:
            raise ValueError(
                f"dim 1 of w {self.w.shape[1]} is not equal to shape "
                f"of idt {self.idt.shape[0]}",
            )

    def check_type_consistency(self):
        """Assert that the dtype of every parameter matches ``precision``."""
        # Lower-case the name exactly as __init__ does, so a precision that
        # was accepted at construction cannot fail the lookup here.
        precision = self.precision.lower()

        def check_var(var):
            if var is not None:
                # assertion "float64" == "double" would fail
                assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision]

        check_var(self.w)
        check_var(self.b)
        check_var(self.idt)

    def __setitem__(self, key, value):
        # "matrix" and "bias" are accepted as aliases of "w" and "b".
        if key in ("w", "matrix"):
            self.w = value
        elif key in ("b", "bias"):
            self.b = value
        elif key == "idt":
            self.idt = value
        elif key == "activation_function":
            self.activation_function = value
        elif key == "resnet":
            self.resnet = value
        elif key == "precision":
            self.precision = value
        else:
            raise KeyError(key)

    def __getitem__(self, key):
        # "matrix" and "bias" are accepted as aliases of "w" and "b".
        if key in ("w", "matrix"):
            return self.w
        elif key in ("b", "bias"):
            return self.b
        elif key == "idt":
            return self.idt
        elif key == "activation_function":
            return self.activation_function
        elif key == "resnet":
            return self.resnet
        elif key == "precision":
            return self.precision
        else:
            raise KeyError(key)

    def dim_in(self) -> int:
        """Return the input dimension (number of rows of ``w``)."""
        return self.w.shape[0]

    def dim_out(self) -> int:
        """Return the output dimension (number of columns of ``w``)."""
        return self.w.shape[1]

    def call(self, x: np.ndarray) -> np.ndarray:
        """Forward pass.

        Parameters
        ----------
        x : np.ndarray
            The input.

        Returns
        -------
        np.ndarray
            The output.

        Raises
        ------
        ValueError
            If ``w`` or ``activation_function`` is not set.
        NotImplementedError
            If the activation function is neither "tanh" nor "none"
            (compared case-insensitively).
        """
        if self.w is None or self.activation_function is None:
            # b is optional, so only w and the activation are required.
            raise ValueError("w and activation_function must be set")
        # Compare case-insensitively for every supported name.
        activation = self.activation_function.lower()
        if activation == "tanh":
            fn = np.tanh
        elif activation == "none":

            def fn(x):
                return x

        else:
            raise NotImplementedError(self.activation_function)
        y = np.matmul(x, self.w)
        if self.b is not None:
            y = y + self.b
        y = fn(y)
        if self.idt is not None:
            y *= self.idt
        # The residual is added only when the output width equals the input
        # width or is exactly twice it; otherwise no residual is applied.
        if self.resnet and self.w.shape[1] == self.w.shape[0]:
            y += x
        elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
            y += np.concatenate([x, x], axis=-1)
        return y
def make_multilayer_network(T_NetworkLayer, ModuleBase):
    """Create a multilayer network class for a given layer type and base class.

    Parameters
    ----------
    T_NetworkLayer
        The layer class. It must provide ``deserialize`` (class method) and
        its instances must provide ``serialize``, ``dim_in``, ``dim_out`` and
        be callable.
    ModuleBase
        The base class of the generated network class.

    Returns
    -------
    type
        The network class.
    """

    class NN(ModuleBase):
        """Native representation of a neural network.

        Parameters
        ----------
        layers : list[dict], optional
            The serialized layers of the network.
        """

        def __init__(self, layers: Optional[List[dict]] = None) -> None:
            super().__init__()
            if layers is None:
                layers = []
            self.layers = [T_NetworkLayer.deserialize(layer) for layer in layers]
            self.check_shape_consistency()

        def serialize(self) -> dict:
            """Serialize the network to a dict.

            Returns
            -------
            dict
                The serialized network.
            """
            return {"layers": [layer.serialize() for layer in self.layers]}

        @classmethod
        def deserialize(cls, data: dict) -> "NN":
            """Deserialize the network from a dict.

            Parameters
            ----------
            data : dict
                The dict to deserialize from.
            """
            return cls(data["layers"])

        def __getitem__(self, key):
            assert isinstance(key, int)
            return self.layers[key]

        def __setitem__(self, key, value):
            assert isinstance(key, int)
            self.layers[key] = value

        def check_shape_consistency(self):
            """Check that each layer's output feeds the next layer's input.

            Raises
            ------
            ValueError
                If the output dimension of a layer differs from the input
                dimension of the following layer.
            """
            for ii, (current, following) in enumerate(
                zip(self.layers, self.layers[1:])
            ):
                dim_out = current.dim_out()
                dim_in = following.dim_in()
                if dim_out != dim_in:
                    # A single concatenated message (no commas between the
                    # fragments), reporting the two values compared.
                    raise ValueError(
                        f"the dim of layer {ii} output {dim_out} "
                        f"does not match the dim of layer {ii + 1} "
                        f"input {dim_in}"
                    )

        def call(self, x):
            """Forward pass.

            Parameters
            ----------
            x : np.ndarray
                The input.

            Returns
            -------
            np.ndarray
                The output.
            """
            for layer in self.layers:
                x = layer(x)
            return x

    return NN
# Multilayer network class built from NativeLayer layers on the NativeOP base.
NativeNet = make_multilayer_network(NativeLayer, NativeOP)
def make_embedding_network(T_Network, T_NetworkLayer):
    """Create an embedding network class for given network and layer types.

    Parameters
    ----------
    T_Network
        The multilayer network class used as base class. Its ``__init__``
        is called with a list of serialized layers.
    T_NetworkLayer
        The layer class used to build the layers.

    Returns
    -------
    type
        The embedding network class.
    """

    class EN(T_Network):
        """The embedding network.

        Parameters
        ----------
        in_dim
            Input dimension.
        neuron
            The number of neurons in each layer. The output dimension
            is the same as the dimension of the last layer.
        activation_function
            The activation function.
        resnet_dt
            Use time step at the resnet architecture.
        precision
            Floating point precision for the model parameters.
        """

        def __init__(
            self,
            in_dim,
            neuron: List[int] = [24, 48, 96],
            activation_function: str = "tanh",
            resnet_dt: bool = False,
            precision: str = DEFAULT_PRECISION,
        ):
            # Work on a private copy so the instance never aliases the shared
            # default list (or the caller's list).
            neuron = list(neuron)
            layers = []
            i_in = in_dim
            for i_ot in neuron:
                layers.append(
                    T_NetworkLayer(
                        i_in,
                        i_ot,
                        bias=True,
                        use_timestep=resnet_dt,
                        activation_function=activation_function,
                        resnet=True,
                        precision=precision,
                    ).serialize()
                )
                i_in = i_ot
            super().__init__(layers)
            self.in_dim = in_dim
            self.neuron = neuron
            self.activation_function = activation_function
            self.resnet_dt = resnet_dt
            self.precision = precision

        def serialize(self) -> dict:
            """Serialize the network to a dict.

            Returns
            -------
            dict
                The serialized network.
            """
            return {
                "in_dim": self.in_dim,
                "neuron": self.neuron.copy(),
                "activation_function": self.activation_function,
                "resnet_dt": self.resnet_dt,
                "precision": self.precision,
                "layers": [layer.serialize() for layer in self.layers],
            }

        @classmethod
        def deserialize(cls, data: dict) -> "EmbeddingNet":
            """Deserialize the network from a dict.

            Parameters
            ----------
            data : dict
                The dict to deserialize from. It is not modified.
            """
            data = copy.deepcopy(data)
            layers = data.pop("layers")
            # Build the object from the hyper-parameters, then re-run the
            # base-class initializer with the stored layers.
            obj = cls(**data)
            super(EN, obj).__init__(layers)
            return obj

    return EN
# Embedding network class derived from NativeNet and built from NativeLayer layers.
EmbeddingNet = make_embedding_network(NativeNet, NativeLayer)
def make_fitting_network(T_EmbeddingNet, T_Network, T_NetworkLayer):
    """Create a fitting network class for given network and layer types.

    Parameters
    ----------
    T_EmbeddingNet
        The embedding network class used as base class.
    T_Network
        The multilayer network class whose ``__init__`` is used to restore
        the stored layers on deserialization.
    T_NetworkLayer
        The layer class used to build the output layer.

    Returns
    -------
    type
        The fitting network class.
    """

    class FN(T_EmbeddingNet):
        """The fitting network.

        It may be implemented as an embedding net connected with a linear
        output layer.

        Parameters
        ----------
        in_dim
            Input dimension.
        out_dim
            Output dimension
        neuron
            The number of neurons in each hidden layer. May be empty, in
            which case the output layer is connected directly to the input.
        activation_function
            The activation function.
        resnet_dt
            Use time step at the resnet architecture.
        precision
            Floating point precision for the model parameters.
        bias_out
            The last linear layer has bias.
        """

        def __init__(
            self,
            in_dim,
            out_dim,
            neuron: List[int] = [24, 48, 96],
            activation_function: str = "tanh",
            resnet_dt: bool = False,
            precision: str = DEFAULT_PRECISION,
            bias_out: bool = True,
        ):
            # Pass a private copy to the base class so the instance never
            # aliases the shared default list (or the caller's list).
            neuron = list(neuron)
            super().__init__(
                in_dim,
                neuron=neuron,
                activation_function=activation_function,
                resnet_dt=resnet_dt,
                precision=precision,
            )
            # Without hidden layers the output layer reads the raw input.
            i_in = neuron[-1] if neuron else in_dim
            i_ot = out_dim
            self.layers.append(
                T_NetworkLayer(
                    i_in,
                    i_ot,
                    bias=bias_out,
                    use_timestep=False,
                    activation_function=None,
                    resnet=False,
                    precision=precision,
                )
            )
            self.out_dim = out_dim
            self.bias_out = bias_out

        def serialize(self) -> dict:
            """Serialize the network to a dict.

            Returns
            -------
            dict
                The serialized network.
            """
            return {
                "in_dim": self.in_dim,
                "out_dim": self.out_dim,
                "neuron": self.neuron.copy(),
                "activation_function": self.activation_function,
                "resnet_dt": self.resnet_dt,
                "precision": self.precision,
                "bias_out": self.bias_out,
                "layers": [layer.serialize() for layer in self.layers],
            }

        @classmethod
        def deserialize(cls, data: dict) -> "FittingNet":
            """Deserialize the network from a dict.

            Parameters
            ----------
            data : dict
                The dict to deserialize from. It is not modified.
            """
            data = copy.deepcopy(data)
            layers = data.pop("layers")
            # Build the object from the hyper-parameters, then restore the
            # stored layers (including the output layer) via T_Network.
            obj = cls(**data)
            T_Network.__init__(obj, layers)
            return obj

    return FN
# Fitting network class derived from EmbeddingNet, using NativeNet and NativeLayer.
FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)
class NetworkCollection:
    """A collection of networks for multiple elements.

    The number of dimensions for types might be 0, 1, or 2.

    - 0: embedding or fitting with type embedding, in ()
    - 1: embedding with type_one_side, or fitting, in (type_i)
    - 2: embedding without type_one_side, in (type_i, type_j)

    The collection holds ``ntypes**ndim`` networks. A network can be
    addressed by its flat integer index, by a tuple of ``ndim`` type indices,
    or by a string of the form ``"<prefix>_<i>_<j>..."`` (the part before the
    first underscore is ignored).

    Parameters
    ----------
    ndim : int
        The number of dimensions.
    ntypes : int
        The number of types.
    network_type : str, optional
        The type of the network; a key of ``NETWORK_TYPE_MAP``.
    networks : list, optional
        The networks to initialize with, in flat index order. Each item is
        either a network instance or a serialized network (dict).
    """

    # subclass may override this
    NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = {
        "network": NativeNet,
        "embedding_network": EmbeddingNet,
        "fitting_network": FittingNet,
    }

    def __init__(
        self,
        ndim: int,
        ntypes: int,
        network_type: str = "network",
        # The default list is only read, never stored or mutated.
        networks: List[Union[NativeNet, dict]] = [],
    ):
        self.ndim = ndim
        self.ntypes = ntypes
        self.network_type = self.NETWORK_TYPE_MAP[network_type]
        self._networks = [None] * (ntypes**ndim)
        for ii, network in enumerate(networks):
            self[ii] = network
        # An empty collection may be filled later; a partially filled one
        # passed at construction is rejected.
        if len(networks):
            self.check_completeness()

    def check_completeness(self):
        """Check whether the collection is complete.

        Raises
        ------
        RuntimeError
            If the collection is incomplete.
        """
        for tt in itertools.product(range(self.ntypes), repeat=self.ndim):
            if self[tuple(tt)] is None:
                raise RuntimeError(f"network for {tt} not found")

    def _convert_key(self, key):
        """Convert an int, tuple or string key to the flat list index."""
        if isinstance(key, int):
            idx = key
        else:
            if isinstance(key, tuple):
                pass
            elif isinstance(key, str):
                # Drop the part before the first underscore; the rest are
                # the type indices.
                key = tuple([int(tt) for tt in key.split("_")[1:]])
            else:
                raise TypeError(key)
            assert isinstance(key, tuple)
            assert len(key) == self.ndim
            # The first type index is the least significant digit.
            idx = sum(tt * self.ntypes**ii for ii, tt in enumerate(key))
        return idx

    def __getitem__(self, key):
        return self._networks[self._convert_key(key)]

    def __setitem__(self, key, value):
        if isinstance(value, self.network_type):
            pass
        elif isinstance(value, dict):
            # Serialized networks are deserialized on assignment.
            value = self.network_type.deserialize(value)
        else:
            raise TypeError(value)
        self._networks[self._convert_key(key)] = value

    def serialize(self) -> dict:
        """Serialize the networks to a dict.

        Returns
        -------
        dict
            The serialized networks.

        Raises
        ------
        RuntimeError
            If the collection is incomplete.
        """
        # Fail with a descriptive error instead of an AttributeError raised
        # by calling ``serialize`` on a missing (None) network.
        self.check_completeness()
        network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()}
        network_type_name = network_type_map_inv[self.network_type]
        return {
            "ndim": self.ndim,
            "ntypes": self.ntypes,
            "network_type": network_type_name,
            "networks": [nn.serialize() for nn in self._networks],
        }

    @classmethod
    def deserialize(cls, data: dict) -> "NetworkCollection":
        """Deserialize the networks from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from. Its items are passed to the
            constructor as keyword arguments.
        """
        return cls(**data)