# Source code for deepmd.pd.infer.inference

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from copy import (
    deepcopy,
)

import paddle

from deepmd.pd.model.model import (
    get_model,
)
from deepmd.pd.train.wrapper import (
    ModelWrapper,
)
from deepmd.pd.utils.env import (
    DEVICE,
    JIT,
)

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class Tester:
    """Load a DeePMD Paddle checkpoint and wrap its model for inference."""

    def __init__(
        self,
        model_ckpt,
        head=None,
    ):
        """Construct a DeePMD tester.

        Args:
            model_ckpt: Checkpoint passed to ``paddle.load``. The loaded object is
                either the wrapper state dict itself or a dict holding it under
                ``"model"``; either way it must provide
                ``["_extra_state"]["model_params"]``.
            head: Name of the model head to use. Required when the checkpoint is
                a multi-task model (``model_params`` contains ``"model_dict"``);
                ignored, with a warning, for single-task checkpoints.

        Raises:
            AssertionError: In multi-task mode, if ``head`` is missing or is not
                one of the checkpoint's heads.
            NotImplementedError: If JIT compilation is enabled; it is not
                supported here yet.
        """
        # Model
        state_dict = paddle.load(model_ckpt)
        # Some checkpoints nest the state dict under "model"; unwrap it.
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model_params = state_dict["_extra_state"]["model_params"]
        self.multi_task = "model_dict" in model_params
        # Always define ``head`` so single-task testers do not raise
        # AttributeError when it is read.
        self.head = None
        if self.multi_task:
            self.head = head
            model_params, state_dict = self._select_head(
                state_dict, model_params, head, model_ckpt
            )
        elif head is not None:
            log.warning(
                "Head %s is ignored because model %s is not a multi-task model.",
                head,
                model_ckpt,
            )

        self.model_params = deepcopy(model_params)
        self.model = get_model(model_params).to(DEVICE)

        # Model Wrapper
        self.wrapper = ModelWrapper(self.model)  # inference only
        if JIT:
            # TODO: enable once static-graph conversion is supported, e.g. via
            # paddle.jit.to_static(self.wrapper).
            raise NotImplementedError(
                "JIT compilation is not supported by the Paddle Tester yet."
            )
        self.wrapper.set_state_dict(state_dict)

    @staticmethod
    def _select_head(state_dict, model_params, head, model_ckpt):
        """Restrict a multi-task checkpoint to a single head.

        Keys containing ``model.<head>.`` are renamed to use
        ``model.Default.`` and their values are cloned. ``_extra_state`` is
        carried over unchanged. Entries belonging to other heads are dropped.

        Args:
            state_dict: Multi-task wrapper state dict (must contain
                ``"_extra_state"``).
            model_params: Multi-task model parameters containing
                ``"model_dict"``.
            head: Name of the head to keep.
            model_ckpt: Checkpoint identifier, used only in error messages.

        Returns:
            Tuple ``(head_model_params, head_state_dict)``.

        Raises:
            AssertionError: If ``head`` is None or not in ``model_dict``.
        """
        assert head is not None, "Head must be specified in multitask mode!"
        model_dict = model_params["model_dict"]
        assert head in model_dict, (
            f"Specified head {head} not found in model {model_ckpt}! "
            f"Available ones are {list(model_dict.keys())}."
        )
        prefix = f"model.{head}."
        head_state_dict = {"_extra_state": state_dict["_extra_state"]}
        for key, value in state_dict.items():
            if prefix in key:
                # Single-task wrappers expect their sole model under "Default".
                head_state_dict[key.replace(prefix, "model.Default.")] = (
                    value.clone()
                )
        return model_dict[head], head_state_dict