Source code for deepmd.utils.multi_init

# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import logging
from typing import (
    Any,
    Dict,
)

from deepmd.utils.errors import (
    GraphWithoutTensorError,
)
from deepmd.utils.graph import (
    get_tensor_by_name,
)

log = logging.getLogger(__name__)


def replace_model_params_with_frz_multi_model(
    jdata: Dict[str, Any], pretrained_model: str
) -> Dict[str, Any]:
    """Replace the model params in the input script according to a pretrained frozen multi-task united model.

    Parameters
    ----------
    jdata : Dict[str, Any]
        input script
    pretrained_model : str
        filename of the pretrained frozen multi-task united model

    Returns
    -------
    Dict[str, Any]
        the updated input script
    """
    # Get the input script from the pretrained model
    try:
        t_jdata = get_tensor_by_name(pretrained_model, "train_attr/training_script")
    except GraphWithoutTensorError as e:
        raise RuntimeError(
            "The input frozen pretrained model: %s has no training script, "
            "so it cannot be used for multi-task training. "
            "Please use a model pretrained with DeePMD-kit v2.1.5 or later."
            % pretrained_model
        ) from e
    pretrained_jdata = json.loads(t_jdata)

    # Check the model type
    assert "fitting_net_dict" in pretrained_jdata["model"], (
        "The multi-task init process only supports models trained in multi-task "
        "mode and frozen into a united model! "
        "Please use the '--united-model' argument of the 'dp freeze' command."
    )

    # Check the type map
    pretrained_type_map = pretrained_jdata["model"]["type_map"]
    cur_type_map = jdata["model"].get("type_map", [])
    out_line_type = []
    for i in cur_type_map:
        if i not in pretrained_type_map:
            out_line_type.append(i)
    assert not out_line_type, (
        f"{out_line_type!s} type(s) not contained in the pretrained model! "
        "Please choose another suitable one."
    )
    if cur_type_map != pretrained_type_map:
        log.info(
            "Change the type_map from {} to {}.".format(
                str(cur_type_map), str(pretrained_type_map)
            )
        )
        jdata["model"]["type_map"] = pretrained_type_map

    # Change model configurations
    pretrained_fitting_keys = sorted(
        pretrained_jdata["model"]["fitting_net_dict"].keys()
    )
    cur_fitting_keys = sorted(jdata["model"]["fitting_net_dict"].keys())
    newly_added_fittings = set(cur_fitting_keys) - set(pretrained_fitting_keys)
    reused_fittings = set(cur_fitting_keys) - newly_added_fittings
    log.info("Change the model configurations according to the pretrained one...")

    for config_key in ["type_embedding", "descriptor", "fitting_net_dict"]:
        if (
            config_key not in jdata["model"].keys()
            and config_key in pretrained_jdata["model"].keys()
        ):
            log.info(
                "Add the '{}' from pretrained model: {}.".format(
                    config_key, str(pretrained_jdata["model"][config_key])
                )
            )
            jdata["model"][config_key] = pretrained_jdata["model"][config_key]
        elif (
            config_key == "type_embedding"
            and config_key in jdata["model"].keys()
            and config_key not in pretrained_jdata["model"].keys()
        ):
            # 'type_embedding' may be omitted when using the 'se_atten'
            # descriptor; in that case the activation_function will be None.
            cur_para = jdata["model"].pop(config_key)
            if "trainable" in cur_para and not cur_para["trainable"]:
                jdata["model"][config_key] = {
                    "trainable": False,
                    "activation_function": "None",
                }
                log.info(
                    "The type_embeddings from the pretrained model will be frozen."
                )
        elif config_key == "fitting_net_dict":
            if reused_fittings:
                log.info(
                    "These fitting nets will use the configurations from the "
                    f"pretrained frozen model: {reused_fittings}."
                )
                for fitting_key in reused_fittings:
                    _change_sub_config(
                        jdata["model"][config_key],
                        pretrained_jdata["model"][config_key],
                        fitting_key,
                    )
            if newly_added_fittings:
                log.info(
                    f"These fitting nets will be initialized from scratch: {newly_added_fittings}."
                )
        elif (
            config_key in jdata["model"].keys()
            and config_key in pretrained_jdata["model"].keys()
            and jdata["model"][config_key] != pretrained_jdata["model"][config_key]
        ):
            _change_sub_config(jdata["model"], pretrained_jdata["model"], config_key)

    # Change other multi-task configurations
    log.info("Change the training configurations according to the pretrained one...")
    for config_key in ["loss_dict", "training/data_dict"]:
        cur_jdata = jdata
        target_jdata = pretrained_jdata
        for sub_key in config_key.split("/"):
            cur_jdata = cur_jdata[sub_key]
            target_jdata = target_jdata[sub_key]
        for fitting_key in reused_fittings:
            if fitting_key not in cur_jdata:
                target_para = target_jdata[fitting_key]
                cur_jdata[fitting_key] = target_para
                log.info(
                    f"Add '{config_key}/{fitting_key}' configurations from the pretrained frozen model."
                )

    # The learning rate dict keeps backward compatibility.
    config_key = "learning_rate_dict"
    single_config_key = "learning_rate"
    cur_jdata = jdata
    target_jdata = pretrained_jdata
    if (single_config_key not in cur_jdata) and (config_key in cur_jdata):
        cur_jdata = cur_jdata[config_key]
        if config_key in target_jdata:
            target_jdata = target_jdata[config_key]
            for fitting_key in reused_fittings:
                if fitting_key not in cur_jdata:
                    target_para = target_jdata[fitting_key]
                    cur_jdata[fitting_key] = target_para
                    log.info(
                        f"Add '{config_key}/{fitting_key}' configurations from the pretrained frozen model."
                    )
        else:
            for fitting_key in reused_fittings:
                if fitting_key not in cur_jdata:
                    cur_jdata[fitting_key] = {}
                    log.info(
                        f"Add '{config_key}/{fitting_key}' configurations as default."
                    )
    return jdata
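
# A minimal usage sketch (comment-only; not part of the module). The file
# names below are hypothetical, and "pretrained_multi.pb" is assumed to be a
# multi-task model frozen with `dp freeze --united-model`:
#
#     with open("input.json") as f:
#         jdata = json.load(f)
#     jdata = replace_model_params_with_frz_multi_model(jdata, "pretrained_multi.pb")
#     with open("input_updated.json", "w") as f:
#         json.dump(jdata, f, indent=4)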


def _change_sub_config(jdata: Dict[str, Any], src_jdata: Dict[str, Any], sub_key: str):
    target_para = src_jdata[sub_key]
    cur_para = jdata[sub_key]
    # Keep some params that are irrelevant to model structures (need to discuss) TODO
    if "trainable" in cur_para.keys():
        target_para["trainable"] = cur_para["trainable"]
    log.info(f"Change the '{sub_key}' from {cur_para!s} to {target_para!s}.")
    jdata[sub_key] = target_para
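
# Behavior sketch for _change_sub_config, using hypothetical minimal configs:
#
#     jdata = {"descriptor": {"type": "se_e2_a", "trainable": False}}
#     src_jdata = {"descriptor": {"type": "se_atten", "trainable": True}}
#     _change_sub_config(jdata, src_jdata, "descriptor")
#     # jdata["descriptor"] == {"type": "se_atten", "trainable": False}
#     # The pretrained sub-config replaces the current one, but the current
#     # 'trainable' flag is preserved.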