Source code for deepmd.jax.utils.network

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
    ClassVar,
)

from deepmd.dpmodel.common import (
    NativeOP,
)
from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP
from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP
from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP
from deepmd.dpmodel.utils.network import (
    make_embedding_network,
    make_fitting_network,
    make_multilayer_network,
)
from deepmd.jax.common import (
    flax_module,
    to_jax_array,
)
from deepmd.jax.env import (
    nnx,
)


class ArrayAPIParam(nnx.Param):
    """Flax ``nnx.Param`` that also exposes the array protocols of its payload.

    Each special method below forwards its arguments unchanged to the
    identically named method of the wrapped ``self.value``. This lets the
    parameter object be handed directly to code that calls ``__array__``,
    ``__array_namespace__``, ``__dlpack__`` or ``__dlpack_device__`` on it.
    """

    def __array__(self, *args, **kwargs):
        """Forward to ``self.value.__array__``."""
        payload = self.value
        return payload.__array__(*args, **kwargs)

    def __array_namespace__(self, *args, **kwargs):
        """Forward to ``self.value.__array_namespace__``."""
        payload = self.value
        return payload.__array_namespace__(*args, **kwargs)

    def __dlpack__(self, *args, **kwargs):
        """Forward to ``self.value.__dlpack__``."""
        payload = self.value
        return payload.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self, *args, **kwargs):
        """Forward to ``self.value.__dlpack_device__``."""
        payload = self.value
        return payload.__dlpack_device__(*args, **kwargs)
@flax_module
class NativeLayer(NativeLayerDP):
    """JAX variant of the dpmodel ``NativeLayer``.

    Assignments to the ``w``, ``b`` and ``idt`` attributes are first passed
    through ``to_jax_array``. A non-``None`` result is then wrapped in
    :class:`ArrayAPIParam`. Every other attribute is stored unchanged by the
    parent ``__setattr__``.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        # Attributes other than the three tensor slots need no conversion.
        if name not in ("w", "b", "idt"):
            return super().__setattr__(name, value)
        array = to_jax_array(value)
        # ``None`` is stored as-is; only real arrays become Flax params.
        stored = None if array is None else ArrayAPIParam(array)
        return super().__setattr__(name, stored)
@flax_module
class NativeNet(make_multilayer_network(NativeLayer, NativeOP)):
    """Multilayer network produced by ``make_multilayer_network`` from the
    JAX :class:`NativeLayer` and ``NativeOP``."""
class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)):
    """Embedding network produced by ``make_embedding_network`` from the
    JAX :class:`NativeNet` and :class:`NativeLayer`."""
class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
    """Fitting network produced by ``make_fitting_network`` from the JAX
    :class:`EmbeddingNet`, :class:`NativeNet` and :class:`NativeLayer`."""
@flax_module
class NetworkCollection(NetworkCollectionDP):
    """JAX network collection.

    ``NETWORK_TYPE_MAP`` maps each network-type key to the JAX class
    defined in this module, overriding the mapping in the dpmodel base.
    """

    NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = {
        "network": NativeNet,
        "embedding_network": EmbeddingNet,
        "fitting_network": FittingNet,
    }
class LayerNorm(LayerNormDP, NativeLayer):
    """Layer normalisation combining the dpmodel ``LayerNormDP`` with the
    JAX :class:`NativeLayer`.

    NOTE(review): presumably this relies on ``NativeLayer.__setattr__`` for
    the ``w``/``b`` conversion; confirm ``LayerNormDP`` does not override it
    earlier in the MRO.
    """