Source code for dpgen2.entrypoint.submit

import copy
import glob
import json
import logging
import os
import pickle
import re
from copy import (
    deepcopy,
)
from pathlib import (
    Path,
)
from typing import (
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

import dpdata
from dflow import (
    ArgoStep,
    InputArtifact,
    InputParameter,
    Inputs,
    OutputArtifact,
    OutputParameter,
    Outputs,
    S3Artifact,
    Step,
    Steps,
    Workflow,
    argo_range,
    download_artifact,
    upload_artifact,
)
from dflow.python import (
    OP,
    OPIO,
    Artifact,
    FatalError,
    OPIOSign,
    PythonOPTemplate,
    TransientError,
    upload_packages,
)

from dpgen2.conf import (
    conf_styles,
)
from dpgen2.constants import (
    default_host,
    default_image,
)
from dpgen2.entrypoint.args import normalize as normalize_args
from dpgen2.entrypoint.common import (
    expand_idx,
    expand_sys_str,
    global_config_workflow,
)
from dpgen2.exploration.render import (
    TrajRenderLammps,
)
from dpgen2.exploration.report import (
    ExplorationReportTrustLevelsRandom,
    conv_styles,
)
from dpgen2.exploration.scheduler import (
    ConvergenceCheckStageScheduler,
    ExplorationScheduler,
)
from dpgen2.exploration.selector import (
    ConfFilters,
    ConfSelectorFrames,
    conf_filter_styles,
)
from dpgen2.exploration.task import (
    CustomizedLmpTemplateTaskGroup,
    ExplorationStage,
    ExplorationTask,
    LmpTemplateTaskGroup,
    NPTTaskGroup,
    caly_normalize,
    diffcsp_normalize,
    make_calypso_task_group_from_config,
    make_diffcsp_task_group_from_config,
    make_lmp_task_group_from_config,
    normalize_lmp_task_group_config,
)
from dpgen2.flow import (
    ConcurrentLearning,
)
from dpgen2.fp import (
    fp_styles,
)
from dpgen2.op import (
    CollectData,
    CollRunCaly,
    DiffCSPGen,
    PrepCalyDPOptim,
    PrepCalyInput,
    PrepCalyModelDevi,
    PrepDPTrain,
    PrepLmp,
    PrepRelax,
    RunCalyDPOptim,
    RunCalyModelDevi,
    RunDPTrain,
    RunLmp,
    RunLmpHDF5,
    RunRelax,
    RunRelaxHDF5,
    SelectConfs,
)
from dpgen2.op.caly_evo_step_merge import (
    CalyEvoStepMerge,
)
from dpgen2.superop import (
    ConcurrentLearningBlock,
    PrepRunCaly,
    PrepRunDiffCSP,
    PrepRunDPTrain,
    PrepRunFp,
    PrepRunLmp,
)
from dpgen2.superop.caly_evo_step import (
    CalyEvoStep,
)
from dpgen2.utils import (
    BinaryFileInput,
    bohrium_config_from_dict,
    dump_object_to_file,
    get_artifact_from_uri,
    get_subkey,
    load_object_from_file,
    matched_step_key,
    print_keys_in_nice_format,
    sort_slice_ops,
    upload_artifact_and_print_uri,
    workflow_config_from_dict,
)
from dpgen2.utils.step_config import normalize as normalize_step_dict

default_config = normalize_step_dict(
    {
        "template_config": {
            "image": default_image,
        }
    }
)


def make_concurrent_learning_op(
    train_style: str = "dp",
    explore_style: str = "lmp",
    fp_style: str = "vasp",
    prep_train_config: dict = default_config,
    run_train_config: dict = default_config,
    prep_explore_config: dict = default_config,
    run_explore_config: dict = default_config,
    prep_fp_config: dict = default_config,
    run_fp_config: dict = default_config,
    select_confs_config: dict = default_config,
    collect_data_config: dict = default_config,
    cl_step_config: dict = default_config,
    upload_python_packages: Optional[List[os.PathLike]] = None,
    valid_data: Optional[S3Artifact] = None,
    train_optional_files: Optional[List[str]] = None,
    explore_config: Optional[dict] = None,
):
    if train_style in ("dp", "dp-dist"):
        prep_run_train_op = PrepRunDPTrain(
            "prep-run-dp-train",
            PrepDPTrain,
            RunDPTrain,
            prep_config=prep_train_config,
            run_config=run_train_config,
            upload_python_packages=upload_python_packages,
            valid_data=valid_data,
            optional_files=train_optional_files,
        )
    else:
        raise RuntimeError(f"unknown train_style {train_style}")
    if explore_style == "lmp":
        prep_run_explore_op = PrepRunLmp(
            "prep-run-lmp",
            PrepLmp,
            RunLmpHDF5 if explore_config["use_hdf5"] else RunLmp,  # type: ignore
            prep_config=prep_explore_config,
            run_config=run_explore_config,
            upload_python_packages=upload_python_packages,
        )
    elif "calypso" in explore_style:
        expl_mode = explore_style.split(":")[-1] if ":" in explore_style else "default"
        if expl_mode == "merge":
            caly_evo_step_op = CalyEvoStepMerge(
                name="caly-evo-step",
                collect_run_caly=CollRunCaly,
                prep_dp_optim=PrepCalyDPOptim,
                run_dp_optim=RunCalyDPOptim,
                expl_mode=expl_mode,
                prep_config=prep_explore_config,
                run_config=run_explore_config,
                upload_python_packages=None,
            )
        elif expl_mode == "default":
            caly_evo_step_op = CalyEvoStep(
                name="caly-evo-step",
                collect_run_caly=CollRunCaly,
                prep_dp_optim=PrepCalyDPOptim,
                run_dp_optim=RunCalyDPOptim,
                expl_mode=expl_mode,
                prep_config=prep_explore_config,
                run_config=run_explore_config,
                upload_python_packages=upload_python_packages,
            )
        else:
            raise KeyError(
                f"Unknown key: {explore_style}, support `calypso:default` and `calypso:merge`."
            )
        prep_run_explore_op = PrepRunCaly(
            "prep-run-calypso",
            prep_caly_input_op=PrepCalyInput,
            caly_evo_step_op=caly_evo_step_op,
            prep_caly_model_devi_op=PrepCalyModelDevi,
            run_caly_model_devi_op=RunCalyModelDevi,
            expl_mode=expl_mode,
            prep_config=prep_explore_config,
            run_config=run_explore_config,
            upload_python_packages=upload_python_packages,
        )
    elif explore_style == "diffcsp":
        prep_run_explore_op = PrepRunDiffCSP(
            "prep-run-diffcsp",
            DiffCSPGen,
            PrepRelax,
            RunRelaxHDF5 if explore_config["use_hdf5"] else RunRelax,  # type: ignore
            prep_config=prep_explore_config,
            run_config=run_explore_config,
            upload_python_packages=upload_python_packages,
        )
    else:
        raise RuntimeError(f"unknown explore_style {explore_style}")

    if fp_style in fp_styles.keys():
        prep_run_fp_op = PrepRunFp(
            "prep-run-fp",
            fp_styles[fp_style]["prep"],
            fp_styles[fp_style]["run"],
            prep_config=prep_fp_config,
            run_config=run_fp_config,
            upload_python_packages=upload_python_packages,
        )
    else:
        raise RuntimeError(f"unknown fp_style {fp_style}")

    # ConcurrentLearningBlock
    block_cl_op = ConcurrentLearningBlock(
        "concurrent-learning-block",
        prep_run_train_op,
        prep_run_explore_op,
        SelectConfs,
        prep_run_fp_op,
        CollectData,
        select_confs_config=select_confs_config,
        collect_data_config=collect_data_config,
        upload_python_packages=upload_python_packages,
    )
    # dpgen
    dpgen_op = ConcurrentLearning(
        "concurrent-learning",
        block_cl_op,
        upload_python_packages=upload_python_packages,
        step_config=cl_step_config,
    )

    return dpgen_op

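# Illustrative sketch (not part of the module): how the factory above is usually
# driven. The argument values shown are assumptions for illustration; in practice
# they come from the normalized workflow config (see workflow_concurrent_learning
# below).
#
#     dpgen_op = make_concurrent_learning_op(
#         train_style="dp",
#         explore_style="lmp",
#         fp_style="vasp",
#         explore_config={"use_hdf5": False},
#     )
#     # dpgen_op is the ConcurrentLearning super-OP used as the template of the
#     # top-level "dpgen-step".
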
def make_naive_exploration_scheduler(
    config,
):
    # use npt task group
    explore_style = config["explore"]["type"]

    if explore_style == "lmp":
        return make_lmp_naive_exploration_scheduler(config)
    elif "calypso" in explore_style or explore_style == "diffcsp":
        return make_naive_exploration_scheduler_without_conf(config, explore_style)
    else:
        raise KeyError(f"Unknown explore_style `{explore_style}`")

def get_conf_filters(config):
    conf_filters = None
    if len(config) > 0:
        conf_filters = ConfFilters()
        for c in config:
            c = deepcopy(c)
            conf_filter = conf_filter_styles[c.pop("type")](**c)
            conf_filters.add(conf_filter)
    return conf_filters

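# Illustrative sketch: config["explore"]["filters"] is a list of dicts, each
# selecting a style from conf_filter_styles by its "type" key and passing the
# remaining keys to the filter constructor. The "distance" style name below is an
# assumption for illustration only.
#
#     filters_config = [
#         {"type": "distance"},
#     ]
#     conf_filters = get_conf_filters(filters_config)  # returns None for an empty list
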
def make_naive_exploration_scheduler_without_conf(config, explore_style):
    model_devi_jobs = config["explore"]["stages"]
    fp_task_max = config["fp"]["task_max"]
    max_numb_iter = config["explore"]["max_numb_iter"]
    fatal_at_max = config["explore"]["fatal_at_max"]
    convergence = config["explore"]["convergence"]
    output_nopbc = config["explore"]["output_nopbc"]
    conf_filters = get_conf_filters(config["explore"]["filters"])
    scheduler = ExplorationScheduler()
    # report
    conv_style = convergence.pop("type")
    report = conv_styles[conv_style](**convergence)
    # trajectory render; the output trajs are assumed to be in lammps/dump format
    render = TrajRenderLammps(nopbc=output_nopbc)
    # selector
    selector = ConfSelectorFrames(
        render,
        report,
        fp_task_max,
        conf_filters,
    )

    for job_ in model_devi_jobs:
        if not isinstance(job_, list):
            job = [job_]
        else:
            job = job_
        # stage
        stage = ExplorationStage()
        for jj in job:
            if "calypso" in explore_style:
                jconf = caly_normalize(jj)
                # make task group
                tgroup = make_calypso_task_group_from_config(jconf)
            elif explore_style == "diffcsp":
                jconf = diffcsp_normalize(jj)
                # make task group
                tgroup = make_diffcsp_task_group_from_config(jconf)
            else:
                raise KeyError(f"Unknown explore_style `{explore_style}`")
            # add the tasks to the stage
            tasks = tgroup.make_task()
            stage.add_task_group(tasks)
        # stage_scheduler
        stage_scheduler = ConvergenceCheckStageScheduler(
            stage,
            selector,
            max_numb_iter=max_numb_iter,
            fatal_at_max=fatal_at_max,
        )
        # scheduler
        scheduler.add_stage_scheduler(stage_scheduler)

    return scheduler

def make_lmp_naive_exploration_scheduler(config):
    model_devi_jobs = config["explore"]["stages"]
    sys_configs = config["explore"]["configurations"]
    mass_map = config["inputs"]["mass_map"]
    type_map = config["inputs"]["type_map"]
    numb_models = config["train"]["numb_models"]
    fp_task_max = config["fp"]["task_max"]
    max_numb_iter = config["explore"]["max_numb_iter"]
    fatal_at_max = config["explore"]["fatal_at_max"]
    convergence = config["explore"]["convergence"]
    output_nopbc = config["explore"]["output_nopbc"]
    conf_filters = get_conf_filters(config["explore"]["filters"])
    use_ele_temp = config["inputs"]["use_ele_temp"]
    scheduler = ExplorationScheduler()
    # report
    conv_style = convergence.pop("type")
    report = conv_styles[conv_style](**convergence)
    render = TrajRenderLammps(nopbc=output_nopbc, use_ele_temp=use_ele_temp)
    # selector
    selector = ConfSelectorFrames(
        render,
        report,
        fp_task_max,
        conf_filters,
    )

    sys_configs_lmp = []
    for sys_config in sys_configs:
        conf_style = sys_config.pop("type")
        generator = conf_styles[conf_style](**sys_config)
        sys_configs_lmp.append(generator.get_file_content(type_map))

    for job_ in model_devi_jobs:
        if not isinstance(job_, list):
            job = [job_]
        else:
            job = job_
        # stage
        stage = ExplorationStage()
        for jj in job:
            jconf = normalize_lmp_task_group_config(jj)
            n_sample = jconf.pop("n_sample")
            ## ignore the expansion of sys_idx
            # get all file names of md initial configurations
            sys_idx = jconf.pop("conf_idx")
            conf_list = []
            for ii in sys_idx:
                conf_list += sys_configs_lmp[ii]
            # make task group
            tgroup = make_lmp_task_group_from_config(numb_models, mass_map, jconf)
            # add the configurations to the task group
            tgroup.set_conf(
                conf_list,
                n_sample=n_sample,
                random_sample=True,
            )
            tasks = tgroup.make_task()
            stage.add_task_group(tasks)
        # stage_scheduler
        stage_scheduler = ConvergenceCheckStageScheduler(
            stage,
            selector,
            max_numb_iter=max_numb_iter,
            fatal_at_max=fatal_at_max,
        )
        # scheduler
        scheduler.add_stage_scheduler(stage_scheduler)

    return scheduler

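# Illustrative sketch of one entry of config["explore"]["stages"] consumed above.
# Only the keys this function reads directly (conf_idx, n_sample) are certain; the
# remaining task-group settings are handled by normalize_lmp_task_group_config and
# are omitted here.
#
#     stage_job = {
#         "conf_idx": [0],   # indices into config["explore"]["configurations"]
#         "n_sample": 5,     # configurations randomly drawn per task group
#         # ... lmp task group settings ...
#     }
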
def get_kspacing_kgamma_from_incar(
    fname,
):
    with open(fname) as fp:
        lines = fp.readlines()
    ks = None
    kg = None
    for ii in lines:
        if "KSPACING" in ii:
            ks = float(ii.split("=")[1])
        if "KGAMMA" in ii:
            if "T" in ii.split("=")[1]:
                kg = True
            elif "F" in ii.split("=")[1]:
                kg = False
            else:
                raise RuntimeError(f"invalid kgamma value {ii.split('=')[1]}")
    assert ks is not None and kg is not None
    return ks, kg

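# Example INCAR lines the parser above accepts (illustrative values):
#
#     KSPACING = 0.24
#     KGAMMA = .TRUE.
#
# KSPACING is read as a float; KGAMMA is mapped to True/False by looking for "T"
# or "F" in the value. Both tags must be present, otherwise the assertion fails.
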
def make_optional_parameter(
    mixed_type=False,
    finetune_mode="no",
):
    return {"data_mixed_type": mixed_type, "finetune_mode": finetune_mode}

def get_systems_from_data(data, data_prefix=None):
    data = [data] if isinstance(data, str) else data
    assert isinstance(data, list)
    if data_prefix is not None:
        data = [os.path.join(data_prefix, ii) for ii in data]
    data = sum([expand_sys_str(ii) for ii in data], [])
    return data

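# Illustrative sketch: a single path or a list of paths is accepted; each entry is
# optionally prefixed and then expanded by expand_sys_str. The paths below are
# assumptions for illustration.
#
#     systems = get_systems_from_data(["init/sys-a", "init/sys-b"], data_prefix="/data")
#     # -> flat list of the system directories found under the two prefixed paths
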
def workflow_concurrent_learning(
    config: Dict,
) -> Step:
    default_config = config["default_step_config"]

    train_config = config["train"]["config"]
    explore_config = config["explore"]["config"]
    train_style = config["train"]["type"]
    explore_style = config["explore"]["type"]
    fp_style = config["fp"]["type"]
    prep_train_config = config["step_configs"]["prep_train_config"]
    run_train_config = config["step_configs"]["run_train_config"]
    prep_explore_config = config["step_configs"]["prep_explore_config"]
    run_explore_config = config["step_configs"]["run_explore_config"]
    prep_fp_config = config["step_configs"]["prep_fp_config"]
    run_fp_config = config["step_configs"]["run_fp_config"]
    select_confs_config = config["step_configs"]["select_confs_config"]
    collect_data_config = config["step_configs"]["collect_data_config"]
    cl_step_config = config["step_configs"]["cl_step_config"]
    upload_python_packages = config.get("upload_python_packages", None)
    train_optional_files = config["train"].get("optional_files", None)

    if train_style == "dp":
        init_models_paths = config["train"].get("init_models_paths", None)
        numb_models = config["train"]["numb_models"]
        if init_models_paths is not None and len(init_models_paths) != numb_models:
            raise RuntimeError(
                f"{len(init_models_paths)} init models provided, which does "
                f"not match numb_models={numb_models}"
            )
    elif train_style == "dp-dist":
        init_models_paths = (
            [config["train"]["student_model_path"]]
            if "student_model_path" in config["train"]
            else None
        )
        config["train"]["numb_models"] = 1
    else:
        raise RuntimeError(f"unknown params, train_style: {train_style}")

    if upload_python_packages is not None and isinstance(upload_python_packages, str):
        upload_python_packages = [upload_python_packages]
    if upload_python_packages is not None:
        _upload_python_packages: List[os.PathLike] = [
            Path(ii) for ii in upload_python_packages
        ]
        upload_python_packages = _upload_python_packages

    multitask = config["inputs"]["multitask"]
    valid_data = None
    if multitask:
        if config["inputs"]["multi_valid_data_uri"] is not None:
            valid_data = get_artifact_from_uri(config["inputs"]["multi_valid_data_uri"])
        elif config["inputs"]["multi_valid_data"] is not None:
            multi_valid_data = config["inputs"]["multi_valid_data"]
            valid_data = {}
            for k, v in multi_valid_data.items():
                sys = v["sys"]
                sys = get_systems_from_data(sys, v.get("prefix", None))
                valid_data[k] = sys
            valid_data = upload_artifact_and_print_uri(valid_data, "multi_valid_data")
    else:
        if config["inputs"]["valid_data_uri"] is not None:
            valid_data = get_artifact_from_uri(config["inputs"]["valid_data_uri"])
        elif config["inputs"]["valid_data_prefix"] is not None:
            valid_data_prefix = config["inputs"]["valid_data_prefix"]
            valid_data = config["inputs"]["valid_data_sys"]
            valid_data = get_systems_from_data(valid_data, valid_data_prefix)
            valid_data = upload_artifact_and_print_uri(valid_data, "valid_data")

    concurrent_learning_op = make_concurrent_learning_op(
        train_style,
        explore_style,
        fp_style,
        prep_train_config=prep_train_config,
        run_train_config=run_train_config,
        prep_explore_config=prep_explore_config,
        run_explore_config=run_explore_config,
        prep_fp_config=prep_fp_config,
        run_fp_config=run_fp_config,
        select_confs_config=select_confs_config,
        collect_data_config=collect_data_config,
        cl_step_config=cl_step_config,
        upload_python_packages=upload_python_packages,
        valid_data=valid_data,
        train_optional_files=train_optional_files,
        explore_config=explore_config,
    )
    scheduler = make_naive_exploration_scheduler(config)

    type_map = config["inputs"]["type_map"]
    numb_models = config["train"]["numb_models"]
    template_script_ = config["train"]["template_script"]
    if isinstance(template_script_, list):
        template_script = [json.loads(Path(ii).read_text()) for ii in template_script_]
    else:
        template_script = json.loads(Path(template_script_).read_text())

    if (
        "teacher_model_path" in explore_config
        and explore_config["teacher_model_path"] is not None
    ):
        assert os.path.exists(
            explore_config["teacher_model_path"]
        ), f"No such file: {explore_config['teacher_model_path']}"
        explore_config["teacher_model_path"] = BinaryFileInput(
            explore_config["teacher_model_path"]
        )

    fp_config = {}
    fp_inputs_config = config["fp"]["inputs_config"]
    fp_inputs = fp_styles[fp_style]["inputs"](**fp_inputs_config)

    fp_config["inputs"] = fp_inputs
    fp_config["run"] = config["fp"]["run_config"]
    fp_config["extra_output_files"] = config["fp"]["extra_output_files"]
    if fp_style == "deepmd":
        assert (
            "teacher_model_path" in fp_config["run"]
        ), "Cannot find 'teacher_model_path' in config['fp']['run_config'] when fp_style == 'deepmd'"
        assert os.path.exists(
            fp_config["run"]["teacher_model_path"]
        ), f"No such file: {fp_config['run']['teacher_model_path']}"
        fp_config["run"]["teacher_model_path"] = BinaryFileInput(
            fp_config["run"]["teacher_model_path"]
        )

    multitask = config["inputs"]["multitask"]
    if multitask:
        head = config["inputs"]["head"]
        if config["inputs"]["multi_init_data_uri"] is not None:
            init_data = get_artifact_from_uri(config["inputs"]["multi_init_data_uri"])
        else:
            multi_init_data = config["inputs"]["multi_init_data"]
            init_data = {}
            for k, v in multi_init_data.items():
                sys = v["sys"]
                sys = get_systems_from_data(sys, v.get("prefix", None))
                init_data[k] = sys
            init_data = upload_artifact_and_print_uri(init_data, "multi_init_data")
        train_config["multitask"] = True
        train_config["head"] = head
        explore_config["model_frozen_head"] = head
    else:
        if config["inputs"]["init_data_uri"] is not None:
            init_data = get_artifact_from_uri(config["inputs"]["init_data_uri"])
        else:
            init_data_prefix = config["inputs"]["init_data_prefix"]
            init_data = config["inputs"]["init_data_sys"]
            init_data = get_systems_from_data(init_data, init_data_prefix)
            init_data = upload_artifact_and_print_uri(init_data, "init_data")
    iter_data = upload_artifact([])

    if train_style == "dp" and config["train"]["init_models_uri"] is not None:
        init_models = get_artifact_from_uri(config["train"]["init_models_uri"])
    elif train_style == "dp-dist" and config["train"]["student_model_uri"] is not None:
        init_models = get_artifact_from_uri(config["train"]["student_model_uri"])
    elif init_models_paths is not None:
        init_models = upload_artifact_and_print_uri(init_models_paths, "init_models")
    else:
        init_models = None

    if config["inputs"]["use_ele_temp"]:
        explore_config["use_ele_temp"] = config["inputs"]["use_ele_temp"]

    optional_parameter = make_optional_parameter(
        config["inputs"]["mixed_type"],
    )

    if config["inputs"].get("do_finetune", False):
        if train_config["init_model_policy"] != "yes":
            logging.warning("In finetune mode, init_model_policy is forced to be 'yes'")
            train_config["init_model_policy"] = "yes"
        optional_parameter = make_optional_parameter(
            config["inputs"]["mixed_type"],
            finetune_mode="finetune",
        )

    # here the scheduler is passed as an input parameter to the concurrent_learning_op
    dpgen_step = Step(
        "dpgen-step",
        template=concurrent_learning_op,
        parameters={
            "type_map": type_map,
            "numb_models": numb_models,
            "template_script": template_script,
            "train_config": train_config,
            "explore_config": explore_config,
            "fp_config": fp_config,
            "exploration_scheduler": scheduler,
            "optional_parameter": optional_parameter,
        },
        artifacts={
            "init_models": init_models,
            "init_data": init_data,
            "iter_data": iter_data,
        },
    )
    return dpgen_step

def get_scheduler_ids(
    reuse_step,
):
    scheduler_ids = []
    for idx, ii in enumerate(reuse_step):
        if get_subkey(ii.key, 1) == "scheduler":
            scheduler_ids.append(idx)
    scheduler_keys = [reuse_step[ii].key for ii in scheduler_ids]
    assert (
        sorted(scheduler_keys) == scheduler_keys
    ), "The scheduler keys are not properly sorted"

    if len(scheduler_ids) == 0:
        logging.warning(
            "No scheduler found in the workflow; no replacement will be done."
        )
    return scheduler_ids

def update_reuse_step_scheduler(
    reuse_step,
    scheduler_new,
):
    scheduler_ids = get_scheduler_ids(reuse_step)
    if len(scheduler_ids) == 0:
        return reuse_step

    # do the replacement
    reuse_step[scheduler_ids[-1]].modify_output_parameter(
        "exploration_scheduler", scheduler_new
    )

    return reuse_step

def copy_scheduler_plans(
    scheduler_new,
    scheduler_old,
):
    if len(scheduler_old.stage_schedulers) == 0:
        return scheduler_new
    if len(scheduler_new.stage_schedulers) < len(scheduler_old.stage_schedulers):
        raise RuntimeError(
            "The new scheduler has fewer stages than the old scheduler; "
            "scheduler copy is not supported."
        )
    # scheduler_old has already been planned; mimic the init call of the scheduler
    if scheduler_old.get_iteration() > -1:
        scheduler_new.plan_next_iteration()
    for ii in range(len(scheduler_old.stage_schedulers)):
        old_stage = scheduler_old.stage_schedulers[ii]
        old_reports = old_stage.get_reports()
        if old_stage.next_iteration() > 0:
            if ii != scheduler_new.get_stage():
                raise RuntimeError(
                    f"The stage {scheduler_new.get_stage()} of the new "
                    f"scheduler does not match "
                    f"the stage {ii} of the old scheduler, "
                    f"which should not happen."
                )
            for report in old_reports:
                scheduler_new.plan_next_iteration(report)
            if old_stage.complete() and (
                not scheduler_new.stage_schedulers[ii].complete()
            ):
                scheduler_new.force_stage_complete()
        else:
            break
    return scheduler_new

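# Note (editorial summary of the function above): copy_scheduler_plans replays the
# already-executed plan of scheduler_old onto a freshly constructed scheduler_new
# by re-feeding the stored exploration reports through plan_next_iteration, and it
# force-completes stages that the old scheduler had finished. It stops at the
# first stage the old scheduler had not yet started.
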
def submit_concurrent_learning(
    wf_config,
    reuse_step: Optional[List[ArgoStep]] = None,
    replace_scheduler: bool = False,
    no_submission: bool = False,
):
    # normalize args
    wf_config = normalize_args(wf_config)

    global_config_workflow(wf_config)

    dpgen_step = workflow_concurrent_learning(wf_config)

    if reuse_step is not None and replace_scheduler:
        scheduler_new = copy.deepcopy(
            dpgen_step.inputs.parameters["exploration_scheduler"].value
        )
        idx_old = get_scheduler_ids(reuse_step)[-1]
        scheduler_old = (
            reuse_step[idx_old].inputs.parameters["exploration_scheduler"].value
        )
        scheduler_new = copy_scheduler_plans(scheduler_new, scheduler_old)
        exploration_report = (
            reuse_step[idx_old].inputs.parameters["exploration_report"].value
        )
        # plan next
        # hack! trajs is set to None...
        conv, expl_task_grp, selector = scheduler_new.plan_next_iteration(
            exploration_report, trajs=None
        )
        # update output of the scheduler step
        reuse_step[idx_old].modify_output_parameter(
            "converged",
            conv,
        )
        reuse_step[idx_old].modify_output_parameter(
            "exploration_scheduler",
            scheduler_new,
        )
        reuse_step[idx_old].modify_output_parameter(
            "expl_task_grp",
            expl_task_grp,
        )
        reuse_step[idx_old].modify_output_parameter(
            "conf_selector",
            selector,
        )

    wf = Workflow(name=wf_config["name"], parallelism=wf_config["parallelism"])
    wf.add(dpgen_step)

    # for debug purposes, we may not really submit the wf
    if not no_submission:
        wf.submit(reuse_step=reuse_step)

    return wf

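# Illustrative usage (assumed values): submit a workflow from a config dict, or set
# no_submission=True to build the Workflow object without actually submitting it.
#
#     wf = submit_concurrent_learning(wf_config, no_submission=True)  # dry build
#     wf = submit_concurrent_learning(wf_config)                      # real submission
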
def get_resubmit_keys(
    wf,
):
    wf_info = wf.query()
    all_steps = [
        step
        for step in wf_info.get_step(sort_by_generation=True)
        if step.key is not None
    ]
    super_keys = ["prep-run-train", "prep-run-explore", "prep-run-fp"]
    other_keys = [
        "select-confs",
        "collect-data",
        "scheduler",
        "id",
    ]
    folded_keys = {}
    for step in all_steps:
        if len(matched_step_key([step.key], super_keys)) > 0:
            sub_steps = wf_info.get_step(parent_id=step.id, sort_by_generation=True)
            sub_keys = [
                step.key
                for step in sub_steps
                if step.key is not None and step.phase == "Succeeded"
            ]
            sub_keys = sort_slice_ops(
                sub_keys,
                ["run-train", "run-lmp", "run-fp", "diffcsp-gen", "run-relax"],
            )
            if step.phase == "Succeeded":
                folded_keys[step.key] = sub_keys
            else:
                for key in sub_keys:
                    folded_keys[key] = [key]
        elif len(matched_step_key([step.key], other_keys)) > 0:
            folded_keys[step.key] = [step.key]
    return folded_keys

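# Illustrative shape of the returned mapping (the key names below are assumptions
# for illustration): a succeeded super-OP folds its children under its own key,
# while a failed super-OP exposes each succeeded child key individually.
#
#     {
#         "iter-000000--prep-run-train": ["iter-000000--prep-train", "iter-000000--run-train-0000", ...],
#         "iter-000000--scheduler": ["iter-000000--scheduler"],
#     }
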
def resubmit_concurrent_learning(
    wf_config,
    wfid,
    list_steps=False,
    reuse=None,
    replace_scheduler=False,
    fold=False,
):
    wf_config = normalize_args(wf_config)

    global_config_workflow(wf_config)

    old_wf = Workflow(id=wfid)
    folded_keys = get_resubmit_keys(old_wf)
    all_step_keys = []
    super_keys = {}
    for super_key, keys in folded_keys.items():
        all_step_keys += keys
        for key in keys:
            super_keys[key] = super_key

    if list_steps:
        prt_str = print_keys_in_nice_format(
            all_step_keys,
            ["run-train", "run-lmp", "run-fp", "diffcsp-gen", "run-relax"],
        )
        print(prt_str)

    if reuse is None:
        return None
    reuse_idx = expand_idx(reuse)
    reused_keys = [all_step_keys[ii] for ii in reuse_idx]
    if fold:
        reused_folded_keys = {}
        for key in reused_keys:
            super_key = super_keys[key]
            if super_key not in reused_folded_keys:
                reused_folded_keys[super_key] = []
            reused_folded_keys[super_key].append(key)
        for k, v in reused_folded_keys.items():
            # reuse the super OP only when all steps within it are reused
            if set(v) == set(folded_keys[k]):
                reused_folded_keys[k] = [k]
        reused_keys = sum(reused_folded_keys.values(), [])
    reuse_step = old_wf.query_step(key=reused_keys, sort_by_generation=True)

    wf = submit_concurrent_learning(
        wf_config,
        reuse_step=reuse_step,
        replace_scheduler=replace_scheduler,
    )
    return wf

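# Illustrative usage (the workflow id and reuse range are assumptions): first list
# the reusable step keys of an existing workflow, then resubmit while reusing a
# selected range of them.
#
#     resubmit_concurrent_learning(wf_config, "dpgen-abcde", list_steps=True)
#     resubmit_concurrent_learning(wf_config, "dpgen-abcde", reuse=["0-10"], fold=True)
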