from pathlib import (
Path,
)
from typing import (
List,
Optional,
Tuple,
)
import numpy as np
from dflow.python import (
FatalError,
)
from dpgen2.exploration.report import (
ExplorationReport,
)
from dpgen2.exploration.selector import (
ConfSelector,
)
from dpgen2.exploration.task import (
ExplorationStage,
ExplorationTaskGroup,
)
from .stage_scheduler import (
StageScheduler,
)
class ExplorationScheduler:
    """
    The exploration scheduler.

    Keeps an ordered list of stage schedulers, tracks which stage is
    currently being explored and delegates the planning of each iteration
    to the scheduler of that stage.
    """

    def __init__(
        self,
    ):
        # Stage schedulers, in the order in which they are run.
        self.stage_schedulers = []
        # Index (into `stage_schedulers`) of the stage being explored.
        self.cur_stage = 0
        # Set to True once the last stage has completed.
        self.complete_ = False

    def add_stage_scheduler(
        self,
        stage_scheduler: StageScheduler,
    ) -> "ExplorationScheduler":
        """
        Add stage scheduler.

        All added schedulers can be treated as a `list` (order matters).
        Only when one stage is converged does the exploration go to the
        next stage. Adding a stage resets the "all stages complete" flag.

        Parameters
        ----------
        stage_scheduler : StageScheduler
            The added stage scheduler

        Returns
        -------
        self : ExplorationScheduler
            The scheduler itself, so that calls can be chained.
        """
        self.stage_schedulers.append(stage_scheduler)
        self.complete_ = False
        return self

    def get_stage(self) -> int:
        """
        Get the index of current stage.

        Stage index increases when the previous stage converges. Usually
        called after `self.plan_next_iteration`.
        """
        return self.cur_stage

    def get_iteration(self) -> int:
        """
        Get the index of the current iteration.

        Iteration index increases when `self.plan_next_iteration` returns
        valid `expl_task_grp` and `conf_selector` for the next iteration.
        The value is -1 as long as no stage reports a planned iteration.
        """
        tot_iter = -1
        for stage in self.stage_schedulers:
            if stage.complete():
                # the last plan is not used because the stage
                # is found converged
                tot_iter += stage.next_iteration() - 1
            else:
                tot_iter += stage.next_iteration()
        return tot_iter

    def complete(self) -> bool:
        """
        Tell if all stages are converged.
        """
        return self.complete_

    def force_stage_complete(self) -> None:
        """
        Force complete the current stage.

        The scheduler then moves on to the next stage and plans its first
        iteration (the returned plan is discarded). If the forced stage
        was the last one, the scheduler is marked as complete instead.
        """
        self.stage_schedulers[self.cur_stage].force_complete()
        self.cur_stage += 1
        if self.cur_stage < len(self.stage_schedulers):
            # goes to next stage
            self.plan_next_iteration()
        else:
            # all stages complete
            self.complete_ = True

    def plan_next_iteration(
        self,
        report: Optional[ExplorationReport] = None,
        trajs: Optional[List[Path]] = None,
    ) -> Tuple[bool, Optional[ExplorationTaskGroup], Optional[ConfSelector]]:
        """
        Make the plan for the next DPGEN iteration.

        Parameters
        ----------
        report : ExplorationReport
            The exploration report of this iteration.
        trajs : List[Path]
            A list of configurations generated during the exploration.
            May be used to generate new configurations for the next
            iteration.

        Returns
        -------
        complete: bool
            If all the DPGEN stages complete.
        task: ExplorationTaskGroup
            A `ExplorationTaskGroup` defining the exploration of the next
            iteration. Should be `None` if converged.
        conf_selector: ConfSelector
            The configuration selector for the next iteration. Should be
            `None` if converged.

        Raises
        ------
        FatalError
            If the scheduler of the current stage raises a `FatalError`.
            The message is prefixed with the index of the failing stage
            and the original error is kept as the cause.
        """
        try:
            stg_complete, expl_task_grp, conf_selector = self.stage_schedulers[
                self.cur_stage
            ].plan_next_iteration(
                report,
                trajs,
            )
        except FatalError as e:
            raise FatalError(f"stage {self.cur_stage}: " + str(e)) from e

        if not stg_complete:
            return stg_complete, expl_task_grp, conf_selector

        self.cur_stage += 1
        if self.cur_stage < len(self.stage_schedulers):
            # goes to next stage; the new stage starts without report/trajs
            return self.plan_next_iteration()
        # all stages complete
        self.complete_ = True
        return (
            True,
            None,
            None,
        )

    def get_stage_of_iterations(self) -> Tuple[List[int], List[int], List[int]]:
        """
        Get the stage index and the index in the stage of iterations.

        Returns
        -------
        stage_idx : List[int]
            For each finished iteration, the index of its stage in
            `self.stage_schedulers`.
        idx_in_stage : List[int]
            For each finished iteration, its index within the reports of
            its stage.
        iter_idx : List[int]
            The global indices of the finished iterations, i.e.
            ``0, 1, ..., n - 1``.
        """
        stages = self.stage_schedulers
        # Number of reports of every stage up to and including the current
        # one (slicing clamps when the current stage index is past the end).
        # Stages without any report keep their slot with a count of 0, so
        # that a position in the cumulative sum is a valid stage index.
        n_stage_iters = [len(stage.reports) for stage in stages[: self.get_stage() + 1]]
        cumsum_stage_iters = np.cumsum(n_stage_iters, dtype=int)

        max_iter = self.get_iteration()
        if self.complete() or max_iter == -1:
            max_iter += 1

        stage_idx = []
        idx_in_stage = []
        iter_idx = []
        for ii in range(max_iter):
            # First stage whose cumulative report count reaches ii + 1.
            # The default side="left" never selects a stage with 0 reports.
            idx = int(np.searchsorted(cumsum_stage_iters, ii + 1))
            n_before = int(cumsum_stage_iters[idx - 1]) if idx > 0 else 0
            stage_idx.append(idx)
            idx_in_stage.append(ii - n_before)
            iter_idx.append(ii)
        return stage_idx, idx_in_stage, iter_idx

    def get_convergence_ratio(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Get the accurate, candidate and failed ratios of the iterations

        Returns
        -------
        accu  np.ndarray
            The accurate ratio. length of array the same as # iterations.
        cand  np.ndarray
            The candidate ratio. length of array the same as # iterations.
        fail  np.ndarray
            The failed ratio. length of array the same as # iterations.
        """
        stages = self.stage_schedulers
        stage_idx, idx_in_stage, _ = self.get_stage_of_iterations()
        accu = []
        cand = []
        fail = []
        for s_idx, r_idx in zip(stage_idx, idx_in_stage):
            report = stages[s_idx].reports[r_idx]
            accu.append(report.accurate_ratio())
            cand.append(report.candidate_ratio())
            fail.append(report.failed_ratio())
        return np.array(accu), np.array(cand), np.array(fail)

    def _print_prev_summary(self, prev_stg_idx: int) -> Optional[str]:
        """
        Build the one-line summary of a stage.

        Returns `None` when `prev_stg_idx` is negative (no previous stage).
        """
        if prev_stg_idx < 0:
            return None
        stage = self.stage_schedulers[prev_stg_idx]
        yes = "YES" if stage.converged() else "NO "
        rmx = "YES" if stage.reached_max_iteration() else "NO "
        return f"# Stage {prev_stg_idx:4d} converged {yes} reached max numb iterations {rmx}"

    def print_last_iteration(self, print_header: bool = False) -> str:
        """
        Return the printed report of the last finished iteration.

        Parameters
        ----------
        print_header : bool
            If True, the header of the report is printed before the
            report line.

        Returns
        -------
        summary : str
            The report text, ending with a newline. A note is appended
            when all stages are complete. If no iteration has finished a
            fixed message is returned instead.
        """
        stages = self.stage_schedulers
        stage_idx, idx_in_stage, iter_idx = self.get_stage_of_iterations()
        if len(iter_idx) == 0:
            return "No finished iteration found\n"

        iidx = len(iter_idx) - 1
        report = stages[stage_idx[iidx]].reports[idx_in_stage[iidx]]
        ret = []
        if print_header:
            ret.append(report.print_header())
        ret.append(report.print(stage_idx[iidx], idx_in_stage[iidx], iidx))
        if self.complete():
            ret.append("# All stages converged")
        return "\n".join(ret + [""])

    def print_convergence(self) -> str:
        """
        Return the printed reports of all finished iterations.

        The output starts with the report header. Each stage is introduced
        by a separator line and, once a later stage starts (or all stages
        are complete), followed by a one-line summary of the stage.

        Returns
        -------
        summary : str
            The convergence table, ending with a newline. If no iteration
            has finished a fixed message is returned instead.
        """
        stages = self.stage_schedulers
        stage_idx, idx_in_stage, iter_idx = self.get_stage_of_iterations()
        if len(iter_idx) == 0:
            return "No finished iteration found\n"

        ret = []
        prev_stg_idx = -1
        for iidx, (s_idx, r_idx) in enumerate(zip(stage_idx, idx_in_stage)):
            report = stages[s_idx].reports[r_idx]
            if not ret:
                ret.append(report.print_header())
            if s_idx != prev_stg_idx:
                # A new stage starts: close the previous one (if any) first.
                summary = self._print_prev_summary(prev_stg_idx)
                if summary is not None:
                    ret.append(summary)
                ret.append(f"# Stage {s_idx:4d} " + "-" * 20)
                prev_stg_idx = s_idx
            ret.append(report.print(s_idx, r_idx, iidx))

        if self.complete():
            summary = self._print_prev_summary(prev_stg_idx)
            if summary is not None:
                ret.append(summary)
            ret.append("# All stages converged")
        return "\n".join(ret + [""])