# SPDX-License-Identifier: LGPL-3.0-or-later
from packaging.version import (
Version,
)
from deepmd.env import (
TF_VERSION,
tf,
)
from deepmd.utils.errors import (
OutOfMemoryError,
)
from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase
class AutoBatchSize(AutoBatchSizeBase):
    """TensorFlow flavour of the automatic batch-size helper.

    Provides the TensorFlow-specific GPU-availability check and
    out-of-memory detection on top of ``AutoBatchSizeBase`` (imported from
    ``deepmd_utils.utils.batch_size``). Presumably the base class calls these
    two hooks while probing for a usable batch size — confirm against the
    base implementation.
    """

    def is_gpu_available(self) -> bool:
        """Check if GPU is available.

        Returns
        -------
        bool
            True if GPU is available
        """
        # ``get_visible_devices`` is only queried on TF >= 1.14. It returns a
        # (possibly empty) list of devices, so return a real ``True`` rather
        # than leaking the list to callers of this ``-> bool`` method.
        if Version(TF_VERSION) >= Version("1.14") and (
            tf.config.experimental.get_visible_devices("GPU")
        ):
            return True
        # Fallback: older TF, or no GPU was reported as visible above.
        return bool(tf.test.is_gpu_available())

    def is_oom_error(self, e: Exception) -> bool:
        """Check if the exception is an OOM error.

        Parameters
        ----------
        e : Exception
            Exception

        Returns
        -------
        bool
            True if ``e`` is a TensorFlow ``tf.errors.ResourceExhaustedError``
            or a DeePMD ``OutOfMemoryError``
        """
        # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
        # but luckily we only need to catch once
        return isinstance(e, (tf.errors.ResourceExhaustedError, OutOfMemoryError))