import random
from abc import (
abstractmethod,
)
from typing import (
List,
Optional,
Tuple,
)
import numpy as np
from dargs import (
Argument,
)
from dflow.python import (
FatalError,
)
from ..deviation import (
DeviManager,
)
from . import (
ExplorationReport,
)
# [docs]
class ExplorationReportTrustLevels(ExplorationReport):
    """Exploration report that sorts frames by model-deviation trust levels.

    Each frame of each recorded trajectory is put into exactly one of three
    groups according to its force model deviation ``md``:

    - accurate:  ``md < level_f_lo``
    - candidate: ``level_f_lo <= md < level_f_hi``
    - failed:    ``md >= level_f_hi``

    When a virial deviation and both virial levels are available, the same
    classification is applied to the virial deviation and the two results are
    combined: a frame is accurate only if it is accurate for both quantities,
    failed if it is failed for either of them, and a candidate otherwise.

    ``converged`` and ``get_candidate_ids`` are abstract and have to be
    implemented by subclasses.

    Parameters
    ----------
    level_f_lo : float
        The lower trust level of force model deviation.
    level_f_hi : float
        The higher trust level of force model deviation.
    level_v_lo : float, optional
        The lower trust level of virial model deviation.
    level_v_hi : float, optional
        The higher trust level of virial model deviation.
    conv_accuracy : float
        Stored on the instance; not used by this class itself. According to
        ``args()`` the stage is converged if the ratio of accurate frames is
        larger than this value.
    """

    def __init__(
        self,
        level_f_lo: float,
        level_f_hi: float,
        level_v_lo: Optional[float] = None,
        level_v_hi: Optional[float] = None,
        conv_accuracy: float = 0.9,
    ):
        self.level_f_lo = level_f_lo
        self.level_f_hi = level_f_hi
        self.level_v_lo = level_v_lo
        self.level_v_hi = level_v_hi
        self.conv_accuracy = conv_accuracy
        # start with empty per-trajectory records
        self.clear()
        # virial columns are reported only when BOTH virial levels are given
        self.v_level = (self.level_v_lo is not None) and (self.level_v_hi is not None)
        self.model_devi = None

        # column titles and column widths of the printed report
        columns = (
            "stage",
            "id_stg.",
            "iter.",
            "accu.",
            "cand.",
            "fail.",
            "lvl_f_lo",
            "lvl_f_hi",
        )
        widths = [8, 8, 8, 10, 10, 10, 10, 10]
        if self.v_level:
            columns += (
                "v_lo",
                "v_hi",
            )
            widths += [10, 10]
        columns += ("cvged",)
        widths += [8]

        # one right-aligned "%<width>s" field per column, separated by a space
        self.fmt_str = " ".join([f"%{ii}s" for ii in widths])
        self.fmt_flt = "%.4f"
        # the leading "#" of the header pairs with the leading " " of the rows
        # produced by `print`, which keeps the columns aligned
        self.header_str = "#" + self.fmt_str % columns

    @staticmethod
    def args() -> List[Argument]:
        """Return the argument definitions accepted by the constructor.

        Returns
        -------
        List[Argument]
            One ``Argument`` per constructor parameter. The two force levels
            are mandatory; the virial levels default to ``None`` and
            ``conv_accuracy`` defaults to 0.9.
        """
        doc_level_f_lo = "The lower trust level of force model deviation"
        doc_level_f_hi = "The higher trust level of force model deviation"
        doc_level_v_lo = "The lower trust level of virial model deviation"
        doc_level_v_hi = "The higher trust level of virial model deviation"
        doc_conv_accuracy = "If the ratio of accurate frames is larger than this value, the stage is converged"
        return [
            Argument("level_f_lo", float, optional=False, doc=doc_level_f_lo),
            Argument("level_f_hi", float, optional=False, doc=doc_level_f_hi),
            Argument(
                "level_v_lo", float, optional=True, default=None, doc=doc_level_v_lo
            ),
            Argument(
                "level_v_hi", float, optional=True, default=None, doc=doc_level_v_hi
            ),
            Argument(
                "conv_accuracy",
                float,
                optional=True,
                default=0.9,
                doc=doc_conv_accuracy,
            ),
        ]

    def clear(
        self,
    ) -> None:
        """Reset all per-trajectory records and forget the model deviation."""
        self.traj_nframes = []
        self.traj_cand = []
        self.traj_accu = []
        self.traj_fail = []
        self.traj_cand_picked = []
        self.model_devi = None

    def record(
        self,
        model_devi: DeviManager,
    ) -> None:
        """Classify the frames of every trajectory held by ``model_devi``.

        For each trajectory the number of frames and the sets of candidate,
        accurate and failed frame indexes are appended to ``traj_nframes``,
        ``traj_cand``, ``traj_accu`` and ``traj_fail``. The manager itself is
        stored in ``self.model_devi``.

        The records are appended, not replaced, so ``clear`` has to be called
        before recording again; otherwise the length checks at the end fail.

        Parameters
        ----------
        model_devi : DeviManager
            Provides ``ntraj`` and the per-trajectory maximal force and virial
            model deviations.
        """
        ntraj = model_devi.ntraj
        md_f = model_devi.get(DeviManager.MAX_DEVI_F)
        md_v = model_devi.get(DeviManager.MAX_DEVI_V)

        for ii in range(ntraj):
            id_f_cand, id_f_accu, id_f_fail = self._get_indexes(
                md_f[ii], self.level_f_lo, self.level_f_hi
            )
            # all three are None when there is no virial deviation or no
            # virial trust levels
            id_v_cand, id_v_accu, id_v_fail = self._get_indexes(
                md_v[ii], self.level_v_lo, self.level_v_hi
            )
            nframes, set_accu, set_cand, set_fail = self._record_one_traj(
                id_f_accu,
                id_f_cand,
                id_f_fail,
                id_v_accu,
                id_v_cand,
                id_v_fail,
            )
            # record
            self.traj_nframes.append(nframes)
            self.traj_cand.append(set_cand)
            self.traj_accu.append(set_accu)
            self.traj_fail.append(set_fail)

        # exactly one entry per trajectory is expected
        assert len(self.traj_nframes) == ntraj
        assert len(self.traj_cand) == ntraj
        assert len(self.traj_accu) == ntraj
        assert len(self.traj_fail) == ntraj
        self.model_devi = model_devi

    def _get_indexes(
        self,
        md,
        level_lo,
        level_hi,
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
        """Split the frame indexes of one trajectory by two trust levels.

        Parameters
        ----------
        md
            Model deviation of the frames, compared element-wise against the
            levels (assumed to be a 1-D numpy array -- confirm with the
            ``DeviManager`` implementation), or ``None``.
        level_lo
            Lower trust level, or ``None``.
        level_hi
            Higher trust level, or ``None``.

        Returns
        -------
        id_cand, id_accu, id_fail
            Index arrays of the frames with ``level_lo <= md < level_hi``,
            ``md < level_lo`` and ``md >= level_hi``, in this order. All three
            are ``None`` if any of the inputs is ``None``.
        """
        if (md is not None) and (level_hi is not None) and (level_lo is not None):
            id_cand = np.where(np.logical_and(md >= level_lo, md < level_hi))[0]
            id_accu = np.where(md < level_lo)[0]
            id_fail = np.where(md >= level_hi)[0]
        else:
            id_cand = id_accu = id_fail = None
        return id_cand, id_accu, id_fail

    def _record_one_traj(
        self,
        id_f_accu,
        id_f_cand,
        id_f_fail,
        id_v_accu,
        id_v_cand,
        id_v_fail,
    ):
        """Combine the force and virial classification of one trajectory.

        Parameters
        ----------
        id_f_accu, id_f_cand, id_f_fail
            Indexes of the accurate, candidate and failed frames judged by the
            force model deviation.
        id_v_accu, id_v_cand, id_v_fail
            The same judged by the virial model deviation. All three must be
            ``None`` when the virial is not taken into account.

        Returns
        -------
        nframes : int
            Total number of frames of the trajectory.
        set_accu, set_cand, set_fail : set
            Disjoint sets of accurate, candidate and failed frame indexes that
            together cover all frames.

        Raises
        ------
        FatalError
            If the force and virial classifications cover different numbers of
            frames.
        """
        # check consistency
        novirial = id_v_cand is None
        if novirial:
            assert id_v_accu is None
            assert id_v_fail is None

        nframes = np.size(np.concatenate((id_f_cand, id_f_accu, id_f_fail)))
        if not novirial:
            nframes_v = np.size(np.concatenate((id_v_cand, id_v_accu, id_v_fail)))
            if nframes != nframes_v:
                raise FatalError(
                    f"number of frames by virial ({nframes_v}) "
                    f"is not consistent with that by force ({nframes})"
                )

        # to sets
        set_f_accu = set(id_f_accu)
        set_f_cand = set(id_f_cand)
        set_f_fail = set(id_f_fail)
        if novirial:
            # without virial information every frame counts as virial-accurate,
            # so the force classification alone decides
            set_v_accu = set(range(nframes))
            set_v_cand = set()
            set_v_fail = set()
        else:
            set_v_accu = set(id_v_accu)
            set_v_cand = set(id_v_cand)
            set_v_fail = set(id_v_fail)

        # accu, cand, fail
        # accurate: accurate in both force and virial
        set_accu = set_f_accu & set_v_accu
        # candidate: candidate in at least one of them and failed in neither
        set_cand = (
            (set_f_cand & set_v_accu)
            | (set_f_cand & set_v_cand)
            | (set_f_accu & set_v_cand)
        )
        # failed: failed in force or in virial
        set_fail = set_f_fail | set_v_fail

        # check size: the three sets must partition the frames
        assert nframes == len(set_accu | set_cand | set_fail)
        assert 0 == len(set_accu & set_cand)
        assert 0 == len(set_accu & set_fail)
        assert 0 == len(set_cand & set_fail)
        return nframes, set_accu, set_cand, set_fail

    @abstractmethod
    def converged(
        self,
        reports: Optional[List[ExplorationReport]] = None,
    ) -> bool:
        """Tell whether the exploration is converged (abstract).

        Parameters
        ----------
        reports : List[ExplorationReport], optional
            Presumably the reports of previous iterations -- confirm against
            the implementing subclass.

        Returns
        -------
        bool
            The convergence flag; it is shown in the last column of `print`.
        """

    def failed_ratio(
        self,
        tag=None,
    ) -> float:
        """Return the fraction of failed frames over all recorded trajectories.

        ``tag`` is accepted but not used. Raises ``ZeroDivisionError`` if no
        frame has been recorded.
        """
        n_fail = sum(len(ii) for ii in self.traj_fail)
        return float(n_fail) / float(sum(self.traj_nframes))

    def accurate_ratio(
        self,
        tag=None,
    ) -> float:
        """Return the fraction of accurate frames over all recorded trajectories.

        ``tag`` is accepted but not used. Raises ``ZeroDivisionError`` if no
        frame has been recorded.
        """
        n_accu = sum(len(ii) for ii in self.traj_accu)
        return float(n_accu) / float(sum(self.traj_nframes))

    def candidate_ratio(
        self,
        tag=None,
    ) -> float:
        """Return the fraction of candidate frames over all recorded trajectories.

        ``tag`` is accepted but not used. Raises ``ZeroDivisionError`` if no
        frame has been recorded.
        """
        n_cand = sum(len(ii) for ii in self.traj_cand)
        return float(n_cand) / float(sum(self.traj_nframes))

    @abstractmethod
    def get_candidate_ids(
        self,
        max_nframes: Optional[int] = None,
    ) -> List[List[int]]:
        """Return the ids of the candidate frames (abstract).

        Parameters
        ----------
        max_nframes : int, optional
            Presumably an upper bound on the number of returned frames --
            confirm against the implementing subclass.

        Returns
        -------
        List[List[int]]
            Lists of frame ids (presumably one list per trajectory).
        """

    def print(
        self,
        stage_idx: int,
        idx_in_stage: int,
        iter_idx: int,
    ) -> str:
        r"""Format the report as one row matching ``header_str``.

        Nothing is written to stdout; the formatted row is returned.

        Parameters
        ----------
        stage_idx : int
            Value of the "stage" column.
        idx_in_stage : int
            Value of the "id_stg." column.
        iter_idx : int
            Value of the "iter." column.

        Returns
        -------
        str
            The row: the three indexes, the accurate / candidate / failed
            ratios, the force trust levels, the virial trust levels (only if
            both are set) and the result of ``converged()``.
        """
        fmt_str = self.fmt_str
        fmt_flt = self.fmt_flt
        fields = (
            str(stage_idx),
            str(idx_in_stage),
            str(iter_idx),
            fmt_flt % (self.accurate_ratio()),
            fmt_flt % (self.candidate_ratio()),
            fmt_flt % (self.failed_ratio()),
            fmt_flt % (self.level_f_lo),
            fmt_flt % (self.level_f_hi),
        )
        if self.v_level:
            fields += (
                fmt_flt % (self.level_v_lo),
                fmt_flt % (self.level_v_hi),
            )
        fields += (str(self.converged()),)
        # the leading blank takes the place of the "#" of the header line
        ret = " " + fmt_str % fields
        return ret