Source code for deepmd.dpmodel.array_api

# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for the array API."""

import array_api_compat
import numpy as np
from packaging.version import (
    Version,
)


[docs] def support_array_api(version: str) -> callable: """Mark a function as supporting the specific version of the array API. Parameters ---------- version : str The version of the array API Returns ------- callable The decorated function Examples -------- >>> @support_array_api(version="2022.12") ... def f(x): ... pass """ def set_version(func: callable) -> callable: func.array_api_version = version return func return set_version
# array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816 # but it hasn't been released yet # below is a pure Python implementation of take_along_axis # https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595
[docs] def xp_swapaxes(a, axis1, axis2): xp = array_api_compat.array_namespace(a) axes = list(range(a.ndim)) axes[axis1], axes[axis2] = axes[axis2], axes[axis1] a = xp.permute_dims(a, axes) return a
[docs] def xp_take_along_axis(arr, indices, axis): xp = array_api_compat.array_namespace(arr) if Version(xp.__array_api_version__) >= Version("2024.12"): # see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39 return xp.take_along_axis(arr, indices, axis=axis) arr = xp_swapaxes(arr, axis, -1) indices = xp_swapaxes(indices, axis, -1) m = arr.shape[-1] n = indices.shape[-1] shape = list(arr.shape) shape.pop(-1) shape = [*shape, n] arr = xp.reshape(arr, (-1,)) if n != 0: indices = xp.reshape(indices, (-1, n)) else: indices = xp.reshape(indices, (0, 0)) offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis] indices = xp.reshape(offset + indices, (-1,)) out = xp.take(arr, indices) out = xp.reshape(out, shape) return xp_swapaxes(out, axis, -1)
[docs] def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray: """Reduces all values from the src tensor to the indices specified in the index tensor.""" # jax only if array_api_compat.is_jax_array(input): from deepmd.jax.common import ( scatter_sum, ) return scatter_sum( input, dim, index, src, ) else: raise NotImplementedError("Only JAX arrays are supported.")