# Source code for deepmd.dpmodel.common

# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
    ABC,
    abstractmethod,
)
from functools import (
    wraps,
)
from typing import (
    Any,
    Callable,
    Optional,
    overload,
)

import array_api_compat
import ml_dtypes
import numpy as np

from deepmd.common import (
    VALID_PRECISION,
)
from deepmd.env import (
    GLOBAL_ENER_FLOAT_PRECISION,
    GLOBAL_NP_FLOAT_PRECISION,
)

# Maps every accepted precision name (including aliases) to a NumPy dtype.
PRECISION_DICT = {
    "float16": np.float16,
    "float32": np.float32,
    "float64": np.float64,
    "half": np.float16,
    "single": np.float32,
    "double": np.float64,
    "int32": np.int32,
    "int64": np.int64,
    "bool": np.bool_,
    "default": GLOBAL_NP_FLOAT_PRECISION,
    # NumPy has no native bfloat16 dtype (and does not plan to add one).
    # ml_dtypes provides one, but it does not seem to support np.save/np.load,
    # and HDF5 does not support bfloat16 either
    # (see https://forum.hdfgroup.org/t/11975).
    "bfloat16": ml_dtypes.bfloat16,
}
# Every name in VALID_PRECISION must have an entry in PRECISION_DICT.
assert VALID_PRECISION.issubset(PRECISION_DICT)

# Reverse lookup: maps each dtype back to its canonical precision name.
RESERVED_PRECISON_DICT = {
    np.float16: "float16",
    np.float32: "float32",
    np.float64: "float64",
    np.int32: "int32",
    np.int64: "int64",
    ml_dtypes.bfloat16: "bfloat16",
    np.bool_: "bool",
}
# The reverse table must cover exactly the dtypes used by PRECISION_DICT.
assert set(RESERVED_PRECISON_DICT) == set(PRECISION_DICT.values())

DEFAULT_PRECISION = "float64"
def get_xp_precision(
    xp: Any,
    precision: str,
):
    """Resolve a precision name to a dtype for an array namespace.

    Parameters
    ----------
    xp : Any
        The array namespace whose dtype attributes (``float16``, ``float32``,
        ``float64``, ``int32``, ``int64``) are looked up.
    precision : str
        The precision name. Besides the dtype names themselves, the aliases
        ``"half"``, ``"single"`` and ``"double"``, as well as ``"bool"``,
        ``"bfloat16"``, ``"default"`` and ``"global"`` are accepted.

    Returns
    -------
    Any
        The dtype object that corresponds to `precision`.

    Raises
    ------
    ValueError
        If `precision` is not one of the supported names.
    """
    if precision in ("float16", "half"):
        return xp.float16
    if precision in ("float32", "single"):
        return xp.float32
    if precision in ("float64", "double"):
        return xp.float64
    if precision == "int32":
        return xp.int32
    if precision == "int64":
        return xp.int64
    if precision == "bool":
        # The builtin type is returned here, not an attribute of ``xp``.
        return bool
    if precision == "bfloat16":
        # bfloat16 is always taken from ml_dtypes, whatever ``xp`` is.
        return ml_dtypes.bfloat16
    if precision == "default":
        # Translate the configured default dtype back to its canonical name,
        # then resolve that name against ``xp``.
        canonical = RESERVED_PRECISON_DICT[PRECISION_DICT[precision]]
        return get_xp_precision(xp, canonical)
    if precision == "global":
        canonical = RESERVED_PRECISON_DICT[GLOBAL_NP_FLOAT_PRECISION]
        return get_xp_precision(xp, canonical)
    raise ValueError(f"unsupported precision {precision} for {xp}")
class NativeOP(ABC):
    """Abstract base class for one unit operation of a native model.

    Subclasses must implement :meth:`call`; calling an instance forwards all
    arguments to that method.
    """

    @abstractmethod
    def call(self, *args, **kwargs):
        """Forward pass in NumPy implementation."""

    def __call__(self, *args, **kwargs):
        """Run the forward pass by delegating to :meth:`call`."""
        outcome = self.call(*args, **kwargs)
        return outcome
