Source code for deepmd.dpmodel.model.transform_output
# SPDX-License-Identifier: LGPL-3.0-or-later
import array_api_compat
import numpy as np
from deepmd.dpmodel.common import (
GLOBAL_ENER_FLOAT_PRECISION,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
get_deriv_name,
get_reduce_name,
)
[docs]
def fit_output_to_model_output(
    fit_ret: dict[str, np.ndarray],
    fit_output_def: FittingOutputDef,
    coord_ext: np.ndarray,
    do_atomic_virial: bool = False,
) -> dict[str, np.ndarray]:
    """Transform the output of the fitting network to the model output.

    Parameters
    ----------
    fit_ret : dict[str, np.ndarray]
        Outputs of the fitting network, keyed by variable name.
    fit_output_def : FittingOutputDef
        Output definition; ``fit_output_def[name]`` provides the variable's
        ``shape`` and its ``reducible`` / ``r_differentiable`` /
        ``c_differentiable`` flags.
    coord_ext : np.ndarray
        Extended coordinates. Only used here to pick the array namespace.
    do_atomic_virial : bool, optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    dict[str, np.ndarray]
        A shallow copy of ``fit_ret`` extended, per variable, with:

        * the value summed over the atom axis (after casting to
          ``GLOBAL_ENER_FLOAT_PRECISION``) under the reduce name, for
          reducible variables;
        * ``None`` placeholders under the derivative names returned by
          ``get_deriv_name``, for differentiable variables.
    """
    xp = array_api_compat.get_namespace(coord_ext)
    model_ret = {name: value for name, value in fit_ret.items()}
    for name, value in fit_ret.items():
        var_def = fit_output_def[name]
        n_var_dims = len(var_def.shape)
        if var_def.reducible:
            # The atom axis sits right before the variable's own dimensions;
            # cast to the energy precision *before* summing over it.
            reduced = xp.sum(
                value.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=-1 - n_var_dims
            )
            model_ret[get_reduce_name(name)] = reduced
        if var_def.r_differentiable:
            derv_r_name, _ = get_deriv_name(name)
            # name-holder only; no derivative is computed here
            model_ret[derv_r_name] = None
        if var_def.c_differentiable:
            assert var_def.r_differentiable
            _, derv_c_name = get_deriv_name(name)
            # name-holder only; no derivative is computed here
            model_ret[derv_c_name] = None
    return model_ret
[docs]
def get_leading_dims(
    vv: np.ndarray,
    vdef: OutputVariableDef,
):
    """Get the dimensions of nf x nloc.

    The leading dimensions are everything in ``vv.shape`` except the
    trailing ``len(vdef.shape)`` per-variable dimensions.

    Parameters
    ----------
    vv : np.ndarray
        The input array from which to compute the leading dimensions.
    vdef : OutputVariableDef
        The output variable definition; only ``len(vdef.shape)`` is used,
        as the number of trailing dimensions to drop from ``vv``.

    Returns
    -------
    list
        The leading dimensions of ``vv``, excluding its last
        ``len(vdef.shape)`` dimensions.
    """
    n_var_dims = len(vdef.shape)
    full_shape = tuple(vv.shape)
    cut = len(full_shape) - n_var_dims
    return [*full_shape[:cut]]
[docs]
def _expand_mapping(xp, mapping: np.ndarray, ext_dims: list[int]) -> np.ndarray:
    """Tile ``mapping`` to shape ``list(mapping.shape) + ext_dims``.

    Every new trailing position repeats the value of ``mapping`` at the
    corresponding leading position, so the result can be used as the
    ``index`` argument of a scatter whose source has that full shape.
    The input ``mapping`` is not modified.
    """
    mldims = list(mapping.shape)
    expanded = xp.reshape(mapping, tuple(mldims + [1] * len(ext_dims)))
    return xp.tile(expanded, tuple([1] * len(mldims) + list(ext_dims)))


