# Source code for deepmd.infer

"""Submodule containing all the implemented potentials."""

from pathlib import (
    Path,
)
from typing import (
    Union,
)

from .data_modifier import (
    DipoleChargeModifier,
)
from .deep_dipole import (
    DeepDipole,
)
from .deep_eval import (
    DeepEval,
)
from .deep_polar import (
    DeepGlobalPolar,
    DeepPolar,
)
from .deep_pot import (
    DeepPot,
)
from .deep_wfc import (
    DeepWFC,
)
from .ewald_recp import (
    EwaldRecp,
)
from .model_devi import (
    calc_model_devi,
)

# Public names of this package, i.e. what `from <package> import *` exports.
# "DeepPotential" is the factory function defined in this file; every other
# entry is a name imported from one of the sibling modules above.
__all__ = [
    "DeepPotential",
    "DeepDipole",
    "DeepEval",
    "DeepGlobalPolar",
    "DeepPolar",
    "DeepPot",
    "DeepWFC",
    "DipoleChargeModifier",
    "EwaldRecp",
    "calc_model_devi",
]


def DeepPotential(
    model_file: Union[str, Path],
    load_prefix: str = "load",
    default_tf_graph: bool = False,
) -> Union[DeepDipole, DeepGlobalPolar, DeepPolar, DeepPot, DeepWFC]:
    """Factory function that initializes the appropriate potential read from `model_file`.

    Parameters
    ----------
    model_file : str or Path
        The name of the frozen model file.
    load_prefix : str
        The prefix in the load computational graph
    default_tf_graph : bool
        If uses the default tf graph, otherwise build a new tf graph for evaluation

    Returns
    -------
    Union[DeepDipole, DeepGlobalPolar, DeepPolar, DeepPot, DeepWFC]
        one of the available potentials

    Raises
    ------
    RuntimeError
        if model file does not correspond to any implemented potential
    """
    mf = Path(model_file)

    # A generic evaluator is created first only to read `model_type`; the
    # specialised class matching that type is then constructed from the same
    # file with the same options.
    model_type = DeepEval(
        mf, load_prefix=load_prefix, default_tf_graph=default_tf_graph
    ).model_type

    # Model type string -> class implementing that kind of potential.
    potential_classes = {
        "ener": DeepPot,
        "dipole": DeepDipole,
        "polar": DeepPolar,
        "global_polar": DeepGlobalPolar,
        "wfc": DeepWFC,
    }

    potential_cls = potential_classes.get(model_type)
    if potential_cls is None:
        raise RuntimeError(f"unknown model type {model_type}")

    return potential_cls(
        mf, load_prefix=load_prefix, default_tf_graph=default_tf_graph
    )