deepmd.dpmodel.array_api
Utilities for the array API.
Attributes
Functions
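| Function | Description |
|---|---|
| xp_asarray_nodetach(xp, obj, *, dtype, device) | xp.asarray that preserves autograd for backend tensors. |
| xp_take_first_n(arr, dim, n) | Take the first n elements along dim. |
| xp_scatter_sum(input, dim, index, src) | Reduces all values from the src tensor to the indices specified in the index tensor. |
| xp_add_at(x, indices, values) | Adds values to the specified indices of x in place or returns new x (for JAX). |
| xp_sigmoid(x) | Compute the sigmoid function. |
| | Set items at boolean mask indices. |
| | Counts the number of occurrences of each value in x. |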
Module Contents
- deepmd.dpmodel.array_api.xp_asarray_nodetach(xp: Any, obj: Any, *, dtype: Any = None, device: Any = None) -> Array
xp.asarray that preserves autograd for backend tensors.
torch.asarray detaches its input from the autograd graph, so calling xp.asarray on a weight attribute that is already a backend tensor (e.g. a torch.nn.Parameter registered by the pt_expt backend) silently breaks gradient flow to that weight. This helper converts genuine non-backend data (NumPy arrays, Python scalars and lists) via xp.asarray; backend tensors are returned as-is, with an optional differentiable dtype cast via xp.astype.
The device argument applies only to the conversion path: backend tensors are assumed to already live on the working device (they are created together with the inputs).
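The dispatch described above can be sketched as follows. This is a hypothetical illustration, not the deepmd implementation: the name asarray_nodetach_sketch and the caller-supplied predicate is_backend_array are assumptions, and NumPy stands in for the backend namespace. NumPy has no autograd, so the example only shows the control flow: backend arrays pass through by identity, while other data is converted.

```python
from typing import Any, Callable

import numpy as np


def asarray_nodetach_sketch(
    xp: Any,
    obj: Any,
    *,
    is_backend_array: Callable[[Any], bool],  # hypothetical predicate
    dtype: Any = None,
) -> Any:
    """Convert obj with xp.asarray unless it is already a backend array."""
    if is_backend_array(obj):
        # Return the backend array itself so any autograd history survives;
        # cast only when a different dtype is requested (xp.astype is the
        # array API standard cast).
        if dtype is not None and obj.dtype != dtype:
            return xp.astype(obj, dtype)
        return obj
    # Non-backend data (Python scalars, lists, foreign arrays) is converted.
    return xp.asarray(obj, dtype=dtype)


# Usage with NumPy standing in for the backend namespace.
weight = np.ones(3)
same = asarray_nodetach_sketch(
    np, weight, is_backend_array=lambda a: isinstance(a, np.ndarray)
)
converted = asarray_nodetach_sketch(
    np, [1.0, 2.0], is_backend_array=lambda a: isinstance(a, np.ndarray)
)
```

With a real autograd backend, returning the object itself (rather than a converted copy) is what keeps the weight connected to the graph.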
- deepmd.dpmodel.array_api.xp_take_first_n(arr: Array, dim: int, n: int) -> Array
Take the first n elements along dim.
For torch tensors, uses torch.index_select so that torch.export does not emit a contiguity guard that would prevent the nall == nloc (no-PBC) case from working. For NumPy and JAX, uses regular slicing.
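The two code paths should produce the same values; only the way the selection is expressed differs. A minimal NumPy sketch (both function names are hypothetical) compares plain slicing with an index-based gather, where np.take plays the role of torch.index_select(arr, dim, torch.arange(n)). It does not reproduce the torch.export guard behavior, only the equivalence of results.

```python
import numpy as np


def take_first_n_slicing(arr: np.ndarray, dim: int, n: int) -> np.ndarray:
    """Slicing path, as described for NumPy / JAX (returns a view)."""
    index = [slice(None)] * arr.ndim
    index[dim] = slice(0, n)
    return arr[tuple(index)]


def take_first_n_indexed(arr: np.ndarray, dim: int, n: int) -> np.ndarray:
    """Index-based gather along dim (returns a copy); the NumPy analogue
    of torch.index_select(arr, dim, torch.arange(n))."""
    return np.take(arr, np.arange(n), axis=dim)


arr = np.arange(24).reshape(2, 3, 4)
a = take_first_n_slicing(arr, 2, 2)
b = take_first_n_indexed(arr, 2, 2)
```

When n equals the full size of dim (the nall == nloc case), both paths simply return all elements.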
- deepmd.dpmodel.array_api.xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array
Reduces all values from the src tensor to the indices specified in the index tensor.
This function is similar to PyTorch’s torch.Tensor.scatter_add and JAX’s jax.lax.scatter_add. It adds values from src to input at positions specified by index along the given dimension.
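The scatter-add semantics can be sketched in NumPy. For every position p in index, the value src[p] is added to the output at p with its dim coordinate replaced by index[p]. This is an illustrative sketch (the name scatter_sum_sketch is hypothetical), assuming index and src have the same shape and the same number of dimensions as input. torch.Tensor.scatter_add is more permissive about shapes.

```python
import numpy as np


def scatter_sum_sketch(
    inp: np.ndarray, dim: int, index: np.ndarray, src: np.ndarray
) -> np.ndarray:
    """Return a copy of inp with src[p] added at index[p] along dim."""
    out = np.array(inp, copy=True)  # leave the caller's array untouched
    # One coordinate array per axis, each shaped like index.
    coords = list(np.indices(index.shape))
    # Along dim, target the position given by index instead.
    coords[dim] = index
    # np.add.at is unbuffered, so repeated indices accumulate.
    np.add.at(out, tuple(coords), src)
    return out


# 1-D: positions 0 and 2 receive two and one contributions respectively.
out_1d = scatter_sum_sketch(
    np.zeros(3), 0, np.array([0, 1, 0, 2]), np.array([1.0, 2.0, 3.0, 4.0])
)

# 2-D, scattering along dim=1 (columns) within each row.
out_2d = scatter_sum_sketch(
    np.zeros((2, 3)),
    1,
    np.array([[0, 0, 2], [1, 1, 1]]),
    np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
)
```

The unbuffered accumulation is the key point. A plain fancy-index assignment would drop duplicate contributions.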
- deepmd.dpmodel.array_api.xp_add_at(x: Array, indices: Array, values: Array) -> Array
Adds values to x at the specified indices, either in place or, for JAX (whose arrays are immutable), by returning a new x.
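One reason such a helper is needed is that a naive x[indices] += values does not accumulate repeated indices. The following NumPy-only sketch (the name add_at_sketch is hypothetical) contrasts the unbuffered np.add.at with buffered fancy-index assignment. The JAX counterpart, x.at[indices].add(values), is mentioned only in a comment because it returns a new array rather than mutating x.

```python
import numpy as np


def add_at_sketch(x: np.ndarray, indices: np.ndarray, values: np.ndarray) -> np.ndarray:
    # NumPy path: unbuffered in-place accumulation, so every duplicate index
    # contributes. A JAX version would instead return a new array via
    # x.at[indices].add(values), since JAX arrays are immutable.
    np.add.at(x, indices, values)
    return x


idx = np.array([0, 0, 2])
vals = np.array([1.0, 1.0, 1.0])

accumulated = add_at_sketch(np.zeros(3), idx, vals)  # [2., 0., 1.]

buffered = np.zeros(3)
buffered[idx] += vals  # duplicates are NOT accumulated: [1., 0., 1.]
```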
- deepmd.dpmodel.array_api.xp_sigmoid(x: Array) -> Array
Compute the sigmoid function.
JAX and PyTorch have optimized sigmoid implementations; see jax-ml/jax#15617.
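On a backend without a dedicated routine, the sigmoid has to be composed from elementwise operations. The sketch below (hypothetical name sigmoid_sketch, NumPy only) uses a standard numerically stable formulation. Whether deepmd uses this exact fallback is an assumption. On JAX and PyTorch, the built-in jax.nn.sigmoid and torch.sigmoid are presumably the optimized implementations the note above refers to.

```python
import numpy as np


def sigmoid_sketch(x) -> np.ndarray:
    """Numerically stable logistic sigmoid 1 / (1 + exp(-x))."""
    x = np.asarray(x, dtype=np.float64)
    # exp(-|x|) lies in (0, 1], so it never overflows.
    z = np.exp(-np.abs(x))
    # For x >= 0: 1 / (1 + exp(-x)); for x < 0: exp(x) / (1 + exp(x)).
    return np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))


s = sigmoid_sketch([0.0, 1000.0, -1000.0])
```

The naive form 1 / (1 + exp(-x)) overflows in exp for large negative x. Splitting on the sign keeps every exponent non-positive.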