Source code for deepmd.entrypoints.train

# SPDX-License-Identifier: LGPL-3.0-or-later
"""DeePMD training entrypoint script.

Can handle local or distributed training.
"""

import json
import logging
import time
from typing import (
    Any,
    Dict,
    Optional,
)

from deepmd.common import (
    data_requirement,
    expand_sys_str,
    j_loader,
    j_must_have,
)
from deepmd.env import (
    GLOBAL_ENER_FLOAT_PRECISION,
    reset_default_tf_session_config,
    tf,
)
from deepmd.infer.data_modifier import (
    DipoleChargeModifier,
)
from deepmd.model.model import (
    Model,
)
from deepmd.train.run_options import (
    BUILD,
    CITATION,
    WELCOME,
    RunOptions,
)
from deepmd.train.trainer import (
    DPTrainer,
)
from deepmd.utils import random as dp_random
from deepmd.utils.argcheck import (
    normalize,
)
from deepmd.utils.compat import (
    update_deepmd_input,
)
from deepmd.utils.data_system import (
    DeepmdDataSystem,
)
from deepmd.utils.finetune import (
    replace_model_params_with_pretrained_model,
)
from deepmd.utils.multi_init import (
    replace_model_params_with_frz_multi_model,
)
from deepmd.utils.neighbor_stat import (
    NeighborStat,
)
from deepmd.utils.path import (
    DPPath,
)

__all__ = ["train"]

log = logging.getLogger(__name__)


