deepmd.jax.jax2tf.transform_output

Functions

get_leading_dims(→ tensorflow.experimental.numpy.ndarray)

Get the dimensions of nf x nloc.

communicate_extended_output(→ dict[str, ...)

Transform the output of the model network defined on local and ghost (extended) atoms to local atoms.

Module Contents

deepmd.jax.jax2tf.transform_output.get_leading_dims(vv: tensorflow.experimental.numpy.ndarray, vdef: deepmd.dpmodel.output_def.OutputVariableDef) → tensorflow.experimental.numpy.ndarray

Get the dimensions of nf x nloc.

Parameters:
vv : tensorflow.experimental.numpy.ndarray

The input array from which to compute the leading dimensions.

vdef : OutputVariableDef

The output variable definition containing the shape to exclude from vv.

Returns:
tensorflow.experimental.numpy.ndarray

The leading dimensions of vv, i.e. the shape of vv with the last len(vdef.shape) dimensions excluded.
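The shape arithmetic described above can be illustrated with a small NumPy analogue. This is a sketch under stated assumptions, not the deepmd implementation: the real function works on TensorFlow tensors, while here `leading_dims_sketch` and `VarDefStub` are hypothetical names, and the stub only carries the `shape` attribute that the description says is excluded from vv.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class VarDefStub:
    """Hypothetical stand-in for OutputVariableDef; only `shape` is used here."""

    shape: tuple[int, ...]  # trailing per-variable shape, e.g. (3,) for a vector


def leading_dims_sketch(vv: np.ndarray, vdef: VarDefStub) -> tuple[int, ...]:
    """Return the shape of vv with the last len(vdef.shape) dimensions removed."""
    n_trailing = len(vdef.shape)
    # Slice up to ndim - n_trailing rather than using [:-n_trailing],
    # because [:-0] would wrongly return an empty tuple when vdef.shape == ().
    return vv.shape[: vv.ndim - n_trailing]


# A per-atom vector quantity: 2 frames x 5 local atoms x 3 components.
per_atom_vec = np.zeros((2, 5, 3))
print(leading_dims_sketch(per_atom_vec, VarDefStub(shape=(3,))))  # (2, 5)
```

Slicing by `ndim - n_trailing` keeps the result correct for scalar-per-atom variables whose definition has an empty shape.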

deepmd.jax.jax2tf.transform_output.communicate_extended_output(model_ret: dict[str, tensorflow.experimental.numpy.ndarray], model_output_def: deepmd.dpmodel.output_def.ModelOutputDef, mapping: tensorflow.experimental.numpy.ndarray, do_atomic_virial: bool = False) → dict[str, tensorflow.experimental.numpy.ndarray]

Transform the output of the model network defined on local and ghost (extended) atoms to local atoms.
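The core idea of moving a per-atom quantity from extended atoms back to local atoms can be sketched as a scatter-add over `mapping`. This minimal NumPy sketch assumes that `mapping` has shape (nf, nall) and that `mapping[f, j]` is the index of the local atom that extended atom j is an image of. `fold_extended_to_local` is a hypothetical helper, not part of deepmd, and it omits the per-variable logic that the real function drives from `model_output_def` and `do_atomic_virial`.

```python
import numpy as np


def fold_extended_to_local(ext: np.ndarray, mapping: np.ndarray, nloc: int) -> np.ndarray:
    """Sum values on extended (local + ghost) atoms onto their local owners.

    ext:     (nf, nall, *rest) per-atom values on extended atoms.
    mapping: (nf, nall) integers; mapping[f, j] is assumed to be the local
             atom index that extended atom j corresponds to.
    """
    nf = ext.shape[0]
    out = np.zeros((nf, nloc) + ext.shape[2:], dtype=ext.dtype)
    for f in range(nf):
        # np.add.at accumulates over repeated indices; a plain
        # out[f][mapping[f]] += ext[f] would apply each index only once.
        np.add.at(out[f], mapping[f], ext[f])
    return out


# 1 frame, 2 local atoms, 4 extended atoms (atoms 2 and 3 are ghosts of 0 and 1).
mapping = np.array([[0, 1, 0, 1]])
ext = np.array([[[1.0, 0, 0], [0, 1, 0], [2, 0, 0], [0, 3, 0]]])
local = fold_extended_to_local(ext, mapping, nloc=2)
# local[0] holds [[3, 0, 0], [0, 4, 0]]
```

The explicit accumulation matters because several ghost atoms usually map to the same local atom, and each of their contributions has to be added.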