Source code for dpgen2.utils.dflow_query

import numpy as np
import re
from typing import (
    List, Optional, Any,
)

def get_subkey(
        key: str,
        idx: int = -1,
):
    """
    Return one '--'-separated component of a step key.

    Parameters
    ----------
    key : str
        The step key, made of sub-keys joined by '--'.
    idx : int
        Position of the wanted sub-key; negative values count from the
        end. Defaults to -1, i.e. the last sub-key.

    Returns
    -------
    str
        The sub-key found at position `idx`.

    Raises
    ------
    IndexError
        If `idx` is out of range for the number of sub-keys in `key`.
    """
    subkeys = key.split('--')
    return subkeys[idx]
def get_iteration(
        key: str,
):
    """
    Return the leading sub-key of a step key.

    This is `get_subkey(key, 0)`: the text that precedes the first '--'
    in `key` (the whole key when it contains no '--').

    Parameters
    ----------
    key : str
        The step key, made of sub-keys joined by '--'.

    Returns
    -------
    str
        The first sub-key of `key`.
    """
    first_position = 0
    return get_subkey(key, first_position)
def matched_step_key(
        all_keys: List[str],
        step_keys: Optional[List[str]] = None,
):
    """
    Return the keys in `all_keys` that match any of the `step_keys`.

    A key matches a step key `jj` when it starts with one of
    'iter-[0-9]*--{jj}-[0-9]*', 'iter-[0-9]*--{jj}' or 'init--{jj}'.
    Each step key is interpolated into the pattern verbatim (it is not
    escaped) and matching is anchored at the start of the key only.

    Parameters
    ----------
    all_keys : List[str]
        The step keys to filter.
    step_keys : Optional[List[str]]
        The step names to look for. If None, no filtering is done.

    Returns
    -------
    List[str]
        `all_keys` itself when `step_keys` is None. Otherwise a new list
        with the matching keys, in their original order. Each key is
        reported at most once, even if it matches several step keys.
    """
    if step_keys is None:
        return all_keys
    # Compile every pattern once instead of once per (key, step key) pair.
    patterns = [
        re.compile(pattern)
        for jj in step_keys
        for pattern in (
            f'iter-[0-9]*--{jj}-[0-9]*',
            f'iter-[0-9]*--{jj}',
            f'init--{jj}',
        )
    ]
    # any() stops at the first hit, so a key matching several step keys
    # (e.g. 'prep' and 'prep-run-train') is not appended repeatedly.
    return [
        kk for kk in all_keys
        if any(pattern.match(kk) for pattern in patterns)
    ]
def get_last_scheduler(
        wf: Any,
        keys: List[str],
):
    """
    Get the scheduler output by the last scheduler step found in `keys`.

    The scheduler steps are the keys whose last '--'-separated sub-key
    is exactly 'scheduler'. The "last" one is the largest such key in
    string order.

    Parameters
    ----------
    wf : Any
        Workflow object. It must provide `query_step(key=...)` returning
        a sequence of steps; the first step is used.
    keys : List[str]
        The step keys of the workflow.

    Returns
    -------
    The `value` of the 'exploration_scheduler' output parameter of the
    selected step, or None when `keys` holds no scheduler step.
    """
    candidates = [kk for kk in keys if kk.split('--')[-1] == 'scheduler']
    if not candidates:
        return None
    # The largest key in string order is taken as the most recent one.
    latest_key = max(candidates)
    latest_step = wf.query_step(key=latest_key)[0]
    parameters = latest_step.outputs.parameters
    return parameters['exploration_scheduler'].value
def get_all_schedulers(
        wf: Any,
        keys: List[str],
):
    """
    Get the schedulers output by all the scheduler steps found in `keys`.

    The scheduler steps are selected with
    `matched_step_key(keys, ['scheduler'])` and visited in sorted
    (string) order of their keys.

    Parameters
    ----------
    wf : Any
        Workflow object. It must provide `query_step(key=...)` returning
        a sequence of steps; the first step of each query is used.
    keys : List[str]
        The step keys of the workflow.

    Returns
    -------
    A list with the `value` of the 'exploration_scheduler' output
    parameter of every scheduler step, or None when no key matches.
    """
    ordered_keys = sorted(matched_step_key(keys, ['scheduler']))
    if not ordered_keys:
        return None
    schedulers = []
    for skey in ordered_keys:
        step = wf.query_step(key=skey)[0]
        schedulers.append(
            step.outputs.parameters['exploration_scheduler'].value
        )
    return schedulers
