Source code for deepmd.pt.entrypoints.main

# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import copy
import json
import logging
import os
from pathlib import (
    Path,
)
from typing import (
    List,
    Optional,
    Union,
)

import h5py
import torch
import torch.distributed as dist
import torch.version
from torch.distributed.elastic.multiprocessing.errors import (
    record,
)

from deepmd import (
    __version__,
)
from deepmd.common import (
    expand_sys_str,
)
from deepmd.env import (
    GLOBAL_CONFIG,
)
from deepmd.loggers.loggers import (
    set_log_handles,
)
from deepmd.main import (
    parse_args,
)
from deepmd.pt.cxx_op import (
    ENABLE_CUSTOMIZED_OP,
)
from deepmd.pt.infer import (
    inference,
)
from deepmd.pt.model.model import (
    BaseModel,
)
from deepmd.pt.train import (
    training,
)
from deepmd.pt.train.wrapper import (
    ModelWrapper,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.pt.utils.dataloader import (
    DpLoaderSet,
)
from deepmd.pt.utils.env import (
    DEVICE,
)
from deepmd.pt.utils.finetune import (
    get_finetune_rules,
)
from deepmd.pt.utils.multi_task import (
    preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
    make_stat_input,
)
from deepmd.pt.utils.utils import (
    to_numpy_array,
)
from deepmd.utils.argcheck import (
    normalize,
)
from deepmd.utils.compat import (
    update_deepmd_input,
)
from deepmd.utils.data_system import (
    get_data,
    process_systems,
)
from deepmd.utils.path import (
    DPPath,
)
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter

# Module-level logger shared by every entry point in this file.
# The stray "[docs]" label that preceded this assignment was a leaked
# documentation-viewer link and made the statement invalid Python.
log = logging.getLogger(__name__)
def get_trainer(
    config,
    init_model=None,
    restart_model=None,
    finetune_model=None,
    force_load=False,
    init_frz_model=None,
    shared_links=None,
    finetune_links=None,
):
    """Build the data loaders and construct a ``training.Trainer``.

    Parameters
    ----------
    config
        Training configuration. ``config["training"]`` is always read;
        ``config["model"]`` is read for the type map(s), and the presence of
        ``"model_dict"`` inside it selects multi-task mode.
    init_model, restart_model, finetune_model, force_load, init_frz_model
        Forwarded unchanged to ``training.Trainer``.
    shared_links, finetune_links
        Forwarded unchanged to ``training.Trainer``.

    Returns
    -------
    The ``training.Trainer`` instance built from the inputs above.
    """
    multi_task = "model_dict" in config.get("model", {})

    # Initialize DDP, but only when the launcher exported LOCAL_RANK.
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        local_rank = int(local_rank)
        assert dist.is_nccl_available()
        dist.init_process_group(backend="nccl")

    def prepare_trainer_input_single(
        model_params_single, data_dict_single, rank=0, seed=None
    ):
        """Create the loaders and the stat-file handle for a single model."""
        training_dataset_params = data_dict_single["training_data"]
        validation_dataset_params = data_dict_single.get("validation_data", None)
        if validation_dataset_params:
            validation_systems = validation_dataset_params["systems"]
        else:
            validation_systems = None
        training_systems = process_systems(training_dataset_params["systems"])
        if validation_systems is not None:
            validation_systems = process_systems(validation_systems)

        # stat files: only rank 0 keeps a handle, every other rank gets None
        stat_file_path_single = data_dict_single.get("stat_file", None)
        if rank == 0 and stat_file_path_single is not None:
            if not Path(stat_file_path_single).exists():
                if stat_file_path_single.endswith((".h5", ".hdf5")):
                    # opening in "w" mode is enough to create an empty file
                    with h5py.File(stat_file_path_single, "w"):
                        pass
                else:
                    Path(stat_file_path_single).mkdir()
            stat_file_path_single = DPPath(stat_file_path_single, "a")
        else:
            stat_file_path_single = None

        # validation and training data
        # avoid the same batch sequence among devices
        if seed is None:
            rank_seed = None
        else:
            rank_seed = (seed + rank) % (2**32)

        validation_data_single = None
        if validation_systems:
            validation_data_single = DpLoaderSet(
                validation_systems,
                validation_dataset_params["batch_size"],
                model_params_single["type_map"],
                seed=rank_seed,
            )
        train_data_single = DpLoaderSet(
            training_systems,
            training_dataset_params["batch_size"],
            model_params_single["type_map"],
            seed=rank_seed,
        )
        return train_data_single, validation_data_single, stat_file_path_single

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
    data_seed = config["training"].get("seed", None)

    if multi_task:
        # one entry per model key in each of the three dictionaries
        train_data, validation_data, stat_file_path = {}, {}, {}
        for model_key in config["model"]["model_dict"]:
            prepared = prepare_trainer_input_single(
                config["model"]["model_dict"][model_key],
                config["training"]["data_dict"][model_key],
                rank=rank,
                seed=data_seed,
            )
            (
                train_data[model_key],
                validation_data[model_key],
                stat_file_path[model_key],
            ) = prepared
    else:
        train_data, validation_data, stat_file_path = prepare_trainer_input_single(
            config["model"],
            config["training"],
            rank=rank,
            seed=data_seed,
        )

    return training.Trainer(
        config,
        train_data,
        stat_file_path=stat_file_path,
        validation_data=validation_data,
        init_model=init_model,
        restart_model=restart_model,
        finetune_model=finetune_model,
        force_load=force_load,
        shared_links=shared_links,
        finetune_links=finetune_links,
        init_frz_model=init_frz_model,
    )
class SummaryPrinter(BaseSummaryPrinter):
    """Summary printer for the PyTorch backend."""

    def is_built_with_cuda(self) -> bool:
        """Return True when ``torch.version.cuda`` is set."""
        cuda_version = torch.version.cuda
        return cuda_version is not None

    def is_built_with_rocm(self) -> bool:
        """Return True when ``torch.version.hip`` is set."""
        hip_version = torch.version.hip
        return hip_version is not None

    def get_compute_device(self) -> str:
        """Return the string form of the configured ``DEVICE``."""
        return str(DEVICE)

    def get_ngpus(self) -> int:
        """Return the GPU count reported by ``torch.cuda.device_count()``."""
        return torch.cuda.device_count()

    def get_backend_info(self) -> dict:
        """Return a dictionary describing the backend.

        The build-time entries of the customized OP library are appended
        only when ``ENABLE_CUSTOMIZED_OP`` is truthy.
        """
        backend_info = {
            "Backend": "PyTorch",
            "PT ver": f"v{torch.__version__}-g{torch.version.git_version[:11]}",
            "Enable custom OP": ENABLE_CUSTOMIZED_OP,
        }
        if ENABLE_CUSTOMIZED_OP:
            # ";"-separated lists are shown one entry per line
            backend_info["build with PT ver"] = GLOBAL_CONFIG["pt_version"]
            backend_info["build with PT inc"] = GLOBAL_CONFIG[
                "pt_include_dir"
            ].replace(";", "\n")
            backend_info["build with PT lib"] = GLOBAL_CONFIG["pt_libs"].replace(
                ";", "\n"
            )
        return backend_info
def train(FLAGS):
    """Run the ``train`` command.

    Loads the JSON input, applies the multi-task / fine-tuning /
    pretrain-script adjustments to ``config["model"]``, normalizes the
    configuration, optionally computes neighbor statistics, dumps the final
    configuration to ``FLAGS.output`` and runs the trainer.

    Parameters
    ----------
    FLAGS
        Parsed command-line namespace. The attributes read here are
        ``INPUT``, ``init_model``, ``restart``, ``finetune``, ``model_branch``,
        ``use_pretrain_script``, ``init_frz_model``, ``skip_neighbor_stat``,
        ``output`` and ``force_load``. ``init_model`` and ``restart`` are
        updated in place so that they end with ``.pt``.
    """
    log.info("Configuration path: %s", FLAGS.INPUT)
    SummaryPrinter()()
    with open(FLAGS.INPUT) as fin:
        config = json.load(fin)
    # ensure suffix, as in the command line help, we say "path prefix of checkpoint files"
    if FLAGS.init_model is not None and not FLAGS.init_model.endswith(".pt"):
        FLAGS.init_model += ".pt"
    if FLAGS.restart is not None and not FLAGS.restart.endswith(".pt"):
        FLAGS.restart += ".pt"

    # update multitask config
    multi_task = "model_dict" in config["model"]
    shared_links = None
    if multi_task:
        config["model"], shared_links = preprocess_shared_params(config["model"])
        # handle the special key
        assert (
            "RANDOM" not in config["model"]["model_dict"]
        ), "Model name can not be 'RANDOM' in multi-task mode!"

    # update fine-tuning config
    finetune_links = None
    if FLAGS.finetune is not None:
        config["model"], finetune_links = get_finetune_rules(
            FLAGS.finetune,
            config["model"],
            model_branch=FLAGS.model_branch,
            change_model_params=FLAGS.use_pretrain_script,
        )
    # update init_model or init_frz_model config if necessary
    if (
        FLAGS.init_model is not None or FLAGS.init_frz_model is not None
    ) and FLAGS.use_pretrain_script:
        if FLAGS.init_model is not None:
            init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE)
            # the state dict may be nested under a "model" key
            if "model" in init_state_dict:
                init_state_dict = init_state_dict["model"]
            config["model"] = init_state_dict["_extra_state"]["model_params"]
        else:
            config["model"] = json.loads(
                torch.jit.load(
                    FLAGS.init_frz_model, map_location=DEVICE
                ).get_model_def_script()
            )

    # argcheck
    config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
    config = normalize(config, multi_task=multi_task)

    # do neighbor stat
    min_nbor_dist = None
    if not FLAGS.skip_neighbor_stat:
        log.info(
            "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
        )
        type_map = config["model"].get("type_map")
        if not multi_task:
            train_data = get_data(
                config["training"]["training_data"], 0, type_map, None
            )
            config["model"], min_nbor_dist = BaseModel.update_sel(
                train_data, type_map, config["model"]
            )
        else:
            min_nbor_dist = {}
            model_dict = config["model"]["model_dict"]
            for model_item in model_dict:
                # Each branch carries its own type_map (the loaders in
                # get_trainer read it from the same place), so use it here;
                # fall back to the top-level lookup when a branch has none.
                branch_type_map = model_dict[model_item].get("type_map", type_map)
                train_data = get_data(
                    config["training"]["data_dict"][model_item]["training_data"],
                    0,
                    branch_type_map,
                    None,
                )
                model_dict[model_item], min_nbor_dist[model_item] = (
                    BaseModel.update_sel(
                        train_data, branch_type_map, model_dict[model_item]
                    )
                )

    # dump the final (normalized, possibly sel-updated) configuration
    with open(FLAGS.output, "w") as fp:
        json.dump(config, fp, indent=4)

    trainer = get_trainer(
        config,
        FLAGS.init_model,
        FLAGS.restart,
        FLAGS.finetune,
        FLAGS.force_load,
        FLAGS.init_frz_model,
        shared_links=shared_links,
        finetune_links=finetune_links,
    )
    # save min_nbor_dist
    if min_nbor_dist is not None:
        if not multi_task:
            trainer.model.min_nbor_dist = min_nbor_dist
        else:
            for model_item in min_nbor_dist:
                trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item]
    trainer.run()
