Source code for deepmd.pt.model.descriptor.sezm_nn.projection

# SPDX-License-Identifier: LGPL-3.0-or-later
"""
Grid projection helpers for SeZM function-space nonlinearities.

The projectors in this module only handle basis transforms.  They do not apply
channel mixing or nonlinearities.  A projector maps coefficient tensors to a
fixed quadrature grid, and maps grid fields back to coefficients with the
matching quadrature rule.
"""

from __future__ import (
    annotations,
)

import math
from typing import (
    Any,
)

import torch
import torch.nn as nn
from e3nn.o3 import (
    FromS2Grid,
    ToS2Grid,
    spherical_harmonics,
)

from deepmd.pt.utils.env import (
    PRECISION_DICT,
    RESERVED_PRECISION_DICT,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

from .indexing import (
    build_l_major_index,
    build_m_major_index,
    so3_packed_index,
)
from .lebedev import (
    LEBEDEV_PRECISION_TO_NPOINTS,
    load_lebedev_rule,
)
from .wignerd import (
    WignerDCalculator,
    build_edge_quaternion,
    quaternion_multiply,
    quaternion_z_rotation,
)


[docs] class BaseGridProjector(nn.Module): """ Base class for fixed coefficient-to-grid projection matrices. Subclasses build ``to_grid_mat`` with shape ``(G, J)`` and ``from_grid_mat`` with shape ``(J, G)``, where ``G`` is the number of grid samples and ``J`` is the flattened coefficient axis consumed by the grid net. For ordinary S2 projections, ``J`` is the SO(3) feature coefficient axis: ``D = (lmax + 1)^2`` in packed layout, or the retained ``D_m`` axis in m-major layout. For SO(3) frame projections, ``J = D * n_frames`` with frame index packed inside each coefficient row. """ def __init__( self, *, lmax: int, mmax: int | None, dtype: torch.dtype, n_frames: int, coefficient_layout: str, ) -> None: super().__init__()
[docs] self.lmax = int(lmax)
[docs] self.mmax = int(self.lmax if mmax is None else mmax)
if self.mmax < 0: raise ValueError("`mmax` must be non-negative") if self.mmax > self.lmax: raise ValueError("`mmax` must be <= `lmax`")
[docs] self.coefficient_layout = str(coefficient_layout).lower()
if self.coefficient_layout not in {"packed", "m_major"}: raise ValueError( "`coefficient_layout` must be either 'packed' or 'm_major'" )
[docs] self.dtype = dtype
[docs] self.device = torch.device("cpu")
[docs] self.precision = RESERVED_PRECISION_DICT[dtype]
[docs] self.n_frames = int(n_frames)
[docs] self.packed_dim = int((self.lmax + 1) ** 2)
coeff_index = self._build_coefficient_index(device=torch.device("cpu")) to_grid_mat, from_grid_mat = self._build_projection_mats(coeff_index)
[docs] self.coeff_dim = int(to_grid_mat.shape[1])
[docs] self.grid_size = int(to_grid_mat.shape[0])
if self.coeff_dim != int(from_grid_mat.shape[0]): raise ValueError("Projection matrix coefficient axes `J` do not match") if self.grid_size != int(from_grid_mat.shape[1]): raise ValueError("Projection matrix grid axes `G` do not match") self.register_buffer( "to_grid_mat", to_grid_mat.to(device=self.device, dtype=self.dtype), persistent=False, ) self.register_buffer( "from_grid_mat", from_grid_mat.to(device=self.device, dtype=self.dtype), persistent=False, )
[docs] def to_grid(self, embedding: torch.Tensor) -> torch.Tensor: """Project flattened coefficients ``(N, J, C)`` to grid fields ``(N, G, C)``.""" to_grid_mat = self.to_grid_mat.to( device=embedding.device, dtype=embedding.dtype, ) return torch.einsum("gj,njc->ngc", to_grid_mat, embedding)
[docs] def from_grid(self, grid: torch.Tensor) -> torch.Tensor: """Project grid fields ``(N, G, C)`` back to flattened coefficients ``(N, J, C)``.""" from_grid_mat = self.from_grid_mat.to( device=grid.device, dtype=grid.dtype, ) return torch.einsum("jg,ngc->njc", from_grid_mat, grid)
[docs] def _build_coefficient_index(self, device: torch.device) -> torch.Tensor: """Build the coefficient subset consumed by the projector matrices.""" if self.coefficient_layout == "m_major": return build_m_major_index(self.lmax, self.mmax, device=device) if self.mmax == self.lmax: return torch.arange((self.lmax + 1) ** 2, device=device, dtype=torch.long) return build_l_major_index(self.lmax, self.mmax, device=device)
[docs] def _build_projection_mats( self, coeff_index: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Build ``to_grid_mat (G, J)`` and ``from_grid_mat (J, G)``.""" raise NotImplementedError
[docs] class S2GridProjector(BaseGridProjector): """ Project SO(3) coefficients to/from a flattened S2 grid. Parameters ---------- lmax Maximum spherical harmonic degree. mmax Maximum order kept in the coefficient layout. If None, use ``lmax``. dtype Buffer dtype used by the projection matrices. grid_resolution_list Two-element resolution list. For ``grid_method='e3nn'`` it is ``[R_phi, R_theta]`` and is converted to the ``e3nn`` ``(lat, long) = (R_theta, R_phi)`` ordering. For ``grid_method='lebedev'`` it is ``[precision, n_points]``. coefficient_layout Coefficient ordering expected by the caller: - ``"packed"``: packed ``(l, m)`` order, optionally truncated by ``mmax``. - ``"m_major"``: reduced m-major order used inside ``SO2Convolution``. grid_method S2 quadrature backend. Must be ``"e3nn"`` or ``"lebedev"``. """ def __init__( self, *, lmax: int, mmax: int | None = None, dtype: torch.dtype, grid_resolution_list: list[int] | None = None, coefficient_layout: str = "packed", grid_method: str = "e3nn", ) -> None: lmax_i = int(lmax) mmax_i = int(lmax_i if mmax is None else mmax)
[docs] self.grid_method = str(grid_method).lower()
if self.grid_method not in {"e3nn", "lebedev"}: raise ValueError("`grid_method` must be either 'e3nn' or 'lebedev'")
[docs] self.grid_resolution_list = _normalize_s2_grid_resolution( lmax_i, mmax_i, grid_resolution_list, method=self.grid_method, )
if self.grid_method == "e3nn": self.phi_resolution, self.theta_resolution = self.grid_resolution_list self.lebedev_precision = 0 self.lebedev_npoints = 0 else: self.phi_resolution = 0 self.theta_resolution = 0 self.lebedev_precision, self.lebedev_npoints = self.grid_resolution_list super().__init__( lmax=lmax_i, mmax=mmax_i, dtype=dtype, n_frames=1, coefficient_layout=coefficient_layout, )
[docs] def _rescale_truncated_orders(self, mat: torch.Tensor) -> None: if self.lmax == self.mmax: return for degree in range(self.lmax + 1): if degree <= self.mmax: continue start_idx = degree * degree length = 2 * degree + 1 rescale = math.sqrt(length / float(2 * self.mmax + 1)) mat[:, :, start_idx : start_idx + length].mul_(rescale)
[docs] def _rescale_truncated_matrix(self, mat: torch.Tensor) -> None: if self.lmax == self.mmax: return for degree in range(self.lmax + 1): if degree <= self.mmax: continue start_idx = degree * degree length = 2 * degree + 1 rescale = math.sqrt(length / float(2 * self.mmax + 1)) mat[:, start_idx : start_idx + length].mul_(rescale)
[docs] def _build_projection_mats( self, coeff_index: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: if self.grid_method == "lebedev": return self._build_lebedev_projection_mats(coeff_index) return self._build_e3nn_projection_mats(coeff_index)
[docs] def _build_e3nn_projection_mats( self, coeff_index: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: with torch.device("cpu"): to_grid = ToS2Grid( self.lmax, (self.theta_resolution, self.phi_resolution), normalization="component", device="cpu", ) to_grid_mat = torch.einsum("mbi,am->bai", to_grid.shb, to_grid.sha).detach() self._rescale_truncated_orders(to_grid_mat) from_grid = FromS2Grid( (self.theta_resolution, self.phi_resolution), self.lmax, normalization="component", device="cpu", ) from_grid_mat = torch.einsum( "am,mbi->bai", from_grid.sha, from_grid.shb ).detach() self._rescale_truncated_orders(from_grid_mat) to_grid_mat = to_grid_mat.flatten(0, 1).index_select(1, coeff_index) from_grid_mat = ( from_grid_mat.flatten(0, 1).permute(1, 0).index_select(0, coeff_index) ) return to_grid_mat, from_grid_mat
[docs] def _build_lebedev_projection_mats( self, coeff_index: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: with torch.device("cpu"): points, weights = load_lebedev_rule( self.lebedev_precision, dtype=torch.float64, device=torch.device("cpu"), ) harmonics = spherical_harmonics( list(range(self.lmax + 1)), points, normalize=True, normalization="norm", ) # Match the component-normalized product-grid convention used by # e3nn's ToS2Grid/FromS2Grid pair so both S2 backends are drop-in # replacements for the same grid net. scale = math.sqrt(float(self.lmax + 1)) degree_factors = harmonics.new_tensor( [ float(2 * degree + 1) for degree in range(self.lmax + 1) for _ in range(2 * degree + 1) ] ) to_grid_mat = harmonics / scale from_grid_mat = harmonics * ( weights[:, None] * scale * degree_factors[None, :] ) self._rescale_truncated_matrix(to_grid_mat) self._rescale_truncated_matrix(from_grid_mat) to_grid_mat = to_grid_mat.index_select(1, coeff_index) from_grid_mat = from_grid_mat.index_select(1, coeff_index).transpose(0, 1) return to_grid_mat, from_grid_mat
[docs] def serialize(self) -> dict[str, Any]: return { "@class": "S2GridProjector", "@version": 1, "config": { "lmax": self.lmax, "mmax": self.mmax, "precision": RESERVED_PRECISION_DICT[self.dtype], "grid_resolution_list": self.grid_resolution_list, "coefficient_layout": self.coefficient_layout, "grid_method": self.grid_method, }, "@variables": {}, }
@classmethod
[docs] def deserialize(cls, data: dict[str, Any]) -> S2GridProjector: data = data.copy() data_cls = data.pop("@class") if data_cls != "S2GridProjector": raise ValueError(f"Invalid class for S2GridProjector: {data_cls}") version = int(data.pop("@version")) check_version_compatibility(version, 1, 1) config = data.pop("config") data.pop("@variables", None) precision = config.pop("precision") config["dtype"] = PRECISION_DICT[precision] return cls(**config)
[docs] class SO3GridProjector(BaseGridProjector): """ Project SO(3) coefficients to/from a Wigner-D grid with frame indices. The coefficient axis is packed as ``(l, m, k)`` with ordinary SeZM ``(l, m)`` order outside and the configured frame set inside each row. A frame index outside ``[-l, l]`` is kept as a zero column/row. This keeps the tensor layout regular while preserving the exact per-degree frame support. """ def __init__( self, *, lmax: int, mmax: int | None = None, kmax: int = 1, dtype: torch.dtype, lebedev_precision: int | None = None, coefficient_layout: str = "packed", ) -> None: lmax_i = int(lmax) mmax_i = int(lmax_i if mmax is None else mmax)
[docs] self.kmax = int(kmax)
if self.kmax < 0: raise ValueError("`kmax` must be non-negative")
[docs] self.frame_set = _build_so3_frame_set(self.kmax)
[docs] self.frame_zero_index = self.frame_set.index(0)
self.lebedev_precision, self.lebedev_npoints, self.n_gamma = resolve_so3_grid( lmax_i, kmax=self.kmax, lebedev_precision=lebedev_precision, ) super().__init__( lmax=lmax_i, mmax=mmax_i, dtype=dtype, n_frames=len(self.frame_set), coefficient_layout=coefficient_layout, ) self.register_buffer( "frame_values", torch.tensor(self.frame_set, dtype=torch.long, device=self.device), persistent=False, )
[docs] def _build_projection_mats( self, coeff_index: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: with torch.device("cpu"): points, weights = load_lebedev_rule( self.lebedev_precision, dtype=torch.float64, device=torch.device("cpu"), ) gamma = torch.arange( self.n_gamma, dtype=torch.float64, device=points.device ) * (2.0 * math.pi / float(self.n_gamma)) edge_quaternion = build_edge_quaternion(points, eps=1e-14) edge_quaternion = edge_quaternion.repeat_interleave(self.n_gamma, dim=0) gamma_quaternion = quaternion_z_rotation(gamma).repeat(points.shape[0], 1) grid_quaternion = quaternion_multiply(gamma_quaternion, edge_quaternion) wigner_grid, _ = WignerDCalculator(self.lmax, dtype=torch.float64).to( torch.device("cpu") )(grid_quaternion) # ``build_edge_quaternion`` follows SeZM's global-to-local convention. # The transpose below stores the local m=0 column in the same layout # as ``WignerDCalculator.forward_zonal`` and extends it to k != 0. wigner_grid = wigner_grid.transpose(-1, -2).contiguous() haar_weight = weights.repeat_interleave(self.n_gamma) / float(self.n_gamma) grid_size = int(grid_quaternion.shape[0]) coeff_dim = int(coeff_index.numel() * len(self.frame_set)) to_grid_mat = torch.zeros( grid_size, coeff_dim, dtype=torch.float64, device=points.device, ) from_grid_mat = torch.zeros( coeff_dim, grid_size, dtype=torch.float64, device=points.device, ) for degree in range(self.lmax + 1): degree_factor = float(2 * degree + 1) for m_order in range(-degree, degree + 1): packed_idx = so3_packed_index(degree, m_order) coeff_positions = (coeff_index == packed_idx).nonzero( as_tuple=False ) if coeff_positions.numel() == 0: continue coeff_pos = int(coeff_positions[0, 0]) for frame_pos, frame_order in enumerate(self.frame_set): flat_idx = coeff_pos * len(self.frame_set) + frame_pos if abs(frame_order) > degree: continue row = so3_packed_index(degree, m_order) col = so3_packed_index(degree, frame_order) values = wigner_grid[:, row, col] to_grid_mat[:, flat_idx] = values from_grid_mat[flat_idx, :] = ( degree_factor * haar_weight * values ) return to_grid_mat, from_grid_mat
[docs] def serialize(self) -> dict[str, Any]: return { "@class": "SO3GridProjector", "@version": 1, "config": { "lmax": self.lmax, "mmax": self.mmax, "kmax": self.kmax, "precision": RESERVED_PRECISION_DICT[self.dtype], "lebedev_precision": self.lebedev_precision, "coefficient_layout": self.coefficient_layout, }, "@variables": {}, }
@classmethod
[docs] def deserialize(cls, data: dict[str, Any]) -> SO3GridProjector: data = data.copy() data_cls = data.pop("@class") if data_cls != "SO3GridProjector": raise ValueError(f"Invalid class for SO3GridProjector: {data_cls}") version = int(data.pop("@version")) check_version_compatibility(version, 1, 1) config = data.pop("config") data.pop("@variables", None) precision = config.pop("precision") config["dtype"] = PRECISION_DICT[precision] return cls(**config)
[docs] def resolve_s2_grid_resolution( lmax: int, mmax: int, *, method: str = "e3nn", ) -> list[int]: """ Resolve the default S2 grid resolution. For ``method='e3nn'``, the automatic default uses even azimuthal sampling ``R_phi = 2 * mmax + 4`` and even polar sampling ``R_theta = ceil_even(3 * lmax + 2)``. For ``method='lebedev'``, the automatic default picks the smallest packaged Lebedev rule whose algebraic precision is at least ``3 * lmax`` and returns ``[precision, n_points]``. """ method = str(method).lower() if method not in {"e3nn", "lebedev"}: raise ValueError("`method` must be either 'e3nn' or 'lebedev'") if method == "lebedev": required_precision = 3 * int(lmax) for precision, n_points in LEBEDEV_PRECISION_TO_NPOINTS.items(): if precision >= required_precision: return [precision, n_points] raise ValueError( f"No packaged Lebedev rule has precision >= {required_precision}" ) phi_resolution = 2 * int(mmax) + 4 theta_resolution = 3 * int(lmax) + 2 theta_resolution += theta_resolution % 2 return [phi_resolution, theta_resolution]
[docs] def resolve_so3_grid( lmax: int, *, kmax: int = 1, lebedev_precision: int | None = None, ) -> tuple[int, int, int]: """ Resolve the default SO(3) quadrature as Lebedev sphere times gamma samples. The Lebedev precision follows the same conservative ``3*lmax`` rule used by the S2 grid path. The gamma grid is chosen for the quadratic grid products used by the SO(3) grid nets, whose third-angle frequency can reach ``k1 + k2 - kout``. """ lmax_i = int(lmax) kmax_i = int(kmax) if kmax_i < 0: raise ValueError("`kmax` must be non-negative") if lebedev_precision is None: required_precision = 3 * lmax_i for precision, n_points in LEBEDEV_PRECISION_TO_NPOINTS.items(): if precision >= required_precision: lebedev_precision = precision lebedev_npoints = n_points break else: raise ValueError( f"No packaged Lebedev rule has precision >= {required_precision}" ) else: lebedev_precision = int(lebedev_precision) lebedev_npoints = LEBEDEV_PRECISION_TO_NPOINTS.get(lebedev_precision) if lebedev_npoints is None: raise ValueError( f"Lebedev rule with precision {lebedev_precision} is not packaged" ) # A quadratic product followed by analysis can contain gamma frequencies up # to ``3*kmax``. A uniform grid with more samples than that frequency # resolves the integer Fourier modes exactly. n_gamma = 1 if kmax_i == 0 else 3 * kmax_i + 1 return int(lebedev_precision), int(lebedev_npoints), int(n_gamma)
[docs] def _normalize_s2_grid_resolution( lmax: int, mmax: int, grid_resolution_list: list[int] | None, *, method: str, ) -> list[int]: """Resolve default grids or validate already-resolved low-level grids.""" method = str(method).lower() if grid_resolution_list is None: return resolve_s2_grid_resolution(lmax, mmax, method=method) if method == "lebedev": if len(grid_resolution_list) != 2: raise ValueError( "Lebedev `grid_resolution_list` must be [precision, n_points]" ) precision = int(grid_resolution_list[0]) n_points = int(grid_resolution_list[1]) expected_n_points = LEBEDEV_PRECISION_TO_NPOINTS.get(precision) if expected_n_points != n_points: raise ValueError( "Lebedev `grid_resolution_list` must match a packaged " f"[precision, n_points] pair; got [{precision}, {n_points}]" ) return [precision, n_points] if len(grid_resolution_list) != 2: raise ValueError("`grid_resolution_list` must contain two integers") resolution = [int(grid_resolution_list[0]), int(grid_resolution_list[1])] if resolution[0] < 1 or resolution[1] < 1: raise ValueError("grid resolutions must be positive") return resolution
[docs] def _build_so3_frame_set(kmax: int) -> list[int]: """Build the symmetric frame-index set with zero first.""" kmax_i = int(kmax) if kmax_i < 0: raise ValueError("`kmax` must be non-negative") return [0, *[frame for kk in range(1, kmax_i + 1) for frame in (-kk, kk)]]