Source code for dpgen2.exploration.selector.conf_filter
from __future__ import (
annotations,
)
from abc import (
ABC,
abstractmethod,
)
from typing import (
List,
)
import dpdata
import numpy as np
[docs]
class ConfFilter(ABC):
    """Abstract interface for configuration filters.

    Subclasses implement :meth:`check`, which judges one single-frame
    system. :meth:`batched_check` judges a list of such systems and, by
    default, simply calls :meth:`check` on each element.
    """

    @abstractmethod
    def check(
        self,
        frame: dpdata.System,
    ) -> bool:
        """Decide whether a single configuration is valid.

        Parameters
        ----------
        frame : dpdata.System
            A dpdata.System containing a single frame

        Returns
        -------
        valid : bool
            `True` if the configuration is a valid configuration, else `False`.
        """

    def batched_check(
        self,
        frames: List[dpdata.System],
    ) -> List[bool]:
        """Decide, for each configuration in a list, whether it is valid.

        The default implementation evaluates :meth:`check` on every element
        of `frames`, in order.

        Parameters
        ----------
        frames : List[dpdata.System]
            A list of dpdata.System each containing a single frame

        Returns
        -------
        valid : List[bool]
            One entry per element of `frames`, in the same order: `True` if
            that configuration is a valid configuration, else `False`.
        """
        return [self.check(frame) for frame in frames]
[docs]
class ConfFilters:
    """An ordered collection of :class:`ConfFilter` objects.

    Filters are applied one after another in the order they were added;
    only frames accepted by a filter are handed to the next one.
    """

    def __init__(
        self,
    ):
        # Registered filters, in the order they will be applied.
        self._filters = []

    def add(
        self,
        conf_filter: ConfFilter,
    ) -> ConfFilters:
        """Register a filter at the end of the chain.

        Parameters
        ----------
        conf_filter : ConfFilter
            The filter to append.

        Returns
        -------
        filters : ConfFilters
            This object itself, so that calls can be chained.
        """
        self._filters.append(conf_filter)
        return self

    def check(
        self,
        ms: dpdata.MultiSystems,
    ) -> dpdata.MultiSystems:
        """Keep only the frames that are accepted by every registered filter.

        Parameters
        ----------
        ms : dpdata.MultiSystems
            The systems whose frames are to be filtered.

        Returns
        -------
        ms2 : dpdata.MultiSystems
            A new MultiSystems, created with `ms.atom_names` as type map,
            holding the surviving frames of each system. Systems without
            any surviving frame are not added.
        """
        n_systems = len(ms)

        # (system index, frame index) for every frame still in the running.
        candidates = [
            (sys_idx, frame_idx)
            for sys_idx in range(n_systems)
            for frame_idx in range(ms[sys_idx].get_nframes())
        ]

        # Each filter only sees the frames that passed all earlier filters.
        for conf_filter in self._filters:
            verdicts = conf_filter.batched_check(
                [ms[sys_idx][frame_idx] for sys_idx, frame_idx in candidates]
            )
            candidates = [
                pair for pos, pair in enumerate(candidates) if verdicts[pos]
            ]

        # Regroup the surviving frame indices by the system they belong to.
        kept_frames = [[] for _ in range(n_systems)]
        for sys_idx, frame_idx in candidates:
            kept_frames[sys_idx].append(frame_idx)

        filtered = dpdata.MultiSystems(type_map=ms.atom_names)
        for sys_idx, frame_ids in enumerate(kept_frames):
            if frame_ids:
                filtered.append(ms[sys_idx].sub_system(frame_ids))
        return filtered