deepmd.dpmodel.array_api

Utilities for the array API.

Functions

support_array_api(version) → callable

Mark a function as supporting a specific version of the array API.

xp_swapaxes(a, axis1, axis2)

Interchanges two axes of an array, analogous to numpy.swapaxes.

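xp_take_along_axis(arr, indices, axis)

Takes values from arr by matching 1-D index and data slices along the given axis, analogous to numpy.take_along_axis.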

xp_scatter_sum(input, dim, index, src) → numpy.ndarray

Sums all values from the src tensor into the positions specified by the index tensor.

xp_add_at(x, indices, values)

Adds values at the specified indices of x, either in place or, for JAX (whose arrays are immutable), by returning a new array.

xp_bincount(x[, weights, minlength])

Counts the number of occurrences of each value in x.

Module Contents

deepmd.dpmodel.array_api.support_array_api(version: str) → callable

Mark a function as supporting a specific version of the array API.

Parameters:
version : str

The version of the array API, such as "2022.12".

Returns:
callable

A decorator; applying it to a function yields the marked function.

Examples

>>> @support_array_api(version="2022.12")
... def f(x):
...     pass
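Because support_array_api(version=...) is used with @ syntax, it acts as a decorator factory: calling it with a version returns the decorator that is then applied to f. The following is a minimal sketch of that pattern, assuming the decorator simply records the version on the function object. The name support_array_api_sketch and the attribute name array_api_version are hypothetical and not taken from this page.

```python
from typing import Callable


def support_array_api_sketch(version: str) -> Callable[[Callable], Callable]:
    """Return a decorator that tags a function with an array API version."""

    def decorator(func: Callable) -> Callable:
        # Hypothetical attribute name; the real module may store the
        # version differently.
        func.array_api_version = version
        return func  # the same function object, now tagged

    return decorator


@support_array_api_sketch(version="2022.12")
def f(x):
    return x


print(f.array_api_version)  # 2022.12
```

Returning the original function (rather than a wrapper) keeps its signature and behavior unchanged; only the extra attribute is added.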

deepmd.dpmodel.array_api.xp_swapaxes(a, axis1, axis2)

Interchanges two axes of an array, analogous to numpy.swapaxes.

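The array API standard provides permute_dims for general axis reordering, so one plausible way to express a swap is to exchange two entries of the identity permutation and apply it. The following is a minimal sketch of that idea. It uses NumPy's transpose as a stand-in for permute_dims, and swapaxes_sketch is a hypothetical name.

```python
import numpy as np


def swapaxes_sketch(a, axis1: int, axis2: int):
    """Swap two axes using only a general axis permutation."""
    axes = list(range(a.ndim))
    # Exchange the two entries; negative axes index from the end of the list.
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    # np.transpose plays the role of the standard's permute_dims here.
    return np.transpose(a, axes)


a = np.zeros((2, 3, 4))
print(swapaxes_sketch(a, 0, 2).shape)  # (4, 3, 2)
```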
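deepmd.dpmodel.array_api.xp_take_along_axis(arr, indices, axis)

Takes values from arr by matching 1-D index and data slices along the given axis, analogous to numpy.take_along_axis.
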
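When a backend lacks a native take_along_axis, one way to build it is as follows: flatten arr, shift each row's indices by that row's starting offset, and gather everything with a single 1-D take. The following is a minimal NumPy sketch of that technique. It assumes indices are non-negative, that indices has the same shape as arr except along axis, and that at least one index is taken per row. take_along_axis_sketch is a hypothetical name.

```python
import numpy as np


def take_along_axis_sketch(arr, indices, axis: int):
    """Gather values along `axis` with one flat take (non-negative indices)."""
    # Move the target axis last so each "row" is a contiguous slice.
    arr = np.moveaxis(arr, axis, -1)
    indices = np.moveaxis(indices, axis, -1)
    m = arr.shape[-1]  # length of each source row
    n = indices.shape[-1]  # picks per row; assumed > 0
    out_shape = (*arr.shape[:-1], n)
    flat_arr = arr.reshape(-1)
    rows = indices.reshape(-1, n)
    # Offset each row's indices by where that row starts in flat_arr.
    offsets = (np.arange(rows.shape[0]) * m)[:, None]
    out = np.take(flat_arr, (rows + offsets).reshape(-1))
    # Restore the original axis order.
    return np.moveaxis(out.reshape(out_shape), -1, axis)


a = np.array([[10, 30, 20], [60, 40, 50]])
order = np.argsort(a, axis=1)
print(take_along_axis_sketch(a, order, axis=1))
# [[10 20 30]
#  [40 50 60]]
```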
deepmd.dpmodel.array_api.xp_scatter_sum(input, dim, index: numpy.ndarray, src: numpy.ndarray) → numpy.ndarray

Sums all values from the src tensor into the positions specified by the index tensor.
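The parameter names input, dim, index, and src mirror PyTorch's scatter-add. The sketch below assumes those semantics (for dim=1, out[i][index[i][j]] += src[i][j], with every other coordinate kept). It is a minimal NumPy illustration, assuming index and src share a shape, and is not the documented function's implementation. scatter_sum_sketch is a hypothetical name.

```python
import numpy as np


def scatter_sum_sketch(input, dim: int, index, src):
    """Return a copy of input with src summed into positions given by index."""
    out = np.array(input, copy=True)
    # Coordinates of every element of index/src (assumed to share a shape).
    coords = list(np.indices(index.shape))
    # Along `dim`, redirect each element to the position named by index.
    coords[dim] = index
    # np.add.at is unbuffered, so repeated indices all contribute.
    np.add.at(out, tuple(coords), src)
    return out


print(scatter_sum_sketch(np.zeros(3), 0, np.array([0, 1, 0, 2]), np.array([1.0, 2.0, 3.0, 4.0])))
# [4. 2. 4.]
```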

deepmd.dpmodel.array_api.xp_add_at(x, indices, values)

Adds values at the specified indices of x, either in place or, for JAX (whose arrays are immutable), by returning a new array.
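A dedicated helper is useful because plain fancy-index assignment x[indices] += values is buffered in NumPy. When indices repeats, only one update per position survives. For NumPy arrays, the unbuffered np.add.at accumulates every occurrence in place. In JAX, the functional equivalent is x.at[indices].add(values), which returns a new array. The following is a minimal NumPy-only sketch of the difference. add_at_sketch is a hypothetical name, and how the documented function dispatches between backends is not shown on this page.

```python
import numpy as np


def add_at_sketch(x, indices, values):
    """Add values at indices of a NumPy array in place and return it."""
    # Unbuffered: every occurrence of a repeated index is accumulated.
    np.add.at(x, indices, values)
    return x


idx = np.array([0, 0, 2])
vals = np.array([1.0, 1.0, 5.0])

buffered = np.zeros(3)
buffered[idx] += vals  # repeated index 0 is applied only once
print(buffered)  # [1. 0. 5.]

unbuffered = add_at_sketch(np.zeros(3), idx, vals)
print(unbuffered)  # [2. 0. 5.]
```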

deepmd.dpmodel.array_api.xp_bincount(x, weights=None, minlength=0)

Counts the number of occurrences of each value in x.
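The parameters (and the default minlength=0) match numpy.bincount, so the sketch below assumes the same semantics. Without weights, the function counts each non-negative integer value. With weights, it sums the matching weights instead. The result has at least minlength entries. On a backend without a native bincount, one way to obtain this is as follows: allocate max(x) + 1 zeros (or minlength, whichever is larger), then scatter-add 1, or the weight, at each value of x. The following is a minimal NumPy sketch of that technique, checked against np.bincount. bincount_sketch is a hypothetical name, and the documented function's actual fallback is not described on this page.

```python
import numpy as np


def bincount_sketch(x, weights=None, minlength: int = 0):
    """Count (or sum weights of) each non-negative integer value in x."""
    x = np.asarray(x)
    if x.ndim != 1 or (x.size and x.min() < 0):
        raise ValueError("x must be a 1-D array of non-negative integers")
    # Output length: one bin per value up to max(x), padded to minlength.
    length = max(int(x.max()) + 1 if x.size else 0, minlength)
    if weights is None:
        out = np.zeros(length, dtype=np.intp)
        np.add.at(out, x, 1)  # one count per occurrence
    else:
        out = np.zeros(length, dtype=np.float64)
        np.add.at(out, x, np.asarray(weights, dtype=np.float64))
    return out


print(bincount_sketch([0, 1, 1, 3]))  # [1 2 0 1]
```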