# SPDX-License-Identifier: LGPL-3.0-or-later
import contextlib
import datetime
import functools
import logging
import time
from copy import (
deepcopy,
)
from pathlib import (
Path,
)
from typing import (
Any,
)
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import (
fleet,
)
from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
from paddle.framework import (
core,
)
from paddle.io import (
DataLoader,
)
from deepmd.common import (
symlink_prefix_files,
)
from deepmd.dpmodel.utils import (
compute_total_numb_batch,
resolve_model_prob,
resolve_model_prob_from_epochs,
)
from deepmd.dpmodel.utils.learning_rate import (
BaseLR,
)
from deepmd.loggers.training import (
format_training_message,
format_training_message_per_task,
)
from deepmd.pd.loss import (
EnergyHessianStdLoss,
EnergyStdLoss,
TaskLoss,
)
from deepmd.pd.model.model import (
get_model,
)
from deepmd.pd.train.wrapper import (
ModelWrapper,
)
from deepmd.pd.utils import (
dp_random,
)
from deepmd.pd.utils.dataloader import (
BufferedIterator,
get_sampler_from_params,
)
from deepmd.pd.utils.env import (
CINN,
CINN_ALLOW_DYNAMIC_SHAPE,
DEFAULT_PRECISION,
DEVICE,
JIT,
NUM_WORKERS,
SAMPLER_RECORD,
enable_prim,
)
from deepmd.pd.utils.stat import (
make_stat_input,
)
from deepmd.pd.utils.utils import (
nvprof_context,
to_numpy_array,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.path import (
DPH5Path,
)
[docs]
# Module-level logger, named after this module's import path; used by the
# trainer for progress, checkpoint and diagnostic messages.
log = logging.getLogger(__name__)
[docs]
class Trainer:
def __init__(
    self,
    config: dict[str, Any],
    training_data: Any,
    stat_file_path: str | Path | None = None,
    validation_data: Any | None = None,
    init_model: str | None = None,
    restart_model: str | None = None,
    finetune_model: str | None = None,
    force_load: bool = False,
    shared_links: dict[str, Any] | None = None,
    finetune_links: dict[str, Any] | None = None,
    init_frz_model: str | None = None,
) -> None:
    """Construct a DeePMD trainer.

    Args:
    - config: The Dict-like configuration with training options. The keys
      "model", "training" and "learning_rate" are read, together with
      "loss" (single-task) or "loss_dict" (multi-task) and the optional
      "optimizer" section.
    - training_data: Training data object; in multi-task mode a mapping
      from model key to such an object.
    - stat_file_path: Forwarded to ``compute_or_load_stat``; indexed by
      model key in multi-task mode.
    - validation_data: Optional validation data, same layout as
      ``training_data``.
    - init_model: Checkpoint to initialise the parameters from.
    - restart_model: Checkpoint to restart from; the stored step and the
      optimizer state are restored as well.
    - finetune_model: Checkpoint of a pretrained model to fine-tune from.
    - force_load: If True, keys missing from the checkpoint are filled
      with the freshly initialised values instead of being left out.
    - shared_links: Passed to ``ModelWrapper.share_params`` when given.
    - finetune_links: Mapping from model key to its fine-tune rule.
    - init_frz_model: Path handed to ``paddle.jit.load``; its state dict
      is copied into the model.

    When several checkpoints are given, ``init_model`` takes precedence
    over ``restart_model``, which takes precedence over ``finetune_model``.
    """
    enable_prim(True)
    if init_model is not None:
        resume_model = init_model
    elif restart_model is not None:
        resume_model = restart_model
    elif finetune_model is not None:
        resume_model = finetune_model
    else:
        resume_model = None
    resuming = resume_model is not None
    self.restart_training = restart_model is not None
    model_params = config["model"]
    training_params = config["training"]
    optimizer_params = config.get("optimizer", {})
    self.multi_task = "model_dict" in model_params
    self.finetune_links = finetune_links
    self.finetune_update_stat = False
    self.model_keys = (
        list(model_params["model_dict"]) if self.multi_task else ["Default"]
    )
    self.rank = (
        dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    )
    self.world_size = (
        dist.get_world_size()
        if dist.is_available() and dist.is_initialized()
        else 1
    )
    self.num_model = len(self.model_keys)

    # Iteration config
    self.num_steps = training_params.get("numb_steps")
    # The error messages below name this option "training.num_epoch";
    # accept that spelling as a fallback for the "numb_epoch" key.
    self.num_epoch = training_params.get("numb_epoch")
    if self.num_epoch is None:
        self.num_epoch = training_params.get("num_epoch")
    self.num_epoch_dict = training_params.get("num_epoch_dict")
    self.acc_freq: int = training_params.get(
        "acc_freq", 1
    )  # gradient accumulation steps
    self.disp_file = training_params.get("disp_file", "lcurve.out")
    self.disp_freq = training_params.get("disp_freq", 1000)
    self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
    self.save_freq = training_params.get("save_freq", 1000)
    self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
    self.display_in_training = training_params.get("disp_training", True)
    self.timing_in_training = training_params.get("time_training", True)
    self.change_bias_after_training = training_params.get(
        "change_bias_after_training", False
    )
    # run() prints the learning-curve header once and then clears this flag.
    self.lcurve_should_print_header = True

    def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """
        Extract optimizer parameters.

        Note: Default values are already filled by argcheck.normalize()
        before this function is called.
        """
        opt_type = params.get("type", "Adam")
        if opt_type not in ["Adam", "AdamW"]:
            raise ValueError(f"Not supported optimizer type '{opt_type}'")
        opt_param = dict(params)
        opt_param.pop("type", None)
        return opt_type, opt_param

    def get_data_loader(
        _training_data: Any, _validation_data: Any, _training_params: dict[str, Any]
    ) -> tuple[Any, Any, Any, Any, Any]:
        """Build the training/validation loaders and their buffered iterators.

        Returns the training loader, the buffered training iterator, the
        validation loader, the buffered validation iterator (both None when
        no validation data is given) and the number of validation batches.
        """

        def get_dataloader_and_buffer(
            _data: Any, _params: dict[str, Any]
        ) -> tuple[Any, Any]:
            _sampler = get_sampler_from_params(_data, _params)
            if _sampler is None:
                log.warning(
                    "Sampler not specified!"
                )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
            _dataloader = DataLoader(
                _data,
                batch_sampler=paddle.io.BatchSampler(
                    sampler=_sampler,
                    drop_last=False,
                ),
                num_workers=NUM_WORKERS
                if dist.is_available()
                else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1
                collate_fn=lambda batch: batch[0],  # prevent extra conversion
            )
            _data_buffered = BufferedIterator(iter(_dataloader))
            return _dataloader, _data_buffered

        training_dataloader, training_data_buffered = get_dataloader_and_buffer(
            _training_data, _training_params["training_data"]
        )
        if _validation_data is not None:
            (
                validation_dataloader,
                validation_data_buffered,
            ) = get_dataloader_and_buffer(
                _validation_data, _training_params["validation_data"]
            )
            valid_numb_batch = _training_params["validation_data"].get(
                "numb_btch", 1
            )
        else:
            validation_dataloader = None
            validation_data_buffered = None
            valid_numb_batch = 1
        return (
            training_dataloader,
            training_data_buffered,
            validation_dataloader,
            validation_data_buffered,
            valid_numb_batch,
        )

    def single_model_stat(
        _model: Any,
        _data_stat_nbatch: int,
        _training_data: Any,
        _validation_data: Any | None,
        _stat_file_path: str | Path | None,
        _data_requirement: list[DataRequirementItem],
        finetune_has_new_type: bool = False,
    ) -> Any:
        """Register data requirements and compute/load statistics for one model.

        Returns the cached sampling function so that later bias adjustment
        can reuse the same samples.
        """
        # Build a new list instead of extending the caller's list in place,
        # so the loss object's own requirement list is left untouched.
        _data_requirement = [
            *_data_requirement,
            *get_additional_data_requirement(_model),
        ]
        _training_data.add_data_requirement(_data_requirement)
        if _validation_data is not None:
            _validation_data.add_data_requirement(_data_requirement)

        @functools.lru_cache
        def get_sample() -> dict[str, Any]:
            sampled = make_stat_input(
                _training_data.systems,
                _training_data.dataloaders,
                _data_stat_nbatch,
            )
            return sampled

        # Statistics are only computed on rank 0, and only when starting
        # fresh or when fine-tuning introduces new atom types.
        if (not resuming or finetune_has_new_type) and self.rank == 0:
            _model.compute_or_load_stat(
                sampled_func=get_sample,
                stat_file_path=_stat_file_path,
            )
            if isinstance(_stat_file_path, DPH5Path):
                _stat_file_path.root.close()
        return get_sample

    def get_lr(lr_params: dict[str, Any]) -> BaseLR:
        # NOTE: writes the resolved step count into the caller's dict.
        lr_params["num_steps"] = self.num_steps
        lr_schedule = BaseLR(**lr_params)
        return lr_schedule

    # Optimizer
    self.opt_type, self.opt_param = get_opt_param(optimizer_params)

    # loss_param_tmp for Hessian activation
    loss_param_tmp = None
    if not self.multi_task:
        loss_param_tmp = config["loss"]
    else:
        loss_param_tmp = {
            model_key: config["loss_dict"][model_key]
            for model_key in self.model_keys
        }

    # Model
    self.model = get_model_for_wrapper(
        model_params,
        resuming=resuming,
        _loss_params=loss_param_tmp,
    )

    # Loss
    if not self.multi_task:
        self.loss = get_loss(
            config["loss"],
            config["learning_rate"]["start_lr"],
            len(model_params["type_map"]),
            self.model,
        )
    else:
        self.loss = {}
        for model_key in self.model_keys:
            loss_param = config["loss_dict"][model_key]
            lr_param = config["learning_rate"]["start_lr"]
            ntypes = len(model_params["model_dict"][model_key]["type_map"])
            self.loss[model_key] = get_loss(
                loss_param, lr_param, ntypes, self.model[model_key]
            )

    # Data
    if not self.multi_task:
        self.get_sample_func = single_model_stat(
            self.model,
            model_params.get("data_stat_nbatch", 10),
            training_data,
            validation_data,
            stat_file_path,
            self.loss.label_requirement,
            finetune_has_new_type=self.finetune_links["Default"].get_has_new_type()
            if self.finetune_links is not None
            else False,
        )
        (
            self.training_dataloader,
            self.training_data,
            self.validation_dataloader,
            self.validation_data,
            self.valid_numb_batch,
        ) = get_data_loader(training_data, validation_data, training_params)
        training_data.print_summary(
            "training",
            to_numpy_array(self.training_dataloader.batch_sampler.sampler.weights),
        )
        if validation_data is not None:
            validation_data.print_summary(
                "validation",
                to_numpy_array(
                    self.validation_dataloader.batch_sampler.sampler.weights
                ),
            )
    else:
        (
            self.training_dataloader,
            self.training_data,
            self.validation_dataloader,
            self.validation_data,
            self.valid_numb_batch,
            self.get_sample_func,
        ) = {}, {}, {}, {}, {}, {}
        for model_key in self.model_keys:
            # Tolerate a missing validation mapping, consistent with the
            # None check applied before print_summary below.
            task_validation_data = (
                validation_data[model_key] if validation_data is not None else None
            )
            self.get_sample_func[model_key] = single_model_stat(
                self.model[model_key],
                model_params["model_dict"][model_key].get("data_stat_nbatch", 10),
                training_data[model_key],
                task_validation_data,
                stat_file_path[model_key],
                self.loss[model_key].label_requirement,
                finetune_has_new_type=self.finetune_links[
                    model_key
                ].get_has_new_type()
                if self.finetune_links is not None
                else False,
            )
            (
                self.training_dataloader[model_key],
                self.training_data[model_key],
                self.validation_dataloader[model_key],
                self.validation_data[model_key],
                self.valid_numb_batch[model_key],
            ) = get_data_loader(
                training_data[model_key],
                task_validation_data,
                training_params["data_dict"][model_key],
            )
            training_data[model_key].print_summary(
                f"training in {model_key}",
                to_numpy_array(
                    self.training_dataloader[
                        model_key
                    ].batch_sampler.sampler.weights
                ),
            )
            if task_validation_data is not None:
                task_validation_data.print_summary(
                    f"validation in {model_key}",
                    to_numpy_array(
                        self.validation_dataloader[
                            model_key
                        ].batch_sampler.sampler.weights
                    ),
                )

    # Resolve the total number of steps (and, for multi-task, the task
    # sampling probabilities) from either explicit steps or epochs.
    per_task_total = []
    if not self.multi_task:
        if self.num_steps is None:
            if self.num_epoch is None:
                raise ValueError(
                    "Either training.numb_steps or training.num_epoch must be set."
                )
            if self.num_epoch <= 0:
                raise ValueError("training.num_epoch must be positive.")
            sampler_weights = to_numpy_array(
                self.training_dataloader.batch_sampler.sampler.weights
            )
            total_numb_batch = compute_total_numb_batch(
                training_data.index,
                sampler_weights,
            )
            if total_numb_batch <= 0:
                raise ValueError(
                    "Total number of training batches must be positive."
                )
            self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
            log.info(
                "Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
                self.num_steps,
                self.num_epoch,
                total_numb_batch,
            )
    else:
        if self.num_epoch_dict:
            if self.num_steps is not None:
                raise ValueError(
                    "training.numb_steps and training.num_epoch_dict "
                    "are mutually exclusive."
                )
            for model_key in self.model_keys:
                sampler_weights = to_numpy_array(
                    self.training_dataloader[
                        model_key
                    ].batch_sampler.sampler.weights
                )
                per_task_total.append(
                    compute_total_numb_batch(
                        training_data[model_key].index,
                        sampler_weights,
                    )
                )
            (
                self.model_prob,
                self.num_steps,
                per_task_steps,
            ) = resolve_model_prob_from_epochs(
                self.model_keys,
                self.num_epoch_dict,
                np.asarray(per_task_total, dtype=np.float64),
            )
            log.info(
                "Computed model_prob=%s and num_steps=%d from num_epoch_dict=%s "
                "with per-task target steps: %s.",
                self.model_prob,
                self.num_steps,
                self.num_epoch_dict,
                {k: int(np.ceil(v)) for k, v in per_task_steps.items()},
            )
        else:
            if self.num_steps is None:
                raise ValueError(
                    "Either training.numb_steps (multi-task only) or "
                    "training.num_epoch_dict must be set."
                )
            self.model_prob = resolve_model_prob(
                self.model_keys,
                training_params.get("model_prob"),
                training_data,
                rank=self.rank,
            )

    # Learning rate
    self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
    self.lr_schedule = get_lr(config["learning_rate"])

    # JIT
    if JIT:
        raise NotImplementedError(
            "JIT is not supported yet when training with Paddle"
        )

    # Model Wrapper
    self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params)
    # Fresh training starts at step 0; overwritten below when resuming.
    # The LR lambda and run() read this attribute in every case.
    self.start_step = 0

    # resuming and finetune
    optimizer_state_dict = None
    if resuming:
        log.info(f"Resuming from {resume_model}.")
        state_dict = paddle.load(resume_model)
        if "model" in state_dict:
            optimizer_state_dict = (
                state_dict["optimizer"] if finetune_model is None else None
            )
            state_dict = state_dict["model"]
        self.start_step = (
            state_dict["_extra_state"]["train_infos"]["step"]
            if self.restart_training
            else 0
        )
        if self.rank == 0:
            if force_load:
                input_keys = list(state_dict.keys())
                target_keys = list(self.wrapper.state_dict().keys())
                missing_keys = [
                    item for item in target_keys if item not in input_keys
                ]
                if missing_keys:
                    target_state_dict = self.wrapper.state_dict()
                    slim_keys = []
                    for item in missing_keys:
                        state_dict[item] = target_state_dict[item].clone().detach()
                        # Collapse the report to three-component key prefixes.
                        new_key = True
                        for slim_key in slim_keys:
                            if slim_key in item:
                                new_key = False
                                break
                        if new_key:
                            tmp_keys = ".".join(item.split(".")[:3])
                            slim_keys.append(tmp_keys)
                    slim_keys = [i + ".*" for i in slim_keys]
                    log.warning(
                        f"Force load mode allowed! These keys are not in ckpt and will re-init: {slim_keys}"
                    )
            # update model params in the pretrained model
            if finetune_model is not None:
                new_state_dict = {}
                target_state_dict = self.wrapper.state_dict()
                # pretrained_model
                pretrained_model = get_model_for_wrapper(
                    state_dict["_extra_state"]["model_params"]
                )
                pretrained_model_wrapper = ModelWrapper(pretrained_model)
                pretrained_model_wrapper.set_state_dict(state_dict)
                # update type related params
                for model_key in self.model_keys:
                    finetune_rule_single = self.finetune_links[model_key]
                    _model_key_from = finetune_rule_single.get_model_branch()
                    # skip if updated
                    if (
                        finetune_rule_single.get_finetune_tmap()
                        != pretrained_model_wrapper.model[
                            _model_key_from
                        ].get_type_map()
                    ):
                        model_with_new_type_stat = None
                        if finetune_rule_single.get_has_new_type():
                            self.finetune_update_stat = True
                            model_with_new_type_stat = self.wrapper.model[model_key]
                        pretrained_model_wrapper.model[
                            _model_key_from
                        ].change_type_map(
                            finetune_rule_single.get_finetune_tmap(),
                            model_with_new_type_stat=model_with_new_type_stat,
                        )
                state_dict = pretrained_model_wrapper.state_dict()

                def collect_single_finetune_params(
                    _model_key: str,
                    _finetune_rule_single: Any,
                    _new_state_dict: dict,
                    _origin_state_dict: dict,
                    _random_state_dict: dict,
                ) -> None:
                    """Fill ``_new_state_dict`` with the parameters of one branch.

                    With a random fitting net, non-descriptor entries keep the
                    freshly initialised values; everything else is copied from
                    the source branch of the pretrained model.
                    """
                    _new_fitting = _finetune_rule_single.get_random_fitting()
                    _model_key_from = _finetune_rule_single.get_model_branch()
                    target_keys = [
                        i
                        for i in _random_state_dict.keys()
                        if i != "_extra_state" and f".{_model_key}." in i
                    ]
                    for item_key in target_keys:
                        if _new_fitting and (".descriptor." not in item_key):
                            _new_state_dict[item_key] = (
                                _random_state_dict[item_key].clone().detach()
                            )
                        else:
                            new_key = item_key.replace(
                                f".{_model_key}.", f".{_model_key_from}."
                            )
                            _new_state_dict[item_key] = (
                                _origin_state_dict[new_key].clone().detach()
                            )

                # collect model params from the pretrained model
                for model_key in self.model_keys:
                    finetune_rule_single = self.finetune_links[model_key]
                    collect_single_finetune_params(
                        model_key,
                        finetune_rule_single,
                        new_state_dict,
                        state_dict,
                        target_state_dict,
                    )
                state_dict = new_state_dict
                state_dict["_extra_state"] = self.wrapper.state_dict()[
                    "_extra_state"
                ]

            self.wrapper.set_state_dict(state_dict)

            # change bias for fine-tuning
            if finetune_model is not None:

                def single_model_finetune(
                    _model: Any,
                    _finetune_rule_single: Any,
                    _sample_func: Any,
                ) -> Any:
                    _model = model_change_out_bias(
                        _model,
                        _sample_func,
                        _bias_adjust_mode="change-by-statistic"
                        if not _finetune_rule_single.get_random_fitting()
                        else "set-by-statistic",
                    )
                    return _model

                if not self.multi_task:
                    finetune_rule_single = self.finetune_links["Default"]
                    self.model = single_model_finetune(
                        self.model, finetune_rule_single, self.get_sample_func
                    )
                else:
                    for model_key in self.model_keys:
                        finetune_rule_single = self.finetune_links[model_key]
                        if not finetune_rule_single.get_resuming():
                            log.info(
                                f"Model branch {model_key} will be fine-tuned. This may take a long time..."
                            )
                            self.model[model_key] = single_model_finetune(
                                self.model[model_key],
                                finetune_rule_single,
                                self.get_sample_func[model_key],
                            )
                        else:
                            log.info(
                                f"Model branch {model_key} will resume training."
                            )

    if init_frz_model is not None:
        frz_model = paddle.jit.load(init_frz_model)
        self.model.set_state_dict(frz_model.state_dict())

    # Multi-task share params
    if shared_links is not None:
        self.wrapper.share_params(
            shared_links,
            resume=(resuming and not self.finetune_update_stat) or self.rank != 0,
        )

    # TODO add lr warmups for multitask
    # author: iProzd
    # TODO add optimizers for multitask
    # author: iProzd
    if self.opt_type in ["Adam", "AdamW"]:
        # The scheduler returns a multiplicative factor on start_lr, hence
        # the division; the step is offset by start_step when restarting.
        self.scheduler = paddle.optimizer.lr.LambdaDecay(
            learning_rate=self.lr_schedule.start_lr,
            lr_lambda=lambda step: (
                self.lr_schedule.value(step + self.start_step)
                / self.lr_schedule.start_lr
            ),
        )
        opt_cls = (
            paddle.optimizer.Adam
            if self.opt_type == "Adam"
            else paddle.optimizer.AdamW
        )
        self.optimizer = opt_cls(
            learning_rate=self.scheduler,
            parameters=self.wrapper.parameters(),
            beta1=float(self.opt_param["adam_beta1"]),
            beta2=float(self.opt_param["adam_beta2"]),
            weight_decay=float(self.opt_param["weight_decay"]),
        )
        if optimizer_state_dict is not None and self.restart_training:
            self.optimizer.set_state_dict(optimizer_state_dict)
            self.scheduler.last_epoch -= 1
    else:
        raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

    # NOTE: to_static + compiler should be before distributed wrapper
    if CINN:
        from paddle import (
            jit,
            static,
        )

        backend = "CINN" if CINN else None
        if CINN_ALLOW_DYNAMIC_SHAPE:
            # Build spec only for keys present in sample data
            # NOTE: This is a trick to decide the right input_spec for wrapper.forward
            _, label_dict, _ = self.get_data(is_train=True)

            # Define specification templates
            spec_templates = {
                "find_box": np.float32(1.0),
                "find_coord": np.float32(1.0),
                "find_numb_copy": np.float32(0.0),
                "numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"),
                "find_energy": np.float32(1.0),
                "energy": static.InputSpec([1, 1], "float64", name="energy"),
                "find_force": np.float32(1.0),
                "force": static.InputSpec([1, -1, 3], "float64", name="force"),
                "find_virial": np.float32(0.0),
                "virial": static.InputSpec([1, 9], "float64", name="virial"),
                "natoms": static.InputSpec([1, -1], "int32", name="natoms"),
            }
            label_dict_spec = {
                k: spec_templates[k]
                for k in label_dict.keys()
                if k in spec_templates
            }
            self.wrapper.forward = jit.to_static(
                backend=backend,
                input_spec=[
                    static.InputSpec([1, -1, 3], "float64", name="coord"),  # coord
                    static.InputSpec([1, -1], "int32", name="atype"),  # atype
                    None,  # spin
                    static.InputSpec([1, 9], "float64", name="box"),  # box
                    static.InputSpec([], "float64", name="cur_lr"),  # cur_lr
                    label_dict_spec,  # label,
                    # None, # task_key
                    # False, # inference_only
                    # False, # do_atomic_virial
                    # None, # fparam
                    # None, # aparam
                ],
                full_graph=True,
            )(self.wrapper.forward)
        else:
            self.wrapper.forward = jit.to_static(full_graph=True, backend=backend)(
                self.wrapper.forward
            )

        log.info(
            "[CINN] Enable CINN during training, there may be some additional "
            "compilation time in the first training step."
        )
        if not CINN_ALLOW_DYNAMIC_SHAPE:
            log.info(
                "[CINN] Dynamic shape is disabled (CINN_ALLOW_DYNAMIC_SHAPE=0). "
                "Make sure the input batch shapes are fixed during training. "
                "This is recommended for optimal performance, e.g., as in examples/water."
            )
            log.info(
                "[CINN] If batch data from your dataset(s) has varying input shapes, consider setting "
                "CINN_ALLOW_DYNAMIC_SHAPE=1 to enable dynamic shape support."
            )

    if dist.is_available() and dist.is_initialized():
        # DDP will guarantee the model parameters are identical across all processes
        self.wrapper = fleet.distributed_model(
            self.wrapper,
            # find_unused_parameters=True,
        )
        self.optimizer = fleet.distributed_optimizer(self.optimizer)

    # Tensorboard
    self.enable_tensorboard = training_params.get("tensorboard", False)
    self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
    self.tensorboard_freq = training_params.get("tensorboard_freq", 1)
    self.enable_profiler = training_params.get("enable_profiler", False)
    self.profiling = training_params.get("profiling", False)
    self.profiling_file = training_params.get("profiling_file", "timeline.json")
