Source code for dpgen2.utils.dflow_query
import logging
import re
from typing import (
Any,
List,
Optional,
)
import numpy as np
[docs]
def get_subkey(
    key: str,
    idx: int = -1,
):
    """
    Return one ``--``-separated component of a step key.

    Parameters
    ----------
    key
        The step key, e.g. ``iter-000000--prep-run-train``.
    idx
        Index of the component to return; defaults to the last one.

    Returns
    -------
    The selected component (the whole key when it contains no ``--``).
    """
    parts = key.split("--")
    return parts[idx]
[docs]
def get_iteration(
    key: str,
):
    """
    Return the leading ``--``-separated component of a step key.

    For ``iter-000001--scheduler`` this is ``iter-000001``; a key without
    ``--`` is returned unchanged.
    """
    head, _sep, _rest = key.partition("--")
    return head
[docs]
def matched_step_key(
    all_keys: List[str],
    step_keys: Optional[List[str]] = None,
):
    """
    returns the keys in `all_keys` that matches any of the `step_keys`

    A key matches a step key ``jj`` when it starts with
    ``iter-<digits>--<jj>`` or ``init--<jj>``. ``re.match`` only anchors at
    the start, so this is a prefix match. ``jj`` is interpolated into the
    pattern unescaped.

    Parameters
    ----------
    all_keys
        Candidate step keys.
    step_keys
        Step names to look for. ``None`` disables filtering.

    Returns
    -------
    The matching keys, in the order of `all_keys`, each reported at most
    once. When `step_keys` is ``None``, `all_keys` itself is returned.
    """
    if step_keys is None:
        return all_keys
    ret = []
    for kk in all_keys:
        for jj in step_keys:
            if (
                re.match(f"iter-[0-9]*--{jj}-[0-9]*", kk)
                or re.match(f"iter-[0-9]*--{jj}", kk)
                or re.match(f"init--{jj}", kk)
            ):
                ret.append(kk)
                # Stop at the first matching step key. Otherwise a key that
                # matches several step keys (e.g. "run-lmp" and "run") would
                # be appended once per match.
                break
    return ret
[docs]
def get_last_scheduler(
    wf: Any,
    keys: List[str],
):
    """
    get the output Scheduler of the last successful iteration

    The scheduler is first read from the workflow's global outputs (parameter
    ``exploration_scheduler``). If it is not there, the steps whose last
    ``--``-separated subkey is ``scheduler`` are queried. The
    ``exploration_scheduler`` output of the succeeded one with the greatest
    key (string ordering) is returned.

    Parameters
    ----------
    wf
        The workflow. It must provide ``query_global_outputs()`` and
        ``query_step_by_key(keys)``.
    keys
        Step keys of the workflow.

    Returns
    -------
    The value of the ``exploration_scheduler`` output parameter, or ``None``
    when no succeeded scheduler step is found.
    """
    outputs = wf.query_global_outputs()
    if (
        outputs is not None
        and hasattr(outputs, "parameters")
        and "exploration_scheduler" in outputs.parameters
        and hasattr(outputs.parameters["exploration_scheduler"], "value")
    ):
        return outputs.parameters["exploration_scheduler"].value
    # `logging.warn` is a deprecated alias; `logging.warning` is the supported API.
    logging.warning("Exploration scheduler not found in the global outputs")
    candidate_keys = [ii for ii in keys if get_subkey(ii) == "scheduler"]
    if not candidate_keys:
        # No scheduler step exists, so there is nothing to query.
        return None
    scheduler_steps = wf.query_step_by_key(candidate_keys)
    # Index only the succeeded steps by key (the first occurrence wins). The
    # step whose output is read below is then guaranteed to have succeeded,
    # even if a failed step shares its key.
    succeeded = {}
    for step in scheduler_steps:
        if step["phase"] == "Succeeded":
            succeeded.setdefault(step.key, step)
    if not succeeded:
        return None
    # String ordering, as in the previous sorted(...)[-1]. Presumably the
    # iteration numbers are zero-padded, so this selects the latest
    # iteration -- verify.
    last_key = max(succeeded)
    return succeeded[last_key].outputs.parameters["exploration_scheduler"].value
[docs]
def get_all_schedulers(
    wf: Any,
    keys: List[str],
):
    """
    get the output Scheduler of all the iterations

    The keys selected by ``matched_step_key(keys, ["scheduler"])`` are sorted.
    ``wf.query_step(key=...)`` is called once per key, and the
    ``exploration_scheduler`` output of the first returned step is collected.

    Returns
    -------
    A list of scheduler values in sorted key order, or ``None`` when no
    scheduler key is present.
    """
    scheduler_keys = sorted(matched_step_key(keys, ["scheduler"]))
    if not scheduler_keys:
        return None
    collected = []
    for skey in scheduler_keys:
        first_step = wf.query_step(key=skey)[0]
        collected.append(first_step.outputs.parameters["exploration_scheduler"].value)
    return collected
[docs]
def get_last_iteration(
    keys: List[str],
):
    """
    get the index of the last iteration from a list of step keys.

    The leading ``--``-separated component of every key (e.g.
    ``iter-000003``) is taken, and these components are sorted as strings.
    The text after the first ``-`` of the greatest one is returned as an
    ``int``. An empty `keys` raises ``IndexError``.
    """
    heads = [kk.split("--")[0] for kk in keys]
    heads.sort()
    last_head = heads[-1]
    return int(last_head.split("-")[1])
[docs]
def find_slice_ranges(
    keys: List[str],
    sliced_subkey: str,
):
    """
    find range of sliced OPs that matches the pattern 'iter-[0-9]*--{sliced_subkey}-[0-9]*'

    Parameters
    ----------
    keys
        Step keys, in order.
    sliced_subkey
        Name of the sliced OP. It is interpolated into the regex unescaped.

    Returns
    -------
    A list of ``[start, end)`` index pairs, one per maximal contiguous run of
    keys whose prefix matches the pattern. A run that extends to the end of
    `keys` is closed at ``len(keys)``.
    """
    pattern = re.compile(f"iter-[0-9]*--{sliced_subkey}-[0-9]*")
    found_range = []
    start = None  # index where the current run began, or None outside a run
    for idx, ii in enumerate(keys):
        is_sliced = pattern.match(ii) is not None
        if start is None and is_sliced:
            start = idx
        elif start is not None and not is_sliced:
            found_range.append([start, idx])
            start = None
    if start is not None:
        # Close a run that extends to the last key. Previously such a run was
        # silently dropped, so trailing sliced ops were never sorted.
        found_range.append([start, len(keys)])
    return found_range
def _sort_slice_ops(keys, sliced_subkey):
    """
    Sort, in place, each ``[begin, end)`` slice of `keys` reported by
    ``find_slice_ranges(keys, sliced_subkey)``; return the same list object.
    """
    for begin, end in find_slice_ranges(keys, sliced_subkey):
        keys[begin:end] = sorted(keys[begin:end])
    return keys
[docs]
def sort_slice_ops(
    keys: List[str],
    sliced_subkey: List[str],
):
    """
    sort the keys of the sliced ops. the keys of the sliced ops contains sliced_subkey

    Parameters
    ----------
    keys
        Step keys. The list is sorted in place by ``_sort_slice_ops`` and
        also returned.
    sliced_subkey
        One sliced-OP name, or a list of them; a bare ``str`` is also accepted.

    Returns
    -------
    `keys`, with each range of sliced ops sorted.
    """
    subkeys = [sliced_subkey] if isinstance(sliced_subkey, str) else sliced_subkey
    for subkey in subkeys:
        keys = _sort_slice_ops(keys, subkey)
    return keys