def freeze(FLAGS):
    """Script the model loaded from ``FLAGS.model`` and save it.

    The model is obtained from ``inference.Tester`` (with ``FLAGS.head``),
    switched to eval mode, compiled with ``torch.jit.script`` and written to
    ``FLAGS.output`` with an empty extra-files mapping.
    """
    tester = inference.Tester(FLAGS.model, head=FLAGS.head)
    eager_model = tester.model
    eager_model.eval()
    scripted_model = torch.jit.script(eager_model)
    torch.jit.save(scripted_model, FLAGS.output, {})
def show(FLAGS):
    """Log selected attributes of a checkpoint or frozen model.

    Parameters
    ----------
    FLAGS
        Namespace with ``INPUT`` (path ending in ``.pt`` or ``.pth``) and
        ``ATTRIBUTES`` (container of attribute names; ``"model-branch"``,
        ``"type-map"``, ``"descriptor"`` and ``"fitting-net"`` are handled).

    Raises
    ------
    RuntimeError
        If the file extension is neither ``pt`` nor ``pth``, or if
        ``"model-branch"`` is requested for a model without ``"model_dict"``.
    """
    extension = FLAGS.INPUT.split(".")[-1]
    if extension == "pt":
        state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
        # the state dict may be nested under a "model" key
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model_params = state_dict["_extra_state"]["model_params"]
    elif extension == "pth":
        frozen_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
        model_params_string = frozen_model.model_def_script
        model_params = json.loads(model_params_string)
    else:
        raise RuntimeError(
            "The model provided must be a checkpoint file with a .pt extension "
            "or a frozen model with a .pth extension"
        )

    model_is_multi_task = "model_dict" in model_params
    if model_is_multi_task:
        log.info("This is a multitask model")
    else:
        log.info("This is a singletask model")

    requested = FLAGS.ATTRIBUTES

    if "model-branch" in requested:
        #  The model must be multitask mode
        if not model_is_multi_task:
            raise RuntimeError(
                "The 'model-branch' option requires a multitask model."
                " The provided model does not meet this criterion."
            )
        model_branches = [*model_params["model_dict"].keys(), "RANDOM"]
        log.info(
            f"Available model branches are {model_branches}, "
            f"where 'RANDOM' means using a randomly initialized fitting net."
        )

    if "type-map" in requested:
        if not model_is_multi_task:
            type_map = model_params["type_map"]
            log.info(f"The type_map is {type_map}")
        else:
            for branch, branch_params in model_params["model_dict"].items():
                type_map = branch_params["type_map"]
                log.info(f"The type_map of branch {branch} is {type_map}")

    if "descriptor" in requested:
        if not model_is_multi_task:
            descriptor = model_params["descriptor"]
            log.info(f"The descriptor parameter is {descriptor}")
        else:
            for branch, branch_params in model_params["model_dict"].items():
                descriptor = branch_params["descriptor"]
                log.info(f"The descriptor parameter of branch {branch} is {descriptor}")

    if "fitting-net" in requested:
        if not model_is_multi_task:
            fitting_net = model_params["fitting_net"]
            log.info(f"The fitting_net parameter is {fitting_net}")
        else:
            for branch, branch_params in model_params["model_dict"].items():
                fitting_net = branch_params["fitting_net"]
                log.info(
                    f"The fitting_net parameter of branch {branch} is {fitting_net}"
                )
