# Source code for dpgen2.exploration.deviation.deviation_manager

from abc import (
    ABC,
    abstractmethod,
)
from typing import (
    List,
    Optional,
)

import numpy as np


class DeviManager(ABC):
    r"""A class for model deviation management.

    This abstract base class fixes the set of accepted deviation names and
    validates them in the public entry points :meth:`add` and :meth:`get`
    before delegating to the storage hooks :meth:`_add` and :meth:`_get`,
    which concrete subclasses must implement together with :meth:`clear`
    and :meth:`_check_data`.
    """

    # The only deviation names accepted by `add` and `get`.
    MAX_DEVI_V = "max_devi_v"
    MIN_DEVI_V = "min_devi_v"
    AVG_DEVI_V = "avg_devi_v"
    MAX_DEVI_F = "max_devi_f"
    MIN_DEVI_F = "min_devi_f"
    AVG_DEVI_F = "avg_devi_f"

    def __init__(self) -> None:
        r"""Initialize the manager with ``ntraj`` set to zero."""
        super().__init__()
        # NOTE(review): presumably the number of trajectories stored so far;
        # this base class only initializes it, so confirm against subclasses.
        self.ntraj = 0

    def _check_name(self, name: str):
        r"""Validate a deviation name.

        Parameters
        ----------
        name : str
            The name to validate.

        Raises
        ------
        AssertionError
            If `name` is not one of the six deviation names defined on
            DeviManager.
        """
        valid_names = (
            DeviManager.MAX_DEVI_V,
            DeviManager.MIN_DEVI_V,
            DeviManager.AVG_DEVI_V,
            DeviManager.MAX_DEVI_F,
            DeviManager.MIN_DEVI_F,
            DeviManager.AVG_DEVI_F,
        )
        if name not in valid_names:
            # Raised explicitly instead of via an `assert` statement so the
            # check is not stripped under `python -O`. The exception type and
            # message are unchanged for callers that already catch them.
            raise AssertionError(f"Error: unknown deviation name {name}")

    def add(self, name: str, deviation: np.ndarray) -> None:
        r"""Add a model deviation into this manager.

        The name is validated first; storage is delegated to :meth:`_add`.

        Parameters
        ----------
        name : str
            The name of the deviation. The name is restricted to
            (DeviManager.MAX_DEVI_V, DeviManager.MIN_DEVI_V,
            DeviManager.AVG_DEVI_V, DeviManager.MAX_DEVI_F,
            DeviManager.MIN_DEVI_F, DeviManager.AVG_DEVI_F)
        deviation : np.ndarray
            The model deviation is a one-dimensional array extracted
            from a trajectory file.

        Raises
        ------
        AssertionError
            If `name` is not a known deviation name.
        """
        self._check_name(name)
        return self._add(name, deviation)

    @abstractmethod
    def _add(self, name: str, deviation: np.ndarray) -> None:
        r"""Store `deviation` under `name`; called by :meth:`add` after
        the name has been validated."""
        pass

    def get(self, name: str) -> List[Optional[np.ndarray]]:
        r"""Get a model deviation from this manager.

        The name is validated, then :meth:`_check_data` is run, and finally
        the lookup is delegated to :meth:`_get`.

        Parameters
        ----------
        name : str
            The name of the deviation. The name is restricted to
            (DeviManager.MAX_DEVI_V, DeviManager.MIN_DEVI_V,
            DeviManager.AVG_DEVI_V, DeviManager.MAX_DEVI_F,
            DeviManager.MIN_DEVI_F, DeviManager.AVG_DEVI_F)

        Returns
        -------
        List[Optional[np.ndarray]]
            Whatever the concrete :meth:`_get` returns for `name`.
            NOTE(review): presumably one entry per added trajectory, with
            None where no deviation was recorded; confirm against subclasses.

        Raises
        ------
        AssertionError
            If `name` is not a known deviation name.
        """
        self._check_name(name)
        self._check_data()
        return self._get(name)

    @abstractmethod
    def _get(self, name: str) -> List[Optional[np.ndarray]]:
        r"""Return the stored deviations for `name`; called by :meth:`get`
        after the name and the data have been checked."""
        pass

    @abstractmethod
    def clear(self) -> None:
        r"""Clear all data in this manager."""
        pass

    @abstractmethod
    def _check_data(self) -> None:
        r"""Check if data is valid"""
        pass