Source code for dpdata.driver
"""Driver plugin system."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable
from .plugin import Plugin
if TYPE_CHECKING:
import ase.calculators.calculator
[docs]
class Driver(ABC):
    """The base class for a driver plugin.

    A driver can label a pure System to generate the LabeledSystem.
    Concrete drivers subclass this class, implement :meth:`label`, and
    are registered under a string key with the :meth:`register` decorator.

    See Also
    --------
    dpdata.plugins.deepmd.DPDriver : an example of Driver
    """

    # Registry of driver classes, shared by the whole class hierarchy.
    # It is only used through ``register(key)`` and the ``plugins`` mapping.
    __DriverPlugin = Plugin()

    @staticmethod
    def register(key: str) -> Callable:
        """Register a driver plugin. Used as decorators.

        Parameters
        ----------
        key : str
            key of the plugin.

        Returns
        -------
        Callable
            decorator of a class

        Examples
        --------
        >>> @Driver.register("some_driver")
        ... class SomeDriver(Driver):
        ...     pass
        """
        return Driver.__DriverPlugin.register(key)

    @staticmethod
    def get_driver(key: str) -> type[Driver]:
        """Get a driver plugin.

        Parameters
        ----------
        key : str
            key of the plugin.

        Returns
        -------
        Driver
            the specific driver class

        Raises
        ------
        RuntimeError
            if the requested driver is not implemented, i.e. `key` is
            not present in the registry
        """
        try:
            return Driver.__DriverPlugin.plugins[key]
        except KeyError as e:
            # Format with an f-string instead of ``"..." + key``: a non-str
            # key (e.g. None) would otherwise raise TypeError here and hide
            # the documented RuntimeError.
            raise RuntimeError(f"Unknown driver: {key}") from e

    @staticmethod
    def get_drivers() -> dict:
        """Get all driver plugins.

        Returns
        -------
        dict
            dict for all driver plugins. This is the registry's own
            ``plugins`` object, not a copy.
        """
        return Driver.__DriverPlugin.plugins

    def __init__(self, *args, **kwargs) -> None:
        """Setup the driver."""

    @abstractmethod
    def label(self, data: dict) -> dict:
        """Label a system data. Returns new data with energy, forces, and virials.

        Parameters
        ----------
        data : dict
            data with coordinates and atom types

        Returns
        -------
        dict
            labeled data with energies and forces
        """
        # Subclasses must override this method; the base implementation
        # only returns the ``NotImplemented`` placeholder.
        return NotImplemented

    @property
    def ase_calculator(self) -> ase.calculators.calculator.Calculator:
        """Returns an ase calculator based on this driver."""
        # Imported lazily, at property access time, not at module import.
        from .ase_calculator import DPDataCalculator

        return DPDataCalculator(self)
[docs]
@Driver.register("hybrid")
class HybridDriver(Driver):
    """Hybrid driver, with mixed drivers.

    Parameters
    ----------
    drivers : list[dict, Driver]
        list of drivers or drivers dict. For a dict, it should
        contain `type` as the name of the driver, and others
        are arguments of the driver. The dict itself is left
        unmodified.

    Raises
    ------
    TypeError
        The value of `drivers` is not a dict or `Driver`.

    Examples
    --------
    >>> driver = HybridDriver([
    ...     {"type": "sqm", "qm_theory": "DFTB3"},
    ...     {"type": "dp", "dp": "frozen_model.pb"},
    ... ])

    This driver is the hybrid of SQM and DP.
    """

    def __init__(self, drivers: list[dict | Driver]) -> None:
        """Build the list of sub-drivers from instances and/or config dicts."""
        self.drivers = []
        for driver in drivers:
            if isinstance(driver, Driver):
                self.drivers.append(driver)
            elif isinstance(driver, dict):
                # Pop "type" from a shallow copy so the caller's dict is not
                # mutated and can be reused to build another HybridDriver.
                driver_kwargs = dict(driver)
                driver_type = driver_kwargs.pop("type")
                driver_cls = Driver.get_driver(driver_type)
                self.drivers.append(driver_cls(**driver_kwargs))
            else:
                raise TypeError("driver should be Driver or dict")

    def label(self, data: dict) -> dict:
        """Label a system data.

        Energies and forces are the sum of those of each driver.
        Virials are summed as well when present. Forces and virials are
        only accumulated from a driver when both the running result and
        that driver's result contain the key. With no sub-drivers an
        empty dict is returned.

        Parameters
        ----------
        data : dict
            data with coordinates and atom types

        Returns
        -------
        dict
            labeled data with energies and forces
        """
        labeled_data = {}
        for ii, driver in enumerate(self.drivers):
            # Each driver gets its own shallow copy of the input dict.
            lb_data = driver.label(data.copy())
            if ii == 0:
                # The first result is the base; all its keys are kept.
                labeled_data = lb_data.copy()
            else:
                # Out-of-place addition: the shallow copy above shares its
                # values with the first driver's result, so ``+=`` would
                # modify the objects that driver returned.
                labeled_data["energies"] = (
                    labeled_data["energies"] + lb_data["energies"]
                )
                if "forces" in labeled_data and "forces" in lb_data:
                    labeled_data["forces"] = (
                        labeled_data["forces"] + lb_data["forces"]
                    )
                if "virials" in labeled_data and "virials" in lb_data:
                    labeled_data["virials"] = (
                        labeled_data["virials"] + lb_data["virials"]
                    )
        return labeled_data
[docs]
class Minimizer(ABC):
    """The base class for a minimizer plugin.

    A minimizer can minimize geometry. Concrete minimizers subclass this
    class, implement :meth:`minimize`, and are registered under a string
    key with the :meth:`register` decorator.
    """

    # Registry of minimizer classes, shared by the whole class hierarchy.
    # It is only used through ``register(key)`` and the ``plugins`` mapping.
    __MinimizerPlugin = Plugin()

    @staticmethod
    def register(key: str) -> Callable:
        """Register a minimizer plugin. Used as decorators.

        Parameters
        ----------
        key : str
            key of the plugin.

        Returns
        -------
        Callable
            decorator of a class

        Examples
        --------
        >>> @Minimizer.register("some_minimizer")
        ... class SomeMinimizer(Minimizer):
        ...     pass
        """
        return Minimizer.__MinimizerPlugin.register(key)

    @staticmethod
    def get_minimizer(key: str) -> type[Minimizer]:
        """Get a minimizer plugin.

        Parameters
        ----------
        key : str
            key of the plugin.

        Returns
        -------
        Minimizer
            the specific minimizer class

        Raises
        ------
        RuntimeError
            if the requested minimizer is not implemented, i.e. `key` is
            not present in the registry
        """
        try:
            return Minimizer.__MinimizerPlugin.plugins[key]
        except KeyError as e:
            # Format with an f-string instead of ``"..." + key``: a non-str
            # key (e.g. None) would otherwise raise TypeError here and hide
            # the documented RuntimeError.
            raise RuntimeError(f"Unknown minimizer: {key}") from e

    @staticmethod
    def get_minimizers() -> dict:
        """Get all minimizer plugins.

        Returns
        -------
        dict
            dict for all minimizer plugins. This is the registry's own
            ``plugins`` object, not a copy.
        """
        return Minimizer.__MinimizerPlugin.plugins

    def __init__(self, *args, **kwargs) -> None:
        """Setup the minimizer."""

    @abstractmethod
    def minimize(self, data: dict) -> dict:
        """Minimize the geometry.

        Parameters
        ----------
        data : dict
            data with coordinates and atom types

        Returns
        -------
        dict
            labeled data with minimized coordinates, energies, and forces
        """