deepmd.dpmodel.array_api

Utilities for the array API.

Functions

support_array_api(→ callable)

Mark a function as supporting a given version of the array API.

xp_swapaxes(a, axis1, axis2)

xp_take_along_axis(arr, indices, axis)

xp_scatter_sum(→ numpy.ndarray)

Reduces all values from the src array into input by summation, at the positions specified by the index array along dimension dim.

Module Contents

deepmd.dpmodel.array_api.support_array_api(version: str) → callable

Mark a function as supporting a given version of the array API.

Parameters:
version : str

The array API version the function supports, e.g. "2022.12".

Returns:
callable

A decorator that, applied to a function, returns the decorated function.

Examples

>>> @support_array_api(version="2022.12")
... def f(x):
...     pass
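As a rough illustration of what "marking" a function can mean, the sketch below assumes the decorator simply stores the version string on the function object and returns the function unchanged. The name support_array_api_sketch and the attribute name array_api_version are hypothetical choices for this example, not taken from this page.

```python
from typing import Callable


def support_array_api_sketch(version: str) -> Callable:
    """Hypothetical stand-in for support_array_api: tag a function with a version."""

    def set_version(func: Callable) -> Callable:
        # Assumption: the mark is a plain attribute on the function object.
        func.array_api_version = version
        return func  # the function itself is returned unchanged

    return set_version


@support_array_api_sketch(version="2022.12")
def f(x):
    return x


print(f.array_api_version)  # 2022.12
```

Under this assumption the decorated function behaves exactly as before; other code can read the attribute to check which array API version a function was written against.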
deepmd.dpmodel.array_api.xp_swapaxes(a, axis1, axis2)
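xp_swapaxes carries no description on this page. Assuming it mirrors numpy.swapaxes, one array-API-friendly way to express it is a single axis permutation, since the array API standard defines permute_dims but no swapaxes function. The sketch below uses NumPy's transpose with an explicit axes list as a stand-in for permute_dims; the name swapaxes_via_permute is hypothetical.

```python
import numpy as np


def swapaxes_via_permute(a, axis1, axis2):
    """Hypothetical swapaxes expressed as a single axis permutation."""
    # Start from the identity permutation and exchange the two entries.
    axes = list(range(a.ndim))
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    # np.transpose with an explicit axes list plays the role of permute_dims.
    return np.transpose(a, axes)


x = np.arange(24).reshape(2, 3, 4)
print(swapaxes_via_permute(x, 0, 2).shape)  # (4, 3, 2)
```

Python list indexing also accepts negative positions, so calls such as swapaxes_via_permute(x, -1, 1) work without extra axis normalization.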
deepmd.dpmodel.array_api.xp_take_along_axis(arr, indices, axis)
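Assuming xp_take_along_axis reproduces numpy.take_along_axis for array namespaces that do not provide it, the gather can be built from more basic operations: move the target axis to the end, flatten, shift each row's indices by that row's starting offset, and perform a single 1-D take. This is a NumPy sketch, not the module's code; take_along_axis_sketch is a hypothetical name, and it assumes non-negative indices whose shape matches arr on every axis except axis, with at least one index per row.

```python
import numpy as np


def take_along_axis_sketch(arr, indices, axis):
    """Hypothetical gather along `axis` built from reshape and a 1-D take."""
    # Move the gathered axis to the end so each row is contiguous.
    arr = np.swapaxes(arr, axis, -1)
    indices = np.swapaxes(indices, axis, -1)
    m = arr.shape[-1]  # length of each row in arr
    n = indices.shape[-1]  # number of indices per row
    out_shape = (*arr.shape[:-1], n)

    flat = np.reshape(arr, (-1,))
    rows = np.reshape(indices, (-1, n))
    # Row r starts at position r * m in the flattened array.
    offset = (np.arange(rows.shape[0]) * m)[:, None]
    out = np.take(flat, np.reshape(rows + offset, (-1,)))

    # Restore the leading shape and move the axis back.
    return np.swapaxes(np.reshape(out, out_shape), axis, -1)


rng = np.random.default_rng(0)
x = rng.random((2, 3, 4))
idx = np.argsort(x, axis=1)
print(np.array_equal(take_along_axis_sketch(x, idx, axis=1),
                     np.take_along_axis(x, idx, axis=1)))  # True
```

The offset trick relies on indices being non-negative; negative indices would have to be wrapped first (for example with idx % m) before the offsets are added.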
deepmd.dpmodel.array_api.xp_scatter_sum(input, dim, index: numpy.ndarray, src: numpy.ndarray) → numpy.ndarray

Reduces all values from the src array into input by summation, at the positions specified by the index array along dimension dim.
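To make the reduction concrete, the sketch below assumes PyTorch-style scatter_add semantics: for dim=0 on 2-D arrays, out[index[i][j]][j] += src[i][j], with other dimensions handled analogously. It further assumes index and src share one shape that matches input on every axis other than dim. It is a NumPy stand-in (np.take_along_axis to locate the target cells, np.add.at to accumulate), not the module's implementation; scatter_sum_sketch is a hypothetical name.

```python
import numpy as np


def scatter_sum_sketch(input, dim, index, src):
    """Hypothetical NumPy version of a scatter-sum along `dim`."""
    out = np.array(input, copy=True)  # leave the caller's array untouched
    # Flat position of every cell of `out`.
    flat_pos = np.arange(out.size).reshape(out.shape)
    # Replace each element's coordinate along `dim` with the value in `index`.
    target = np.take_along_axis(flat_pos, index, axis=dim)
    flat = out.reshape(-1)
    # np.add.at is unbuffered, so repeated targets accumulate instead of overwriting.
    np.add.at(flat, target.reshape(-1), np.reshape(src, (-1,)))
    return flat.reshape(out.shape)


print(scatter_sum_sketch(np.zeros(3), 0, np.array([0, 1, 0, 2]),
                         np.array([1.0, 2.0, 3.0, 4.0])))
# [4. 2. 4.]
```

np.add.at is used rather than flat[target] += src because buffered fancy-index assignment keeps only one contribution per repeated index, which would silently drop values that land in the same cell.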