# Source code for deepmd.utils.batch_size

# SPDX-License-Identifier: LGPL-3.0-or-later
from packaging.version import (
    Version,
)

from deepmd.env import (
    TF_VERSION,
    tf,
)
from deepmd.utils.errors import (
    OutOfMemoryError,
)
from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class AutoBatchSize(AutoBatchSizeBase):
    """Automatic batch-size determination for the TensorFlow backend.

    Extends the framework-agnostic :class:`AutoBatchSizeBase` with
    TensorFlow-specific checks for GPU availability and for recognizing
    out-of-memory errors, so the base class can shrink the batch size
    when the device runs out of memory.
    """

    def is_gpu_available(self) -> bool:
        """Check if GPU is available.

        Returns
        -------
        bool
            True if GPU is available
        """
        # TF >= 1.14 provides tf.config.experimental.get_visible_devices;
        # older versions fall back to tf.test.is_gpu_available.
        # Wrap in bool(): get_visible_devices returns a (possibly empty)
        # list of devices, not a boolean, and this method's contract is bool.
        return bool(
            (
                Version(TF_VERSION) >= Version("1.14")
                and tf.config.experimental.get_visible_devices("GPU")
            )
            or tf.test.is_gpu_available()
        )

    def is_oom_error(self, e: Exception) -> bool:
        """Check if the exception is an OOM error.

        Parameters
        ----------
        e : Exception
            Exception

        Returns
        -------
        bool
            True if ``e`` is TensorFlow's ResourceExhaustedError or this
            package's own OutOfMemoryError.
        """
        # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
        # but luckily we only need to catch once
        return isinstance(e, (tf.errors.ResourceExhaustedError, OutOfMemoryError))