def change_bias(FLAGS):
    """Change the output bias of a checkpoint or frozen model and save it.

    The bias is either set from ``FLAGS.bias_value`` or recomputed from data
    given by ``FLAGS.datafile`` / ``FLAGS.system``; the updated model is
    written to ``FLAGS.output`` or, when that is None, next to the input with
    ``_updated`` inserted before the file extension.

    Parameters
    ----------
    FLAGS
        Namespace providing ``INPUT``, ``model_branch``, ``mode``,
        ``bias_value``, ``datafile``, ``system``, ``numb_batch`` and
        ``output``.

    Raises
    ------
    RuntimeError
        If ``FLAGS.INPUT`` ends with neither ``.pt`` nor ``.pth``.
    AssertionError
        If a multi-task model is given without a valid ``model_branch``, or
        if a user-defined bias is used with a non-energy model or with a
        length different from the type map.
    """
    is_checkpoint = FLAGS.INPUT.endswith(".pt")
    if is_checkpoint:
        old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
        # the state dict may be nested under a "model" key
        model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
        model_params = model_state_dict["_extra_state"]["model_params"]
    elif FLAGS.INPUT.endswith(".pth"):
        old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
        model_params_string = old_model.get_model_def_script()
        model_params = json.loads(model_params_string)
        old_state_dict = old_model.state_dict()
        model_state_dict = old_state_dict
    else:
        raise RuntimeError(
            "The model provided must be a checkpoint file with a .pt extension "
            "or a frozen model with a .pth extension"
        )
    multi_task = "model_dict" in model_params
    model_branch = FLAGS.model_branch
    bias_adjust_mode = (
        "change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic"
    )
    if multi_task:
        assert (
            model_branch is not None
        ), "For multitask model, the model branch must be set!"
        assert model_branch in model_params["model_dict"], (
            f"For multitask model, the model branch must be in the 'model_dict'! "
            f"Available options are : {list(model_params['model_dict'].keys())}."
        )
        log.info(f"Changing out bias for model {model_branch}.")
    model = training.get_model_for_wrapper(model_params)
    type_map = (
        model_params["type_map"]
        if not multi_task
        else model_params["model_dict"][model_branch]["type_map"]
    )
    model_to_change = model if not multi_task else model[model_branch]
    if is_checkpoint:
        wrapper = ModelWrapper(model)
        # Accept both layouts handled when loading above: weights nested
        # under "model" or stored at the top level. Indexing ["model"]
        # unconditionally raised KeyError for the latter.
        wrapper.load_state_dict(old_state_dict.get("model", old_state_dict))
    else:
        # for .pth
        model.load_state_dict(old_state_dict)

    if FLAGS.bias_value is not None:
        # use user-defined bias
        assert model_to_change.model_type in [
            "ener"
        ], "User-defined bias is only available for energy model!"
        assert (
            len(FLAGS.bias_value) == len(type_map)
        ), f"The number of elements in the bias should be the same as that in the type_map: {type_map}."
        old_bias = model_to_change.get_out_bias()
        bias_to_set = torch.tensor(
            FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device
        ).view(old_bias.shape)
        model_to_change.set_out_bias(bias_to_set)
        log.info(
            f"Change output bias of {type_map!s} "
            f"from {to_numpy_array(old_bias).reshape(-1)!s} "
            f"to {to_numpy_array(bias_to_set).reshape(-1)!s}."
        )
        updated_model = model_to_change
    else:
        # calculate bias on given systems
        if FLAGS.datafile is not None:
            with open(FLAGS.datafile) as datalist:
                all_sys = datalist.read().splitlines()
        else:
            all_sys = expand_sys_str(FLAGS.system)
        data_systems = process_systems(all_sys)
        data_single = DpLoaderSet(
            data_systems,
            1,
            type_map,
        )
        mock_loss = training.get_loss(
            {"inference": True}, 1.0, len(type_map), model_to_change
        )
        data_requirement = mock_loss.label_requirement
        data_requirement += training.get_additional_data_requirement(model_to_change)
        data_single.add_data_requirement(data_requirement)
        # numb_batch == 0 means "no limit on the number of batches"
        nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf")
        sampled_data = make_stat_input(
            data_single.systems,
            data_single.dataloaders,
            nbatches,
        )
        updated_model = training.model_change_out_bias(
            model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
        )

    if not multi_task:
        model = updated_model
    else:
        model[model_branch] = updated_model

    if is_checkpoint:
        # Only the trailing extension is rewritten; str.replace would also
        # alter any earlier ".pt" occurrence in the path.
        output_path = (
            FLAGS.output
            if FLAGS.output is not None
            else FLAGS.INPUT[: -len(".pt")] + "_updated.pt"
        )
        wrapper = ModelWrapper(model)
        # keep the original layout and reuse the original "_extra_state"
        if "model" in old_state_dict:
            old_state_dict["model"] = wrapper.state_dict()
            old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
        else:
            old_state_dict = wrapper.state_dict()
            old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
        torch.save(old_state_dict, output_path)
    else:
        # for .pth
        output_path = (
            FLAGS.output
            if FLAGS.output is not None
            else FLAGS.INPUT[: -len(".pth")] + "_updated.pth"
        )
        model = torch.jit.script(model)
        torch.jit.save(
            model,
            output_path,
            {},
        )
    log.info(f"Saved model to {output_path}")
