deepmd.dpmodel.array_api
Utilities for the array API.
Functions
| support_array_api(version) | Mark a function as supporting the specific version of the array API. |
| | |
| | |
| xp_scatter_sum(input, dim, index, src) | Reduces all values from the src tensor to the indices specified in the index tensor. |
| | Adds values to the specified indices of x in place or returns new x (for JAX). |
| | Counts the number of occurrences of each value in x. |
Module Contents
- deepmd.dpmodel.array_api.support_array_api(version: str) -> callable
Mark a function as supporting the specific version of the array API.
- Parameters:
  - version : str
    The version of the array API.
- Returns:
  - callable
    The decorated function.
Examples
>>> @support_array_api(version="2022.12")
... def f(x):
...     pass
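To illustrate what "marking" a function can look like, here is a minimal sketch of a version-marking decorator factory. It is not taken from the deepmd source: the name `mark_array_api_version` and the attribute `array_api_version` are hypothetical, chosen only for this example. The real `support_array_api` may record the version differently.

```python
from typing import Callable


def mark_array_api_version(version: str) -> Callable:
    """Hypothetical stand-in for a version-marking decorator factory."""

    def decorator(func: Callable) -> Callable:
        # Assumption: the marker is stored as a plain attribute on the
        # function object; the function itself is returned unchanged.
        func.array_api_version = version
        return func

    return decorator


@mark_array_api_version(version="2022.12")
def f(x):
    return x
```

Because the decorator returns the original function rather than a wrapper, calling `f` behaves exactly as before. Only the metadata attribute is new.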
- deepmd.dpmodel.array_api.xp_scatter_sum(input, dim, index: numpy.ndarray, src: numpy.ndarray) -> numpy.ndarray
Reduces all values from the src tensor to the indices specified in the index tensor.
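The following NumPy sketch shows one common reading of a scatter-sum. It assumes semantics analogous to PyTorch's `scatter_add`: each element of `src` is added into a copy of `input` at the position obtained by replacing its coordinate along `dim` with the matching entry of `index`. That assumption is not confirmed by the text above, and `scatter_sum_sketch` is a hypothetical helper, not the library function.

```python
import numpy as np


def scatter_sum_sketch(input, dim, index, src):
    """Hypothetical NumPy illustration of a scatter-sum along axis `dim`."""
    # Work on a copy so the caller's array is left untouched.
    out = np.array(input, copy=True)
    # Start from the identity coordinates of every element of `index` ...
    coords = list(np.indices(index.shape, sparse=True))
    # ... and replace the coordinate along `dim` with the target indices.
    coords[dim] = index
    # Only the part of `src` covered by `index` takes part in the sum.
    values = src[tuple(slice(0, n) for n in index.shape)]
    # np.add.at accumulates correctly when target indices repeat.
    np.add.at(out, tuple(coords), values)
    return out


result_1d = scatter_sum_sketch(
    np.zeros(3), 0, np.array([0, 1, 0, 2]), np.array([1.0, 2.0, 3.0, 4.0])
)
result_2d = scatter_sum_sketch(
    np.zeros((2, 3)),
    1,
    np.array([[0, 0, 2], [1, 1, 1]]),
    np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
)
```

`np.add.at` is used instead of `out[coords] += values` because plain fancy-index assignment keeps only one contribution per repeated index, whereas a scatter-sum must accumulate all of them.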