Source code for deepmd.pd.utils.decomp

# SPDX-License-Identifier: LGPL-3.0-or-later

# This file is used to implement some paddle functions with composite API,
# so as to support high-order differentation when double-backward is needed.
# For example: [norm] --decomposition--> [multiply, power, sum]
# This file will be removed when implmented functions are decomposed into primitive
# function in Paddle framework in the future.

from __future__ import (
    annotations,
)

import numpy as np
import paddle

__all__ = [
    "masked_add_",
    "numel",
    "scatter_reduce",
    "sec",
]


def scatter_reduce_decomp(
    input: paddle.Tensor,
    axis: int,
    index: paddle.Tensor,
    src: paddle.Tensor,
    reduce: str,
) -> paddle.Tensor:
    """Forward decompsition function of scatter_reduce.

    Parameters
    ----------
    input : paddle.Tensor
        Input tensor.
    axis : int
        The axis along which to index.
    index : paddle.Tensor
        The indices of elements to scatter and reduce.
    src : paddle.Tensor
        The source elements to scatter and reduce.
    reduce : str
        The reduction operation to apply for non-unique indices.
        Supported modes: ("sum", "prod", "mean", "amax", "amin").

    Returns
    -------
    paddle.Tensor
        Computed output.
    """
    # reduce: "sum", "prod", "mean", "amax", "amin"
    if reduce == "sum":
        output = input.put_along_axis(
            indices=index, values=src, axis=axis, reduce="add"
        )
    elif reduce == "mean":
        output = input.put_along_axis(
            indices=index, values=src, axis=axis, reduce="mean"
        )
    elif reduce == "prod":
        output = input.put_along_axis(
            indices=index, values=src, axis=axis, reduce="mul"
        )
    else:
        raise NotImplementedError("only support mode in ['sum', 'prod', 'mean']!")
    return output


[docs] def sec(length: int, size: int) -> list[int]: """Auxiliary function for decomposed functions. If length is not divisible by size, the last chunk will be smaller. Parameters ---------- length : int Length to be chunked. size : int Chunk size. Returns ------- list[int] Chunked output list. """ assert length > 0 assert size > 0 if length % size == 0: return [size] * (length // size) return [size] * (length // size) + [length % size]
def masked_add__decomp( x: paddle.Tensor, mask: paddle.Tensor, v: paddle.Tensor ) -> paddle.Tensor: """Forward decompsition function of masked_add_(inplace operator). Parameters ---------- x : paddle.Tensor Input tensor. mask : paddle.Tensor Mask tensor. v : paddle.Tensor Value to add. Returns ------- paddle.Tensor Computed output. """ assert mask.dtype == paddle.bool, f"mask must be bool type, but got {mask.dtype}" # indices is bool mask mask_coord = paddle.nonzero(mask, as_tuple=False) # [nz, dim] if not paddle.is_tensor(v): v = paddle.full([mask_coord.shape[0]], v, dtype=x.dtype) t = paddle.scatter_nd_add( x, mask_coord, v, ) paddle.assign(t, x) # inplace update return x
[docs] def numel(x: paddle.Tensor) -> int: if paddle.in_dynamic_mode(): return np.prod(x.shape) return paddle.numel(x)
# alias for decomposed functions for convinience
[docs] masked_add_ = masked_add__decomp
[docs] scatter_reduce = scatter_reduce_decomp