Source code for deepmd.jax.model.dp_zbl_model

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Any,
)

from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP
from deepmd.jax.atomic_model.linear_atomic_model import (
    DPZBLLinearEnergyAtomicModel,
)
from deepmd.jax.common import (
    flax_module,
)
from deepmd.jax.env import (
    jax,
    jnp,
)
from deepmd.jax.model.base_model import (
    BaseModel,
    forward_common_atomic,
)


@BaseModel.register("zbl")
@flax_module
class DPZBLModel(DPZBLModelDP):
    """JAX-backend counterpart of the dpmodel ``DPZBLModel``.

    Registered with ``BaseModel`` under the key ``"zbl"`` and wrapped by
    ``flax_module``. Almost everything is inherited from the dpmodel class;
    this subclass only:

    * rebuilds any ``atomic_model`` assignment as the JAX-backend
      ``DPZBLLinearEnergyAtomicModel``,
    * routes the atomic forward pass through the shared JAX helper
      ``forward_common_atomic`` from ``deepmd.jax.model.base_model``, and
    * stops gradients through the coordinates used for neighbor-list
      formatting.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, rebuilding ``atomic_model`` with the JAX class.

        An object assigned to ``atomic_model`` is serialized with its own
        ``serialize()`` method and rebuilt via
        ``DPZBLLinearEnergyAtomicModel.deserialize``, the class imported from
        ``deepmd.jax.atomic_model``. All other attributes are passed through
        unchanged.

        NOTE(review): presumably this is how the backend-agnostic atomic
        model created by the parent constructor gets swapped for its JAX
        version. Confirm against the dpmodel base class.

        Parameters
        ----------
        name
            Attribute name being assigned.
        value
            Value being assigned. When ``name == "atomic_model"`` it must
            provide a ``serialize()`` method.
        """
        converted = (
            DPZBLLinearEnergyAtomicModel.deserialize(value.serialize())
            if name == "atomic_model"
            else value
        )
        return super().__setattr__(name, converted)

    def forward_common_atomic(
        self,
        extended_coord: jnp.ndarray,
        extended_atype: jnp.ndarray,
        nlist: jnp.ndarray,
        mapping: jnp.ndarray | None = None,
        fparam: jnp.ndarray | None = None,
        aparam: jnp.ndarray | None = None,
        do_atomic_virial: bool = False,
        extended_coord_corr: jnp.ndarray | None = None,
        comm_dict: dict | None = None,
        charge_spin: jnp.ndarray | None = None,
    ) -> dict[str, jnp.ndarray]:
        """Run the atomic forward pass through the shared JAX helper.

        Delegates to ``forward_common_atomic`` from
        ``deepmd.jax.model.base_model``. ``self`` is passed as the model,
        followed by the three positional inputs and every optional input
        except ``comm_dict``.

        Parameters
        ----------
        extended_coord, extended_atype, nlist
            Forwarded positionally to the helper.
        mapping, fparam, aparam, do_atomic_virial, extended_coord_corr, charge_spin
            Forwarded to the helper as keyword arguments of the same name.
        comm_dict
            Accepted for interface compatibility and discarded. The JAX path
            has no MPI ghost exchange.

        Returns
        -------
        dict[str, jnp.ndarray]
            The helper's result, returned unchanged.
        """
        # Intentionally dropped: the JAX path has no MPI ghost exchange.
        del comm_dict
        optional_inputs = {
            "mapping": mapping,
            "fparam": fparam,
            "aparam": aparam,
            "do_atomic_virial": do_atomic_virial,
            "extended_coord_corr": extended_coord_corr,
            "charge_spin": charge_spin,
        }
        # This name resolves to the module-level helper imported from
        # base_model, not to this method: class attributes are not part of a
        # method's lookup scope.
        return forward_common_atomic(
            self, extended_coord, extended_atype, nlist, **optional_inputs
        )

    def format_nlist(
        self,
        extended_coord: jnp.ndarray,
        extended_atype: jnp.ndarray,
        nlist: jnp.ndarray,
        extra_nlist_sort: bool = False,
    ) -> jnp.ndarray:
        """Format the neighbor list with gradients stopped on coordinates.

        Behaves like ``DPZBLModelDP.format_nlist``, except that
        ``extended_coord`` is first wrapped in ``jax.lax.stop_gradient``.
        As a result, no gradient flows back through the coordinates used
        while formatting ``nlist``.

        Parameters
        ----------
        extended_coord
            Coordinates passed (gradient-stopped) to the base implementation.
        extended_atype
            Forwarded unchanged.
        nlist
            Neighbor list to format, forwarded unchanged.
        extra_nlist_sort
            Forwarded unchanged as a keyword argument.

        Returns
        -------
        jnp.ndarray
            The formatted neighbor list produced by the base implementation.
        """
        coord_no_grad = jax.lax.stop_gradient(extended_coord)
        # NOTE(review): the dpmodel implementation is called explicitly
        # rather than through super(), exactly as the original did.
        # Presumably this is deliberate given the flax_module wrapping;
        # confirm before changing.
        return DPZBLModelDP.format_nlist(
            self,
            coord_no_grad,
            extended_atype,
            nlist,
            extra_nlist_sort=extra_nlist_sort,
        )