# Source code for dpgen2.exploration.selector.distance_conf_filter

import logging
from concurrent.futures import (
    ProcessPoolExecutor,
)
from copy import (
    deepcopy,
)
from typing import (
    List,
)

import dargs
import dpdata
import numpy as np
from dargs import (
    Argument,
)

from . import (
    ConfFilter,
)

# Per-element safe radius in bohr. A pair of atoms is considered dangerously
# close when their distance is below the sum of the two elements' entries
# (after unit conversion / scaling in DistanceConfFilter.check).
#
# Insertion order follows the periodic table from H up to Cm, so
# (position + 1) equals the atomic number for those entries. "Bk" (Z=97) has
# no entry, so "Cf" sits at position 96: code must not assume that a position
# in this dict equals an atomic number for Cf (or for any element appended
# later via a custom update).
safe_dist_dict = {
    "H": 1.2255,
    "He": 0.936,
    "Li": 1.8,
    "Be": 1.56,
    "B": 1.32,
    "C": 1.32,
    "N": 1.32,
    "O": 1.32,
    "F": 1.26,
    "Ne": 1.92,
    "Na": 1.595,
    "Mg": 1.87,
    "Al": 1.87,
    "Si": 1.76,
    "P": 1.65,
    "S": 1.65,
    "Cl": 1.65,
    "Ar": 2.09,
    "K": 2.3,
    "Ca": 2.3,
    "Sc": 2.0,
    "Ti": 2.0,
    "V": 2.0,
    "Cr": 1.9,
    "Mn": 1.95,
    "Fe": 1.9,
    "Co": 1.9,
    "Ni": 1.9,
    "Cu": 1.9,
    "Zn": 1.9,
    "Ga": 2.0,
    "Ge": 2.0,
    "As": 2.0,
    "Se": 2.1,
    "Br": 2.1,
    "Kr": 2.3,
    "Rb": 2.5,
    "Sr": 2.5,
    "Y": 2.1,
    "Zr": 2.1,
    "Nb": 2.1,
    "Mo": 2.1,
    "Tc": 2.1,
    "Ru": 2.1,
    "Rh": 2.1,
    "Pd": 2.1,
    "Ag": 2.1,
    "Cd": 2.1,
    "In": 2.0,
    "Sn": 2.0,
    "Sb": 2.0,
    "Te": 2.0,
    "I": 2.0,
    "Xe": 2.0,
    "Cs": 2.5,
    "Ba": 2.8,
    "La": 2.5,
    "Ce": 2.55,
    "Pr": 2.7,
    "Nd": 2.8,
    "Pm": 2.8,
    "Sm": 2.8,
    "Eu": 2.8,
    "Gd": 2.8,
    "Tb": 2.8,
    "Dy": 2.8,
    "Ho": 2.8,
    "Er": 2.6,
    "Tm": 2.8,
    "Yb": 2.8,
    "Lu": 2.8,
    "Hf": 2.4,
    "Ta": 2.5,
    "W": 2.3,
    "Re": 2.3,
    "Os": 2.3,
    "Ir": 2.3,
    "Pt": 2.3,
    "Au": 2.3,
    "Hg": 2.3,
    "Tl": 2.3,
    "Pb": 2.3,
    "Bi": 2.3,
    "Po": 2.3,
    "At": 2.3,
    "Rn": 2.3,
    "Fr": 2.9,
    "Ra": 2.9,
    "Ac": 2.9,
    "Th": 2.8,
    "Pa": 2.8,
    "U": 2.8,
    "Np": 2.8,
    "Pu": 2.8,
    "Am": 2.8,
    "Cm": 2.8,
    "Cf": 2.3,
}

def check_multiples(a, b, c, multiple):
    """Report whether one of three values exceeds ``multiple`` times another.

    Ordered pairs of distinct positions among ``(a, b, c)`` are examined in
    order. At the first pair whose first member is strictly greater than
    ``multiple`` times the second, a warning is logged and ``True`` is
    returned. If no pair qualifies, ``False`` is returned.
    """
    sides = (a, b, c)
    for big_idx, big in enumerate(sides):
        for small_idx, small in enumerate(sides):
            if big_idx == small_idx:
                continue
            if big > multiple * small:
                logging.warning(
                    f"Value {big} is {multiple} times greater than {small}"
                )
                return True
    return False
