Source code for dpti.dags.dp_ti_gdi

import json
import os
import time
from datetime import datetime, timedelta
from typing import ClassVar, Dict

from airflow import DAG
from airflow.api.client.local_client import Client
from airflow.decorators import task
from airflow.exceptions import (
    AirflowFailException,
    DagRunAlreadyExists,
    DagRunNotFound,
)
from airflow.models import DagRun
from airflow.operators.python import get_current_context

from airflow.utils.state import State
from dpdispatcher import Machine, Resources, Submission, Task

from dpti.gdi import gdi_main_loop


class GDIDAGFactory:
    """Factory that builds the paired Airflow DAGs for one GDI run: a main
    DAG that drives gdi_main_loop, and a loop DAG that runs the MD tasks."""

    default_args: ClassVar[Dict[str, object]] = {
        "owner": "airflow",
        "start_date": datetime(2018, 1, 1),
    }
    dagargs: ClassVar[Dict[str, object]] = {
        "default_args": default_args,
        "schedule_interval": None,
    }

    def __init__(self, gdi_name, dag_work_base):
        self.gdi_name = gdi_name
        self.dag_loop_name = self.gdi_name + "_gdi_loop_dag"
        self.dag_main_name = self.gdi_name + "_gdi_main_dag"
        self.var_name = self.gdi_name + "_dv_dh"
        self.dag_work_base = dag_work_base
        self.main_dag = self.create_main_dag()
        self.loop_dag = self.create_loop_dag()
    def create_main_dag(self):
        # Note: the workflow handed to gdi_main_loop points at the *loop*
        # DAG, which the main loop triggers once per iteration.
        dag_name = self.dag_loop_name
        var_name = self.var_name
        work_base = self.dag_work_base

        @task()
        def dpti_gdi_main_prepare(**kwargs):
            prepare_return = True
            return prepare_return

        @task(retries=2, retry_delay=timedelta(minutes=1))
        def dpti_gdi_main_loop(prepare_return, **kwargs):
            print("debug:prepare_return", prepare_return)
            with open(os.path.join(work_base, "machine.json")) as f:
                mdata = json.load(f)
            with open(os.path.join(work_base, "pb.json")) as f:
                jdata = json.load(f)
            with open(os.path.join(work_base, "gdidata.json")) as f:
                gdidata_dict = json.load(f)
            output_dir = os.path.join(work_base, "new_job/")
            gdidata_dict["output"] = output_dir
            gdi_workflow = GDIWorkflow(var_name=var_name, dag_name=dag_name)
            gdi_main_loop(
                jdata=jdata,
                mdata=mdata,
                gdidata_dict=gdidata_dict,
                gdidata_cli={},
                workflow=gdi_workflow,
            )
            loop_return = True
            return loop_return

        @task()
        def dpti_gdi_main_end(loop_return, **kwargs):
            end_return = True
            return end_return

        main_dag = DAG(self.dag_main_name, **self.__class__.dagargs)
        with main_dag:
            prepare_return = dpti_gdi_main_prepare()
            loop_return = dpti_gdi_main_loop(prepare_return)
            end_return = dpti_gdi_main_end(loop_return)
        return main_dag
    def create_loop_dag(self):
        @task(multiple_outputs=True)
        def dpti_gdi_loop_prepare():
            # Pull the two serialized MD tasks out of the dag_run conf.
            context = get_current_context()
            dag_run = context["params"]
            task0_dict = dag_run["task_dict_list"][0]
            task1_dict = dag_run["task_dict_list"][1]
            return {"task0_dict": task0_dict, "task1_dict": task1_dict}

        @task(retries=2, retry_delay=timedelta(minutes=1))
        def dpti_gdi_loop_md(task_dict):
            context = get_current_context()
            dag_run = context["params"]
            submission_dict = dag_run["submission_dict"]
            print("submission_dict", submission_dict)
            mdata = dag_run["mdata"]
            print("mdata", mdata)
            print("debug:task_dict", task_dict)
            machine = Machine.load_from_dict(mdata["machine"])
            resources = Resources.load_from_dict(mdata["resources"])
            submission = Submission.deserialize(
                submission_dict=submission_dict, machine=machine
            )
            submission.resources = resources
            submission.register_task(task=Task.deserialize(task_dict=task_dict))
            submission.run_submission()
            return True

        @task()
        def dpti_gdi_loop_end(task0_return, task1_return):
            end_return = True
            return end_return

        loop_dag = DAG(self.dag_loop_name, **self.__class__.dagargs)
        with loop_dag:
            tasks_dict = dpti_gdi_loop_prepare()
            task0_return = dpti_gdi_loop_md(tasks_dict["task0_dict"])
            task1_return = dpti_gdi_loop_md(tasks_dict["task1_dict"])
            end_return = dpti_gdi_loop_end(task0_return, task1_return)
        return loop_dag
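
A minimal usage sketch (not taken from this module) of how the factory might
be instantiated so that Airflow's scheduler discovers both DAGs at import
time; the gdi_name and work-base path are hypothetical placeholders:

    # hypothetical instantiation; the name and path are placeholders
    factory = GDIDAGFactory(
        gdi_name="example_gdi",
        dag_work_base="/path/to/work_base",
    )
    # Airflow only registers DAG objects bound to module-level names.
    main_dag = factory.main_dag
    loop_dag = factory.loop_dag
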
class GDIWorkflow:
    """Handle used by gdi_main_loop to trigger the loop DAG and poll its
    dag_run state until it finishes."""

    def __init__(self, dag_name, var_name):
        self.dag_name = dag_name
        self.var_name = var_name
        self.run_id = None
    def get_dag_run_state(self):
        if self.run_id is None:
            raise DagRunNotFound(f"dag_id {self.dag_name}; {self.run_id}")
        dag_runs = DagRun.find(dag_id=self.dag_name, run_id=self.run_id)
        return dag_runs[0].state if dag_runs else None
    def wait_until_end(self):
        while True:
            dag_run_state = self.get_dag_run_state()
            if dag_run_state == State.SUCCESS:
                print(f"dag_run_state: {dag_run_state}")
                break
            elif dag_run_state == State.RUNNING:
                print(f"dag_run_state: {dag_run_state}")
                time.sleep(30)
            else:
                raise AirflowFailException(
                    f"subdag dag_run fail dag_id:{self.dag_name}; run_id:{self.run_id};"
                )
        return dag_run_state
    def trigger_loop(self, submission, task_list, mdata):
        c = Client(None, None)
        submission_dict = submission.serialize()
        task_dict_list = [task.serialize() for task in task_list]
        submission_hash = submission.submission_hash
        # Derive the run_id from the submission hash so retriggering the
        # same submission resumes the existing dag_run instead of forking.
        self.run_id = f"dag_run_{submission_hash}"
        try:
            c.trigger_dag(
                dag_id=self.dag_name,
                run_id=self.run_id,
                conf={
                    "submission_dict": submission_dict,
                    "task_dict_list": task_dict_list,
                    "mdata": mdata,
                },
            )
        except DagRunAlreadyExists:
            dag_run_state = self.get_dag_run_state()
            if dag_run_state == State.FAILED:
                raise AirflowFailException(
                    f"subdag dag_run fail dag_id:{self.dag_name}; run_id:{self.run_id};"
                )
            else:
                print(f"continue from old dag_run {self.run_id}")
        loop_return = self.wait_until_end()
        return loop_return
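
For context, a hedged sketch of how trigger_loop is expected to be driven:
the caller builds a dpdispatcher Submission plus its Task list, and the
workflow serializes them into the loop DAG's conf before polling the run.
The commands, file names, and working paths below are illustrative
assumptions, not values taken from this module:

    # hypothetical driver code; commands and paths are assumptions
    with open("machine.json") as f:
        mdata = json.load(f)
    machine = Machine.load_from_dict(mdata["machine"])
    resources = Resources.load_from_dict(mdata["resources"])
    task_list = [
        Task(command="lmp -i in.lammps", task_work_path="task.000000"),
        Task(command="lmp -i in.lammps", task_work_path="task.000001"),
    ]
    submission = Submission(
        work_base="new_job/",
        machine=machine,
        resources=resources,
        task_list=task_list,
    )
    workflow = GDIWorkflow(
        dag_name="example_gdi_gdi_loop_dag",  # matches GDIDAGFactory naming
        var_name="example_gdi_dv_dh",
    )
    # Triggers the loop DAG, then blocks until the dag_run succeeds or fails.
    workflow.trigger_loop(submission=submission, task_list=task_list, mdata=mdata)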