import itertools
import random
from pathlib import (
Path,
)
from typing import (
List,
Optional,
)
from dpgen2.constants import (
lmp_conf_name,
lmp_input_name,
lmp_model_devi_name,
lmp_pimd_model_devi_name,
lmp_pimd_traj_name,
lmp_traj_name,
model_name_pattern,
plm_input_name,
plm_output_name,
)
from .conf_sampling_task_group import (
ConfSamplingTaskGroup,
)
from .lmp import (
make_lmp_input,
)
from .task import (
ExplorationTask,
)
class LmpTemplateTaskGroup(ConfSamplingTaskGroup):
    """Exploration task group built from user-supplied LAMMPS (and optional
    PLUMED) input templates.

    :meth:`set_lmp` loads the LAMMPS template and passes it through
    ``revise_lmp_input_model`` and ``revise_lmp_input_dump`` once.
    :meth:`make_task` expands the ``revisions`` mapping into the Cartesian
    product of its value lists and creates one task per
    (sampled configuration, revision combination) pair.
    """

    def __init__(
        self,
    ):
        super().__init__()
        # make_task refuses to run until set_lmp has succeeded.
        self.lmp_set = False
        # True only when the last set_lmp call supplied a PLUMED template.
        self.plm_set = False

    def set_lmp(
        self,
        numb_models: int,
        lmp_template_fname: str,
        plm_template_fname: Optional[str] = None,
        revisions: Optional[dict] = None,
        traj_freq: int = 10,
        extra_pair_style_args: str = "",
        pimd_bead: Optional[str] = None,
    ) -> None:
        """Load and pre-process the input templates.

        Parameters
        ----------
        numb_models : int
            Number of models. Model file names are generated with
            ``model_name_pattern % i`` for ``i in range(numb_models)`` and
            passed to ``revise_lmp_input_model``.
        lmp_template_fname : str
            Path of the LAMMPS input template.
        plm_template_fname : Optional[str]
            Path of the PLUMED input template. When ``None``, no PLUMED
            input is produced. A template loaded by a previous call is
            then no longer used.
        revisions : Optional[dict]
            Maps each placeholder string to the list of values that are
            substituted into the templates. ``None`` (the default) means
            no revisions.
        traj_freq : int
            Forwarded to ``revise_lmp_input_model`` and
            ``revise_lmp_input_dump`` (presumably the trajectory dump
            frequency).
        extra_pair_style_args : str
            Forwarded to ``revise_lmp_input_model``.
        pimd_bead : Optional[str]
            Forwarded to ``revise_lmp_input_model`` and
            ``revise_lmp_input_dump``.
        """
        # Build everything in locals first, so that a failure (missing file,
        # malformed template, ...) leaves any previous configuration intact
        # instead of a half-configured object with lmp_set already True.
        model_list = sorted([model_name_pattern % ii for ii in range(numb_models)])
        lmp_template = Path(lmp_template_fname).read_text().split("\n")
        lmp_template = revise_lmp_input_model(
            lmp_template,
            model_list,
            traj_freq,
            extra_pair_style_args,
            pimd_bead,
        )
        lmp_template = revise_lmp_input_dump(lmp_template, traj_freq, pimd_bead)
        plm_template = None
        if plm_template_fname is not None:
            plm_template = Path(plm_template_fname).read_text().split("\n")

        # Commit the new configuration.
        # A fresh dict per call avoids the shared-mutable-default pitfall.
        self.revisions = {} if revisions is None else revisions
        self.traj_freq = traj_freq
        self.extra_pair_style_args = extra_pair_style_args
        self.pimd_bead = pimd_bead
        self.model_list = model_list
        self.lmp_template = lmp_template
        self.lmp_set = True
        if plm_template is not None:
            self.plm_template = plm_template
        # Reset on every call, so a stale PLUMED template from an earlier
        # call is not silently reused.
        self.plm_set = plm_template is not None

    def make_task(
        self,
    ) -> "LmpTemplateTaskGroup":
        """Generate the exploration tasks.

        All previously generated tasks are cleared first. Then one task is
        added for each combination of a sampled configuration and a
        revision-value combination.

        Returns
        -------
        LmpTemplateTaskGroup
            ``self``, now holding the generated tasks.

        Raises
        ------
        RuntimeError
            If the configurations or the LAMMPS template have not been set.
        """
        if not self.conf_set:
            raise RuntimeError("confs are not set")
        if not self.lmp_set:
            raise RuntimeError("Lammps template and revisions are not set")
        if self.plm_set:
            # Patch the LAMMPS input to reference the PLUMED input/output
            # file names (done per call; self.lmp_template stays unpatched).
            lmp_template = revise_lmp_input_plm(
                self.lmp_template,
                plm_input_name,
                out_plm=plm_output_name,
            )
        else:
            lmp_template = self.lmp_template
        # clear all existing tasks
        self.clear()
        confs = self._sample_confs()
        templates = [lmp_template]
        if self.plm_set:
            templates.append(self.plm_template)
        conts = self.make_cont(templates, self.revisions)
        lmp_conts = conts[0]
        # Without PLUMED, pair every LAMMPS input with None (no plumed file).
        plm_conts = conts[1] if self.plm_set else [None] * len(lmp_conts)
        for conf, (lmp_cont, plm_cont) in itertools.product(  # type: ignore
            confs, zip(lmp_conts, plm_conts)
        ):
            self.add_task(self._make_lmp_task(conf, lmp_cont, plm_cont))
        return self

    def make_cont(
        self,
        templates: list,
        revisions: dict,
    ):
        """Expand templates over every combination of revision values.

        Parameters
        ----------
        templates : list
            Templates, each a list of lines.
        revisions : dict
            Maps each placeholder string to the list of values it takes.

        Returns
        -------
        list
            ``ret[i][j]`` is template ``i`` with the ``j``-th value
            combination (in ``itertools.product`` order) substituted, joined
            with newlines. An empty ``revisions`` gives exactly one
            combination: the unchanged templates.
        """
        keys = list(revisions.keys())
        prod_vv = [revisions[kk] for kk in keys]
        ret = [[] for _ in templates]
        for vv in itertools.product(*prod_vv):
            for ii, template in enumerate(templates):
                # revise_by_keys edits its argument in place, so pass a copy
                # to keep the template reusable for the next combination.
                ret[ii].append("\n".join(revise_by_keys(template.copy(), keys, vv)))
        return ret

    def _make_lmp_task(
        self,
        conf: str,
        lmp_cont: str,
        plm_cont: Optional[str] = None,
    ) -> ExplorationTask:
        """Build a single task from a configuration and input contents.

        ``conf`` is stored as ``lmp_conf_name`` and ``lmp_cont`` as
        ``lmp_input_name``. ``plm_cont``, when given, is stored as
        ``plm_input_name``.
        """
        task = ExplorationTask()
        task.add_file(
            lmp_conf_name,
            conf,
        ).add_file(
            lmp_input_name,
            lmp_cont,
        )
        if plm_cont is not None:
            task.add_file(
                plm_input_name,
                plm_cont,
            )
        return task
