Source code for deepmd.fit.fitting

# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
    abstractmethod,
)
from typing import (
    Callable,
)

from deepmd.env import (
    tf,
)
from deepmd.loss.loss import (
    Loss,
)
from deepmd.utils import (
    Plugin,
    PluginVariant,
)


class Fitting(PluginVariant):
    """Base class of fitting networks, doubling as a plugin registry.

    Concrete fittings register themselves with :meth:`Fitting.register`.
    Instantiating ``Fitting`` itself (rather than a subclass) dispatches to
    the registered subclass named by the ``type`` keyword argument.
    """

    # Registry mapping a fitting type key to its registered class.
    # Name-mangled to ``_Fitting__plugins`` so subclasses share this one registry.
    __plugins = Plugin()

    @staticmethod
    def register(key: str) -> Callable:
        """Register a Fitting plugin.

        Parameters
        ----------
        key : str
            the key of a Fitting

        Returns
        -------
        Fitting
            the registered Fitting

        Examples
        --------
        >>> @Fitting.register("some_fitting")
            class SomeFitting(Fitting):
                pass
        """
        return Fitting.__plugins.register(key)

    def __new__(cls, *args, **kwargs):
        """Create a fitting, resolving the concrete class from ``type``.

        When called on ``Fitting`` directly, the keyword argument ``type``
        selects the registered subclass to instantiate. When called on a
        subclass, that subclass is instantiated as-is.

        Raises
        ------
        KeyError
            If ``Fitting`` is instantiated directly without a ``type`` kwarg.
        RuntimeError
            If ``type`` does not name a registered fitting.
        """
        if cls is Fitting:
            try:
                fitting_type = kwargs["type"]
            except KeyError:
                raise KeyError("the type of fitting should be set by `type`") from None
            if fitting_type in Fitting.__plugins.plugins:
                # Rebinding ``cls`` is deliberate: zero-argument super() below
                # then resolves against the chosen subclass's MRO.
                cls = Fitting.__plugins.plugins[fitting_type]
            else:
                # f-string so a non-str ``type`` value still yields this
                # message instead of an unrelated TypeError.
                raise RuntimeError(f"Unknown fitting type: {fitting_type}")
        return super().__new__(cls)

    @property
    def precision(self) -> tf.DType:
        """Precision of fitting network."""
        # NOTE(review): ``fitting_precision`` is not set in this class;
        # presumably every concrete subclass assigns it in ``__init__``.
        return self.fitting_precision

    def init_variables(
        self,
        graph: tf.Graph,
        graph_def: tf.GraphDef,
        suffix: str = "",
    ) -> None:
        """Init the fitting net variables with the given dict.

        Parameters
        ----------
        graph : tf.Graph
            The input frozen model graph
        graph_def : tf.GraphDef
            The input frozen model graph_def
        suffix : str
            suffix to name scope

        Raises
        ------
        NotImplementedError
            Always, in this base implementation; fittings that support
            initialization from existing variables override this method.

        Notes
        -----
        This method is called by others when the fitting supported initialization from the given variables.
        """
        raise NotImplementedError(
            f"Fitting {type(self).__name__} doesn't support initialization from the given variables!"
        )

    @abstractmethod
    def get_loss(self, loss: dict, lr) -> Loss:
        """Get the loss function.

        Parameters
        ----------
        loss : dict
            the loss dict
        lr : LearningRateExp
            the learning rate

        Returns
        -------
        Loss
            the loss function
        """