deepmd.jax.jax2tf.make_model

Functions

model_call_from_call_lower(*, call_lower, rcut, sel, ...)

Return the model prediction from the lower interface.

Module Contents

deepmd.jax.jax2tf.make_model.model_call_from_call_lower(*, call_lower: Callable[[tensorflow.experimental.numpy.ndarray, tensorflow.experimental.numpy.ndarray, tensorflow.experimental.numpy.ndarray, tensorflow.experimental.numpy.ndarray, tensorflow.experimental.numpy.ndarray, bool], dict[str, tensorflow.experimental.numpy.ndarray]], rcut: float, sel: list[int], mixed_types: bool, model_output_def: deepmd.dpmodel.output_def.ModelOutputDef, coord: tensorflow.experimental.numpy.ndarray, atype: tensorflow.experimental.numpy.ndarray, box: tensorflow.experimental.numpy.ndarray, fparam: tensorflow.experimental.numpy.ndarray, aparam: tensorflow.experimental.numpy.ndarray, do_atomic_virial: bool = False)

Return the model prediction from the lower interface (call_lower).
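The leading `*` in the signature makes every argument keyword-only. The sketch below illustrates that calling convention with a hypothetical stand-in function, `model_call_sketch`, which mirrors only a few of the parameter names and is not the deepmd implementation; the values passed for `rcut` and `sel` are arbitrary.

```python
# Hypothetical stand-in: it copies only the keyword-only shape of the
# documented signature and simply echoes its arguments back.
def model_call_sketch(*, rcut, sel, do_atomic_virial=False):
    return {"rcut": rcut, "sel": sel, "do_atomic_virial": do_atomic_virial}


# Passing every argument by keyword is accepted.
result = model_call_sketch(rcut=6.0, sel=[46, 92])

# Passing arguments positionally is rejected because of the leading `*`.
try:
    model_call_sketch(6.0, [46, 92])
    positional_ok = True
except TypeError:
    positional_ok = False
```

The same rule applies to the documented function: `coord`, `atype`, `box`, and the other arguments all have to be named at the call site.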

Parameters:
coord

The coordinates of the atoms. shape: nf x (nloc x 3)

atype

The types of the atoms. shape: nf x nloc

box

The simulation box. shape: nf x 9

fparam

The frame parameters. shape: nf x ndf

aparam

The atomic parameters. shape: nf x nloc x nda

do_atomic_virial

Whether to calculate the atomic virial.
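The shapes listed above can be made concrete with a small sketch. It uses NumPy arrays as stand-ins for the tensorflow.experimental.numpy arrays the function expects, and it does not call deepmd. The sizes (nf = 2, nloc = 4, ndf = 1, nda = 3) and the 10 x 10 x 10 box are arbitrary assumptions chosen for illustration.

```python
import numpy as np

# Illustrative sizes only; none of these values come from the library.
nf, nloc, ndf, nda = 2, 4, 1, 3

# coord: nf x (nloc x 3), each frame's xyz triples flattened into one row.
coord = np.zeros((nf, nloc * 3), dtype=np.float64)
# atype: nf x nloc, one integer type index per atom.
atype = np.zeros((nf, nloc), dtype=np.int64)
# box: nf x 9, one flattened 3 x 3 cell matrix per frame.
box = np.tile(np.eye(3).reshape(1, 9) * 10.0, (nf, 1))
# fparam: nf x ndf, per-frame parameters.
fparam = np.zeros((nf, ndf), dtype=np.float64)
# aparam: nf x nloc x nda, per-atom parameters.
aparam = np.zeros((nf, nloc, nda), dtype=np.float64)

# Viewing coord per atom recovers an nf x nloc x 3 array.
coord_per_atom = coord.reshape(nf, nloc, 3)
```

Here `coord` is stored flattened per frame, so the reshape at the end only changes the view of the same data.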

Returns:
ret_dict

The result dictionary, of type dict[str, tnp.ndarray], where tnp is tensorflow.experimental.numpy. The keys are defined by model_output_def, a ModelOutputDef.