[docs]
def run(self) -> None:
    """Run the training loop from ``self.start_step`` to ``self.num_steps``.

    Each step fetches a batch, runs forward and backward, and applies the
    optimizer every ``acc_freq`` steps. Losses are logged and written to
    the learning-curve file every ``disp_freq`` steps (and at the first
    step); checkpoints are written every ``save_freq`` steps and at the
    last step. Only rank 0 writes files.
    """
    fout = (
        open(
            self.disp_file,
            mode="w" if not self.restart_training else "a",
            buffering=1,
        )
        if self.rank == 0
        else None
    )  # line buffered
    if SAMPLER_RECORD:
        record_file = f"Sample_rank_{self.rank}.txt"
        fout1 = open(record_file, mode="w", buffering=1)
    log.info("Start to train %d steps.", self.num_steps)
    if dist.is_available() and dist.is_initialized():
        log.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}")
    if self.enable_tensorboard:
        from tensorboardX import (
            SummaryWriter,
        )

        writer = SummaryWriter(log_dir=self.tensorboard_log_dir)
    enable_profiling = self.enable_profiler or self.profiling
    if enable_profiling:
        core.nvprof_start()
        core.nvprof_enable_record_event()

    def step(_step_id: int, task_key: str = "Default") -> None:
        """Run one training step and the logging/checkpointing tied to it."""
        if self.multi_task:
            # Draw the task to train on according to model_prob.
            model_index = dp_random.choice(
                np.arange(self.num_model, dtype=np.int_),
                p=self.model_prob,
            )
            task_key = self.model_keys[model_index]
        # Paddle Profiler
        if enable_profiling:
            core.nvprof_nvtx_push(f"Training step {_step_id}")

        cur_lr = self.lr_schedule.value(_step_id)
        pref_lr = cur_lr

        with nvprof_context(enable_profiling, "Fetching data"):
            input_dict, label_dict, log_dict = self.get_data(
                is_train=True, task_key=task_key
            )
        if SAMPLER_RECORD:
            print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
            fout1.write(print_str)
            fout1.flush()
        if self.opt_type in ["Adam", "AdamW"]:
            cur_lr = self.scheduler.get_lr()
            pref_lr = cur_lr

            # disable synchronization in forward-backward manually
            # as derivatives exist in model forward
            no_sync_context = (
                self.wrapper.no_sync
                if self.world_size > 1
                else contextlib.nullcontext
            )
            with no_sync_context():
                with nvprof_context(enable_profiling, "Forward pass"):
                    model_pred, loss, more_loss = self.wrapper(
                        **input_dict,
                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
                        label=label_dict,
                        task_key=task_key,
                    )

                with nvprof_context(enable_profiling, "Backward pass"):
                    loss.backward()

            # gradient accumulation
            if (_step_id + 1) % self.acc_freq == 0:
                # fuse + allreduce manually before optimization if use DDP + no_sync
                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
                if self.world_size > 1:
                    hpu.fused_allreduce_gradients(
                        list(self.wrapper.parameters()), None
                    )

                if self.gradient_max_norm > 0.0:
                    with nvprof_context(enable_profiling, "Gradient clip"):
                        paddle.nn.utils.clip_grad_norm_(
                            self.wrapper.parameters(),
                            self.gradient_max_norm,
                            error_if_nonfinite=True,
                        )

                with nvprof_context(enable_profiling, "Optimizer update"):
                    self.optimizer.step()
                self.optimizer.clear_grad(set_to_zero=False)

            # NOTE(review): advanced once per training step so the scheduler's
            # counter tracks _step_id (its lambda adds start_step) — confirm
            # against the original indentation.
            self.scheduler.step()
        else:
            raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

        # Log and persist
        display_step_id = _step_id + 1
        if self.display_in_training and (
            display_step_id % self.disp_freq == 0 or display_step_id == 1
        ):
            self.wrapper.eval()  # Will set to train mode before finishing validation

            def log_loss_train(
                _loss: Any, _more_loss: dict, _task_key: str = "Default"
            ) -> dict:
                """Return the entries of ``_more_loss`` without "l2_" in the key."""
                results = {}
                rmse_val = {
                    item: _more_loss[item]
                    for item in _more_loss
                    if "l2_" not in item
                }
                for item in sorted(rmse_val.keys()):
                    results[item] = rmse_val[item]
                return results

            def log_loss_valid(_task_key: str = "Default") -> dict:
                """Average the validation losses, weighted by atom count."""
                single_results = {}
                sum_natoms = 0
                if not self.multi_task:
                    valid_numb_batch = self.valid_numb_batch
                else:
                    valid_numb_batch = self.valid_numb_batch[_task_key]
                for ii in range(valid_numb_batch):
                    self.optimizer.clear_grad()
                    input_dict, label_dict, _ = self.get_data(
                        is_train=False, task_key=_task_key
                    )
                    if input_dict == {}:
                        # no validation data
                        return {}
                    _, loss, more_loss = self.wrapper(
                        **input_dict,
                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
                        label=label_dict,
                        task_key=_task_key,
                    )
                    # more_loss.update({"rmse": math.sqrt(loss)})
                    natoms = int(input_dict["atype"].shape[-1])
                    sum_natoms += natoms
                    for k, v in more_loss.items():
                        if "l2_" not in k:
                            single_results[k] = (
                                single_results.get(k, 0.0) + v * natoms
                            )
                results = {k: v / sum_natoms for k, v in single_results.items()}
                return results

            if not self.multi_task:
                train_results = log_loss_train(loss, more_loss)
                valid_results = log_loss_valid()
                if self.rank == 0:
                    log.info(
                        format_training_message_per_task(
                            batch=display_step_id,
                            task_name="trn",
                            rmse=train_results,
                            learning_rate=cur_lr,
                        )
                    )
                    if valid_results:
                        log.info(
                            format_training_message_per_task(
                                batch=display_step_id,
                                task_name="val",
                                rmse=valid_results,
                                learning_rate=None,
                            )
                        )
            else:
                train_results = {_key: {} for _key in self.model_keys}
                valid_results = {_key: {} for _key in self.model_keys}
                train_results[task_key] = log_loss_train(
                    loss, more_loss, _task_key=task_key
                )
                for _key in self.model_keys:
                    if _key != task_key:
                        self.optimizer.clear_grad()
                        # Separate names keep the trained task's loss and
                        # more_loss intact for the tensorboard section below.
                        other_input, other_label, _ = self.get_data(
                            is_train=True, task_key=_key
                        )
                        _, other_loss, other_more_loss = self.wrapper(
                            **other_input,
                            cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
                            label=other_label,
                            task_key=_key,
                        )
                        train_results[_key] = log_loss_train(
                            other_loss, other_more_loss, _task_key=_key
                        )
                    valid_results[_key] = log_loss_valid(_task_key=_key)
                    if self.rank == 0:
                        log.info(
                            format_training_message_per_task(
                                batch=display_step_id,
                                task_name=_key + "_trn",
                                rmse=train_results[_key],
                                learning_rate=cur_lr,
                            )
                        )
                        if valid_results[_key]:
                            log.info(
                                format_training_message_per_task(
                                    batch=display_step_id,
                                    task_name=_key + "_val",
                                    rmse=valid_results[_key],
                                    learning_rate=None,
                                )
                            )
            self.wrapper.train()

            current_time = time.time()
            train_time = current_time - self.t0
            self.t0 = current_time
            if self.rank == 0 and self.timing_in_training:
                eta = int(
                    (self.num_steps - display_step_id)
                    / min(self.disp_freq, display_step_id - self.start_step)
                    * train_time
                )
                log.info(
                    format_training_message(
                        batch=display_step_id,
                        wall_time=train_time,
                        eta=eta,
                        current_time=datetime.datetime.fromtimestamp(
                            current_time,
                            tz=datetime.timezone.utc,
                        ).astimezone(),
                    )
                )
                # the first training time is not accurate
                if (
                    (_step_id + 1 - self.start_step) > self.disp_freq
                    or self.num_steps - self.start_step < 2 * self.disp_freq
                ):
                    self.total_train_time += train_time

            if fout:
                if self.lcurve_should_print_header:
                    self.print_header(fout, train_results, valid_results)
                    self.lcurve_should_print_header = False
                self.print_on_training(
                    fout, display_step_id, cur_lr, train_results, valid_results
                )

        if (
            ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
            or (_step_id + 1) == self.num_steps
        ) and (self.rank == 0 or dist.get_rank() == 0):
            # Handle the case if rank 0 aborted and re-assigned
            self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pd")
            self.save_model(self.latest_model, lr=cur_lr, step=_step_id)
            log.info(f"Saved model to {self.latest_model}")
            symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
            with open("checkpoint", "w") as f:
                f.write(str(self.latest_model))

        # tensorboard
        if self.enable_tensorboard and (
            display_step_id % self.tensorboard_freq == 0 or display_step_id == 1
        ):
            writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id)
            writer.add_scalar(f"{task_key}/loss", loss.item(), display_step_id)
            for item in more_loss:
                writer.add_scalar(
                    f"{task_key}/{item}", more_loss[item].item(), display_step_id
                )

        if enable_profiling:
            core.nvprof_nvtx_pop()

    self.wrapper.train()
    self.t0 = time.time()
    self.total_train_time = 0.0
    for step_id in range(self.start_step, self.num_steps):
        step(step_id)
        if JIT:
            break

    if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
        if not self.multi_task:
            self.model = model_change_out_bias(
                self.model,
                self.get_sample_func,
                _bias_adjust_mode="change-by-statistic",
            )
        else:
            for model_key in self.model_keys:
                self.model[model_key] = model_change_out_bias(
                    self.model[model_key],
                    self.get_sample_func[model_key],
                    _bias_adjust_mode="change-by-statistic",
                )
        # Save again so the final checkpoint contains the adjusted bias.
        self.latest_model = Path(self.save_ckpt + f"-{self.num_steps}.pd")
        cur_lr = self.lr_schedule.value(self.num_steps - 1)
        self.save_model(self.latest_model, lr=cur_lr, step=self.num_steps - 1)
        log.info(f"Saved model to {self.latest_model}")
        symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
        with open("checkpoint", "w") as f:
            f.write(str(self.latest_model))

    if (
        self.rank == 0 or dist.get_rank() == 0
    ):  # Handle the case if rank 0 aborted and re-assigned
        if self.num_steps == 0:
            # when num_steps is 0, the loop above never saves a checkpoint
            self.latest_model = Path(self.save_ckpt + "-0.pd")
            self.save_model(self.latest_model, lr=0, step=0)
            log.info(f"Saved model to {self.latest_model}")
            symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
            with open("checkpoint", "w") as f:
                f.write(str(self.latest_model))

        elapsed_batch = self.num_steps - self.start_step
        if self.timing_in_training and elapsed_batch // self.disp_freq > 0:
            # Must mirror the accumulation rule in step(): the first
            # disp_freq batches are left out of total_train_time exactly
            # when at least 2 * disp_freq batches were run. Keying this on
            # elapsed_batch also keeps the divisor strictly positive.
            if elapsed_batch >= 2 * self.disp_freq:
                log.info(
                    "average training time: %.4f s/batch (exclude first %d batches)",
                    self.total_train_time
                    / (
                        elapsed_batch // self.disp_freq * self.disp_freq
                        - self.disp_freq
                    ),
                    self.disp_freq,
                )
            else:
                log.info(
                    "average training time: %.4f s/batch",
                    self.total_train_time
                    / (elapsed_batch // self.disp_freq * self.disp_freq),
                )

        if JIT:
            raise NotImplementedError(
                "Paddle JIT saving during training is not supported yet."
            )
        log.info(f"Trained model has been saved to: {self.save_ckpt}")

    if fout:
        fout.close()
    if SAMPLER_RECORD:
        fout1.close()
    if self.enable_tensorboard:
        writer.close()
    if enable_profiling:
        core.nvprof_stop()
        log.info(
            "The nsys profiling trace have been saved to *.nsys-rep and *.sqlite "
            "files, which can be viewd in NVIDIA Nsight Systems software"
        )
[docs]
def save_model(self, save_path: str | Path, lr: float = 0.0, step: int = 0) -> None:
    """Write a checkpoint and prune the oldest one if too many exist.

    Parameters
    ----------
    save_path
        Destination of the checkpoint file; a string or a ``Path``.
    lr
        Learning rate stored in the wrapper's ``train_infos``.
    step
        Step index stored in the wrapper's ``train_infos``.
    """
    # Callers pass a Path, but the signature advertises str: accept both.
    save_path = Path(save_path)
    module = (
        self.wrapper._layers
        if dist.is_available() and dist.is_initialized()
        else self.wrapper
    )
    module.train_infos["lr"] = float(lr)
    module.train_infos["step"] = step
    paddle.save(
        {"model": module.state_dict(), "optimizer": self.optimizer.state_dict()},
        str(save_path),
    )
    # Checkpoint files are named "<save_ckpt>-<step>.pd". Compare against the
    # final path component only, so a save_ckpt that contains a directory
    # still matches the file names found in that directory.
    ckpt_prefix = Path(self.save_ckpt).name
    checkpoint_files = [
        f
        for f in save_path.parent.glob("*.pd")
        if not f.is_symlink() and f.name.startswith(ckpt_prefix)
    ]
    if len(checkpoint_files) > self.max_ckpt_keep:
        # Drop the least recently modified checkpoint.
        checkpoint_files.sort(key=lambda x: x.stat().st_mtime)
        checkpoint_files[0].unlink()
[docs]
def get_data(
self, is_train: bool = True, task_key: str = "Default"
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
if not self.multi_task:
if is_train:
try:
batch_data = next(iter(self.training_data))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
self.training_data = BufferedIterator(
iter(self.training_dataloader)
)
batch_data = next(iter(self.training_data))
else:
if self.validation_data is None:
return {}, {}, {}
try:
batch_data = next(iter(self.validation_data))
except StopIteration:
self.validation_data = BufferedIterator(
iter(self.validation_dataloader)
)
batch_data = next(iter(self.validation_data))
else:
if is_train:
try:
batch_data = next(iter(self.training_data[task_key]))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
self.training_data[task_key] = BufferedIterator(
iter(self.training_dataloader[task_key])
)
batch_data = next(iter(self.training_data[task_key]))
else:
if self.validation_data[task_key] is None:
return {}, {}, {}
try:
batch_data = next(iter(self.validation_data[task_key]))
except StopIteration:
self.validation_data[task_key] = BufferedIterator(
iter(self.validation_dataloader[task_key])
)
batch_data = next(iter(self.validation_data[task_key]))
for key in batch_data.keys():
if key == "sid" or key == "fid" or key == "box" or "find_" in key:
continue
elif not isinstance(batch_data[key], list):
if batch_data[key] is not None:
batch_data[key] = batch_data[key].to(DEVICE, blocking=False)
else:
batch_data[key] = [
item.to(DEVICE, blocking=False) for item in batch_data[key]
]
# we may need a better way to classify which are inputs and which are labels
# now wrapper only supports the following inputs:
input_keys = [
"coord",
"atype",
"spin",
"box",
"fparam",
"aparam",
]
input_dict = dict.fromkeys(input_keys)
label_dict = {}
for item_key in batch_data:
if item_key in input_keys:
input_dict[item_key] = batch_data[item_key]
else:
if item_key not in ["sid", "fid"]:
label_dict[item_key] = batch_data[item_key]
log_dict = {}
if "fid" in batch_data:
log_dict["fid"] = batch_data["fid"]
log_dict["sid"] = batch_data["sid"]
return input_dict, label_dict, log_dict
[docs]
def print_on_training(
    self,
    fout: Any,
    step_id: int,
    cur_lr: float,
    train_results: dict[str, Any],
    valid_results: dict[str, Any],
) -> None:
    """Write one formatted row of results to the training log stream.

    The row consists of the step index, the result columns and finally the
    learning rate. When validation results are present, each metric is
    written as a (validation, training) pair; otherwise only the training
    value is written.

    Args:
    - fout: Output stream; only ``write`` and ``flush`` are called on it.
    - step_id: Training step index, printed in a 7-character field.
    - cur_lr: Current learning rate, printed as the last column.
    - train_results: Training metrics keyed by metric name. In multi-task
      mode the outer keys are model keys and the values are such mappings.
    - valid_results: Validation metrics with the same layout. A falsy
      value (per model key in multi-task mode) selects the train-only
      layout.
    """
    metric_names = sorted(train_results.keys())
    columns = [f"{step_id:7d}"]
    if self.multi_task:
        for task in self.model_keys:
            task_valid = valid_results[task]
            if task_valid:
                # Column order follows the validation metrics of this task.
                columns.extend(
                    " %11.2e %11.2e" % (task_valid[name], train_results[task][name])
                    for name in sorted(task_valid.keys())
                )
            else:
                task_train = train_results[task]
                columns.extend(
                    " %11.2e" % (task_train[name])
                    for name in sorted(task_train.keys())
                )
    elif valid_results:
        columns.extend(
            " %11.2e %11.2e" % (valid_results[name], train_results[name])
            for name in metric_names
        )
    else:
        columns.extend(" %11.2e" % (train_results[name]) for name in metric_names)
    columns.append(f" {cur_lr:8.1e}\n")
    fout.write("".join(columns))
    fout.flush()
[docs]
def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]:
    """Collect the extra data items a model needs beyond the default inputs.

    Args:
    - _model: Object exposing ``get_dim_fparam()`` and ``get_dim_aparam()``,
      and optionally ``has_spin`` as either an attribute or a method.

    Returns:
    A list of mandatory ``DataRequirementItem`` entries, in the order
    "fparam", "aparam", "spin"; empty when none of them is needed.
    """
    requirements = []
    if _model.get_dim_fparam() > 0:
        # fparam is requested with atomic=False
        requirements.append(
            DataRequirementItem(
                "fparam", _model.get_dim_fparam(), atomic=False, must=True
            )
        )
    if _model.get_dim_aparam() > 0:
        # aparam is requested with atomic=True
        requirements.append(
            DataRequirementItem(
                "aparam", _model.get_dim_aparam(), atomic=True, must=True
            )
        )
    # ``has_spin`` may be a plain attribute or a method; a missing attribute
    # counts as "no spin".
    spin_attr = getattr(_model, "has_spin", False)
    spin_enabled = spin_attr() if callable(spin_attr) else spin_attr
    if spin_enabled:
        requirements.append(
            DataRequirementItem("spin", ndof=3, atomic=True, must=True)
        )
    return requirements
[docs]
def whether_hessian(loss_params: dict[str, Any]) -> bool:
    """Tell whether the loss configuration asks for a Hessian term.

    Args:
    - loss_params: Loss configuration; "type" defaults to "ener" and
      "start_pref_h" defaults to 0.0 when absent.

    Returns:
    True only for an "ener" loss whose "start_pref_h" is positive.
    """
    if loss_params.get("type", "ener") != "ener":
        return False
    return loss_params.get("start_pref_h", 0.0) > 0.0
[docs]
def get_loss(
    loss_params: dict[str, Any], start_lr: float, _ntypes: int, _model: Any
) -> TaskLoss:
    """Instantiate the loss object described by ``loss_params``.

    ``start_lr`` is stored in ``loss_params`` under "starter_learning_rate"
    (the dictionary is modified in place) before the loss is built.

    Args:
    - loss_params: Loss configuration; "type" defaults to "ener".
    - start_lr: Initial learning rate handed to the loss.
    - _ntypes: Not used by this function.
    - _model: Not used by this function.

    Returns:
    ``EnergyHessianStdLoss`` when ``whether_hessian`` holds,
    ``EnergyStdLoss`` for any other "ener" loss, otherwise the loss built by
    the class that ``TaskLoss.get_class_by_type`` returns for the type.
    """
    loss_type = loss_params.get("type", "ener")
    # Decide on the Hessian variant before touching the dictionary.
    use_hessian = whether_hessian(loss_params)
    loss_params["starter_learning_rate"] = start_lr
    if use_hessian:
        return EnergyHessianStdLoss(**loss_params)
    if loss_type == "ener":
        return EnergyStdLoss(**loss_params)
    return TaskLoss.get_class_by_type(loss_type).get_loss(loss_params)
[docs]
def get_single_model(
    _model_params: dict[str, Any],
) -> Any:
    """Build one model and move it to ``DEVICE``.

    Args:
    - _model_params: Model configuration. ``get_model`` receives a deep
      copy, so the caller's dictionary is not handed over directly.

    Returns:
    The result of ``get_model(...).to(DEVICE)``.
    """
    params_copy = deepcopy(_model_params)
    return get_model(params_copy).to(DEVICE)
[docs]
def get_model_for_wrapper(
    _model_params: dict[str, Any],
    resuming: bool = False,
    _loss_params: dict[str, Any] | None = None,
) -> Any:
    """Create the model, or the per-task models, described by the configuration.

    Args:
    - _model_params: Model configuration. A "model_dict" entry selects the
      multi-task layout. "hessian_mode" is set to True in place on every
      model configuration whose loss requests a Hessian term.
    - resuming: When True, case embeddings are not assigned to the models.
    - _loss_params: Loss configuration (keyed by model key in the
      multi-task layout), or None to skip the Hessian check.

    Returns:
    A single model for the single-task layout, otherwise a dict mapping each
    model key to its model.
    """
    loss_given = _loss_params is not None
    if "model_dict" not in _model_params:
        # Single-task layout: one model, one optional loss configuration.
        if loss_given and whether_hessian(_loss_params):
            _model_params["hessian_mode"] = True
        return get_single_model(
            _model_params,
        )

    task_params = _model_params["model_dict"]
    task_names = list(task_params)
    do_case_embd, case_embd_index = get_case_embd_config(_model_params)
    # only set case_embd when from scratch multitask training
    assign_case_embd = do_case_embd and not resuming
    models = {}
    for task in task_names:
        params = task_params[task]
        if loss_given and whether_hessian(_loss_params[task]):
            params["hessian_mode"] = True
        models[task] = get_single_model(
            params,
        )
        if assign_case_embd:
            models[task].set_case_embd(case_embd_index[task])
    return models
[docs]
def get_case_embd_config(_model_params: dict[str, Any]) -> tuple[bool, dict[str, Any]]:
    """Determine whether case embedding is used and the index of each model.

    The case-embedding dimension of a model is read from
    ``fitting_net.dim_case_embd`` and defaults to 0 when either level is
    absent.

    Args:
    - _model_params: Multi-task model configuration containing "model_dict".

    Returns:
    ``(False, {})`` when there are no models or the shared dimension is 0;
    otherwise ``(True, index)`` where ``index`` maps every model key to its
    position in the alphabetically sorted list of model keys.

    Raises:
    - AssertionError: If "model_dict" is missing from ``_model_params``.
    - ValueError: If the models do not all use the same dimension.
    """
    assert "model_dict" in _model_params, (
        "Only support setting case embedding for multi-task model!"
    )
    model_dict = _model_params["model_dict"]
    # Indices are assigned in sorted order, so they do not depend on the
    # order of the keys in the configuration.
    sorted_model_keys = sorted(model_dict)
    numb_case_embd_list = [
        model_dict[model_key].get("fitting_net", {}).get("dim_case_embd", 0)
        for model_key in sorted_model_keys
    ]
    if not numb_case_embd_list:
        # No models at all: nothing to embed. Without this guard the lookup
        # of the first element below fails with an unrelated IndexError.
        return False, {}
    reference_dim = numb_case_embd_list[0]
    if not all(item == reference_dim for item in numb_case_embd_list):
        raise ValueError(
            f"All models must have the same dimension of case embedding, while the settings are: {numb_case_embd_list}"
        )
    if reference_dim == 0:
        return False, {}
    case_embd_index = {
        model_key: idx for idx, model_key in enumerate(sorted_model_keys)
    }
    return True, case_embd_index
[docs]
def model_change_out_bias(
    _model: Any,
    _sample_func: Any,
    _bias_adjust_mode: str = "change-by-statistic",
) -> Any:
    """Adjust the output bias of a model and log the change.

    Args:
    - _model: Model whose output bias is changed in place through
      ``change_out_bias``.
    - _sample_func: Forwarded to ``change_out_bias`` and, when applicable,
      to the fitting net's ``compute_input_stats``.
    - _bias_adjust_mode: Forwarded to ``change_out_bias`` as
      ``bias_adjust_mode``. With "set-by-statistic" and a ``DPModelCommon``
      model, the fitting net input statistics are recomputed as well.

    Returns:
    The same ``_model`` object that was passed in.
    """
    # Copies are taken so the logged values are snapshots from before and
    # after the change.
    old_bias = deepcopy(_model.get_out_bias())
    _model.change_out_bias(
        _sample_func,
        bias_adjust_mode=_bias_adjust_mode,
    )
    new_bias = deepcopy(_model.get_out_bias())

    from deepmd.pd.model.model.dp_model import (
        DPModelCommon,
    )

    if isinstance(_model, DPModelCommon) and _bias_adjust_mode == "set-by-statistic":
        _model.get_fitting_net().compute_input_stats(_sample_func)

    model_type_map = _model.get_type_map()
    log.info(
        f"Change output bias of {model_type_map!s} "
        f"from {to_numpy_array(old_bias).reshape(-1)!s} "
        f"to {to_numpy_array(new_bias).reshape(-1)!s}."
    )
    return _model