class DistanceConfFilter(ConfFilter):
    """Reject configurations in which two atoms are closer than a safe distance.

    The threshold for a pair is the sum of the two elements' entries in
    ``safe_dist_dict`` (updated with ``custom_safe_dist``). The entries are
    given in bohr and are multiplied by ``0.529 / 1.2 * safe_dist_ratio``
    before use.
    """

    def __init__(self, max_workers=None, custom_safe_dist=None, safe_dist_ratio=1.0):
        """Store the filter settings.

        Parameters
        ----------
        max_workers : int or None
            ``1`` makes :meth:`batched_check` run serially in this process.
            Any other value (including ``None``) is passed to
            ``ProcessPoolExecutor``.
        custom_safe_dist : dict or None
            Per-element safe distances (bohr) that override or extend
            ``safe_dist_dict``. ``None`` means no overrides.
        safe_dist_ratio : float
            Extra factor multiplied into every safe distance.
        """
        self.max_workers = max_workers
        self.custom_safe_dist = custom_safe_dist if custom_safe_dist is not None else {}
        self.safe_dist_ratio = safe_dist_ratio

    def check(
        self,
        frame: dpdata.System,
    ):
        """Check the first frame of ``frame`` for dangerously close atom pairs.

        Parameters
        ----------
        frame : dpdata.System
            System whose ``coords[0]`` / ``cells[0]`` are checked.

        Returns
        -------
        bool
            ``True`` if every pair is at least its safe distance apart.
            ``False`` otherwise, after logging the first offending pair.

        Raises
        ------
        ValueError
            If an element present in ``frame`` has no safe distance in
            ``safe_dist_dict`` or ``custom_safe_dist``.
        """
        from ase import (
            Atoms,
        )
        from ase.build import (
            make_supercell,
        )

        safe_dist = deepcopy(safe_dist_dict)
        safe_dist.update(self.custom_safe_dist)
        # bohr -> ang and multiply by a relaxation ratio
        scale = 0.529 / 1.2 * self.safe_dist_ratio
        for k in safe_dist:
            safe_dist[k] *= scale
        atom_names = list(safe_dist)

        # Each atom's element is encoded as (position in atom_names) + 1 so
        # it survives make_supercell. These codes are NOT reliable atomic
        # numbers: safe_dist_dict has no "Bk", and custom elements are
        # appended at the end. They must therefore be decoded with the same
        # table, never through ASE chemical symbols.
        structure = Atoms(
            positions=frame["coords"][0],
            numbers=[
                atom_names.index(frame["atom_names"][t]) + 1
                for t in frame["atom_types"]
            ],
            cell=frame["cells"][0],
            pbc=(not frame.nopbc),
        )

        if frame.nopbc:
            # A non-periodic system has no images; replicating it along its
            # cell would create fictitious neighbours.
            extended_structure = structure
        else:
            # The 2x2x2 supercell contains shifted copies of every atom, so
            # distances between an atom and its own periodic images are
            # checked too.
            P = [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
            extended_structure = make_supercell(structure, P)

        names = [atom_names[code - 1] for code in extended_structure.get_atomic_numbers()]
        num_atoms = len(extended_structure)
        for i in range(num_atoms):
            for j in range(i + 1, num_atoms):
                dist = extended_structure.get_distance(i, j, mic=True)
                type_i = names[i]
                type_j = names[j]
                dr = safe_dist[type_i] + safe_dist[type_j]
                if dist < dr:
                    logging.warning(
                        f"Dangerous close for {type_i} - {type_j}, {dist:.5f} less than {dr:.5f}"
                    )
                    return False
        return True

    def batched_check(
        self,
        frames: List[dpdata.System],
    ):
        """Apply :meth:`check` to every frame and return the results in order.

        With ``max_workers == 1`` the frames are checked serially. Otherwise
        they are distributed over a ``ProcessPoolExecutor``.
        """
        if self.max_workers == 1:
            return list(map(self.check, frames))
        with ProcessPoolExecutor(self.max_workers) as executor:
            return list(executor.map(self.check, frames))

    @staticmethod
    def args() -> List[dargs.Argument]:
        r"""The argument definition of the `ConfFilter`.

        Returns
        -------
        arguments: List[dargs.Argument]
            List of dargs.Argument defines the arguments of the `ConfFilter`.
        """
        doc_max_workers = (
            "The maximum number of processes used to filter configurations, "
            + "None represents as many as the processors of the machine, and 1 for serial"
        )
        doc_custom_safe_dist = "Custom safe distance (in unit of bohr) for each element"
        doc_safe_dist_ratio = "The ratio multiplied to the safe distance"
        return [
            Argument(
                "max_workers",
                int,
                optional=True,
                default=None,
                doc=doc_max_workers,
            ),
            Argument(
                "custom_safe_dist",
                dict,
                optional=True,
                default={},
                doc=doc_custom_safe_dist,
            ),
            Argument(
                "safe_dist_ratio",
                float,
                optional=True,
                default=1.0,
                doc=doc_safe_dist_ratio,
            ),
        ]
class BoxSkewnessConfFilter(ConfFilter):
    """Reject configurations whose simulation cell is too strongly tilted."""

    def __init__(self, max_workers=None, theta=60.0):
        """Store the filter settings.

        Parameters
        ----------
        max_workers : int or None
            ``1`` makes :meth:`batched_check` run serially in this process.
            Any other value (including ``None``) is passed to
            ``ProcessPoolExecutor``.
        theta : float
            Tilt threshold in degrees.
        """
        self.max_workers = max_workers
        self.theta = theta

    def check(
        self,
        frame: dpdata.System,
    ):
        """Return ``False`` (and log a warning) if ``cells[0]`` is too tilted.

        Only the cell is inspected, so the atoms (and their elements) play
        no role.
        """
        from ase import (
            Atoms,
        )

        # Build the Atoms object from the cell alone. Encoding species via
        # safe_dist_dict is unnecessary here and would raise ValueError for
        # elements missing from that table.
        cell, _ = Atoms(cell=frame["cells"][0]).get_cell().standard_form()
        max_tilt = np.tan(self.theta / 180.0 * np.pi)
        # In ASE's lower-triangular standard form, the ratio of an
        # off-diagonal entry to the diagonal entry of the same row measures
        # how far that cell vector leans away from its axis. Off-diagonals
        # are negative for obtuse angles, so abs() treats a lean in either
        # direction as equally skewed.
        if (
            abs(cell[1][0]) > max_tilt * cell[1][1]  # type: ignore
            or abs(cell[2][0]) > max_tilt * cell[2][2]  # type: ignore
            or abs(cell[2][1]) > max_tilt * cell[2][2]  # type: ignore
        ):
            logging.warning("Inclined box")
            return False
        return True

    def batched_check(
        self,
        frames: List[dpdata.System],
    ):
        """Apply :meth:`check` to every frame and return the results in order.

        With ``max_workers == 1`` the frames are checked serially. Otherwise
        they are distributed over a ``ProcessPoolExecutor``.
        """
        if self.max_workers == 1:
            return list(map(self.check, frames))
        with ProcessPoolExecutor(self.max_workers) as executor:
            return list(executor.map(self.check, frames))

    @staticmethod
    def args() -> List[dargs.Argument]:
        r"""The argument definition of the `ConfFilter`.

        Returns
        -------
        arguments: List[dargs.Argument]
            List of dargs.Argument defines the arguments of the `ConfFilter`.
        """
        doc_max_workers = (
            "The maximum number of processes used to filter configurations, "
            + "None represents as many as the processors of the machine, and 1 for serial"
        )
        doc_theta = "The threshold for angles between the edges of the cell. If all angles are larger than this value the check is passed"
        return [
            Argument(
                "max_workers",
                int,
                optional=True,
                default=None,
                doc=doc_max_workers,
            ),
            Argument(
                "theta",
                float,
                optional=True,
                default=60.0,
                doc=doc_theta,
            ),
        ]
class BoxLengthFilter(ConfFilter):
    """Reject configurations whose cell dimensions differ by too large a ratio."""

    def __init__(self, max_workers=None, length_ratio=5.0):
        """Store the filter settings.

        Parameters
        ----------
        max_workers : int or None
            ``1`` makes :meth:`batched_check` run serially in this process.
            Any other value (including ``None``) is passed to
            ``ProcessPoolExecutor``.
        length_ratio : float
            A frame is rejected if one dimension exceeds ``length_ratio``
            times another.
        """
        self.max_workers = max_workers
        self.length_ratio = length_ratio

    def check(
        self,
        frame: dpdata.System,
    ):
        """Return ``False`` (and log a warning) if ``cells[0]`` is too elongated.

        Only the cell is inspected, so the atoms (and their elements) play
        no role.
        """
        from ase import (
            Atoms,
        )

        # Build the Atoms object from the cell alone. Encoding species via
        # safe_dist_dict is unnecessary here and would raise ValueError for
        # elements missing from that table.
        cell, _ = Atoms(cell=frame["cells"][0]).get_cell().standard_form()
        # NOTE: these are the diagonal entries of ASE's lower-triangular
        # standard form (|a|, then the parts of b and c perpendicular to the
        # preceding vectors), not the full edge lengths.
        a = cell[0][0]  # type: ignore
        b = cell[1][1]  # type: ignore
        c = cell[2][2]  # type: ignore
        if check_multiples(a, b, c, self.length_ratio):
            logging.warning("One side is %s larger than another" % self.length_ratio)
            return False
        return True

    def batched_check(
        self,
        frames: List[dpdata.System],
    ):
        """Apply :meth:`check` to every frame and return the results in order.

        With ``max_workers == 1`` the frames are checked serially. Otherwise
        they are distributed over a ``ProcessPoolExecutor``.
        """
        if self.max_workers == 1:
            return list(map(self.check, frames))
        with ProcessPoolExecutor(self.max_workers) as executor:
            return list(executor.map(self.check, frames))

    @staticmethod
    def args() -> List[dargs.Argument]:
        r"""The argument definition of the `ConfFilter`.

        Returns
        -------
        arguments: List[dargs.Argument]
            List of dargs.Argument defines the arguments of the `ConfFilter`.
        """
        doc_max_workers = (
            "The maximum number of processes used to filter configurations, "
            + "None represents as many as the processors of the machine, and 1 for serial"
        )
        doc_length_ratio = "The threshold for the length ratio between the edges of the cell. If all length ratios are smaller than this value the check is passed"
        return [
            Argument(
                "max_workers",
                int,
                optional=True,
                default=None,
                doc=doc_max_workers,
            ),
            Argument(
                "length_ratio",
                float,
                optional=True,
                default=5.0,
                doc=doc_length_ratio,
            ),
        ]