def to_numpy_array(x: Any) -> Optional[np.ndarray]:
    """Convert an array to a NumPy array.

    Parameters
    ----------
    x : Any
        The array to be converted. ``None`` is passed through unchanged.

    Returns
    -------
    Optional[np.ndarray]
        The NumPy array, or ``None`` if `x` is ``None``.
    """
    if x is None:
        return None
    try:
        # asarray is not within the Array API standard, so it may fail
        converted = np.asarray(x)
    except (ValueError, AttributeError):
        xp = array_api_compat.array_namespace(x)
        # Copy first to avoid "BufferError: Cannot export readonly array since
        # signalling readonly is unsupported by DLPack."
        exportable = xp.asarray(x, copy=True)
        converted = np.from_dlpack(exportable)
    return converted


def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
    """A decorator that casts the input and output arrays of a method.

    The decorator should be used on an instance method whose instance exposes
    a ``precision`` attribute. It does the following:

    (1) It casts input arrays (positional and keyword) from the global
    precision to the precision given by ``self.precision``.

    (2) It casts output arrays from ``self.precision`` back to the global
    precision. A returned tuple or dict is cast element by element (value by
    value for a dict); any other return value is cast as a single object.

    (3) An input or output is only cast when it is an array and its dtype
    matches the precision it is cast from. Anything else (e.g. an integer
    array or a non-array object) is passed through unchanged.

    The decorator supports the array API.

    Parameters
    ----------
    func : Callable
        The instance method to be wrapped.

    Returns
    -------
    Callable
        The wrapped method that casts and casts back its input and output
        arrays.

    Examples
    --------
    >>> class A:
    ...     def __init__(self):
    ...         self.precision = "float32"
    ...
    ...     @cast_precision
    ...     def f(self, x: Array, y: Array) -> Array:
    ...         return x**2 + y
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # only arrays are converted; see safe_cast_array
        def to_local(value):
            return safe_cast_array(value, "global", self.precision)

        def to_global(value):
            return safe_cast_array(value, self.precision, "global")

        local_args = [to_local(item) for item in args]
        local_kwargs = {key: to_local(item) for key, item in kwargs.items()}
        result = func(self, *local_args, **local_kwargs)

        if isinstance(result, tuple):
            return tuple(to_global(item) for item in result)
        if isinstance(result, dict):
            return {key: to_global(item) for key, item in result.items()}
        return to_global(result)

    return wrapper


@overload
def safe_cast_array(
    input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...


@overload
def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...


def safe_cast_array(
    input: Optional[np.ndarray], from_precision: str, to_precision: str
) -> Optional[np.ndarray]:
    """Convert an array from one precision to another.

    The input is returned unchanged when it is not an array, or when its
    dtype is not `from_precision`. Array API is supported.

    Parameters
    ----------
    input : np.ndarray or None
        Input array
    from_precision : str
        Name of the precision the array is cast from
    to_precision : str
        Name of the precision the array is cast to

    Returns
    -------
    np.ndarray or None
        The cast array, or the untouched input if no cast applies
    """
    if not array_api_compat.is_array_api_obj(input):
        return input
    xp = array_api_compat.array_namespace(input)
    source_dtype = get_xp_precision(xp, from_precision)
    if input.dtype == source_dtype:
        target_dtype = get_xp_precision(xp, to_precision)
        return xp.astype(input, target_dtype)
    return input


__all__ = [
    "GLOBAL_NP_FLOAT_PRECISION",
    "GLOBAL_ENER_FLOAT_PRECISION",
    "PRECISION_DICT",
    "RESERVED_PRECISON_DICT",
    "DEFAULT_PRECISION",
    "NativeOP",
]