def get_last_iteration(
        keys: List[str],
):
    """
    Get the index of the last iteration from a list of step keys.

    Only keys that start with an iteration tag of the form 'iter-<digits>'
    are considered; other keys (e.g. 'init--...') are ignored. The
    indices are compared as integers, so the result does not depend on
    the zero-padding width of the tags ('iter-10' is later than 'iter-9').

    Parameters
    ----------
    keys : List[str]
        The step keys of the workflow.

    Returns
    -------
    int
        The largest iteration index found in `keys`.

    Raises
    ------
    IndexError
        If no key in `keys` carries an iteration tag.
    """
    indices = []
    for key in keys:
        tag = re.match(r'iter-([0-9]+)', key)
        if tag is not None:
            indices.append(int(tag.group(1)))
    if not indices:
        # IndexError (rather than ValueError) keeps callers that guard
        # against an empty key list working.
        raise IndexError('no step key with an iteration tag "iter-<N>" found')
    return max(indices)
def find_slice_ranges(
        keys: List[str],
        sliced_subkey: str,
):
    """
    Find the ranges of consecutive sliced-OP keys in `keys`.

    A key belongs to a sliced OP when it starts with the pattern
    'iter-[0-9]*--{sliced_subkey}-[0-9]*'. `sliced_subkey` is
    interpolated into the pattern verbatim (it is not escaped).

    Parameters
    ----------
    keys : List[str]
        The step keys, in workflow order.
    sliced_subkey : str
        The sub-key that identifies the sliced OP.

    Returns
    -------
    List[List[int]]
        One `[start, stop]` pair per maximal run of consecutive matching
        keys, such that `keys[start:stop]` is the run. A run that
        extends to the end of `keys` is closed with `stop == len(keys)`.
    """
    pattern = re.compile(f'iter-[0-9]*--{sliced_subkey}-[0-9]*')
    found_range = []
    start = None  # index where the currently open run began, if any
    for idx, key in enumerate(keys):
        is_sliced = pattern.match(key) is not None
        if is_sliced and start is None:
            start = idx
        elif not is_sliced and start is not None:
            found_range.append([start, idx])
            start = None
    if start is not None:
        # The last run reaches the end of the list: close it explicitly,
        # otherwise trailing sliced OPs would go unreported.
        found_range.append([start, len(keys)])
    return found_range
def _sort_slice_ops(keys, sliced_subkey):
    """
    Sort, in place, every run of sliced-OP keys of one sub-key.

    The runs are the `[start, stop]` ranges reported by
    `find_slice_ranges(keys, sliced_subkey)`; keys outside of them keep
    their position.

    Returns the same `keys` list object, after modification.
    """
    for span in find_slice_ranges(keys, sliced_subkey):
        window = slice(span[0], span[1])
        keys[window] = sorted(keys[window])
    return keys
def sort_slice_ops(
        keys: List[str],
        sliced_subkey: List[str],
):
    """
    Sort the keys of the sliced OPs.

    The keys of a sliced OP are those that contain one of the given
    sub-keys; each sub-key is handled in turn by `_sort_slice_ops`.

    Parameters
    ----------
    keys : List[str]
        The step keys of the workflow.
    sliced_subkey : List[str]
        The sub-keys that identify the sliced OPs. A single `str` is
        also accepted and treated as a one-element list.

    Returns
    -------
    List[str]
        The list returned by the last `_sort_slice_ops` call, or `keys`
        unchanged when `sliced_subkey` is empty.
    """
    subkeys = [sliced_subkey] if isinstance(sliced_subkey, str) else sliced_subkey
    ordered = keys
    for subkey in subkeys:
        ordered = _sort_slice_ops(ordered, subkey)
    return ordered