Source code for dpgen2.utils.step_config

import os

import dargs
from dargs import (
    Argument,
    Variant,
)
from dflow.config import (
    config,
)
from dflow.plugins.dispatcher import (
    DispatcherExecutor,
)

from dpgen2.constants import (
    default_image,
)


def dispatcher_args():
    """Return the sub-argument list of the ``dispatcher`` executor.

    The list is empty: no dispatcher option is declared here, so the
    dispatcher section stays free style.
    """
    free_style_args = []
    return free_style_args
def variant_executor():
    """Build the ``type`` variant that selects the executor kind.

    The only choice declared is ``dispatcher``, a dict argument whose
    sub-arguments come from ``dispatcher_args()``.
    """
    dispatcher_choice = Argument("dispatcher", dict, dispatcher_args())
    return Variant(
        "type",
        [dispatcher_choice],
        doc="The type of the executor.",
    )
def template_conf_args():
    """Return the argument definitions of the template configuration.

    Declares five optional keys: ``image`` (defaults to ``default_image``),
    ``timeout``, ``retry_on_transient_error``,
    ``timeout_as_transient_error`` (defaults to ``False``) and ``envs``.
    """
    image = Argument(
        "image",
        str,
        optional=True,
        default=default_image,
        doc="The image to run the step.",
    )
    timeout = Argument(
        "timeout",
        int,
        optional=True,
        default=None,
        doc="The time limit of the OP. Unit is second.",
    )
    retry_on_transient_error = Argument(
        "retry_on_transient_error",
        int,
        optional=True,
        default=None,
        doc="The number of retry times if a TransientError is raised.",
    )
    timeout_as_transient_error = Argument(
        "timeout_as_transient_error",
        bool,
        optional=True,
        default=False,
        doc="Treat the timeout as TransientError.",
    )
    envs = Argument(
        "envs",
        dict,
        optional=True,
        default=None,
        doc="The environmental variables.",
    )
    return [
        image,
        timeout,
        retry_on_transient_error,
        timeout_as_transient_error,
        envs,
    ]
def template_slice_conf_args():
    """Return the argument definitions of the slice configuration.

    Declares two optional integer keys, ``group_size`` and ``pool_size``,
    both defaulting to ``None``.
    """
    # (key, documentation) pairs; every entry is an optional int.
    int_options = (
        (
            "group_size",
            "The number of tasks running on a single node. "
            "It is efficient for a large number of short tasks.",
        ),
        (
            "pool_size",
            "The number of tasks running at the same time on one node.",
        ),
    )
    return [
        Argument(key, int, optional=True, default=None, doc=text)
        for key, text in int_options
    ]
def step_conf_args():
    """Return the argument definitions of a step configuration.

    The keys are, in order: ``template_config``, ``template_slice_config``,
    ``continue_on_failed``, ``continue_on_num_success``,
    ``continue_on_success_ratio``, ``parallelism`` and ``executor``.
    All of them are optional.
    """
    template_config = Argument(
        "template_config",
        dict,
        template_conf_args(),
        optional=True,
        default={"image": default_image},
        doc="The configs passed to the PythonOPTemplate.",
    )
    template_slice_config = Argument(
        "template_slice_config",
        dict,
        template_slice_conf_args(),
        optional=True,
        doc="The configs passed to the Slices.",
    )
    # NOTE(review): the wording of this doc string ("If continue the the
    # step is failed") looks garbled; it is kept verbatim because it is
    # emitted into the generated documentation.
    continue_on_failed = Argument(
        "continue_on_failed",
        bool,
        optional=True,
        default=False,
        doc=(
            "If continue the the step is failed (FatalError, TransientError, "
            "A certain number of retrial is reached...)."
        ),
    )
    continue_on_num_success = Argument(
        "continue_on_num_success",
        int,
        optional=True,
        default=None,
        doc=(
            "Only in the sliced OP case. "
            "Continue the workflow if a certain number of the sliced jobs are successful."
        ),
    )
    continue_on_success_ratio = Argument(
        "continue_on_success_ratio",
        float,
        optional=True,
        default=None,
        doc=(
            "Only in the sliced OP case. "
            "Continue the workflow if a certain ratio of the sliced jobs are successful."
        ),
    )
    parallelism = Argument(
        "parallelism",
        int,
        optional=True,
        default=None,
        doc="The parallelism for the step",
    )
    # No fixed sub-fields: the content of the executor section is decided
    # by the "type" variant.
    executor = Argument(
        "executor",
        dict,
        [],
        [variant_executor()],
        optional=True,
        default=None,
        doc="The executor of the step.",
    )
    return [
        template_config,
        template_slice_config,
        continue_on_failed,
        continue_on_num_success,
        continue_on_success_ratio,
        parallelism,
        executor,
    ]
def normalize(data):
    """Normalize a step configuration and check it non-strictly.

    The configuration is wrapped in a ``base`` dict argument built from
    ``step_conf_args()``, normalized with ``trim_pattern="_*"`` and then
    checked with ``strict=False``. The normalized value is returned.
    """
    schema = Argument("base", dict, step_conf_args())
    normalized = schema.normalize_value(data, trim_pattern="_*")
    # The dispatcher arguments are free style and cannot be checked
    # strictly, hence the non-strict check (acknowledged as a hack).
    schema.check_value(normalized, strict=False)
    return normalized
def gen_doc(*, make_anchor=True, make_link=True, **kwargs):
    """Generate the documentation text of the step configuration.

    Parameters
    ----------
    make_anchor : bool
        Forwarded to ``Argument.gen_doc``. Forced to ``True`` whenever
        ``make_link`` is true.
    make_link : bool
        Forwarded to ``Argument.gen_doc``.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``Argument.gen_doc``.

    Returns
    -------
    The value returned by ``gen_doc`` of the ``step_config`` argument built
    from ``step_conf_args()``.
    """
    if make_link:
        make_anchor = True
    base = Argument("step_config", dict, step_conf_args())
    # The previous implementation also scanned the text for "argument path"
    # lines into a list that was never used; that dead code is removed.
    return base.gen_doc(make_anchor=make_anchor, make_link=make_link, **kwargs)
def init_executor(
    executor_dict,
):
    """Create the executor described by ``executor_dict``.

    Parameters
    ----------
    executor_dict : dict or None
        Executor configuration. The ``"type"`` key selects the executor
        kind; every other key is passed as a keyword argument to the
        executor constructor. The dict is not modified.

    Returns
    -------
    ``None`` when ``executor_dict`` is ``None``, otherwise a
    ``DispatcherExecutor`` built from the remaining keys.

    Raises
    ------
    KeyError
        If ``executor_dict`` has no ``"type"`` key.
    RuntimeError
        If the type is not ``"dispatcher"``.
    """
    if executor_dict is None:
        return None
    # Work on a shallow copy: popping "type" from the caller's dict would
    # make a second call with the same configuration fail with KeyError.
    executor_kwargs = dict(executor_dict)
    etype = executor_kwargs.pop("type")
    if etype == "dispatcher":
        return DispatcherExecutor(**executor_kwargs)
    raise RuntimeError("unknown executor type", etype)