# Source code for dpgen2.exploration.scheduler.convergence_check_stage_scheduler

from pathlib import (
    Path,
)
from typing import (
    List,
    Optional,
    Tuple,
)

from dflow.python import (
    FatalError,
)

from dpgen2.exploration.report import (
    ExplorationReport,
)
from dpgen2.exploration.selector import (
    ConfSelector,
)
from dpgen2.exploration.task import (
    BaseExplorationTaskGroup,
    ExplorationStage,
    ExplorationTaskGroup,
)

from .stage_scheduler import (
    StageScheduler,
)


class ConvergenceCheckStageScheduler(StageScheduler):
    """Stage scheduler that repeats one exploration stage until it converges.

    Each call to :meth:`plan_next_iteration` either produces the task group
    for a further iteration of ``stage`` or declares the stage complete.
    Convergence is decided by the exploration report passed to that call.

    Parameters
    ----------
    stage : ExplorationStage
        The stage whose ``make_task()`` is called to build the exploration
        task group of every iteration.
    selector : ConfSelector
        The configuration selector handed back to the caller together with
        every newly planned task group.
    max_numb_iter : Optional[int]
        Iteration limit. When an unconverged report arrives while the
        iteration counter equals this value, the limit is considered
        reached. ``None`` disables the limit.
    fatal_at_max : bool
        If ``True``, reaching the limit raises ``FatalError``; otherwise the
        stage is marked complete without being converged.
    """

    def __init__(
        self,
        stage: ExplorationStage,
        selector: ConfSelector,
        max_numb_iter: Optional[int] = None,
        fatal_at_max: bool = True,
    ):
        self.stage = stage
        self.selector = selector
        self.max_numb_iter = max_numb_iter
        self.fatal_at_max = fatal_at_max
        # index of the next iteration; incremented on every successful plan
        self.nxt_iter = 0
        # result of the most recent convergence check
        self.conv = False
        self.reached_max_iter = False
        self.complete_ = False
        # every report received so far, in the order it was received
        self.reports: List[ExplorationReport] = []

    def get_reports(self) -> List[ExplorationReport]:
        """Return the list of reports collected so far (not a copy)."""
        return self.reports

    def complete(self) -> bool:
        """Tell whether the stage is complete (finished or forced)."""
        return self.complete_

    def force_complete(self) -> None:
        """Mark the stage as complete regardless of convergence."""
        self.complete_ = True

    def next_iteration(self) -> int:
        """Return the index of the next iteration to be planned."""
        return self.nxt_iter

    def converged(self) -> bool:
        """Return the result of the most recent convergence check."""
        return self.conv

    def reached_max_iteration(self) -> bool:
        """Tell whether the iteration limit was hit without convergence."""
        return self.reached_max_iter

    def plan_next_iteration(
        self,
        report: Optional[ExplorationReport] = None,
        trajs: Optional[List[Path]] = None,
    ) -> Tuple[bool, Optional[BaseExplorationTaskGroup], Optional[ConfSelector]]:
        """Plan the next iteration of the stage.

        Parameters
        ----------
        report : Optional[ExplorationReport]
            Report of the previous iteration. ``None`` means there is no
            previous iteration to judge, so a new task group is always
            planned and nothing is appended to the stored reports.
        trajs : Optional[List[Path]]
            Accepted for interface compatibility; not used by this scheduler.

        Returns
        -------
        stg_complete : bool
            Whether the stage is complete after this call.
        task : Optional[BaseExplorationTaskGroup]
            Task group for the next iteration, ``None`` if the stage is
            complete.
        conf_selector : Optional[ConfSelector]
            Selector for the next iteration, ``None`` if the stage is
            complete.

        Raises
        ------
        FatalError
            If the stage has already completed; if the report is not
            converged but selected no candidate configuration; or if the
            iteration limit is reached and ``fatal_at_max`` is ``True``.
        """
        if self.complete():
            raise FatalError("Cannot plan because the stage has completed.")

        if report is None:
            stg_complete = False
            self.conv = stg_complete
            lmp_task_grp = self.stage.make_task()
            ret_selector = self.selector
        else:
            # the report judges itself against all previously stored reports
            stg_complete = report.converged(self.reports)
            self.conv = stg_complete
            if not stg_complete:
                # check if we have any candidate to improve the quality of
                # the model; without candidates further iterations are futile
                if report.no_candidate():
                    raise FatalError(
                        "The iteration is not converged, but we find that "
                        "it did not select any candidate configuration. "
                        "This means the quality of the model would not be "
                        "improved and the iteration would not end. "
                        "Please try to increase the higher trust levels. "
                    )
                # not converged: check the iteration limit.
                # NOTE(review): the comparison is an exact match, so the
                # limit only triggers when the counter equals max_numb_iter
                # at the moment a report arrives -- confirm that values
                # smaller than 1 are not expected here.
                if (
                    self.max_numb_iter is not None
                    and self.nxt_iter == self.max_numb_iter
                ):
                    self.reached_max_iter = True
                    if self.fatal_at_max:
                        raise FatalError("reached maximal number of iterations")
                    # give up gracefully: complete, while self.conv stays False
                    stg_complete = True
            # make lmp tasks
            if stg_complete:
                # a complete stage needs no further exploration task
                lmp_task_grp = None
                ret_selector = None
            else:
                lmp_task_grp = self.stage.make_task()
                ret_selector = self.selector
            self.reports.append(report)

        self.nxt_iter += 1
        self.complete_ = stg_complete
        return stg_complete, lmp_task_grp, ret_selector