import numpy as np
import re
from typing import (
List, Optional, Any,
)
def get_subkey(
    key : str,
    idx : int = -1,
):
    """
    Return one component of a '--'-separated step key.

    Parameters
    ----------
    key
        Step key such as 'iter-000001--run-lmp-000002'.
    idx
        Position of the component to return; the last component by default.
        Negative values count from the end, as with list indexing.

    Raises
    ------
    IndexError
        If `idx` is outside the range of components.
    """
    pieces = key.split('--')
    return pieces[idx]
def get_iteration(
    key : str,
):
    """
    Return the leading component of a step key (everything before the first '--').

    For a key such as 'iter-000001--scheduler' this is 'iter-000001'.
    A key without '--' is returned unchanged.
    """
    # Equivalent to get_subkey(key, 0): partition() yields the text before the
    # first separator, or the whole string when the separator is absent.
    head, _sep, _rest = key.partition('--')
    return head
def matched_step_key(
    all_keys : List[str],
    step_keys : Optional[List[str]] = None,
) -> List[str]:
    """
    Return the keys in `all_keys` that match any of the `step_keys`.

    A key matches a step key `jj` when it *starts with* one of
    'iter-<digits>--{jj}-<digits>', 'iter-<digits>--{jj}' or 'init--{jj}'
    (matching is anchored at the start only, as with `re.match`). Each entry
    of `step_keys` is interpolated into the pattern unescaped, so it is
    interpreted as a regular expression.

    Parameters
    ----------
    all_keys
        Step keys to filter.
    step_keys
        Step names to look for. If None, `all_keys` itself is returned.

    Returns
    -------
    The matching keys, in their original order. A key is included once per
    occurrence in `all_keys`, even if it matches several step keys.
    """
    if step_keys is None:
        return all_keys
    # Compile every candidate pattern once instead of once per key.
    patterns = [
        re.compile(pat)
        for jj in step_keys
        for pat in (
            f'iter-[0-9]*--{jj}-[0-9]*',
            f'iter-[0-9]*--{jj}',
            f'init--{jj}',
        )
    ]
    # any() stops at the first matching pattern, so a key that matches several
    # step keys is only reported once (the old inner-loop `continue` appended
    # it once per matching step key).
    return [kk for kk in all_keys if any(pp.match(kk) for pp in patterns)]
def get_last_scheduler(
    wf : Any,
    keys : List[str],
):
    """
    Return the 'exploration_scheduler' output of the latest scheduler step.

    Among `keys`, those whose last '--' component is exactly 'scheduler' are
    considered, and the lexicographically greatest one is queried on `wf`.
    Presumably this is the most recent iteration, provided iteration indices
    are zero-padded to a common width (TODO confirm with the key producer).

    Returns None when no scheduler key is present.
    """
    # Same test as get_subkey(kk) == 'scheduler'.
    candidates = [kk for kk in keys if kk.split('--')[-1] == 'scheduler']
    if not candidates:
        return None
    newest_key = max(candidates)
    newest_step = wf.query_step(key=newest_key)[0]
    return newest_step.outputs.parameters['exploration_scheduler'].value
def get_all_schedulers(
    wf : Any,
    keys : List[str],
):
    """
    Return the 'exploration_scheduler' output of every scheduler step.

    Keys selected by `matched_step_key(keys, ['scheduler'])` are queried on
    `wf` in sorted order, and their values are returned as a list in that
    order. Returns None (not an empty list) when no scheduler key is found.
    """
    matched = matched_step_key(keys, ['scheduler'])
    if not matched:
        return None
    schedulers = []
    for skey in sorted(matched):
        step = wf.query_step(key=skey)[0]
        schedulers.append(step.outputs.parameters['exploration_scheduler'].value)
    return schedulers
def get_last_iteration(
    keys : List[str],
) -> int:
    """
    Get the index of the last iteration from a list of step keys.

    The iteration index is read from the first '--'-separated component of
    each key, e.g. 'iter-000002--run-lmp-000001' -> 2. Keys whose first
    component does not start with 'iter-<digits>' (such as 'init--scheduler')
    are ignored. Indices are compared numerically, so the result is correct
    even when they are not zero-padded to a common width.

    Raises
    ------
    IndexError
        If no key carries an iteration index (e.g. `keys` is empty). The
        exception type is kept for backward compatibility.
    """
    iterations = []
    for kk in keys:
        # Same as get_subkey(kk, 0): the text before the first '--'.
        head = kk.split('--')[0]
        found = re.match(r'iter-([0-9]+)', head)
        if found:
            iterations.append(int(found.group(1)))
    if not iterations:
        raise IndexError('no iteration key ("iter-<index>--...") found in keys')
    return max(iterations)
def find_slice_ranges(
    keys : List[str],
    sliced_subkey : str,
) -> List[List[int]]:
    """
    Find ranges of sliced OPs that match the pattern 'iter-[0-9]*--{sliced_subkey}-[0-9]*'.

    Each returned item is a two-element list [start, stop] giving a half-open
    index range keys[start:stop] of a maximal run of consecutive matching keys.
    A run that extends to the end of `keys` is closed at len(keys).

    Matching uses `re.match`, i.e. it is anchored at the start of the key only,
    and `sliced_subkey` is interpolated unescaped into the regular expression.
    """
    pattern = re.compile(f'iter-[0-9]*--{sliced_subkey}-[0-9]*')
    found_range = []
    start = None  # index where the currently open run began, if any
    for idx, kk in enumerate(keys):
        if pattern.match(kk):
            if start is None:
                start = idx
        elif start is not None:
            found_range.append([start, idx])
            start = None
    # Close a run that reaches the end of the list. It was previously dropped,
    # so trailing sliced OPs were never reported (and never sorted).
    if start is not None:
        found_range.append([start, len(keys)])
    return found_range
def _sort_slice_ops(keys, sliced_subkey):
    """
    Sort, in place, each contiguous run of `keys` that find_slice_ranges reports
    for `sliced_subkey`, and return the same (mutated) list.
    """
    for span in find_slice_ranges(keys, sliced_subkey):
        lo, hi = span[0], span[1]
        # Slice assignment replaces the run in place, so `keys` itself is mutated.
        keys[lo:hi] = sorted(keys[lo:hi])
    return keys
def sort_slice_ops(
    keys : List[str],
    sliced_subkey : List[str],
):
    """
    Sort the keys of sliced OPs, one sliced subkey at a time.

    `sliced_subkey` may be a single string or a list of strings. For each
    subkey, every contiguous run of matching keys is sorted in place (via
    _sort_slice_ops), so the list passed in is mutated and also returned.
    """
    subkeys = [sliced_subkey] if isinstance(sliced_subkey, str) else sliced_subkey
    for subkey in subkeys:
        keys = _sort_slice_ops(keys, subkey)
    return keys