Source code for deepmd.utils.batch_size

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
from abc import (
    ABC,
    abstractmethod,
)
from collections.abc import (
    Callable,
)
from typing import (
    Any,
)

import array_api_compat
import numpy as np

from deepmd.utils.errors import (
    OutOfMemoryError,
)

[docs] log = logging.getLogger(__name__)
[docs] class RetrySignal(Exception): """Signal to retry execution after OOM error."""
[docs] class AutoBatchSize(ABC): """This class allows DeePMD-kit to automatically decide the maximum batch size that will not cause an OOM error. Notes ----- In some CPU environments, the program may be directly killed when OOM. In this case, by default the batch size will not be increased for CPUs. The environment variable `DP_INFER_BATCH_SIZE` can be set as the batch size. In other cases, we assume all OOM error will raise :class:`OutOfMemoryError`. Parameters ---------- initial_batch_size : int, default: 1024 initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE is not set factor : float, default: 2. increased factor silent : bool, default: False whether to suppress auto batch size informational logs Attributes ---------- current_batch_size : int current batch size (number of total atoms) maximum_working_batch_size : int maximum working batch size minimal_not_working_batch_size : int minimal not working batch size """ def __init__( self, initial_batch_size: int = 1024, factor: float = 2.0, *, silent: bool = False, ) -> None: # See also PyTorchLightning/pytorch-lightning#1638
[docs] self.silent = silent
[docs] self.current_batch_size = initial_batch_size
DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0)) if DP_INFER_BATCH_SIZE > 0: self.current_batch_size = DP_INFER_BATCH_SIZE self.maximum_working_batch_size = DP_INFER_BATCH_SIZE self.minimal_not_working_batch_size = self.maximum_working_batch_size + 1 else: self.maximum_working_batch_size = initial_batch_size if self.is_gpu_available(): self.minimal_not_working_batch_size = 2**31 else: self.minimal_not_working_batch_size = ( self.maximum_working_batch_size + 1 ) if not self.silent: log.warning( "You can use the environment variable DP_INFER_BATCH_SIZE to" "control the inference batch size (nframes * natoms). " f"The default value is {initial_batch_size}." )
[docs] self.factor = factor
[docs] self.oom_retry_mode = False
[docs] def execute( self, callable: Callable, start_index: int, natoms: int ) -> tuple[int, tuple]: """Excuate a method with given batch size. Parameters ---------- callable : Callable The method should accept the batch size and start_index as parameters, and returns executed batch size and data. start_index : int start index natoms : int natoms Returns ------- int executed batch size * number of atoms tuple result from callable, None if failing to execute Raises ------ OutOfMemoryError OOM when batch size is 1 """ if natoms > 0: batch_nframes = self.current_batch_size // natoms else: batch_nframes = self.current_batch_size try: n_batch, result = callable(max(batch_nframes, 1), start_index) except Exception as e: if not self.is_oom_error(e): raise e self.minimal_not_working_batch_size = min( self.minimal_not_working_batch_size, self.current_batch_size ) if self.maximum_working_batch_size >= self.minimal_not_working_batch_size: self.maximum_working_batch_size = int( self.minimal_not_working_batch_size / self.factor ) if self.minimal_not_working_batch_size <= natoms: raise OutOfMemoryError( "The callable still throws an out-of-memory (OOM) error even when batch size is 1!" ) from e # adjust the next batch size self._adjust_batch_size(1.0 / self.factor) if self.oom_retry_mode: raise RetrySignal from e return 0, None else: n_tot = n_batch * natoms self.maximum_working_batch_size = max( self.maximum_working_batch_size, n_tot ) # adjust the next batch size if ( n_tot + natoms > self.current_batch_size and self.current_batch_size * self.factor < self.minimal_not_working_batch_size ): self._adjust_batch_size(self.factor) return n_batch, result
[docs] def _adjust_batch_size(self, factor: float) -> None: old_batch_size = self.current_batch_size self.current_batch_size = int(self.current_batch_size * factor) if not self.silent: log.info( f"Adjust batch size from {old_batch_size} to {self.current_batch_size}" )
[docs] def execute_all( self, callable: Callable, total_size: int, natoms: int, *args: Any, **kwargs: Any, ) -> tuple[np.ndarray]: """Excuate a method with all given data. This method is compatible with Array API. Parameters ---------- callable : Callable The method should accept *args and **kwargs as input and return the similar array. total_size : int Total size natoms : int The number of atoms *args Variable length argument list. **kwargs If 2D np.ndarray, assume the first axis is batch; otherwise do nothing. """ def execute_with_batch_size( batch_size: int, start_index: int ) -> tuple[int, tuple[np.ndarray]]: end_index = start_index + batch_size end_index = min(end_index, total_size) return (end_index - start_index), callable( *[ ( vv[start_index:end_index, ...] if ( (array_api_compat.is_array_api_obj(vv) and vv.ndim > 1) or str(vv.__class__) == "<class 'paddle.Tensor'>" ) else vv ) for vv in args ], **{ kk: ( vv[start_index:end_index, ...] if ( (array_api_compat.is_array_api_obj(vv) and vv.ndim > 1) or str(vv.__class__) == "<class 'paddle.Tensor'>" ) else vv ) for kk, vv in kwargs.items() }, ) index = 0 results = None returned_dict = None while index < total_size: n_batch, result = self.execute(execute_with_batch_size, index, natoms) if n_batch == 0: continue returned_dict = ( isinstance(result, dict) if returned_dict is None else returned_dict ) if not returned_dict: result = (result,) if not isinstance(result, tuple) else result index += n_batch def append_to_list(res_list: list[Any], res: Any) -> list[Any]: if n_batch: res_list.append(res) return res_list if not returned_dict: results = [] if results is None else results results = append_to_list(results, result) else: results = {kk: [] for kk in result} if results is None else results results = {kk: append_to_list(results[kk], result[kk]) for kk in result} assert results is not None assert returned_dict is not None def concate_result(r: list[Any]) -> Any: if array_api_compat.is_array_api_obj(r[0]): xp = array_api_compat.array_namespace(r[0]) ret = xp.concat(r, axis=0) elif str(r[0].__class__) == "<class 'paddle.Tensor'>": try: import paddle except ModuleNotFoundError as e: raise ModuleNotFoundError( "The 'paddlepaddle' is required but not installed." ) from e ret = paddle.concat(r, axis=0) else: raise RuntimeError(f"Unexpected result type {type(r[0])}") return ret if not returned_dict: r_list = [concate_result(r) for r in zip(*results, strict=True)] r = tuple(r_list) if len(r) == 1: # avoid returning tuple if callable doesn't return tuple r = r[0] else: r = {kk: concate_result(vv) for kk, vv in results.items()} return r
@abstractmethod
[docs] def is_gpu_available(self) -> bool: """Check if GPU is available. Returns ------- bool True if GPU is available """
@abstractmethod
[docs] def is_oom_error(self, e: Exception) -> bool: """Check if the exception is an OOM error. Parameters ---------- e : Exception Exception Returns ------- bool True if the exception is an OOM error """
[docs] def set_oom_retry_mode(self, enable: bool) -> None: """Set OOM retry mode. In OOM retry mode, an OOM during execution may reduce the current batch size and raise :class:`RetrySignal` to indicate that execution should be retried. Callers that want all data to be re-executed must catch :class:`RetrySignal` and restart the full evaluation themselves. Parameters ---------- enable : bool True to enable OOM retry mode """ self.oom_retry_mode = enable