Source code for dpgen2.exploration.deviation.deviation_std
from collections import (
defaultdict,
)
from typing import (
Dict,
List,
Optional,
)
import numpy as np
from .deviation_manager import (
DeviManager,
)
[docs]
class DeviManagerStd(DeviManager):
    r"""Standard, list-backed implementation of :class:`DeviManager`.

    Every deviation quantity (e.g. ``max_devi_f`` or ``max_devi_v`` as found
    in ``model_devi.out``) is stored under its name as a list holding one
    entry per trajectory. Each entry is a one-dimensional ``np.ndarray`` whose
    ``jj``-th element belongs to the ``jj``-th frame, so ``devi[ii][jj]`` is
    the value of that quantity for frame ``jj`` of trajectory ``ii``.

    Asking for a quantity that was never added, while other quantities were,
    yields a list of ``None`` whose length equals the number of trajectory
    files recorded so far.
    """

    def __init__(self):
        super().__init__()
        # Maps a deviation name to its per-trajectory 1-D arrays, in the order
        # they were added.
        self._data: Dict[str, List[np.ndarray]] = defaultdict(list)

    def _add(self, name: str, deviation: np.ndarray) -> None:
        """Append one trajectory's deviation array under ``name``.

        Parameters
        ----------
        name
            Name of the deviation quantity.
        deviation
            One-dimensional array with one value per frame.

        Raises
        ------
        AssertionError
            If ``deviation`` is not an ``np.ndarray`` or is not 1-D.
        """
        assert isinstance(deviation, np.ndarray), (
            f"Error: deviation(type: {type(deviation)}) is not a np.ndarray"
        )
        assert deviation.ndim == 1, (
            f"Error: deviation(shape: {deviation.shape}) is not a "
            "one-dimensional array"
        )
        bucket = self._data[name]
        bucket.append(deviation)
        # The trajectory count tracks the longest list seen under any name.
        self.ntraj = max(self.ntraj, len(bucket))

    def _get(self, name: str) -> List[Optional[np.ndarray]]:
        """Return the per-trajectory arrays stored under ``name``.

        Returns
        -------
        list
            ``[]`` when no trajectory has been recorded yet. A list of
            ``self.ntraj`` ``None`` values when ``name`` has no data. Otherwise
            the stored list itself (not a copy).
        """
        if self.ntraj == 0:
            return []
        stored = self._data[name]
        if stored:
            return stored
        return [None] * self.ntraj

    def clear(self) -> None:
        """Discard all stored deviations by re-running ``__init__``.

        This re-creates an empty store. It also re-runs
        ``DeviManager.__init__``, presumably resetting ``ntraj``; confirm
        against the base class.
        """
        self.__init__()

    def _check_data(self) -> None:
        r"""Validate the stored deviations.

        The checks are:

        * every virial/force deviation quantity that holds data has exactly
          ``self.ntraj`` entries, and each entry is an ``np.ndarray``;
        * ``max_devi_f`` is present for every trajectory;
        * all non-empty quantities agree on the number of frames of each
          trajectory.

        Raises
        ------
        AssertionError
            On the first failed check. The checks are ``assert`` statements,
            so they are skipped under ``python -O``.
        """
        checked_names = (
            DeviManager.MAX_DEVI_V,
            DeviManager.MIN_DEVI_V,
            DeviManager.AVG_DEVI_V,
            DeviManager.MAX_DEVI_F,
            DeviManager.MIN_DEVI_F,
            DeviManager.AVG_DEVI_F,
        )

        # Per-name frame counts, kept only for quantities that hold data.
        frame_counts = {}
        for devi_name in checked_names:
            series = self._data[devi_name]
            if series:
                assert len(series) == self.ntraj, (
                    f"Error: the number of model deviation {devi_name} "
                    f"({len(series)}) and trajectory files ({self.ntraj}) "
                    "are not equal."
                )
            for position, item in enumerate(series):
                assert isinstance(item, np.ndarray), (
                    f"Error: model deviation in {devi_name} is not ndarray, "
                    f"index: {position}, type: {type(item)}"
                )
            if series:
                frame_counts[devi_name] = [item.shape[0] for item in series]

        # max_devi_f is mandatory for every trajectory.
        assert len(self._data[DeviManager.MAX_DEVI_F]) == self.ntraj, (
            f"Error: cannot find model deviation {DeviManager.MAX_DEVI_F}"
        )

        # Arrays describing the same trajectory must have equal frame counts.
        present = list(frame_counts)
        if present:
            reference, *others = present
            for devi_name in others:
                assert frame_counts[devi_name] == frame_counts[reference], (
                    f"Error: the number of frames in {devi_name} is different "
                    f"with that in {reference}.\n"
                    f"{devi_name}: {frame_counts[devi_name]}\n"
                    f"{reference}: {frame_counts[reference]}\n"
                )