5.5. Multi-task training #
Note
Supported backends: PyTorch
Warning
The TensorFlow backend multi-task training has been deprecated; please use the PyTorch backend instead.
5.5.1. Theory#
The multi-task training process can simultaneously handle different datasets with properties that cannot be fitted in one network (e.g. properties from DFT calculations under different exchange-correlation functionals or different basis sets). These datasets are denoted by \(\boldsymbol x^{(1)}, \dots, \boldsymbol x^{(n_t)}\). For each dataset, a training task is defined as
\[
\min_{\boldsymbol{\theta}^{(t)}} L^{(t)}\left(\boldsymbol x^{(t)}; \boldsymbol{\theta}^{(t)}\right), \quad t = 1, \dots, n_t,
\]
where \(L^{(t)}\) is the loss evaluated on dataset \(\boldsymbol x^{(t)}\) and \(\boldsymbol{\theta}^{(t)}\) denotes the trainable parameters of task \(t\).
In the PyTorch implementation, during the multi-task training process, all tasks can share any portion of the model parameters. A typical scenario is that each task shares the same descriptor with trainable parameters \(\boldsymbol{\theta}_ {d}\), while each has its own fitting network with trainable parameters \(\boldsymbol{\theta}_ f^{(t)}\), thus \(\boldsymbol{\theta}^{(t)} = \{ \boldsymbol{\theta}_ {d} , \boldsymbol{\theta}_ {f}^{(t)} \}\). At each training step, a task will be randomly selected from \({1, \dots, n_t}\) according to the user-specified probability, and the Adam optimizer is executed to minimize \(L^{(t)}\) for one step to update the parameter \(\boldsymbol \theta^{(t)}\). In the case of multi-GPU parallel training, different GPUs will independently select their tasks. In the DPA-2 model, this multi-task training framework is adopted.[1]
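As a concrete reading of "user-specified probability" (a sketch, assuming the per-task weights are simply normalized, which is how the equal weights of 0.5 in the example below behave), a task \(t\) with weight \(w_t\) is selected at each step with probability
\[
p^{(t)} = \frac{w_t}{\sum_{k=1}^{n_t} w_k}.
\]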
Compared with the previous TensorFlow implementation, the new support in PyTorch is more flexible and efficient. In particular, it makes multi-GPU parallel training and even tasks beyond DFT possible, enabling larger-scale and more general multi-task training to obtain more general pre-trained models.
5.5.2. Perform the multi-task training using PyTorch#
Training on multiple data sets (each data set contains several data systems) can be performed in multi-task mode, typically with one common descriptor and multiple specific fitting nets for each data set. To proceed, one needs to change the representation of the model definition in the input script. The core idea is to replace the previous single model definition model with multiple model definitions model/model_dict/model_key, define the shared parameters of the model part in shared_dict, and then expand the other parts for multi-model settings. Specifically, the following parts need to be modified:

- model/shared_dict: The parameter definition of the shared part, including various descriptors and type maps (even fitting nets can be shared). Each module can be defined with a user-defined part_key, such as my_descriptor. The content needs to align with the corresponding definition in the single-task training model component, such as the definition of the descriptor.
- model/model_dict: The core definition of the model part and the explanation of the sharing rules, starting with user-defined model name keys model_key, such as my_model_1. Each model part needs to align with the components of the single-task training model, but with the following sharing rules:
  - If you want to share the current model component with other tasks, which should be part of the model/shared_dict, you can directly fill in the corresponding part_key, such as "descriptor": "my_descriptor", to replace the previous detailed parameters. Here, you can also specify the shared_level, such as "descriptor": "my_descriptor:shared_level", and use a user-defined integer shared_level to share the corresponding module to varying degrees (the default is to share all parameters, i.e. shared_level=0); see the short fragment after this list.
  - The parts that are exclusive to each model can be written following the previous single-task definition.
- loss_dict: The loss settings corresponding to each task model, specified by the model_key. Each loss_dict/model_key contains the corresponding loss settings, which are the same as the definition in single-task training loss.
- training/data_dict: The data settings corresponding to each task model, specified by the model_key. Each training/data_dict/model_key contains the corresponding training_data and validation_data settings, which are the same as the definition in single-task training training_data and validation_data.
- (Optional) training/model_prob: The sampling weight settings corresponding to each model_key, i.e., the probability weight in the training step. You can specify any positive real number weight for each task. The higher the weight, the higher the probability of being sampled in each training step. This setting is optional; if not set, tasks will be sampled with equal weights.
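As a sketch of the sharing syntax only: assuming a descriptor named my_descriptor has been defined in model/shared_dict, a model_dict entry that shares it at a hypothetical shared_level of 1 while keeping its own fitting net could look like

  "model_dict": {
    "my_model_1": {
      "type_map": "type_map_all",
      "descriptor": "my_descriptor:1",
      "fitting_net": {
        "neuron": [240, 240, 240],
        "resnet_dt": true,
        "seed": 1
      }
    }
  }

What a given non-zero shared_level shares exactly depends on the module implementation, so check the code or the module documentation before relying on a particular value.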
An example input for multi-task training of two models on a water system is shown below:
{
  "_comment": "that's all",
  "model": {
    "shared_dict": {
      "type_map_all": [
        "O",
        "H"
      ],
      "sea_descriptor_1": {
        "type": "se_e2_a",
        "sel": [
          46,
          92
        ],
        "rcut_smth": 0.50,
        "rcut": 6.00,
        "neuron": [
          25,
          50,
          100
        ],
        "resnet_dt": false,
        "axis_neuron": 16,
        "type_one_side": true,
        "seed": 1,
        "_comment": " that's all"
      },
      "_comment": "that's all"
    },
    "model_dict": {
      "water_1": {
        "type_map": "type_map_all",
        "descriptor": "sea_descriptor_1",
        "fitting_net": {
          "neuron": [
            240,
            240,
            240
          ],
          "resnet_dt": true,
          "seed": 1,
          "_comment": " that's all"
        }
      },
      "water_2": {
        "type_map": "type_map_all",
        "descriptor": "sea_descriptor_1",
        "fitting_net": {
          "neuron": [
            240,
            240,
            240
          ],
          "resnet_dt": true,
          "seed": 1,
          "_comment": " that's all"
        }
      }
    }
  },
  "learning_rate": {
    "type": "exp",
    "decay_steps": 5000,
    "start_lr": 0.0002,
    "decay_rate": 0.98,
    "stop_lr": 3.51e-08,
    "_comment": "that's all"
  },
  "loss_dict": {
    "water_1": {
      "type": "ener",
      "start_pref_e": 0.02,
      "limit_pref_e": 1,
      "start_pref_f": 1000,
      "limit_pref_f": 1,
      "start_pref_v": 0,
      "limit_pref_v": 0
    },
    "water_2": {
      "type": "ener",
      "start_pref_e": 0.02,
      "limit_pref_e": 1,
      "start_pref_f": 1000,
      "limit_pref_f": 1,
      "start_pref_v": 0,
      "limit_pref_v": 0
    }
  },
  "training": {
    "model_prob": {
      "water_1": 0.5,
      "water_2": 0.5
    },
    "data_dict": {
      "water_1": {
        "training_data": {
          "systems": [
            "../../water/data/data_0/",
            "../../water/data/data_1/",
            "../../water/data/data_2/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        },
        "validation_data": {
          "systems": [
            "../../water/data/data_3/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        }
      },
      "water_2": {
        "training_data": {
          "systems": [
            "../../water/data/data_0/",
            "../../water/data/data_1/",
            "../../water/data/data_2/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        }
      }
    },
    "numb_steps": 100000,
    "seed": 10,
    "disp_file": "lcurve.out",
    "disp_freq": 100,
    "save_freq": 100,
    "_comment": "that's all"
  }
}
5.5.3. Finetune from the pre-trained multi-task model#
To finetune based on the checkpoint model.pt after the multi-task pre-training is completed, users can refer to this section.
5.5.4. Multi-task specific parameters#
Note
Details of parameters that are identical to the regular (single-task) training parameters are not repeated below.
- multi-task:#
- type:
dict
argument path:multi-task
Multi-task arguments.
- model:#
- type:
dict
argument path:multi-task/model
- model_dict:#
- type:
dict
argument path:multi-task/model/model_dict
The multiple definition of the model, used in the multi-task mode.
- learning_rate:#
- type:
dict
, optional
argument path:multi-task/learning_rate
The definition of learning rate
- loss_dict:#
- type:
dict
, optional
argument path:multi-task/loss_dict
The multiple definition of the loss, used in the multi-task mode.
- training:#
- type:
dict
argument path:multi-task/training
The training options.
- model_prob:#
- type:
dict
, optional, default:{}
argument path:multi-task/training/model_prob
The visiting probability of each model for each training step in the multi-task mode.
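For illustration, a hypothetical setting that visits water_1 about twice as often as water_2 (assuming the weights are normalized into probabilities) would be

  "model_prob": {
    "water_1": 2.0,
    "water_2": 1.0
  }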
- data_dict:#
- type:
dict
argument path:multi-task/training/data_dict
The multiple definition of the data, used in the multi-task mode.
This argument takes a dict with each key-value pair containing the following:
- training_data:#
- type:
dict
, optional
argument path:multi-task/training/data_dict/training_data
Configurations of training data.
- systems:#
- type:
list[str]
|str
argument path:multi-task/training/data_dict/training_data/systems
The data systems for training. This key can be provided with a list that specifies the systems, or with a string giving the prefix of all systems, from which the list of systems is automatically generated.
- batch_size:#
- type:
list[int]
|str
|int
, optional, default:auto
argument path:multi-task/training/data_dict/training_data/batch_size
This key can be
- list: the length of the list is the same as the number of systems. The batch size of each system is given by the corresponding element of the list.
- int: all systems use the same batch size.
- string “auto”: automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.
- string “auto:N”: automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.
- string “mixed:N”: the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only supports the se_atten descriptor in the TensorFlow backend.
If MPI is used, the value should be considered as the batch size per task.
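A minimal sketch of the “auto:N” form, requiring at least 64 atoms per batch (the system paths are placeholders):

  "training_data": {
    "systems": [
      "../../water/data/data_0/",
      "../../water/data/data_1/"
    ],
    "batch_size": "auto:64"
  }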
- auto_prob:#
- type:
str
, optional, default:prob_sys_size
, alias: auto_prob_style
argument path:multi-task/training/data_dict/training_data/auto_prob
Determine the probability of systems automatically. The method is assigned by this key and can be
- “prob_uniform”: the probabilities of all the systems are equal, namely 1.0/self.get_nsystems().
- “prob_sys_size”: the probability of a system is proportional to the number of batches in the system.
- “prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;…”: the list of systems is divided into blocks. A block is specified by stt_idx:end_idx:weight, where stt_idx is the starting index of the system, end_idx is the ending index (not included), the probabilities of the systems in this block sum up to weight, and the relative probabilities within this block are proportional to the number of batches in each system.
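As a hypothetical example of the block syntax, the fragment below splits five placeholder systems into two blocks: the first three systems (indices 0 to 2) share a total probability of 0.6 and the last two (indices 3 to 4) share 0.4, with probabilities within each block proportional to the number of batches:

  "training_data": {
    "systems": [
      "sys_0/",
      "sys_1/",
      "sys_2/",
      "sys_3/",
      "sys_4/"
    ],
    "batch_size": "auto",
    "auto_prob": "prob_sys_size;0:3:0.6;3:5:0.4"
  }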
- sys_probs:#
- type:
list[float]
|NoneType
, optional, default:None
, alias: sys_weights
argument path:multi-task/training/data_dict/training_data/sys_probs
A list of float if specified. Should be of the same length as systems, specifying the probability of each system.
- validation_data:#
- type:
NoneType
|dict
, optional, default:None
argument path:multi-task/training/data_dict/validation_data
Configurations of validation data. Similar to that of training data, except that a numb_btch argument may be configured.
- systems:#
- type:
list[str]
|str
argument path:multi-task/training/data_dict/validation_data/systems
The data systems for validation. This key can be provided with a list that specifies the systems, or with a string giving the prefix of all systems, from which the list of systems is automatically generated.
- batch_size:#
- type:
list[int]
|str
|int
, optional, default:auto
argument path:multi-task/training/data_dict/validation_data/batch_size
This key can be
- list: the length of the list is the same as the number of systems. The batch size of each system is given by the corresponding element of the list.
- int: all systems use the same batch size.
- string “auto”: automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.
- string “auto:N”: automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.
- auto_prob:#
- type:
str
, optional, default:prob_sys_size
, alias: auto_prob_style
argument path:multi-task/training/data_dict/validation_data/auto_prob
Determine the probability of systems automatically. The method is assigned by this key and can be
- “prob_uniform”: the probabilities of all the systems are equal, namely 1.0/self.get_nsystems().
- “prob_sys_size”: the probability of a system is proportional to the number of batches in the system.
- “prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;…”: the list of systems is divided into blocks. A block is specified by stt_idx:end_idx:weight, where stt_idx is the starting index of the system, end_idx is the ending index (not included), the probabilities of the systems in this block sum up to weight, and the relative probabilities within this block are proportional to the number of batches in each system.
- sys_probs:#
- type:
list[float]
|NoneType
, optional, default:None
, alias: sys_weights
argument path:multi-task/training/data_dict/validation_data/sys_probs
A list of float if specified. Should be of the same length as systems, specifying the probability of each system.
- numb_btch:#
- type:
int
, optional, default:1
, alias: numb_batch
argument path:multi-task/training/data_dict/validation_data/numb_btch
An integer that specifies the number of batches to be sampled for each validation period.
- stat_file:#
- type:
str
, optional
argument path:multi-task/training/data_dict/stat_file
(Supported Backend: PyTorch) The file path for saving the data statistics results. If set, the results will be saved and directly loaded during the next training session, avoiding the need to recalculate the statistics. If the file extension is .h5 or .hdf5, an HDF5 file is used to store the statistics; otherwise, a directory containing NumPy binary files is used.
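Following the argument path above, a sketch of where stat_file would sit for one task (the file name is a placeholder; the .hdf5 extension selects HDF5 storage):

  "data_dict": {
    "water_1": {
      "stat_file": "./stat_files/water_1.hdf5",
      "training_data": {
        "systems": [
          "../../water/data/data_0/"
        ],
        "batch_size": 1
      }
    }
  }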
- mixed_precision:#
- type:
dict
, optional
argument path:multi-task/training/mixed_precision
Configurations of mixed precision.
- output_prec:#
- type:
str
, optional, default:float32
argument path:multi-task/training/mixed_precision/output_prec
The precision of the mixed precision parameters, i.e. the precision of the trainable variables during the mixed precision training process. Currently only float32 is supported.
- compute_prec:#
- type:
str
argument path:multi-task/training/mixed_precision/compute_prec
The compute precision during the mixed precision training process. Currently float16 and bfloat16 are supported.
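A minimal sketch of a mixed_precision block consistent with the options above (whether mixed precision is available for a given backend and model should be checked separately):

  "mixed_precision": {
    "output_prec": "float32",
    "compute_prec": "float16"
  }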
- numb_steps:#
- type:
int
, alias: stop_batch
argument path:multi-task/training/numb_steps
Number of training batches (steps). Each training step uses one batch of data.
- seed:#
- type:
NoneType
|int
, optional
argument path:multi-task/training/seed
The random seed for getting frames from the training data set.
- disp_file:#
- type:
str
, optional, default:lcurve.out
argument path:multi-task/training/disp_file
The file for printing learning curve.
- disp_freq:#
- type:
int
, optional, default:1000
argument path:multi-task/training/disp_freq
The frequency of printing learning curve.
- save_freq:#
- type:
int
, optional, default:1000
argument path:multi-task/training/save_freq
The frequency of saving check point.
- save_ckpt:#
- type:
str
, optional, default:model.ckpt
argument path:multi-task/training/save_ckpt
The path prefix of saving check point files.
- max_ckpt_keep:#
- type:
int
, optional, default:5
argument path:multi-task/training/max_ckpt_keep
The maximum number of checkpoints to keep. The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. Defaults to 5.
- change_bias_after_training:#
- type:
bool
, optional, default:False
argument path:multi-task/training/change_bias_after_training
Whether to change the output bias after the last training step, by performing predictions with the trained model on the training data and doing a least-squares fit on the errors to add the target shift to the bias.
- disp_training:#
- type:
bool
, optional, default:True
argument path:multi-task/training/disp_training
Displaying verbose information during training.
- time_training:#
- type:
bool
, optional, default:True
argument path:multi-task/training/time_training
Timing during training.
- profiling:#
- type:
bool
, optional, default:False
argument path:multi-task/training/profiling
Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to profiling_file.
- profiling_file:#
- type:
str
, optional, default:timeline.json
argument path:multi-task/training/profiling_file
Output file for profiling.
- enable_profiler:#
- type:
bool
, optional, default:False
argument path:multi-task/training/enable_profiler
Export the profiling results to the TensorBoard log for performance analysis, driven by TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler. The log will be saved to tensorboard_log_dir.
- tensorboard:#
- type:
bool
, optional, default:False
argument path:multi-task/training/tensorboard
Enable tensorboard
- tensorboard_log_dir:#
- type:
str
, optional, default:log
argument path:multi-task/training/tensorboard_log_dir
The log directory of tensorboard outputs
- tensorboard_freq:#
- type:
int
, optional, default:1
argument path:multi-task/training/tensorboard_freq
The frequency of writing tensorboard events.
- warmup_steps:#
- type:
int
, optional
argument path:multi-task/training/warmup_steps
(Supported Backend: PyTorch) The number of steps for learning rate warmup. During warmup, the learning rate begins at zero and increases linearly to start_lr, rather than starting directly from start_lr.
- gradient_max_norm:#
- type:
float
, optional
argument path:multi-task/training/gradient_max_norm
(Supported Backend: PyTorch) Clips the gradient norm to a maximum value. If the gradient norm exceeds this value, it will be clipped to this limit. No gradient clipping will occur if set to 0.
Depending on the value of opt_type, different sub args are accepted.
- opt_type:#
- type:
str
(flag key), default:Adam
argument path:multi-task/training/opt_type
possible choices: Adam, LKF
(Supported Backend: PyTorch) The type of optimizer to use.
When opt_type is set to Adam:
When opt_type is set to LKF:
- kf_blocksize:#
- type:
int
, optional
argument path:multi-task/training[LKF]/kf_blocksize
(Supported Backend: PyTorch) The blocksize for the Kalman filter.
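A sketch of switching the optimizer to LKF inside the training section (the kf_blocksize value 5120 is only a placeholder):

  "training": {
    "opt_type": "LKF",
    "kf_blocksize": 5120
  }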
- nvnmd:#
- type:
dict
, optional
argument path:multi-task/nvnmd
The nvnmd options.