Source code for deepmd.pt_expt.model.transform_output

# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel import (
    FittingOutputDef,
    OutputVariableDef,
    get_deriv_name,
    get_reduce_name,
)
from deepmd.pt_expt.utils import (
    env,
)


[docs] def atomic_virial_corr( extended_coord: torch.Tensor, atom_energy: torch.Tensor, ) -> torch.Tensor: nloc = atom_energy.shape[1] indices = torch.arange(nloc, dtype=torch.int64, device=extended_coord.device) coord = torch.index_select(extended_coord, 1, indices) # no derivative with respect to the loc coord. coord = coord.detach() ce = coord * atom_energy sumce = torch.sum(ce, dim=1) # [nf, 3] # Explicitly loop over the 3 spatial components instead of vmap, # so that make_fx(symbolic) and torch.export can trace through. results = [] for i in range(3): grad_output = torch.zeros_like(sumce) grad_output[:, i] = 1.0 result = torch.autograd.grad( [sumce], [extended_coord], grad_outputs=[grad_output], create_graph=False, retain_graph=True, )[0] assert result is not None results.append(result) # [3, nf, nall, 3] -> [nf, nall, 3, 3] extended_virial_corr = torch.stack(results, dim=0) return extended_virial_corr.permute(1, 2, 3, 0)
[docs] def task_deriv_one( atom_energy: torch.Tensor, energy: torch.Tensor, extended_coord: torch.Tensor, do_virial: bool = True, do_atomic_virial: bool = False, create_graph: bool = True, ) -> tuple[torch.Tensor, torch.Tensor | None]: faked_grad = torch.ones_like(energy) lst: list[torch.Tensor | None] = [faked_grad] extended_force = torch.autograd.grad( [energy], [extended_coord], grad_outputs=lst, create_graph=create_graph, retain_graph=True, )[0] assert extended_force is not None extended_force = -extended_force if do_virial: extended_virial = torch.einsum( "...ik,...ij->...ikj", extended_force, extended_coord ) # the correction sums to zero, which does not contribute to global virial if do_atomic_virial: extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy) extended_virial = extended_virial + extended_virial_corr # to [...,3,3] -> [...,9] extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005 else: extended_virial = None return extended_force, extended_virial
[docs] def get_leading_dims( vv: torch.Tensor, vdef: OutputVariableDef, ) -> list[int]: """Get the dimensions of nf x nloc.""" vshape = vv.shape return list(vshape[: (len(vshape) - len(vdef.shape))])
[docs] def take_deriv( vv: torch.Tensor, svv: torch.Tensor, vdef: OutputVariableDef, coord_ext: torch.Tensor, do_virial: bool = False, do_atomic_virial: bool = False, create_graph: bool = True, ) -> tuple[torch.Tensor, torch.Tensor | None]: size = 1 for ii in vdef.shape: size *= ii vv1 = vv.view(list(get_leading_dims(vv, vdef)) + [size]) # noqa: RUF005 svv1 = svv.view(list(get_leading_dims(svv, vdef)) + [size]) # noqa: RUF005 split_vv1 = torch.split(vv1, [1] * size, dim=-1) split_svv1 = torch.split(svv1, [1] * size, dim=-1) split_ff, split_avir = [], [] for vvi, svvi in zip(split_vv1, split_svv1): # nf x nloc x 3, nf x nloc x 9 ffi, aviri = task_deriv_one( vvi, svvi, coord_ext, do_virial=do_virial, do_atomic_virial=do_atomic_virial, create_graph=create_graph, ) # nf x nloc x 1 x 3, nf x nloc x 1 x 9 ffi = ffi.unsqueeze(-2) split_ff.append(ffi) if do_virial: assert aviri is not None aviri = aviri.unsqueeze(-2) split_avir.append(aviri) # nf x nall x v_dim x 3, nf x nall x v_dim x 9 out_lead_shape = list(coord_ext.shape[:-1]) + vdef.shape ff = torch.concat(split_ff, dim=-2).view(out_lead_shape + [3]) # noqa: RUF005 if do_virial: avir = torch.concat(split_avir, dim=-2).view(out_lead_shape + [9]) # noqa: RUF005 else: avir = None return ff, avir
[docs] def fit_output_to_model_output( fit_ret: dict[str, torch.Tensor], fit_output_def: FittingOutputDef, coord_ext: torch.Tensor, do_atomic_virial: bool = False, create_graph: bool = True, mask: torch.Tensor | None = None, extended_coord_corr: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """Transform the output of the fitting network to the model output. """ redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION model_ret = dict(fit_ret.items()) for kk, vv in fit_ret.items(): vdef = fit_output_def[kk] shap = vdef.shape atom_axis = -(len(shap) + 1) if vdef.reducible: kk_redu = get_reduce_name(kk) if vdef.intensive: if mask is not None: model_ret[kk_redu] = torch.sum( vv.to(redu_prec), dim=atom_axis ) / torch.sum(mask, dim=-1, keepdim=True) else: model_ret[kk_redu] = torch.mean(vv.to(redu_prec), dim=atom_axis) else: model_ret[kk_redu] = torch.sum(vv.to(redu_prec), dim=atom_axis) if vdef.r_differentiable: kk_derv_r, kk_derv_c = get_deriv_name(kk) dr, dc = take_deriv( vv, model_ret[kk_redu], vdef, coord_ext, do_virial=vdef.c_differentiable, do_atomic_virial=do_atomic_virial, create_graph=create_graph, ) model_ret[kk_derv_r] = dr if vdef.c_differentiable: assert dc is not None if extended_coord_corr is not None: dc_corr = ( dr.squeeze(-2).unsqueeze(-1) @ extended_coord_corr.unsqueeze(-2).to(dr.dtype) ).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005 dc = dc + dc_corr model_ret[kk_derv_c] = dc model_ret[kk_derv_c + "_redu"] = torch.sum( model_ret[kk_derv_c].to(redu_prec), dim=1 ) return model_ret