def _scatter_sum_axis1(
    out: np.ndarray, index: np.ndarray, src: np.ndarray
) -> np.ndarray:
    """Scatter-add ``src`` into ``out`` along axis 1 at positions ``index``.

    Delegates to ``deepmd.jax.common.scatter_sum``.

    Raises
    ------
    NotImplementedError
        If ``out`` is not a JAX array.
    """
    if not array_api_compat.is_jax_array(out):
        raise NotImplementedError("Only JAX arrays are supported.")
    # jax only: imported locally, presumably so this module does not require JAX
    from deepmd.jax.common import (
        scatter_sum,
    )

    return scatter_sum(out, 1, index, src)


def communicate_extended_output(
    model_ret: dict[str, np.ndarray],
    model_output_def: ModelOutputDef,
    mapping: np.ndarray,  # nf x nall
    do_atomic_virial: bool = False,
) -> dict[str, np.ndarray]:
    """Transform the output of the model network defined on
    local and ghost (extended) atoms to local atoms.

    Parameters
    ----------
    model_ret : dict[str, np.ndarray]
        Model outputs. For every key in ``model_output_def.keys_outp()`` it
        holds the atomic value. For reducible keys it also holds the reduced
        value, and the derivative entries named by ``get_deriv_name``
        (each of which may be ``None``).
    model_output_def : ModelOutputDef
        Output definition; ``model_output_def[kk]`` provides ``shape`` and the
        ``reducible`` / ``r_differentiable`` / ``c_differentiable`` flags.
    mapping : np.ndarray
        Integer index array whose shape equals the leading dimensions of the
        extended derivative arrays (the scatter requires this). Each entry is
        the position along axis 1 of the local output into which that
        extended entry is summed. Presumably nf x nall (extended atom ->
        local atom); verify with callers. It is never modified.
    do_atomic_virial : bool, optional
        If False, the per-atom c-derivative entry is dropped from the result;
        its ``"_redu"`` sum over axis 1 is always kept.

    Returns
    -------
    dict[str, np.ndarray]
        Per output key:

        * the atomic value and the reduced value, passed through;
        * derivatives scattered from the extended arrays onto the leading
          dims of the atomic value;
        * ``None`` name-holders where the input derivative was ``None``.

    Raises
    ------
    NotImplementedError
        If a non-``None`` derivative must be communicated and the arrays are
        not JAX arrays.
    """
    xp = array_api_compat.get_namespace(mapping)
    new_ret = {}
    for kk in model_output_def.keys_outp():
        vv = model_ret[kk]
        vdef = model_output_def[kk]
        new_ret[kk] = vv
        if not vdef.reducible:
            continue
        kk_redu = get_reduce_name(kk)
        new_ret[kk_redu] = model_ret[kk_redu]
        kk_derv_r, kk_derv_c = get_deriv_name(kk)
        # leading dims of the local output (presumably nf x nloc)
        vldims = get_leading_dims(vv, vdef)
        # NOTE: every scatter index below is built from the caller's
        # ``mapping``. Rebinding ``mapping`` inside this loop would corrupt
        # the index for all later outputs.
        if vdef.r_differentiable:
            derv_r = model_ret[kk_derv_r]
            if derv_r is not None:
                derv_r_ext_dims = [*vdef.shape, 3]
                force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
                new_ret[kk_derv_r] = _scatter_sum_axis1(
                    force, _expand_mapping(xp, mapping, derv_r_ext_dims), derv_r
                )
            else:
                # name holders
                new_ret[kk_derv_r] = None
        if vdef.c_differentiable:
            assert vdef.r_differentiable
            derv_c = model_ret[kk_derv_c]
            kk_derv_c_redu = kk_derv_c + "_redu"
            if derv_c is not None:
                derv_c_ext_dims = [*vdef.shape, 9]
                virial = xp.zeros(vldims + derv_c_ext_dims, dtype=vv.dtype)
                virial = _scatter_sum_axis1(
                    virial, _expand_mapping(xp, mapping, derv_c_ext_dims), derv_c
                )
                new_ret[kk_derv_c] = virial
                new_ret[kk_derv_c_redu] = xp.sum(virial, axis=1)
            else:
                new_ret[kk_derv_c] = None
                new_ret[kk_derv_c_redu] = None
            if not do_atomic_virial:
                # pop atomic virial, because it is not correctly calculated.
                new_ret.pop(kk_derv_c)
    return new_ret