@record
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
    """Entry point: parse arguments, set up logging and dispatch the command.

    Parameters
    ----------
    args
        Either an already-parsed ``argparse.Namespace``, or anything else
        (a list of strings or None), which is handed to ``parse_args``.

    Raises
    ------
    RuntimeError
        If ``FLAGS.command`` is not one of ``train``, ``freeze``, ``show``
        or ``change-bias``.
    """
    FLAGS = args if isinstance(args, argparse.Namespace) else parse_args(args=args)

    set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None)
    log.debug("Log handles were successfully set")
    log.info("DeePMD version: %s", __version__)

    command = FLAGS.command
    if command == "train":
        train(FLAGS)
        return
    if command == "freeze":
        checkpoint_path = Path(FLAGS.checkpoint_folder)
        if checkpoint_path.is_dir():
            # the "checkpoint" file inside the folder names the file to load
            latest_ckpt_file = (checkpoint_path / "checkpoint").read_text()
            FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file))
        else:
            FLAGS.model = FLAGS.checkpoint_folder
        # the output always carries the .pth suffix
        FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
        freeze(FLAGS)
        return
    if command == "show":
        show(FLAGS)
        return
    if command == "change-bias":
        change_bias(FLAGS)
        return
    raise RuntimeError(f"Invalid command {FLAGS.command}!")
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()