deepmd.jax.model.hlo

Attributes

OUTPUT_DEFS

Classes

HLO

Base class for the final exported model that is used directly for inference.

Module Contents

deepmd.jax.model.hlo.OUTPUT_DEFS
class deepmd.jax.model.hlo.HLO(stablehlo, stablehlo_atomic_virial, stablehlo_no_ghost, stablehlo_atomic_virial_no_ghost, model_def_script, type_map, rcut, dim_fparam, dim_aparam, sel_type, is_aparam_nall, model_output_type, mixed_types, min_nbor_dist, sel)

Bases: deepmd.jax.model.base_model.BaseModel

Base class for the final exported model that is used directly for inference.

The class defines abstract methods that are called directly by the inference interface. If the final model class inherits some of those methods from other classes, BaseModel should be listed as the last base class, so that the method resolution order resolves those methods to the inherited implementations rather than to the abstract declarations.

This class is backend-independent.
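The ordering rule for base classes can be seen with plain `abc` classes. In this minimal sketch, `Base` merely stands in for an abstract base such as BaseModel and `RcutMixin` is a hypothetical mixin; neither is part of deepmd. Listing the abstract base first puts its abstract `get_rcut` ahead of the mixin's concrete one in the MRO, so the class stays abstract.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    """Hypothetical stand-in for an abstract base such as BaseModel."""

    @abstractmethod
    def get_rcut(self) -> float: ...


class RcutMixin:
    """Hypothetical mixin that provides a concrete implementation."""

    def get_rcut(self) -> float:
        return 6.0


class Good(RcutMixin, Base):
    """Abstract base listed last: the MRO finds RcutMixin.get_rcut first."""


class Bad(Base, RcutMixin):
    """Abstract base listed first: Base.get_rcut shadows the mixin."""


print(Good().get_rcut())  # 6.0
try:
    Bad()  # still abstract, so instantiation fails
except TypeError as err:
    print("Bad cannot be instantiated:", err)
```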

See also

deepmd.dpmodel.model.base_model.BaseModel

BaseModel class for the DPModel backend.

_call_lower
_call_lower_atomic_virial
_call_lower_no_ghost
_call_lower_atomic_virial_no_ghost
stablehlo
type_map
rcut
dim_fparam
dim_aparam
sel_type
_is_aparam_nall
_model_output_type
_mixed_types
min_nbor_dist

The minimum distance between two atoms, used for model compression. None if neighbor statistics were skipped.

sel
model_def_script

The model definition script.

__call__(coord: deepmd.jax.env.jnp.ndarray, atype: deepmd.jax.env.jnp.ndarray, box: deepmd.jax.env.jnp.ndarray | None = None, fparam: deepmd.jax.env.jnp.ndarray | None = None, aparam: deepmd.jax.env.jnp.ndarray | None = None, do_atomic_virial: bool = False) → Any

Return model prediction.

Parameters:
coord

The coordinates of the atoms. shape: nf x (nloc x 3)

atype

The type of atoms. shape: nf x nloc

box

The simulation box. shape: nf x 9

fparam

The frame parameters. shape: nf x ndf

aparam

The atomic parameters. shape: nf x nloc x nda

do_atomic_virial

Whether to calculate the atomic virial.

Returns:
ret_dict

The result dict of type dict[str, np.ndarray]. The keys are defined by the ModelOutputDef.
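The shape conventions above can be collected into a small checker. This is a sketch under assumptions: `expected_input_shapes` and `box_to_cell` are hypothetical helpers, "nf x (nloc x 3)" is read as each frame's xyz triples flattened into one row, the 9-element box is assumed to be a row-major 3x3 cell, and fparam/aparam are assumed to be omitted when the corresponding dimension is zero.

```python
def expected_input_shapes(nf, nloc, dim_fparam, dim_aparam):
    """Shapes of the __call__ inputs as documented (hypothetical helper)."""
    shapes = {
        "coord": (nf, nloc * 3),  # xyz triples flattened per frame
        "atype": (nf, nloc),
        "box": (nf, 9),  # one flattened 3x3 cell per frame
    }
    # Assumption: fparam/aparam are only passed when the model declares them.
    if dim_fparam > 0:
        shapes["fparam"] = (nf, dim_fparam)
    if dim_aparam > 0:
        shapes["aparam"] = (nf, nloc, dim_aparam)
    return shapes


def box_to_cell(box_row):
    """Reshape one frame's 9-element box into a 3x3 cell (row-major assumed)."""
    return [list(box_row[3 * i : 3 * i + 3]) for i in range(3)]


print(expected_input_shapes(nf=2, nloc=5, dim_fparam=1, dim_aparam=0))
# {'coord': (2, 15), 'atype': (2, 5), 'box': (2, 9), 'fparam': (2, 1)}
print(box_to_cell([10.0, 0, 0, 0, 10.0, 0, 0, 0, 10.0]))
# [[10.0, 0, 0], [0, 10.0, 0], [0, 0, 10.0]]
```

With a loaded model, the dimensions would come from `get_dim_fparam()` and `get_dim_aparam()` rather than literals.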

call(coord: deepmd.jax.env.jnp.ndarray, atype: deepmd.jax.env.jnp.ndarray, box: deepmd.jax.env.jnp.ndarray | None = None, fparam: deepmd.jax.env.jnp.ndarray | None = None, aparam: deepmd.jax.env.jnp.ndarray | None = None, do_atomic_virial: bool = False)

Return model prediction.

Parameters:
coord

The coordinates of the atoms. shape: nf x (nloc x 3)

atype

The type of atoms. shape: nf x nloc

box

The simulation box. shape: nf x 9

fparam

The frame parameters. shape: nf x ndf

aparam

The atomic parameters. shape: nf x nloc x nda

do_atomic_virial

Whether to calculate the atomic virial.

Returns:
ret_dict

The result dict of type dict[str, np.ndarray]. The keys are defined by the ModelOutputDef.

model_output_def()
call_lower(extended_coord: deepmd.jax.env.jnp.ndarray, extended_atype: deepmd.jax.env.jnp.ndarray, nlist: deepmd.jax.env.jnp.ndarray, mapping: deepmd.jax.env.jnp.ndarray | None = None, fparam: deepmd.jax.env.jnp.ndarray | None = None, aparam: deepmd.jax.env.jnp.ndarray | None = None, do_atomic_virial: bool = False)
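The four attributes `_call_lower`, `_call_lower_atomic_virial`, `_call_lower_no_ghost` and `_call_lower_atomic_virial_no_ghost` suggest that `call_lower` chooses among four exported entry points. The sketch below shows one plausible selection rule. The function `select_lower_variant` is hypothetical, and the rule that ghost atoms exist exactly when the extended atom count (nall) exceeds the local atom count (nloc) is an assumption, not something stated on this page.

```python
def select_lower_variant(nall: int, nloc: int, do_atomic_virial: bool) -> str:
    """Return the name of the compiled entry point to use (hypothetical rule).

    Assumes ghost atoms are present exactly when nall > nloc.
    """
    has_ghost = nall > nloc
    if has_ghost:
        return "_call_lower_atomic_virial" if do_atomic_virial else "_call_lower"
    return (
        "_call_lower_atomic_virial_no_ghost"
        if do_atomic_virial
        else "_call_lower_no_ghost"
    )


print(select_lower_variant(nall=120, nloc=100, do_atomic_virial=True))
# _call_lower_atomic_virial
```

In such a scheme, nall and nloc would be read from the shapes of `extended_coord` and `nlist`, and the chosen variant would be called with the same arguments.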
get_type_map() → list[str]

Get the type map.

get_rcut()

Get the cut-off radius.

get_dim_fparam()

Get the number (dimension) of frame parameters of this atomic model.

get_dim_aparam()

Get the number (dimension) of atomic parameters of this atomic model.

get_sel_type() → list[int]

Get the selected atom types of this model.

Only atoms of the selected types contribute atomic terms to the model's result. An empty list means that all atom types are selected.
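The selection rule can be expressed as a per-atom mask. This is a minimal sketch; `contributing_atoms` is a hypothetical helper, and in practice `sel_type` would come from `get_sel_type()` and `atype` from the model input.

```python
def contributing_atoms(atype, sel_type):
    """Per-atom mask of atoms with an atomic contribution (hypothetical helper).

    An empty sel_type selects every atom type.
    """
    if not sel_type:
        return [True] * len(atype)
    selected = set(sel_type)
    return [t in selected for t in atype]


print(contributing_atoms([0, 1, 1, 2], [1]))  # [False, True, True, False]
print(contributing_atoms([0, 1, 1, 2], []))   # [True, True, True, True]
```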

is_aparam_nall() → bool

Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
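The two cases translate directly into an expected shape. In this sketch, `expected_aparam_shape` is a hypothetical helper; the flag would come from `is_aparam_nall()`, and nall denotes the number of extended atoms (local plus ghost).

```python
def expected_aparam_shape(nframes, nloc, nall, ndim, aparam_nall):
    """Expected aparam shape for a given is_aparam_nall() result (hypothetical helper)."""
    natoms = nall if aparam_nall else nloc  # nall counts local plus ghost atoms
    return (nframes, natoms, ndim)


print(expected_aparam_shape(1, 64, 200, 2, aparam_nall=False))  # (1, 64, 2)
print(expected_aparam_shape(1, 64, 200, 2, aparam_nall=True))   # (1, 200, 2)
```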

model_output_type() → list[str]

Get the output type for the model.

abstract serialize() → dict

Serialize the model.

Returns:
dict

The serialized data

abstract classmethod deserialize(data: dict) → deepmd.jax.model.base_model.BaseModel

Deserialize the model.

Parameters:
data (dict)

The serialized data

Returns:
BaseModel

The deserialized model

get_model_def_script() → str

Get the model definition script.
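Because the script is returned as a string, a caller would typically decode it before inspecting it. This sketch assumes the script is JSON-encoded model parameters; the example string and its keys are hypothetical and stand in for the return value of `get_model_def_script()`.

```python
import json

# Hypothetical value; a real one would come from model.get_model_def_script().
script = '{"type_map": ["O", "H"], "descriptor": {"type": "se_e2_a", "rcut": 6.0}}'

params = json.loads(script)  # assumes the script is JSON-encoded
print(params["type_map"])            # ['O', 'H']
print(params["descriptor"]["rcut"])  # 6.0
```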

get_min_nbor_dist() → float | None

Get the minimum distance between two atoms.

get_nnei() → int

Returns the total number of selected neighboring atoms in the cut-off radius.

get_sel() → list[int]
get_nsel() → int

Returns the total number of selected neighboring atoms in the cut-off radius.
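Given the per-type selection returned by `get_sel()`, the total is presumably the sum of its entries. The sketch below assumes exactly that; `total_selected_neighbors` and the `sel` values are hypothetical.

```python
def total_selected_neighbors(sel):
    """Total neighbor count implied by a per-type sel list (assumed to be a plain sum)."""
    return sum(sel)


sel = [46, 92]  # hypothetical per-type selection, e.g. for two atom types
print(total_selected_neighbors(sel))  # 138
```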

mixed_types() → bool
abstract classmethod update_sel(train_data: deepmd.utils.data_system.DeepmdDataSystem, type_map: list[str] | None, local_jdata: dict) → tuple[dict, float | None]

Update the selection and perform neighbor statistics.

Parameters:
train_data (DeepmdDataSystem)

Data used to perform neighbor statistics.

type_map (list[str], optional)

The name of each type of atoms

local_jdata (dict)

The local data referring to the current class.

Returns:
dict

The updated local data

float

The minimum distance between two atoms
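The return contract (updated local data plus a minimum distance) can be illustrated with a toy neighbor statistic. This is a sketch only: `neighbor_stats` and `update_sel_sketch` are hypothetical, boundaries are open (no periodic images), counting ignores atom types, and storing the largest neighbor count under a `"sel"` key is an assumption about what "updating the selection" means.

```python
import math
from itertools import combinations


def neighbor_stats(frames, rcut):
    """Return (max neighbor count within rcut, min pairwise distance) over all frames.

    Toy version: open boundaries, O(N^2) pair loop, type-agnostic counting.
    """
    max_count = 0
    min_dist = math.inf
    for coords in frames:
        counts = [0] * len(coords)
        for i, j in combinations(range(len(coords)), 2):
            d = math.dist(coords[i], coords[j])
            min_dist = min(min_dist, d)
            if d < rcut:
                counts[i] += 1
                counts[j] += 1
        max_count = max([max_count, *counts])
    # None if no atom pair was found at all (e.g. only single-atom frames)
    return max_count, (None if math.isinf(min_dist) else min_dist)


def update_sel_sketch(local_jdata, frames):
    """Mimic the (updated dict, min distance) return shape of update_sel."""
    max_count, min_dist = neighbor_stats(frames, local_jdata["rcut"])
    updated = dict(local_jdata, sel=max_count)  # copy; the input is not mutated
    return updated, min_dist


frames = [[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]]
print(update_sel_sketch({"rcut": 1.5}, frames))  # ({'rcut': 1.5, 'sel': 1}, 1.0)
```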

abstract classmethod get_model(model_params: dict) → deepmd.jax.model.base_model.BaseModel

Get the model by the parameters.

By default, all parameters are passed directly to the constructor. Override this method if the parameters do not map one-to-one onto the constructor arguments.
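The default behavior and an override can be sketched with a toy class. `ToyModel`, `ToyModelWithAlias` and the `"cutoff"` key are hypothetical and are not part of deepmd; they only show forwarding a parameter dict to `__init__` and adapting it when the keys differ.

```python
class ToyModel:
    """Hypothetical model class used only to illustrate get_model."""

    def __init__(self, type_map, rcut):
        self.type_map = type_map
        self.rcut = rcut

    @classmethod
    def get_model(cls, model_params: dict) -> "ToyModel":
        # Default behavior: forward every parameter to the constructor.
        return cls(**model_params)


class ToyModelWithAlias(ToyModel):
    @classmethod
    def get_model(cls, model_params: dict) -> "ToyModel":
        # Override when the config does not match the constructor,
        # e.g. a hypothetical "cutoff" key that must become "rcut".
        params = dict(model_params)
        params["rcut"] = params.pop("cutoff")
        return cls(**params)


model = ToyModel.get_model({"type_map": ["O", "H"], "rcut": 6.0})
print(model.type_map, model.rcut)  # ['O', 'H'] 6.0
```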

Parameters:
model_params (dict)

The model parameters

Returns:
BaseModel

The model