Source code for deepmd.jax.model.dp_model

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
    Optional,
)

from deepmd.dpmodel.model import (
    DPModelCommon,
)
from deepmd.jax.atomic_model.dp_atomic_model import (
    DPAtomicModel,
)
from deepmd.jax.common import (
    flax_module,
)
from deepmd.jax.env import (
    jax,
    jnp,
)
from deepmd.jax.model.base_model import (
    forward_common_atomic,
)


def make_jax_dp_model_from_dpmodel(
    dpmodel_model: type[DPModelCommon], jax_atomicmodel: type[DPAtomicModel]
) -> type[DPModelCommon]:
    """Derive a JAX backend DP model class from a DPModel backend DP model class.

    The generated class subclasses ``dpmodel_model``, is wrapped by
    ``flax_module``, and changes three behaviors of its parent:

    * Assigning the ``atomic_model`` attribute replaces the assigned object
      with ``jax_atomicmodel.deserialize(value.serialize())``.
    * ``forward_common_atomic`` delegates to the JAX-specific
      ``forward_common_atomic`` helper from ``deepmd.jax.model.base_model``.
    * ``format_nlist`` applies ``jax.lax.stop_gradient`` to the extended
      coordinates before calling ``dpmodel_model.format_nlist``.

    Parameters
    ----------
    dpmodel_model : type[DPModelCommon]
        The DPModel backend DP model class to derive from.
    jax_atomicmodel : type[DPAtomicModel]
        The JAX backend DP atomic model class that ``atomic_model`` is
        converted into.

    Returns
    -------
    type[DPModelCommon]
        The JAX backend DP model class.
    """

    @flax_module
    class jax_model(dpmodel_model):
        def __setattr__(self, name: str, value: Any) -> None:
            # Any object assigned as ``atomic_model`` is rebuilt from its
            # serialized form as a ``jax_atomicmodel``; every other attribute
            # is stored unchanged.
            stored = (
                jax_atomicmodel.deserialize(value.serialize())
                if name == "atomic_model"
                else value
            )
            return super().__setattr__(name, stored)

        def forward_common_atomic(
            self,
            extended_coord: jnp.ndarray,
            extended_atype: jnp.ndarray,
            nlist: jnp.ndarray,
            mapping: Optional[jnp.ndarray] = None,
            fparam: Optional[jnp.ndarray] = None,
            aparam: Optional[jnp.ndarray] = None,
            do_atomic_virial: bool = False,
        ):
            """Run the atomic-level forward pass via the JAX backend helper.

            All arguments are forwarded unchanged; the optional ones are
            passed by keyword.
            """
            # Within a method body the bare name below resolves to the
            # module-level helper imported from deepmd.jax.model.base_model,
            # not to this method (class scope is not searched from methods).
            optional_inputs = {
                "mapping": mapping,
                "fparam": fparam,
                "aparam": aparam,
                "do_atomic_virial": do_atomic_virial,
            }
            return forward_common_atomic(
                self, extended_coord, extended_atype, nlist, **optional_inputs
            )

        def format_nlist(
            self,
            extended_coord: jnp.ndarray,
            extended_atype: jnp.ndarray,
            nlist: jnp.ndarray,
            extra_nlist_sort: bool = False,
        ):
            """Format the neighbor list using the parent implementation.

            Gradients are blocked through ``extended_coord`` before it
            reaches ``dpmodel_model.format_nlist``.
            """
            # NOTE(review): presumably done because neighbor-list formatting
            # (e.g. distance-based sorting/truncation) is not meant to be
            # differentiated — confirm with the backend authors.
            detached_coord = jax.lax.stop_gradient(extended_coord)
            return dpmodel_model.format_nlist(
                self,
                detached_coord,
                extended_atype,
                nlist,
                extra_nlist_sort=extra_nlist_sort,
            )

    return jax_model