Source code for deepmd.jax.utils.network
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
ClassVar,
)
from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP
from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP
from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP
from deepmd.dpmodel.utils.network import (
make_embedding_network,
make_fitting_network,
make_multilayer_network,
)
from deepmd.jax.common import (
flax_module,
to_jax_array,
)
from deepmd.jax.env import (
nnx,
)
class ArrayAPIParam(nnx.Param):
    """``nnx.Param`` that exposes the array-interchange protocols of its value.

    Every protocol method below forwards, unchanged, to the method of the
    same name on ``self.value``. Presumably this lets NumPy / Array API /
    DLPack consumers treat the parameter directly as an array.
    """

    def __array__(self, *args, **kwargs):
        """Forward NumPy's ``__array__`` conversion to the wrapped value."""
        wrapped = self.value
        return wrapped.__array__(*args, **kwargs)

    def __array_namespace__(self, *args, **kwargs):
        """Forward the Array API namespace lookup to the wrapped value."""
        wrapped = self.value
        return wrapped.__array_namespace__(*args, **kwargs)

    def __dlpack__(self, *args, **kwargs):
        """Forward DLPack export to the wrapped value."""
        wrapped = self.value
        return wrapped.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self, *args, **kwargs):
        """Forward the DLPack device query to the wrapped value."""
        wrapped = self.value
        return wrapped.__dlpack_device__(*args, **kwargs)
@flax_module
class NativeLayer(NativeLayerDP):
    """JAX variant of the dpmodel ``NativeLayer``.

    Assignments to the attributes ``w``, ``b`` and ``idt`` are intercepted.
    Each value is converted with ``to_jax_array``. A non-None result is
    wrapped in :class:`ArrayAPIParam` (an ``nnx.Param`` subclass), and a
    ``None`` result is stored as ``None``. Every other attribute is
    assigned unchanged.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign ``name``, converting and wrapping the layer's array attributes.

        Parameters
        ----------
        name : str
            Attribute name being assigned.
        value : Any
            Value to assign. For ``w``/``b``/``idt`` it is first passed
            through ``to_jax_array``.
        """
        # Anything other than the three array attributes passes straight through.
        if name not in ("w", "b", "idt"):
            return super().__setattr__(name, value)
        converted = to_jax_array(value)
        # Only real arrays become params; a None result is kept as None.
        wrapped = None if converted is None else ArrayAPIParam(converted)
        return super().__setattr__(name, wrapped)
@flax_module
class NativeNet(make_multilayer_network(NativeLayer, NativeOP)):
    """Multilayer network generated by ``make_multilayer_network``.

    It is built from the JAX :class:`NativeLayer` and ``NativeOP`` and adds
    no members of its own.
    """
class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)):
    """Embedding network generated by ``make_embedding_network``.

    It is built on the JAX :class:`NativeNet` and :class:`NativeLayer` and
    adds no members of its own.
    """
class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
    """Fitting network generated by ``make_fitting_network``.

    It is built on the JAX :class:`EmbeddingNet`, :class:`NativeNet` and
    :class:`NativeLayer` and adds no members of its own.
    """
@flax_module
class NetworkCollection(NetworkCollectionDP):
    """dpmodel ``NetworkCollection`` whose type map points at the JAX networks.

    Only ``NETWORK_TYPE_MAP`` is overridden. Presumably the inherited logic
    in ``NetworkCollectionDP`` uses it to choose which class to build for
    each network-type key (TODO confirm).
    """

    # Network-type key -> JAX network class defined in this module.
    # Insertion order: network, embedding_network, fitting_network.
    NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = {
        "network": NativeNet,
        "embedding_network": EmbeddingNet,
        "fitting_network": FittingNet,
    }
class LayerNorm(LayerNormDP, NativeLayer):
    """dpmodel ``LayerNorm`` mixed with the JAX :class:`NativeLayer`.

    ``LayerNormDP`` is listed first, so its methods win in the MRO.
    Assignments to ``w``/``b``/``idt`` still go through
    ``NativeLayer.__setattr__`` unless a class earlier in the MRO
    overrides ``__setattr__``.
    """