# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from functools import (
wraps,
)
from typing import (
Any,
Callable,
Optional,
overload,
)
import array_api_compat
import ml_dtypes
import numpy as np
from deepmd.common import (
VALID_PRECISION,
)
from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
)
# Map from precision names (including the aliases "half"/"single"/"double"
# and "default") to NumPy / ml_dtypes scalar types.
PRECISION_DICT = {
    "float16": np.float16,
    "float32": np.float32,
    "float64": np.float64,
    "half": np.float16,
    "single": np.float32,
    "double": np.float64,
    "int32": np.int32,
    "int64": np.int64,
    "bool": np.bool_,
    "default": GLOBAL_NP_FLOAT_PRECISION,
    # NumPy doesn't have bfloat16 (and doesn't plan to add it).
    # ml_dtypes provides one, but it does not seem to support np.save/np.load.
    # HDF5 does not support bfloat16 either (see https://forum.hdfgroup.org/t/11975).
    "bfloat16": ml_dtypes.bfloat16,
}
# Every precision name declared valid elsewhere must be resolvable here.
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())
# Reverse lookup of PRECISION_DICT: maps each scalar type to one canonical
# precision name, so aliases ("half", "single", "double", "default") collapse
# onto their canonical names. The spelling "PRECISON" is part of the exported
# name (it is listed in ``__all__``) and is therefore left unchanged.
RESERVED_PRECISON_DICT = {
    np.float16: "float16",
    np.float32: "float32",
    np.float64: "float64",
    np.int32: "int32",
    np.int64: "int64",
    ml_dtypes.bfloat16: "bfloat16",
    np.bool_: "bool",
}
# The reverse map must cover exactly the types that PRECISION_DICT produces.
assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values())
# Name of the default precision; it is one of the keys of PRECISION_DICT.
DEFAULT_PRECISION = "float64"
def get_xp_precision(
    xp: Any,
    precision: str,
):
    """Resolve a precision name to a dtype of an array-API namespace.

    Parameters
    ----------
    xp : Any
        The array-API compatible namespace whose dtype attributes are used.
    precision : str
        One of ``"float16"``/``"half"``, ``"float32"``/``"single"``,
        ``"float64"``/``"double"``, ``"int32"``, ``"int64"``, ``"bool"``,
        ``"default"``, ``"global"`` or ``"bfloat16"``.

    Returns
    -------
    Any
        The matching dtype. Floating and integer names are read from ``xp``.
        ``"bool"`` yields the builtin ``bool`` type, and ``"bfloat16"`` yields
        ``ml_dtypes.bfloat16``; neither depends on ``xp``. ``"default"`` and
        ``"global"`` are resolved recursively through their canonical names.

    Raises
    ------
    ValueError
        If ``precision`` is not one of the names above.
    """
    # (accepted aliases, attribute name looked up on ``xp``)
    namespace_dtypes = (
        (("float16", "half"), "float16"),
        (("float32", "single"), "float32"),
        (("float64", "double"), "float64"),
        (("int32",), "int32"),
        (("int64",), "int64"),
    )
    for aliases, attr_name in namespace_dtypes:
        if precision in aliases:
            return getattr(xp, attr_name)

    if precision == "bool":
        # the builtin type, not an attribute of ``xp``
        return bool
    if precision == "default":
        default_type = PRECISION_DICT[precision]
        return get_xp_precision(xp, RESERVED_PRECISON_DICT[default_type])
    if precision == "global":
        return get_xp_precision(xp, RESERVED_PRECISON_DICT[GLOBAL_NP_FLOAT_PRECISION])
    if precision == "bfloat16":
        # always taken from ml_dtypes, whatever ``xp`` is
        return ml_dtypes.bfloat16
    raise ValueError(f"unsupported precision {precision} for {xp}")
class NativeOP(ABC):
    """The unit operation of a native model.

    Subclasses implement :meth:`call`. Instances are callable, and
    ``op(*args, **kwargs)`` forwards unchanged to ``op.call(*args, **kwargs)``.
    The class cannot be instantiated until :meth:`call` is overridden.
    """

    @abstractmethod
    def call(self, *args, **kwargs):
        """Forward pass in NumPy implementation.

        Must be overridden by subclasses; the base body does nothing.
        """
        pass

    def __call__(self, *args, **kwargs):
        """Forward pass in NumPy implementation; delegates to :meth:`call`."""
        return self.call(*args, **kwargs)