[docs]def train( *, INPUT: str, init_model: Optional[str], restart: Optional[str], output: str, init_frz_model: str, mpi_log: str, log_level: int, log_path: Optional[str], is_compress: bool = False, skip_neighbor_stat: bool = False, finetune: Optional[str] = None, **kwargs, ): """Run DeePMD model training. Parameters ---------- INPUT : str json/yaml control file init_model : Optional[str] path prefix of checkpoint files or None restart : Optional[str] path prefix of checkpoint files or None output : str path for dump file with arguments init_frz_model : str path to frozen model or None mpi_log : str mpi logging mode log_level : int logging level defined by int 0-3 log_path : Optional[str] logging file path or None if logs are to be output only to stdout is_compress : bool indicates whether in the model compress mode skip_neighbor_stat : bool, default=False skip checking neighbor statistics finetune : Optional[str] path to pretrained model or None **kwargs additional arguments Raises ------ RuntimeError if distributed training job name is wrong """ run_opt = RunOptions( init_model=init_model, restart=restart, init_frz_model=init_frz_model, finetune=finetune, log_path=log_path, log_level=log_level, mpi_log=mpi_log, ) if run_opt.is_distrib and len(run_opt.gpus or []) > 1: # avoid conflict of visible gpus among multipe tf sessions in one process reset_default_tf_session_config(cpu_only=True) # load json database jdata = j_loader(INPUT) origin_type_map = None if run_opt.finetune is not None: jdata, origin_type_map = replace_model_params_with_pretrained_model( jdata, run_opt.finetune ) if "fitting_net_dict" in jdata["model"] and run_opt.init_frz_model is not None: jdata = replace_model_params_with_frz_multi_model(jdata, run_opt.init_frz_model) jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") jdata = normalize(jdata) if not is_compress and not skip_neighbor_stat: jdata = update_sel(jdata) with open(output, "w") as fp: json.dump(jdata, fp, indent=4) # save the training script into the graph # remove white spaces as it is not compressed tf.constant( json.dumps(jdata, separators=(",", ":")), name="train_attr/training_script", dtype=tf.string, ) for message in WELCOME + CITATION + BUILD: log.info(message) run_opt.print_resource_summary() if origin_type_map is not None: jdata["model"]["origin_type_map"] = origin_type_map _do_work(jdata, run_opt, is_compress)
def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = False): """Run serial model training. Parameters ---------- jdata : Dict[str, Any] arguments read form json/yaml control file run_opt : RunOptions object with run configuration is_compress : Bool indicates whether in model compress mode Raises ------ RuntimeError If unsupported modifier type is selected for model """ # make necessary checks assert "training" in jdata # init the model model = DPTrainer(jdata, run_opt=run_opt, is_compress=is_compress) rcut = model.model.get_rcut() type_map = model.model.get_type_map() if len(type_map) == 0: ipt_type_map = None else: ipt_type_map = type_map # init random seed of data systems seed = jdata["training"].get("seed", None) if seed is not None: # avoid the same batch sequence among workers seed += run_opt.my_rank seed = seed % (2**32) dp_random.seed(seed) # setup data modifier modifier = get_modifier(jdata["model"].get("modifier", None)) # check the multi-task mode multi_task_mode = "fitting_net_dict" in jdata["model"] # decouple the training data from the model compress process train_data = None valid_data = None if not is_compress: # init data if not multi_task_mode: train_data = get_data( jdata["training"]["training_data"], rcut, ipt_type_map, modifier ) train_data.print_summary("training") if jdata["training"].get("validation_data", None) is not None: valid_data = get_data( jdata["training"]["validation_data"], rcut, train_data.type_map, modifier, ) valid_data.print_summary("validation") else: train_data = {} valid_data = {} for data_systems in jdata["training"]["data_dict"]: if ( jdata["training"]["fitting_weight"][data_systems] > 0.0 ): # check only the available pair train_data[data_systems] = get_data( jdata["training"]["data_dict"][data_systems]["training_data"], rcut, ipt_type_map, modifier, multi_task_mode, ) train_data[data_systems].print_summary( f"training in {data_systems}" ) if ( jdata["training"]["data_dict"][data_systems].get( "validation_data", None ) is not None ): valid_data[data_systems] = get_data( jdata["training"]["data_dict"][data_systems][ "validation_data" ], rcut, train_data[data_systems].type_map, modifier, multi_task_mode, ) valid_data[data_systems].print_summary( f"validation in {data_systems}" ) else: if modifier is not None: modifier.build_fv_graph() # get training info stop_batch = j_must_have(jdata["training"], "numb_steps") origin_type_map = jdata["model"].get("origin_type_map", None) if ( origin_type_map is not None and not origin_type_map ): # get the type_map from data if not provided origin_type_map = get_data( jdata["training"]["training_data"], rcut, None, modifier ).get_type_map() model.build(train_data, stop_batch, origin_type_map=origin_type_map) if not is_compress: # train the model with the provided systems in a cyclic way start_time = time.time() model.train(train_data, valid_data) end_time = time.time() log.info("finished training") log.info(f"wall time: {(end_time - start_time):.3f} s") else: model.save_compressed() log.info("finished compressing") def get_data(jdata: Dict[str, Any], rcut, type_map, modifier, multi_task_mode=False): systems = j_must_have(jdata, "systems") if isinstance(systems, str): systems = expand_sys_str(systems) elif isinstance(systems, list): systems = systems.copy() help_msg = "Please check your setting for data systems" # check length of systems if len(systems) == 0: msg = "cannot find valid a data system" log.fatal(msg) raise OSError(msg, help_msg) # rougly check all items in systems are valid for ii in systems: ii = DPPath(ii) if not ii.is_dir(): msg = f"dir {ii} is not a valid dir" log.fatal(msg) raise OSError(msg, help_msg) if not (ii / "type.raw").is_file(): msg = f"dir {ii} is not a valid data system dir" log.fatal(msg) raise OSError(msg, help_msg) batch_size = j_must_have(jdata, "batch_size") sys_probs = jdata.get("sys_probs", None) auto_prob = jdata.get("auto_prob", "prob_sys_size") optional_type_map = not multi_task_mode data = DeepmdDataSystem( systems=systems, batch_size=batch_size, test_size=1, # to satisfy the old api shuffle_test=True, # to satisfy the old api rcut=rcut, type_map=type_map, optional_type_map=optional_type_map, modifier=modifier, trn_all_set=True, # sample from all sets sys_probs=sys_probs, auto_prob_style=auto_prob, ) data.add_dict(data_requirement) return data def get_modifier(modi_data=None): modifier: Optional[DipoleChargeModifier] if modi_data is not None: if modi_data["type"] == "dipole_charge": modifier = DipoleChargeModifier( modi_data["model_name"], modi_data["model_charge_map"], modi_data["sys_charge_map"], modi_data["ewald_h"], modi_data["ewald_beta"], ) else: raise RuntimeError("unknown modifier type " + str(modi_data["type"])) else: modifier = None return modifier def get_rcut(jdata): if jdata["model"].get("type") == "pairwise_dprc": return max( jdata["model"]["qm_model"]["descriptor"]["rcut"], jdata["model"]["qmmm_model"]["descriptor"]["rcut"], ) descrpt_data = jdata["model"]["descriptor"] rcut_list = [] if descrpt_data["type"] == "hybrid": for ii in descrpt_data["list"]: rcut_list.append(ii["rcut"]) else: rcut_list.append(descrpt_data["rcut"]) return max(rcut_list) def get_type_map(jdata): return jdata["model"].get("type_map", None) def get_nbor_stat(jdata, rcut, one_type: bool = False): # it seems that DeepmdDataSystem does not need rcut # it's not clear why there is an argument... # max_rcut = get_rcut(jdata) max_rcut = rcut type_map = get_type_map(jdata) if type_map and len(type_map) == 0: type_map = None multi_task_mode = "data_dict" in jdata["training"] if not multi_task_mode: train_data = get_data( jdata["training"]["training_data"], max_rcut, type_map, None ) train_data.get_batch() else: assert ( type_map is not None ), "Data stat in multi-task mode must have available type_map! " train_data = None for systems in jdata["training"]["data_dict"]: tmp_data = get_data( jdata["training"]["data_dict"][systems]["training_data"], max_rcut, type_map, None, ) tmp_data.get_batch() assert tmp_data.get_type_map(), f"In multi-task mode, 'type_map.raw' must be defined in data systems {systems}! " if train_data is None: train_data = tmp_data else: train_data.system_dirs += tmp_data.system_dirs train_data.data_systems += tmp_data.data_systems train_data.natoms += tmp_data.natoms train_data.natoms_vec += tmp_data.natoms_vec train_data.default_mesh += tmp_data.default_mesh data_ntypes = train_data.get_ntypes() if type_map is not None: map_ntypes = len(type_map) else: map_ntypes = data_ntypes ntypes = max([map_ntypes, data_ntypes]) neistat = NeighborStat(ntypes, rcut, one_type=one_type) min_nbor_dist, max_nbor_size = neistat.get_stat(train_data) # moved from traier.py as duplicated # TODO: this is a simple fix but we should have a clear # architecture to call neighbor stat tf.constant( min_nbor_dist, name="train_attr/min_nbor_dist", dtype=GLOBAL_ENER_FLOAT_PRECISION, ) tf.constant(max_nbor_size, name="train_attr/max_nbor_size", dtype=tf.int32) return min_nbor_dist, max_nbor_size def get_sel(jdata, rcut, one_type: bool = False): _, max_nbor_size = get_nbor_stat(jdata, rcut, one_type=one_type) return max_nbor_size def get_min_nbor_dist(jdata, rcut): min_nbor_dist, _ = get_nbor_stat(jdata, rcut) return min_nbor_dist def parse_auto_sel(sel): if not isinstance(sel, str): return False words = sel.split(":") if words[0] == "auto": return True else: return False def parse_auto_sel_ratio(sel): if not parse_auto_sel(sel): raise RuntimeError(f"invalid auto sel format {sel}") else: words = sel.split(":") if len(words) == 1: ratio = 1.1 elif len(words) == 2: ratio = float(words[1]) else: raise RuntimeError(f"invalid auto sel format {sel}") return ratio def wrap_up_4(xx): return 4 * ((int(xx) + 3) // 4) def update_one_sel(jdata, descriptor, one_type: bool = False): rcut = descriptor["rcut"] tmp_sel = get_sel( jdata, rcut, one_type=one_type, ) sel = descriptor["sel"] if isinstance(sel, int): # convert to list and finnally convert back to int sel = [sel] if parse_auto_sel(descriptor["sel"]): ratio = parse_auto_sel_ratio(descriptor["sel"]) descriptor["sel"] = sel = [int(wrap_up_4(ii * ratio)) for ii in tmp_sel] else: # sel is set by user for ii, (tt, dd) in enumerate(zip(tmp_sel, sel)): if dd and tt > dd: # we may skip warning for sel=0, where the user is likely # to exclude such type in the descriptor log.warning( "sel of type %d is not enough! The expected value is " "not less than %d, but you set it to %d. The accuracy" " of your model may get worse." % (ii, tt, dd) ) if one_type: descriptor["sel"] = sel = sum(sel) return descriptor def update_sel(jdata): log.info( "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" ) jdata_cpy = jdata.copy() jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"]) return jdata_cpy