# Source code for deepmd.infer

"""Submodule containing all the implemented potentials."""

from pathlib import Path
from typing import Union

from .data_modifier import DipoleChargeModifier
from .deep_dipole import DeepDipole
from .deep_eval import DeepEval
from .deep_polar import DeepGlobalPolar, DeepPolar
from .deep_pot import DeepPot
from .deep_wfc import DeepWFC
from .ewald_recp import EwaldRecp
from .model_devi import calc_model_devi

# Names exported by ``from deepmd.infer import *``; everything else in this
# module is an implementation detail.
__all__ = [
    # factory that selects the evaluator class from the model's type
    "DeepPotential",
    # concrete evaluators and their common base
    "DeepDipole",
    "DeepEval",
    "DeepGlobalPolar",
    "DeepPolar",
    "DeepPot",
    "DeepWFC",
    # auxiliary utilities
    "DipoleChargeModifier",
    "EwaldRecp",
    "calc_model_devi",
]


def DeepPotential(
    model_file: Union[str, Path],
    load_prefix: str = "load",
    default_tf_graph: bool = False,
) -> Union[DeepDipole, DeepGlobalPolar, DeepPolar, DeepPot, DeepWFC]:
    """Initialize the appropriate potential read from `model_file`.

    The file is first opened through :class:`DeepEval` only to read its
    ``model_type``; the evaluator class registered for that type is then
    constructed from the same file with the same ``load_prefix`` and
    ``default_tf_graph`` settings.

    Parameters
    ----------
    model_file : Union[str, Path]
        The name of the frozen model file.
    load_prefix : str
        The prefix in the load computational graph.
    default_tf_graph : bool
        If True, use the default tf graph; otherwise build a new tf graph
        for evaluation.

    Returns
    -------
    Union[DeepDipole, DeepGlobalPolar, DeepPolar, DeepPot, DeepWFC]
        One of the available potentials.

    Raises
    ------
    RuntimeError
        If the model file does not correspond to any implemented potential.
    """
    # Dispatch table: ``model_type`` reported by DeepEval -> evaluator class.
    # Built inside the function so it only needs the classes at call time.
    potential_classes = {
        "ener": DeepPot,
        "dipole": DeepDipole,
        "polar": DeepPolar,
        "global_polar": DeepGlobalPolar,
        "wfc": DeepWFC,
    }

    mf = Path(model_file)
    # NOTE(review): this DeepEval instance is only used to inspect the model
    # type; the selected class below is constructed from the same file again,
    # so the model is presumably loaded twice — confirm whether DeepEval can
    # expose the type more cheaply.
    model_type = DeepEval(
        mf, load_prefix=load_prefix, default_tf_graph=default_tf_graph
    ).model_type

    try:
        potential_cls = potential_classes[model_type]
    except KeyError:
        # The lookup's KeyError adds nothing for the caller; report only the
        # unsupported type.
        raise RuntimeError(f"unknown model type {model_type}") from None

    return potential_cls(
        mf, load_prefix=load_prefix, default_tf_graph=default_tf_graph
    )