deepmd.jax.model.hlo
Attributes
Classes
- HLO: Base class for final exported model that will be directly used for inference.
Module Contents
- class deepmd.jax.model.hlo.HLO(stablehlo, stablehlo_atomic_virial, stablehlo_no_ghost, stablehlo_atomic_virial_no_ghost, model_def_script, type_map, rcut, dim_fparam, dim_aparam, sel_type, is_aparam_nall, model_output_type, mixed_types, min_nbor_dist, sel)
Bases: deepmd.jax.model.base_model.BaseModel
Base class for final exported model that will be directly used for inference.
The class defines some abstract methods that are called directly by the inference interface. If the final model class inherits some of those methods from other classes, BaseModel should be listed as the last base class so that the method resolution order finds those inherited implementations before the abstract ones.
This class is backend-independent.
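The ordering rule can be illustrated with plain Python `abc`. In this minimal sketch, `BaseModel` is a simplified stand-in and `SelTypeMixin`, `GoodModel` and `BadModel` are hypothetical classes; none of them are the DeePMD-kit classes themselves. Placing the abstract base last lets the mixin's implementation win in the MRO, while placing it first leaves the abstract stub in front, so the class stays abstract.

```python
from abc import ABC, abstractmethod


class BaseModel(ABC):
    """Simplified stand-in for the abstract base class (hypothetical)."""

    @abstractmethod
    def get_sel_type(self) -> list[int]:
        """Abstract hook called by the inference interface."""


class SelTypeMixin:
    """Hypothetical mixin that already implements one of the hooks."""

    def get_sel_type(self) -> list[int]:
        return []


# BaseModel listed last: the mixin's implementation is found first in the MRO.
class GoodModel(SelTypeMixin, BaseModel):
    pass


# BaseModel listed first: its abstract stub shadows the mixin's implementation.
class BadModel(BaseModel, SelTypeMixin):
    pass


print([c.__name__ for c in GoodModel.__mro__])
print(GoodModel().get_sel_type())
try:
    BadModel()
except TypeError as err:
    print("BadModel is still abstract:", err)
```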
See also
deepmd.dpmodel.model.base_model.BaseModel
BaseModel class for the DPModel backend.
- min_nbor_dist
The minimum distance between any two atoms, used for model compression. It is None when neighbor statistics are skipped.
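To make the quantity concrete, the following sketch computes the smallest interatomic distance for a single frame with NumPy. It is an illustration only: `min_pair_distance` is a hypothetical helper, and it ignores periodic images and multiple frames, which the actual neighbor statistics would have to account for.

```python
import numpy as np


def min_pair_distance(coord: np.ndarray) -> float:
    """Smallest interatomic distance in one frame.

    coord has shape (natoms, 3). Periodic images are ignored in this sketch.
    """
    diff = coord[:, None, :] - coord[None, :, :]  # (natoms, natoms, 3)
    dist = np.linalg.norm(diff, axis=-1)          # (natoms, natoms)
    np.fill_diagonal(dist, np.inf)                # exclude self-distances
    return float(dist.min())


coord = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 3.0, 0.0]])
print(min_pair_distance(coord))  # 1.5
```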
- __call__(coord: deepmd.jax.env.jnp.ndarray, atype: deepmd.jax.env.jnp.ndarray, box: deepmd.jax.env.jnp.ndarray | None = None, fparam: deepmd.jax.env.jnp.ndarray | None = None, aparam: deepmd.jax.env.jnp.ndarray | None = None, do_atomic_virial: bool = False) -> Any
Return the model prediction.
- Parameters:
- coord
The coordinates of the atoms. shape: nf x (nloc x 3)
- atype
The type of atoms. shape: nf x nloc
- box
The simulation box. shape: nf x 9
- fparam
Frame parameters. shape: nf x ndf
- aparam
Atomic parameters. shape: nf x nloc x nda
- do_atomic_virial
Whether to calculate the atomic virial.
- Returns:
ret_dict
The result dict of type dict[str,np.ndarray]. The keys are defined by the ModelOutputDef.
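The input shapes listed above can be illustrated by building dummy arrays. This sketch only prepares inputs; it assumes an already loaded `HLO` instance named `model` (how one is obtained is not covered here) and assumes NumPy arrays are accepted where `jnp.ndarray` is annotated, so the actual call is left as a comment. The frame and atom counts are arbitrary.

```python
import numpy as np

nf, nloc = 2, 5  # frames and local atoms per frame (illustrative values)
ntypes = 2       # number of atom types (illustrative value)
box_len = 10.0
rng = np.random.default_rng(0)

# box: nf x 9, each row a flattened 3x3 cell matrix (a cubic cell here)
box = np.tile((np.eye(3) * box_len).reshape(1, 9), (nf, 1))
# coord: nf x (nloc x 3), Cartesian coordinates flattened per frame
coord = (rng.random((nf, nloc, 3)) * box_len).reshape(nf, nloc * 3)
# atype: nf x nloc, integer type index of each atom
atype = rng.integers(0, ntypes, size=(nf, nloc))

# With a loaded model (assumption), the call would look like:
# ret = model(coord, atype, box=box)  # ret: dict mapping output names to arrays
print(coord.shape, atype.shape, box.shape)  # (2, 15) (2, 5) (2, 9)
```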
- call(coord: deepmd.jax.env.jnp.ndarray, atype: deepmd.jax.env.jnp.ndarray, box: deepmd.jax.env.jnp.ndarray | None = None, fparam: deepmd.jax.env.jnp.ndarray | None = None, aparam: deepmd.jax.env.jnp.ndarray | None = None, do_atomic_virial: bool = False)
Return the model prediction.
- Parameters:
- coord
The coordinates of the atoms. shape: nf x (nloc x 3)
- atype
The type of atoms. shape: nf x nloc
- box
The simulation box. shape: nf x 9
- fparam
Frame parameters. shape: nf x ndf
- aparam
Atomic parameters. shape: nf x nloc x nda
- do_atomic_virial
Whether to calculate the atomic virial.
- Returns:
ret_dict
The result dict of type dict[str,np.ndarray]. The keys are defined by the ModelOutputDef.
- call_lower(extended_coord: deepmd.jax.env.jnp.ndarray, extended_atype: deepmd.jax.env.jnp.ndarray, nlist: deepmd.jax.env.jnp.ndarray, mapping: deepmd.jax.env.jnp.ndarray | None = None, fparam: deepmd.jax.env.jnp.ndarray | None = None, aparam: deepmd.jax.env.jnp.ndarray | None = None, do_atomic_virial: bool = False)
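call_lower carries no description on this page. Judging only from the constructor's parameter names, HLO appears to hold four serialized StableHLO programs, one for each combination of "atomic virial requested or not" and "ghost atoms present or not". The sketch below is a hypothetical illustration of how a call could be routed among them; `select_stablehlo` is an invented helper, and the real selection logic in deepmd.jax.model.hlo may differ. It assumes nall is the number of extended atoms and nloc the number of local atoms (rows of the neighbor list).

```python
def select_stablehlo(nall: int, nloc: int, do_atomic_virial: bool) -> str:
    """Name of the constructor argument that would serve a call_lower request.

    Hypothetical routing based only on the parameter names of HLO.__init__.
    """
    has_ghost = nall > nloc  # extended atoms beyond the local ones are ghosts
    name = "stablehlo_atomic_virial" if do_atomic_virial else "stablehlo"
    return name if has_ghost else name + "_no_ghost"


print(select_stablehlo(nall=10, nloc=4, do_atomic_virial=False))  # stablehlo
print(select_stablehlo(nall=4, nloc=4, do_atomic_virial=True))    # stablehlo_atomic_virial_no_ghost
```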
- get_sel_type() -> list[int]
Get the selected atom types of this model.
Only atoms of the selected types make an atomic contribution to the model's result. If an empty list is returned, all atom types are selected.
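The selection rule, including the "empty list means all types" convention, can be expressed as a boolean mask over atoms. `selected_atom_mask` is a hypothetical helper written for this sketch, not part of the API; it takes the type array and the list returned by get_sel_type().

```python
import numpy as np


def selected_atom_mask(atype: np.ndarray, sel_type: list[int]) -> np.ndarray:
    """True for atoms whose type contributes atomic results.

    An empty sel_type means every atom type is selected.
    """
    if not sel_type:
        return np.ones(atype.shape, dtype=bool)
    return np.isin(atype, sel_type)


atype = np.array([[0, 1, 2, 1]])
print(selected_atom_mask(atype, [1]).tolist())  # [[False, True, False, True]]
print(selected_atom_mask(atype, []).tolist())   # [[True, True, True, True]]
```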
- is_aparam_nall() -> bool
Check whether the shape of atomic parameters is (nframes, nall, ndim).
If False, the shape is (nframes, nloc, ndim).
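The flag only changes the middle dimension of the atomic-parameter array. This sketch shows the shape a caller would need to supply; `expected_aparam_shape` is a hypothetical helper, the flag value would come from is_aparam_nall(), and ndim is presumed to match the dim_aparam constructor argument.

```python
def expected_aparam_shape(nframes: int, nloc: int, nall: int, ndim: int,
                          aparam_nall: bool) -> tuple[int, int, int]:
    """Shape of the atomic-parameter array implied by is_aparam_nall()."""
    natoms = nall if aparam_nall else nloc  # local atoms only, or local + ghost
    return (nframes, natoms, ndim)


# 1 frame, 4 local atoms, 10 extended atoms (4 local + 6 ghost), 2 parameters per atom
print(expected_aparam_shape(1, 4, 10, 2, aparam_nall=False))  # (1, 4, 2)
print(expected_aparam_shape(1, 4, 10, 2, aparam_nall=True))   # (1, 10, 2)
```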
- classmethod deserialize(data: dict) -> deepmd.jax.model.base_model.BaseModel
- Abstractmethod:
Deserialize the model.
- Parameters:
- data
dict
The serialized data
- Returns:
BaseModel
The deserialized model
- get_nnei() -> int
Returns the total number of selected neighboring atoms within the cutoff radius.
- get_nsel() -> int
Returns the total number of selected neighboring atoms within the cutoff radius.
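A minimal sketch of how this total relates to the `sel` constructor argument and to the `nlist` argument of call_lower. It assumes, without confirmation from this page, that `sel` is a per-type list of neighbor limits, that the total is their sum, and that the neighbor list has shape nf x nloc x nsel with -1 marking empty slots. The numbers are illustrative.

```python
import numpy as np

sel = [46, 92]   # hypothetical per-type neighbor limits, one entry per atom type
nsel = sum(sel)  # assumed value returned by get_nsel() and get_nnei()

# A neighbor list sized for this model: nf x nloc x nsel, with -1 marking
# empty slots (padding convention assumed here, not stated on this page).
nf, nloc = 1, 3
nlist = np.full((nf, nloc, nsel), -1, dtype=np.int64)
print(nsel, nlist.shape)  # 138 (1, 3, 138)
```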
- classmethod update_sel(train_data: deepmd.utils.data_system.DeepmdDataSystem, type_map: list[str] | None, local_jdata: dict) -> tuple[dict, float | None]
- Abstractmethod:
Update the selection and perform neighbor statistics.
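"Neighbor statistics" here means, roughly, counting how many neighbors of each type fall within the cutoff around each atom and recording the minimum interatomic distance. The sketch below computes both for one frame with NumPy. It is illustrative only: `neighbor_stats` is a hypothetical helper that ignores periodic images, whereas the real method works on a DeepmdDataSystem and returns the updated local_jdata together with the minimum neighbor distance; how it turns raw counts into sel values is not described on this page.

```python
import numpy as np


def neighbor_stats(coord: np.ndarray, atype: np.ndarray, ntypes: int,
                   rcut: float) -> tuple[list[int], float]:
    """Max neighbor count per type within rcut, and the minimum pair distance.

    coord: (natoms, 3); atype: (natoms,). Periodic images are ignored.
    """
    dist = np.linalg.norm(coord[:, None, :] - coord[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)  # an atom is not its own neighbor
    within = dist < rcut            # (natoms, natoms) boolean
    sel = []
    for t in range(ntypes):
        # neighbors of type t around each atom, then the worst case over atoms
        counts = within[:, atype == t].sum(axis=1)
        sel.append(int(counts.max()))
    return sel, float(dist.min())


coord = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0], [5.0, 5.0, 5.0]])
atype = np.array([0, 1, 1, 0])
print(neighbor_stats(coord, atype, ntypes=2, rcut=2.0))  # ([1, 2], 1.0)
```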
- classmethod get_model(model_params: dict) -> deepmd.jax.model.base_model.BaseModel
- Abstractmethod:
Get the model by the parameters.
By default, all the parameters are passed directly to the constructor. If that is not the desired behavior, override this method.
- Parameters:
- model_params
dict
The model parameters
- Returns:
BaseModel
The model
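The default behavior of get_model, forwarding every key of model_params to the constructor, amounts to `cls(**model_params)`. The sketch uses a hypothetical `ToyModel` with invented constructor arguments purely to show that pattern; it is not the HLO constructor or the library's implementation.

```python
from typing import Any


class ToyModel:
    """Hypothetical model used only to illustrate the default get_model."""

    def __init__(self, rcut: float, sel: list[int]) -> None:
        self.rcut = rcut
        self.sel = sel

    @classmethod
    def get_model(cls, model_params: dict[str, Any]) -> "ToyModel":
        # Default behavior: pass every parameter straight to the constructor.
        return cls(**model_params)


model = ToyModel.get_model({"rcut": 6.0, "sel": [46, 92]})
print(model.rcut, model.sel)  # 6.0 [46, 92]
```

A subclass whose configuration keys do not map one-to-one onto constructor arguments would override get_model to translate them first.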