Source code for dpgen2.entrypoint.watch

import logging
import time
from typing import (
    Dict,
    List,
    Optional,
    Union,
)

from dflow import (
    Workflow,
)

from dpgen2.entrypoint.args import normalize as normalize_args
from dpgen2.entrypoint.common import (
    global_config_workflow,
)
from dpgen2.utils.dflow_query import (
    matched_step_key,
)
from dpgen2.utils.download_dpgen2_artifacts import (
    download_dpgen2_artifacts,
)

default_watching_keys = [
    "prep-run-train",
    "prep-run-explore",
    "prep-run-fp",
    "collect-data",
]


[docs] def update_finished_steps( wf, finished_keys: Optional[List[str]] = None, download: Optional[bool] = False, watching_keys: Optional[List[str]] = None, prefix: Optional[str] = None, chk_pnt: bool = False, ): wf_keys = wf.query_keys_of_steps() wf_keys = matched_step_key(wf_keys, watching_keys) if finished_keys is not None: diff_keys = [] for kk in wf_keys: if not (kk in finished_keys): diff_keys.append(kk) else: diff_keys = wf_keys for kk in diff_keys: logging.info(f'steps {kk.ljust(50,"-")} finished') if download: download_dpgen2_artifacts(wf, kk, prefix=prefix, chk_pnt=chk_pnt) logging.info(f'steps {kk.ljust(50,"-")} downloaded') finished_keys = wf_keys return finished_keys
[docs] def watch( workflow_id, wf_config: Optional[Dict] = {}, watching_keys: Optional[List] = default_watching_keys, frequency: float = 600.0, download: bool = False, prefix: Optional[str] = None, chk_pnt: bool = False, ): wf_config = normalize_args(wf_config) global_config_workflow(wf_config) wf = Workflow(id=workflow_id) finished_keys = None while wf.query_status() in ["Pending", "Running", "Failed", "Error"]: finished_keys = update_finished_steps( wf, finished_keys, download=download, watching_keys=watching_keys, prefix=prefix, chk_pnt=chk_pnt, ) if wf.query_status() in ["Failed", "Error"]: break time.sleep(frequency) status = wf.query_status() if status == "Succeeded": finished_keys = update_finished_steps( wf, finished_keys, download=download, watching_keys=watching_keys, prefix=prefix, chk_pnt=chk_pnt, ) logging.info("well done") elif status in ["Failed", "Error"]: logging.error("failed or error workflow")