# Source code for dpgen2.exploration.task.conf_sampling_task_group

import itertools, random
from typing import (
    List,
    Optional,
)
from . import (
    ExplorationTask,
    ExplorationTaskGroup,
)
from dpgen2.constants import (
    lmp_conf_name, 
    lmp_input_name,
    model_name_pattern,
)

class ConfSamplingTaskGroup(ExplorationTaskGroup):
    """Exploration task group that draws configurations from a fixed list.

    Register the candidate configurations with :meth:`set_conf` before
    drawing them with :meth:`_sample_confs`.
    """

    def __init__(
        self,
    ):
        super().__init__()
        # Becomes True once `set_conf` has registered the configurations.
        self.conf_set = False

    def set_conf(
        self,
        conf_list: List[str],
        n_sample: Optional[int] = None,
        random_sample: bool = False,
    ):
        """Set the configurations of exploration.

        Parameters
        ----------
        conf_list : List[str]
            A list of file contents. Any iterable of strings is accepted.
            It is copied, so later changes to the caller's list do not
            affect sampling.
        n_sample : int, optional
            Number of samples drawn from the conf list each time
            `make_task` is called. If set to `None`, `n_sample` is set to
            the length of `conf_list`.
        random_sample : bool
            If true the confs are randomly sampled, otherwise they are
            consecutively sampled from `conf_list`.
        """
        # Defensive copy: decouples sampling from later mutation of the
        # caller's list, and lets tuples / other iterables be passed
        # (`_sample_confs` relies on `list.copy`).
        self.conf_list = list(conf_list)
        if n_sample is None:
            self.n_sample = len(self.conf_list)
        else:
            self.n_sample = n_sample
        self.random_sample = random_sample
        # Pending configurations of the current pass over `conf_list`.
        self.conf_queue = []
        self.conf_set = True

    def _sample_confs(
        self,
    ):
        """Draw the next ``n_sample`` configurations.

        Configurations are taken from the front of ``self.conf_queue``.
        Whenever the queue runs dry, it is refilled with a fresh copy of
        ``conf_list``, shuffled if ``random_sample`` is true. Successive
        calls therefore continue where the previous call stopped.

        Returns
        -------
        List[str]
            The drawn configurations, ``n_sample`` of them (none if
            ``n_sample`` is not positive).

        Raises
        ------
        RuntimeError
            If `set_conf` has not been called yet.
        ValueError
            If samples are requested but `conf_list` is empty.
        """
        if not getattr(self, "conf_set", False):
            raise RuntimeError(
                "set_conf must be called before sampling configurations"
            )
        confs = []
        remaining = self.n_sample
        while remaining > 0:
            if not self.conf_queue:
                if not self.conf_list:
                    # Without this guard the refill would add nothing and the
                    # loop could never make progress.
                    raise ValueError(
                        "cannot sample configurations from an empty conf_list"
                    )
                add_list = self.conf_list.copy()
                if self.random_sample:
                    random.shuffle(add_list)
                self.conf_queue += add_list
            # Take as many confs as possible from the front in one slice,
            # rather than repeated list.pop(0) (O(len(queue)) per call).
            take = min(remaining, len(self.conf_queue))
            confs.extend(self.conf_queue[:take])
            del self.conf_queue[:take]
            remaining -= take
        return confs