Source code for deepmd_gnn.autograd

"""Autograd helpers shared by GNN model wrappers."""

from __future__ import annotations

import torch


@torch.jit.script
[docs] def derive_atomic_virial_from_displacement( atom_energy: torch.Tensor, displacement: torch.Tensor, nloc: int, create_graph: bool, ) -> torch.Tensor: """Derive per-atom virials from atom-energy gradients w.r.t. cell strain.""" nf = atom_energy.shape[0] atomic_virial = torch.zeros( (nf, nloc, 9), dtype=displacement.dtype, device=displacement.device, ) for ii in range(nloc): atom_energy_ii = atom_energy[:, ii] atom_grad_outputs = torch.jit.annotate( list[torch.Tensor | None], [torch.ones_like(atom_energy_ii)], ) atom_virial_ii = torch.autograd.grad( outputs=[atom_energy_ii], inputs=[displacement], grad_outputs=atom_grad_outputs, retain_graph=True, create_graph=create_graph, allow_unused=True, )[0] if atom_virial_ii is None: atom_virial_ii = torch.zeros_like(displacement) atomic_virial[:, ii, :] = (-atom_virial_ii).view(nf, 9) return atomic_virial