def find_only_one_key(lmp_lines, key):
    """Return the index of the single line whose leading words equal ``key``.

    Parameters
    ----------
    lmp_lines : list of str
        Lines to search.
    key : list of str, tuple of str, or str
        Leading whitespace-separated words to match. A string is split on
        whitespace, so ``"pair_style"`` is equivalent to ``["pair_style"]``.

    Returns
    -------
    int
        Index in ``lmp_lines`` of the matching line.

    Raises
    ------
    RuntimeError
        If no line matches, or if more than one line matches.
    """
    # Normalise to a list: split() returns lists, and a list never compares
    # equal to a str or tuple, so such keys could previously never match.
    key = key.split() if isinstance(key, str) else list(key)
    nkey = len(key)
    # A line with fewer than nkey words yields a shorter slice and so cannot
    # compare equal to key; no explicit length check is needed.
    found = [
        idx for idx, line in enumerate(lmp_lines) if line.split()[:nkey] == key
    ]
    if len(found) > 1:
        raise RuntimeError("found %d keywords %s" % (len(found), key))
    if len(found) == 0:
        # (key,) rather than (key): a bare tuple would be unpacked by %.
        raise RuntimeError("failed to find keyword %s" % (key,))
    return found[0]
def revise_by_keys(lmp_lines, keys, values):
    """Replace every occurrence of each key with its value, in place.

    Parameters
    ----------
    lmp_lines : list of str
        Lines to revise. The list is modified in place.
    keys : iterable of str
        Placeholder strings to replace.
    values : iterable
        Replacement values, paired positionally with ``keys`` and converted
        with ``str``. Extra keys or values beyond the shorter iterable are
        ignored (``zip`` semantics).

    Returns
    -------
    list of str
        The same ``lmp_lines`` object, after substitution.
    """
    # Substitute longer keys first, so that a key which is a substring of
    # another (e.g. "V_T" inside "V_TEMP") cannot corrupt the longer
    # placeholder. sorted() is stable even with reverse=True, so keys of
    # equal length keep the caller's order.
    pairs = sorted(zip(keys, values), key=lambda kv: len(kv[0]), reverse=True)  # type: ignore
    for kk, vv in pairs:
        replacement = str(vv)
        for ii, line in enumerate(lmp_lines):
            lmp_lines[ii] = line.replace(kk, replacement)
    return lmp_lines