def to_numpy_array(x: Any) -> Optional[np.ndarray]:
    """Return ``x`` as a NumPy array, passing ``None`` through.

    ``np.asarray`` is tried first. If it raises ``ValueError`` or
    ``AttributeError``, the array is copied within its own array-API
    namespace and imported through DLPack.

    Parameters
    ----------
    x : Any
        An array from any array-API compatible library, or ``None``.

    Returns
    -------
    Optional[np.ndarray]
        The converted NumPy array, or ``None`` when ``x`` is ``None``.
    """
    if x is None:
        return None
    try:
        # np.asarray is not part of the Array API standard, so conversion
        # may fail for some array types.
        converted = np.asarray(x)
    except (ValueError, AttributeError):
        namespace = array_api_compat.array_namespace(x)
        # Copy first to avoid "BufferError: Cannot export readonly array
        # since signalling readonly is unsupported by DLPack."
        writable_copy = namespace.asarray(x, copy=True)
        return np.from_dlpack(writable_copy)
    return converted
def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
    """Cast a method's array inputs to ``self.precision`` and its outputs back.

    The decorator must be applied to an instance method; the instance must
    expose a ``precision`` attribute holding a precision name.
    The decorator does the following:

    (1) It casts input arrays (positional and keyword) from the global
        precision to the precision named by ``self.precision``.
    (2) It casts output arrays from ``self.precision`` back to the global
        precision.
    (3) It only casts values that are arrays whose dtype matches the source
        precision of that step. Anything else (e.g. integer arrays or
        non-array objects) is passed through untouched.

    Outputs are handled one level deep. A returned ``tuple`` is cast
    element-wise and rebuilt as a plain ``tuple``. A returned ``dict`` is cast
    value-wise. Any other return value is cast directly. Nested containers are
    not traversed. The decorator supports the array API.

    Parameters
    ----------
    func : Callable[..., Any]
        The instance method to wrap; its first positional argument is ``self``.

    Returns
    -------
    Callable[..., Any]
        The wrapped method, with metadata copied from ``func`` via
        ``functools.wraps``.

    Examples
    --------
    >>> class A:
    ...     def __init__(self):
    ...         self.precision = "float32"
    ...
    ...     @cast_precision
    ...     def f(self, x, y):
    ...         return x**2 + y
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # only arrays with the matching dtype are converted
        returned_tensor = func(
            self,
            *[safe_cast_array(vv, "global", self.precision) for vv in args],
            **{
                kk: safe_cast_array(vv, "global", self.precision)
                for kk, vv in kwargs.items()
            },
        )
        if isinstance(returned_tensor, tuple):
            return tuple(
                safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
            )
        elif isinstance(returned_tensor, dict):
            return {
                kk: safe_cast_array(vv, self.precision, "global")
                for kk, vv in returned_tensor.items()
            }
        else:
            return safe_cast_array(returned_tensor, self.precision, "global")

    return wrapper
@overload
def safe_cast_array(
    input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...
@overload
def safe_cast_array(input: None, from_precision: str, to_precision: str) -> None: ...
def safe_cast_array(
    input: Optional[np.ndarray], from_precision: str, to_precision: str
) -> Optional[np.ndarray]:
    """Cast an array from one precision to another, if its dtype matches.

    Works with any array-API compatible array. Values that are not arrays
    (including ``None``), or arrays whose dtype is not ``from_precision``,
    are returned unchanged.

    Parameters
    ----------
    input : np.ndarray or None
        The value to cast.
    from_precision : str
        Precision name the array must currently have for a cast to happen.
    to_precision : str
        Precision name to cast the array to.

    Returns
    -------
    np.ndarray or None
        The cast array, or ``input`` itself when no cast applies.
    """
    if not array_api_compat.is_array_api_obj(input):
        return input
    namespace = array_api_compat.array_namespace(input)
    dtype_matches = input.dtype == get_xp_precision(namespace, from_precision)
    if not dtype_matches:
        return input
    target_dtype = get_xp_precision(namespace, to_precision)
    return namespace.astype(input, target_dtype)
# Names exported by ``from <this module> import *``. The helper functions
# defined in this module (get_xp_precision, to_numpy_array, cast_precision,
# safe_cast_array) are not listed, so a star-import does not bring them in.
# NOTE(review): confirm whether that omission is intentional.
__all__ = [
    "GLOBAL_NP_FLOAT_PRECISION",
    "GLOBAL_ENER_FLOAT_PRECISION",
    "PRECISION_DICT",
    "RESERVED_PRECISON_DICT",
    "DEFAULT_PRECISION",
    "NativeOP",
]