# Source code for dpgen2.utils.dflow_query

import logging
import re
from typing import (
    Any,
    List,
    Optional,
)

import numpy as np


def get_subkey(
    key: str,
    idx: int = -1,
):
    """
    return one field of a step key, the fields being separated by "--".

    Parameters
    ----------
    key : str
        the step key to split.
    idx : int
        index of the wanted field. negative values count from the end;
        the default -1 selects the last field.

    Returns
    -------
    str
        the selected field. an IndexError is raised when `idx` is out of
        range for the number of fields in `key`.
    """
    fields = key.split("--")
    return fields[idx]
def get_iteration(
    key: str,
):
    """
    return the leading field of a step key (field 0 as given by `get_subkey`).
    """
    return get_subkey(key, idx=0)
def matched_step_key(
    all_keys: List[str],
    step_keys: Optional[List[str]] = None,
):
    """
    returns the keys in `all_keys` that matches any of the `step_keys`

    A key matches a step key `sk` when it starts with `iter-[0-9]*--sk`,
    `finetune--sk` or `init--sk`. Matching is a prefix match, so sliced
    keys carrying a `-<number>` suffix are matched as well.

    Parameters
    ----------
    all_keys : List[str]
        the keys to be filtered.
    step_keys : Optional[List[str]]
        the step keys to look for. They are inserted into the regular
        expressions verbatim (not escaped). If None, `all_keys` is
        returned unchanged.

    Returns
    -------
    List[str]
        the matching keys in their original order. Each key appears at
        most once, even when it matches more than one of the `step_keys`.
    """
    if step_keys is None:
        return all_keys
    # re.match anchors only at the start, so the variants with a trailing
    # "-[0-9]*" are already covered by these shorter patterns.
    patterns = [
        re.compile(f"{prefix}--{sk}")
        for sk in step_keys
        for prefix in ("iter-[0-9]*", "finetune", "init")
    ]
    # any() stops at the first hit, so a key is never reported twice.
    return [kk for kk in all_keys if any(pp.match(kk) for pp in patterns)]
def get_last_scheduler(
    wf: Any,
    keys: List[str],
):
    """
    get the output Scheduler of the last successful iteration

    The scheduler is first looked up in the global outputs of the workflow.
    If it is not available there, a warning is logged and the scheduler is
    taken from the succeeded "scheduler" step whose key sorts last.

    Parameters
    ----------
    wf : Any
        the workflow object. It must provide `query_global_outputs()` and
        `query_step_by_key(keys)`.
    keys : List[str]
        the step keys of the workflow. Only keys whose last "--" separated
        field is "scheduler" are queried.

    Returns
    -------
    the value of the "exploration_scheduler" output parameter, or None if
    the global outputs do not provide it and no scheduler step succeeded.
    """
    outputs = wf.query_global_outputs()
    if (
        outputs is not None
        and hasattr(outputs, "parameters")
        and "exploration_scheduler" in outputs.parameters
        and hasattr(outputs.parameters["exploration_scheduler"], "value")
    ):
        return outputs.parameters["exploration_scheduler"].value

    # logging.warn is a deprecated alias that emits a DeprecationWarning.
    logging.warning("Exploration scheduler not found in the global outputs")
    candidate_keys = [kk for kk in keys if get_subkey(kk) == "scheduler"]
    scheduler_steps = wf.query_step_by_key(candidate_keys)
    succeeded = [step for step in scheduler_steps if step["phase"] == "Succeeded"]
    if not succeeded:
        return None
    # Keys are compared as strings; the step is picked among the succeeded
    # ones only, so a non-succeeded step sharing the key is never returned.
    last_step = max(succeeded, key=lambda step: step.key)
    return last_step.outputs.parameters["exploration_scheduler"].value
def get_all_schedulers(
    wf: Any,
    keys: List[str],
):
    """
    get the output Schedulers of all the iterations

    The keys matching the step key "scheduler" are sorted, and for each of
    them the first step returned by `wf.query_step(key=...)` is read.

    Returns
    -------
    the list of "exploration_scheduler" output values, ordered by sorted
    step key, or None if no key matches.
    """
    ordered_keys = sorted(matched_step_key(keys, ["scheduler"]))
    if not ordered_keys:
        return None
    schedulers = []
    for step_key in ordered_keys:
        step = wf.query_step(key=step_key)[0]
        schedulers.append(step.outputs.parameters["exploration_scheduler"].value)
    return schedulers
def get_last_iteration(
    keys: List[str],
):
    """
    get the index of the last iteration from a list of step keys.

    The leading field of every key is collected and sorted as a string;
    the integer after the first "-" of the last one is returned.
    """
    leading_fields = [get_subkey(kk, 0) for kk in keys]
    leading_fields.sort()
    last_field = leading_fields[-1]
    return int(last_field.split("-")[1])
def find_slice_ranges(
    keys: List[str],
    sliced_subkey: str,
):
    """
    find range of sliced OPs that matches the pattern
    'iter-[0-9]*--{sliced_subkey}-[0-9]*' or
    'finetune--{sliced_subkey}-[0-9]*'

    Parameters
    ----------
    keys : List[str]
        the step keys to scan, in the order they should be examined.
    sliced_subkey : str
        the sub-key of the sliced OP. It is inserted into the regular
        expressions verbatim (not escaped).

    Returns
    -------
    List[List[int]]
        one `[start, stop]` pair per run of consecutive matching keys, so
        that `keys[start:stop]` is the run. A run that extends to the end
        of `keys` is closed with `stop == len(keys)`.
    """
    patterns = (
        re.compile(f"iter-[0-9]*--{sliced_subkey}-[0-9]*"),
        re.compile(f"finetune--{sliced_subkey}-[0-9]*"),
    )
    found_range = []
    start = None  # index where the currently open run began, if any
    for idx, key in enumerate(keys):
        is_sliced = any(pp.match(key) for pp in patterns)
        if is_sliced and start is None:
            start = idx
        elif not is_sliced and start is not None:
            found_range.append([start, idx])
            start = None
    # A run still open here reaches the end of the list; close it so the
    # trailing sliced OPs are not silently dropped.
    if start is not None:
        found_range.append([start, len(keys)])
    return found_range
def _sort_slice_ops(keys, sliced_subkey):
    """
    sort, in place, every range of `keys` reported by `find_slice_ranges`
    for `sliced_subkey`, and return the same list object.
    """
    for begin, end in find_slice_ranges(keys, sliced_subkey):
        keys[begin:end] = sorted(keys[begin:end])
    return keys
def sort_slice_ops(
    keys: List[str],
    sliced_subkey: List[str],
):
    """
    sort the keys of the sliced ops. the keys of the sliced ops contains
    sliced_subkey

    `sliced_subkey` may be a list of sub-keys or a single string; a single
    string is treated as a one-element list. The sub-keys are processed in
    the given order and the resulting list of keys is returned.
    """
    subkeys = [sliced_subkey] if isinstance(sliced_subkey, str) else sliced_subkey
    for subkey in subkeys:
        keys = _sort_slice_ops(keys, subkey)
    return keys