Source code for deepmd.dpmodel.model.transform_output

# SPDX-License-Identifier: LGPL-3.0-or-later

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
    GLOBAL_ENER_FLOAT_PRECISION,
)
from deepmd.dpmodel.output_def import (
    FittingOutputDef,
    ModelOutputDef,
    OutputVariableDef,
    get_deriv_name,
    get_reduce_name,
)


def fit_output_to_model_output(
    fit_ret: dict[str, np.ndarray],
    fit_output_def: FittingOutputDef,
    coord_ext: np.ndarray,
    do_atomic_virial: bool = False,
) -> dict[str, np.ndarray]:
    """Transform the output of the fitting network to the model output.

    Every entry of `fit_ret` is copied to the result. For each variable
    whose definition is reducible, the sum over the atom axis is added
    under the name given by `get_reduce_name`. Derivative entries named by
    `get_deriv_name` are added as `None` name-holders; no derivative is
    computed here.

    Parameters
    ----------
    fit_ret : dict[str, np.ndarray]
        The fitting network output, keyed by variable name.
    fit_output_def : FittingOutputDef
        The definitions of the fitting output; indexed with the keys of
        `fit_ret`.
    coord_ext : np.ndarray
        Only used to select the array namespace.
    do_atomic_virial : bool
        Not used by this function.

    Returns
    -------
    dict[str, np.ndarray]
        A new dict with the original entries, the reduced entries and the
        `None` name-holders for the derivatives.
    """
    xp = array_api_compat.get_namespace(coord_ext)
    model_ret = dict(fit_ret)
    for name, atomic_value in fit_ret.items():
        var_def = fit_output_def[name]
        trailing_rank = len(var_def.shape)
        if not var_def.reducible:
            continue
        # The atom axis sits directly in front of the trailing dims
        # described by the variable definition.
        atom_axis = -(trailing_rank + 1)
        redu_name = get_reduce_name(name)
        # cast to energy prec before reduction
        model_ret[redu_name] = xp.sum(
            atomic_value.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
        )
        if var_def.r_differentiable:
            derv_r_name, _ = get_deriv_name(name)
            # name-holders
            model_ret[derv_r_name] = None
        if var_def.c_differentiable:
            assert var_def.r_differentiable
            _, derv_c_name = get_deriv_name(name)
            model_ret[derv_c_name] = None
    return model_ret
def get_leading_dims(
    vv: np.ndarray,
    vdef: OutputVariableDef,
):
    """Get the leading dimensions (nf x nloc) of an output array.

    Parameters
    ----------
    vv : np.ndarray
        The array whose leading dimensions are wanted.
    vdef : OutputVariableDef
        The output variable definition; the length of its `shape` is the
        number of trailing dimensions of `vv` to drop.

    Returns
    -------
    list
        The shape of `vv` without its last `len(vdef.shape)` entries.
    """
    full_shape = vv.shape
    n_leading = len(full_shape) - len(vdef.shape)
    return list(full_shape[:n_leading])
def _scatter_extended_to_local(
    xp,
    derv_ext,
    mapping,
    vldims: list,
    ext_dims: list,
    dtype,
):
    """Scatter-sum a derivative on extended atoms into a zero array.

    Parameters
    ----------
    xp
        The array namespace used for `reshape`, `tile` and `zeros`.
    derv_ext
        The derivative passed to `scatter_sum` as its source.
    mapping
        The index array as received by the caller; it is never modified.
    vldims : list
        Leading dimensions of the zero array that receives the sum.
    ext_dims : list
        Trailing dimensions of the derivative.
    dtype
        The dtype of the zero array.

    Returns
    -------
    The result of `scatter_sum` along axis 1.

    Raises
    ------
    NotImplementedError
        If the zero array created by `xp` is not a JAX array.
    """
    mldims = list(mapping.shape)
    # Append one singleton axis per trailing dim, then repeat the indices
    # along those axes so the index array carries `ext_dims` as well.
    index = xp.reshape(mapping, mldims + [1] * len(ext_dims))
    index = xp.tile(index, [1] * len(mldims) + ext_dims)
    target = xp.zeros(vldims + ext_dims, dtype=dtype)
    # jax only
    if not array_api_compat.is_jax_array(target):
        raise NotImplementedError("Only JAX arrays are supported.")
    # Kept as a function-local import, as in the original code
    # (presumably so the JAX backend is only needed on this path).
    from deepmd.jax.common import (
        scatter_sum,
    )

    # NOTE(review): scatter_sum(input, dim, index, src) looks like a
    # scatter-add of `src` into `input` along `dim` - confirm in deepmd.jax.
    return scatter_sum(target, 1, index, derv_ext)


def communicate_extended_output(
    model_ret: dict[str, np.ndarray],
    model_output_def: ModelOutputDef,
    mapping: np.ndarray,  # nf x nloc
    do_atomic_virial: bool = False,
) -> dict[str, np.ndarray]:
    """Transform the output of the model network defined on local and ghost
    (extended) atoms to local atoms.

    The index array for each scatter is built from the `mapping` argument
    as passed in. Previously `mapping` was rebound to its expanded form
    inside the loop, so a second differentiable output, or a virial whose
    force entry was a `None` name-holder, used an index of the wrong shape.

    Parameters
    ----------
    model_ret : dict[str, np.ndarray]
        The model output. For every reducible key it must also hold the
        reduced entry, and the derivative entries (array or `None`) for
        differentiable variables.
    model_output_def : ModelOutputDef
        The output definition; `keys_outp()` selects the processed keys.
    mapping : np.ndarray
        The index array used to scatter the derivatives along axis 1.
    do_atomic_virial : bool
        If False, the atomic virial entry is removed from the result and
        only its reduced (`"_redu"`) entry is kept.

    Returns
    -------
    dict[str, np.ndarray]
        The output entries, their reduced entries and the scattered
        derivatives; derivatives that are `None` in `model_ret` stay `None`.

    Raises
    ------
    NotImplementedError
        If a derivative has to be scattered and the arrays are not JAX
        arrays.
    """
    xp = array_api_compat.get_namespace(mapping)
    new_ret = {}
    for kk in model_output_def.keys_outp():
        vv = model_ret[kk]
        vdef = model_output_def[kk]
        new_ret[kk] = vv
        if not vdef.reducible:
            continue
        kk_redu = get_reduce_name(kk)
        new_ret[kk_redu] = model_ret[kk_redu]
        kk_derv_r, kk_derv_c = get_deriv_name(kk)
        vldims = get_leading_dims(vv, vdef)
        if vdef.r_differentiable:
            derv_r = model_ret[kk_derv_r]
            if derv_r is not None:
                new_ret[kk_derv_r] = _scatter_extended_to_local(
                    xp, derv_r, mapping, vldims, [*vdef.shape, 3], vv.dtype
                )
            else:
                # name holders
                new_ret[kk_derv_r] = None
        if vdef.c_differentiable:
            assert vdef.r_differentiable
            derv_c = model_ret[kk_derv_c]
            if derv_c is not None:
                virial = _scatter_extended_to_local(
                    xp, derv_c, mapping, vldims, [*vdef.shape, 9], vv.dtype
                )
                new_ret[kk_derv_c] = virial
                new_ret[kk_derv_c + "_redu"] = xp.sum(virial, axis=1)
            else:
                new_ret[kk_derv_c] = None
                new_ret[kk_derv_c + "_redu"] = None
            if not do_atomic_virial:
                # pop atomic virial, because it is not correctly calculated.
                new_ret.pop(kk_derv_c)
    return new_ret