import itertools, random
from typing import (
List,
Optional,
)
from pathlib import Path
from . import (
ExplorationTask,
ExplorationTaskGroup,
ConfSamplingTaskGroup,
)
from .lmp import make_lmp_input
from dpgen2.constants import (
lmp_conf_name,
lmp_traj_name,
lmp_input_name,
plm_input_name,
plm_output_name,
model_name_pattern,
)
[docs]class LmpTemplateTaskGroup(ConfSamplingTaskGroup):
def __init__(
self,
):
super().__init__()
self.lmp_set = False
self.plm_set = False
[docs] def set_lmp(
self,
numb_models : int,
lmp_template_fname : str,
plm_template_fname : Optional[str] = None,
revisions : dict = {},
traj_freq : int = 10,
)->None:
self.lmp_template = Path(lmp_template_fname).read_text().split('\n')
self.revisions = revisions
self.traj_freq = traj_freq
self.lmp_set = True
self.model_list = sorted([model_name_pattern%ii for ii in range(numb_models)])
self.lmp_template = revise_lmp_input_model(
self.lmp_template, self.model_list, self.traj_freq)
self.lmp_template = revise_lmp_input_dump(self.lmp_template, self.traj_freq)
if plm_template_fname is not None:
self.plm_template = Path(plm_template_fname).read_text().split('\n')
self.plm_set = True
[docs] def make_task(
self,
)->ExplorationTaskGroup:
if not self.conf_set:
raise RuntimeError('confs are not set')
if not self.lmp_set:
raise RuntimeError('Lammps template and revisions are not set')
if self.plm_set:
lmp_template = revise_lmp_input_plm(
self.lmp_template, plm_input_name, out_plm=plm_output_name,)
else:
lmp_template = self.lmp_template
# clear all existing tasks
self.clear()
confs = self._sample_confs()
templates = [lmp_template]
if self.plm_set:
templates.append(self.plm_template)
conts = self.make_cont(templates, self.revisions)
nconts = len(conts[0])
for cc, ii in itertools.product(confs, range(nconts)):
if not self.plm_set:
self.add_task(self._make_lmp_task(cc, conts[0][ii]))
else:
self.add_task(self._make_lmp_task(cc, conts[0][ii], conts[1][ii]))
return self
[docs] def make_cont(
self,
templates : list,
revisions : dict,
):
keys = revisions.keys()
prod_vv = [ revisions[kk] for kk in keys ]
ntemplate = len(templates)
ret = [[] for ii in range(ntemplate)]
for vv in itertools.product(*prod_vv):
for ii in range(ntemplate):
tt = templates[ii].copy()
ret[ii].append('\n'.join(revise_by_keys(tt, keys, vv)))
return ret
def _make_lmp_task(
self,
conf: str,
lmp_cont : str,
plm_cont : Optional[str] = None,
)->ExplorationTask:
task = ExplorationTask()
task\
.add_file(
lmp_conf_name,
conf,
)\
.add_file(
lmp_input_name,
lmp_cont,
)
if plm_cont is not None:
task.add_file(
plm_input_name,
plm_cont,
)
return task
[docs]def find_only_one_key(lmp_lines, key):
found = []
for idx in range(len(lmp_lines)):
words = lmp_lines[idx].split()
nkey = len(key)
if len(words) >= nkey and words[:nkey] == key :
found.append(idx)
if len(found) > 1:
raise RuntimeError('found %d keywords %s' % (len(found), key))
if len(found) == 0:
raise RuntimeError('failed to find keyword %s' % (key))
return found[0]
[docs]def revise_by_keys(lmp_lines, keys, values):
for kk,vv in zip(keys, values):
for ii in range(len(lmp_lines)):
lmp_lines[ii] = lmp_lines[ii].replace(kk, str(vv))
return lmp_lines