Source code for deepmd.pd.utils.finetune

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from copy import (
    deepcopy,
)

import paddle

from deepmd.utils.finetune import (
    FinetuneRuleItem,
)

[docs] log = logging.getLogger(__name__)
[docs] def get_finetune_rule_single( _single_param_target: dict, _model_param_pretrained: dict, from_multitask: bool = False, model_branch: str = "Default", model_branch_from: str = "", change_model_params: bool = False, ) -> dict: single_config = deepcopy(_single_param_target) new_fitting = False model_branch_chosen = "Default" if not from_multitask: single_config_chosen = deepcopy(_model_param_pretrained) if model_branch_from == "RANDOM": # not ["", "RANDOM"], because single-from-single finetune uses pretrained fitting in default new_fitting = True else: model_dict_params = _model_param_pretrained["model_dict"] if model_branch_from in ["", "RANDOM"]: model_branch_chosen = next(iter(model_dict_params.keys())) new_fitting = True log.warning( "The fitting net will be re-init instead of using that in the pretrained model! " "The bias_adjust_mode will be set-by-statistic!" ) else: model_branch_chosen = model_branch_from assert model_branch_chosen in model_dict_params, ( f"No model branch named '{model_branch_chosen}'! " f"Available ones are {list(model_dict_params.keys())}." ) single_config_chosen = deepcopy(model_dict_params[model_branch_chosen]) old_type_map, new_type_map = ( single_config_chosen["type_map"], single_config["type_map"], ) finetune_rule = FinetuneRuleItem( p_type_map=old_type_map, type_map=new_type_map, model_branch=model_branch_chosen, random_fitting=new_fitting, ) if change_model_params: trainable_param = { "descriptor": single_config.get("descriptor", {}).get("trainable", True), "fitting_net": single_config.get("fitting_net", {}).get("trainable", True), } single_config["descriptor"] = single_config_chosen["descriptor"] if not new_fitting: single_config["fitting_net"] = single_config_chosen["fitting_net"] log.info( f"Change the '{model_branch}' model configurations according to the model branch " f"'{model_branch_chosen}' in the pretrained one..." ) for net_type in trainable_param: if net_type in single_config: single_config[net_type]["trainable"] = trainable_param[net_type] else: single_config[net_type] = {"trainable": trainable_param[net_type]} return single_config, finetune_rule
[docs] def get_finetune_rules( finetune_model: str, model_config: dict, model_branch: str = "", change_model_params: bool = True, ) -> tuple[dict, str]: """ Get fine-tuning rules and (optionally) change the model_params according to the pretrained one. This function gets the fine-tuning rules and (optionally) changes input in different modes as follows: 1. Single-task fine-tuning from a single-task pretrained model: - The model will be fine-tuned based on the pretrained model. - (Optional) Updates the model parameters based on the pretrained model. 2. Single-task fine-tuning from a multi-task pretrained model: - The model will be fine-tuned based on the selected branch in the pretrained model. The chosen branch can be defined from the command-line or `finetune_head` input parameter. If not defined, model parameters in the fitting network will be randomly initialized. - (Optional) Updates the model parameters based on the selected branch in the pretrained model. 3. Multi-task fine-tuning from a single-task pretrained model: - The model in each branch will be fine-tuned or resumed based on the single branch ('Default') in the pretrained model. The chosen branches can be defined from the `finetune_head` input parameter of each branch. - If `finetune_head` is defined as 'Default', it will be fine-tuned based on the single branch ('Default') in the pretrained model. - If `finetune_head` is not defined and the model_key is 'Default', it will resume from the single branch ('Default') in the pretrained model without fine-tuning. - If `finetune_head` is not defined and the model_key is not 'Default', it will be fine-tuned based on the single branch ('Default') in the pretrained model, while model parameters in the fitting network of the branch will be randomly initialized. - (Optional) Updates model parameters in each branch based on the single branch ('Default') in the pretrained model. 4. Multi-task fine-tuning from a multi-task pretrained model: - The model in each branch will be fine-tuned or resumed based on the chosen branches in the pretrained model. The chosen branches can be defined from the `finetune_head` input parameter of each branch. - If `finetune_head` is defined as one of the branches in the pretrained model, it will be fine-tuned based on the chosen branch in the pretrained model. - If `finetune_head` is not defined and the model_key is the same as one of those in the pretrained model, it will resume from the model_key branch in the pretrained model without fine-tuning. - If `finetune_head` is not defined and a new model_key is used, it will be fine-tuned based on the chosen branch in the pretrained model, while model parameters in the fitting network of the branch will be randomly initialized. - (Optional) Updates model parameters in each branch based on the chosen branches in the pretrained model. Parameters ---------- finetune_model The pretrained model. model_config The fine-tuning input parameters. model_branch The model branch chosen in command-line mode, only for single-task fine-tuning. change_model_params Whether to change the model parameters according to the pretrained one. Returns ------- model_config: Updated model parameters. finetune_links: Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs. """ multi_task = "model_dict" in model_config state_dict = paddle.load(finetune_model) if "model" in state_dict: state_dict = state_dict["model"] last_model_params = state_dict["_extra_state"]["model_params"] finetune_from_multi_task = "model_dict" in last_model_params finetune_links = {} if not multi_task: # use command-line first if model_branch == "" and "finetune_head" in model_config: model_branch = model_config["finetune_head"] model_config, finetune_rule = get_finetune_rule_single( model_config, last_model_params, from_multitask=finetune_from_multi_task, model_branch="Default", model_branch_from=model_branch, change_model_params=change_model_params, ) finetune_links["Default"] = finetune_rule else: assert model_branch == "", ( "Multi-task fine-tuning does not support command-line branches chosen!" "Please define the 'finetune_head' in each model params!" ) target_keys = model_config["model_dict"].keys() if not finetune_from_multi_task: pretrained_keys = ["Default"] else: pretrained_keys = last_model_params["model_dict"].keys() for model_key in target_keys: resuming = False if ( "finetune_head" in model_config["model_dict"][model_key] and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM" ): pretrained_key = model_config["model_dict"][model_key]["finetune_head"] assert pretrained_key in pretrained_keys, ( f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!" f"Available heads are: {list(pretrained_keys)}" ) model_branch_from = pretrained_key elif ( "finetune_head" not in model_config["model_dict"][model_key] and model_key in pretrained_keys ): # not do anything if not defined "finetune_head" in heads that exist in the pretrained model # this will just do resuming model_branch_from = model_key resuming = True else: # if not defined "finetune_head" in new heads or "finetune_head" is "RANDOM", the fitting net will bre randomly initialized model_branch_from = "RANDOM" model_config["model_dict"][model_key], finetune_rule = ( get_finetune_rule_single( model_config["model_dict"][model_key], last_model_params, from_multitask=finetune_from_multi_task, model_branch=model_key, model_branch_from=model_branch_from, change_model_params=change_model_params, ) ) finetune_links[model_key] = finetune_rule finetune_links[model_key].resuming = resuming return model_config, finetune_links