5.5. Multi-task training PyTorch#

Note

Supported backends: PyTorch

Warning

Multi-task training with the TensorFlow backend has been deprecated; please use the PyTorch backend instead.

5.5.1. Theory#

The multi-task training process can simultaneously handle different datasets with properties that cannot be fitted in one network (e.g. properties from DFT calculations under different exchange-correlation functionals or different basis sets). These datasets are denoted by \(\boldsymbol x^{(1)}, \dots, \boldsymbol x^{(n_t)}\). For each dataset, a training task is defined as

\[ \min_{\boldsymbol \theta} L^{(t)} (\boldsymbol x^{(t)}; \boldsymbol \theta^{(t)}, \tau), \quad t=1, \dots, n_t.\]

In the PyTorch implementation, during the multi-task training process, all tasks can share any portion of the model parameters. A typical scenario is that each task shares the same descriptor with trainable parameters \(\boldsymbol{\theta}_{d}\), while each has its own fitting network with trainable parameters \(\boldsymbol{\theta}_{f}^{(t)}\), thus \(\boldsymbol{\theta}^{(t)} = \{ \boldsymbol{\theta}_{d}, \boldsymbol{\theta}_{f}^{(t)} \}\). At each training step, a task is randomly selected from \(\{1, \dots, n_t\}\) according to the user-specified probability, and the Adam optimizer is executed to minimize \(L^{(t)}\) for one step to update the parameters \(\boldsymbol \theta^{(t)}\). In the case of multi-GPU parallel training, different GPUs will independently select their tasks. This multi-task training framework is adopted in the DPA-2 model.[1]

Compared with the previous TensorFlow implementation, the new PyTorch support is more flexible and efficient. In particular, it makes multi-GPU parallel training and even tasks beyond DFT possible, enabling larger-scale multi-task training and thus more general pre-trained models.

5.5.2. Perform the multi-task training using PyTorch#

Training on multiple data sets (each data set containing several data systems) can be performed in multi-task mode, typically with one common descriptor and a separate fitting net for each data set. To proceed, one needs to change the representation of the model definition in the input script. The core idea is to replace the previous single model definition model with multiple model definitions model/model_dict/model_key, define the shared parameters of the model part in model/shared_dict, and then expand the other parts for multi-model settings. Specifically, the following parts need to be modified:

  • model/shared_dict: The parameter definitions of the shared parts, including various descriptors and type maps (even fitting nets can be shared). Each module can be defined with a user-defined part_key, such as my_descriptor. The content needs to align with the corresponding definition in the single-task training model component, such as the definition of the descriptor.

  • model/model_dict: The core definition of the model part and the explanation of sharing rules, starting with user-defined model name keys model_key, such as my_model_1. Each model part needs to align with the components of the single-task training model, but with the following sharing rules:

    • If you want to share the current model component with other tasks, it must be part of the model/shared_dict; you can then directly fill in the corresponding part_key, such as "descriptor": "my_descriptor", to replace the previous detailed parameters. Here, you can also specify the shared_level, such as "descriptor": "my_descriptor:shared_level", using a user-defined integer shared_level to share the corresponding module to varying degrees (the default is to share all parameters, i.e., shared_level=0); see the short sketch after this list. The parts that are exclusive to each model can be written following the previous single-task definition.

  • loss_dict: The loss settings corresponding to each task model, specified by the model_key. Each loss_dict/model_key contains the corresponding loss settings, which are the same as the loss definition in single-task training.

  • training/data_dict: The data settings corresponding to each task model, specified by the model_key. Each training/data_dict/model_key contains the corresponding training_data and validation_data settings, which are the same as the training_data and validation_data definitions in single-task training.

  • (Optional) training/model_prob: The sampling weight settings corresponding to each model_key, i.e., the probability weight at each training step. You can specify any positive real number as the weight for each task. The higher the weight, the higher the probability of the task being sampled at each training step. This setting is optional; if not set, tasks are sampled with equal weights.
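
For illustration, a hypothetical model_dict entry that reuses a shared descriptor defined under the part_key my_descriptor and shares it at level 1 could look like the following sketch (the key names my_model_1, type_map_all, and my_descriptor are placeholders, and the effect of non-zero shared_level values depends on the module being shared):

"model_dict": {
  "my_model_1": {
    "type_map": "type_map_all",
    "descriptor": "my_descriptor:1",
    "fitting_net": {
      "neuron": [240, 240, 240],
      "resnet_dt": true,
      "seed": 1
    }
  }
}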

An example input for multi-task training of two models on water systems is shown below:

{
  "_comment": "that's all",
  "model": {
    "shared_dict": {
      "type_map_all": [
        "O",
        "H"
      ],
      "sea_descriptor_1": {
        "type": "se_e2_a",
        "sel": [
          46,
          92
        ],
        "rcut_smth": 0.50,
        "rcut": 6.00,
        "neuron": [
          25,
          50,
          100
        ],
        "resnet_dt": false,
        "axis_neuron": 16,
        "type_one_side": true,
        "seed": 1,
        "_comment": " that's all"
      },
      "_comment": "that's all"
    },
    "model_dict": {
      "water_1": {
        "type_map": "type_map_all",
        "descriptor": "sea_descriptor_1",
        "fitting_net": {
          "neuron": [
            240,
            240,
            240
          ],
          "resnet_dt": true,
          "seed": 1,
          "_comment": " that's all"
        }
      },
      "water_2": {
        "type_map": "type_map_all",
        "descriptor": "sea_descriptor_1",
        "fitting_net": {
          "neuron": [
            240,
            240,
            240
          ],
          "resnet_dt": true,
          "seed": 1,
          "_comment": " that's all"
        }
      }
    }
  },
  "learning_rate": {
    "type": "exp",
    "decay_steps": 5000,
    "start_lr": 0.0002,
    "decay_rate": 0.98,
    "stop_lr": 3.51e-08,
    "_comment": "that's all"
  },
  "loss_dict": {
    "water_1": {
      "type": "ener",
      "start_pref_e": 0.02,
      "limit_pref_e": 1,
      "start_pref_f": 1000,
      "limit_pref_f": 1,
      "start_pref_v": 0,
      "limit_pref_v": 0
    },
    "water_2": {
      "type": "ener",
      "start_pref_e": 0.02,
      "limit_pref_e": 1,
      "start_pref_f": 1000,
      "limit_pref_f": 1,
      "start_pref_v": 0,
      "limit_pref_v": 0
    }
  },
  "training": {
    "model_prob": {
      "water_1": 0.5,
      "water_2": 0.5
    },
    "data_dict": {
      "water_1": {
        "training_data": {
          "systems": [
            "../../water/data/data_0/",
            "../../water/data/data_1/",
            "../../water/data/data_2/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        },
        "validation_data": {
          "systems": [
            "../../water/data/data_3/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        }
      },
      "water_2": {
        "training_data": {
          "systems": [
            "../../water/data/data_0/",
            "../../water/data/data_1/",
            "../../water/data/data_2/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        }
      }
    },
    "numb_steps": 100000,
    "seed": 10,
    "disp_file": "lcurve.out",
    "disp_freq": 100,
    "save_freq": 100,
    "_comment": "that's all"
  }
}

5.5.3. Finetune from the pre-trained multi-task model#

To finetune based on the checkpoint model.pt after the multi-task pre-training is completed, users can refer to this section.

5.5.4. Multi-task specific parameters#

Note

Parameters that are the same as the regular (single-task) parameters are not detailed below.

multi-task:#
type: dict
argument path: multi-task

Multi-task arguments.

model:#
type: dict
argument path: multi-task/model
model_dict:#
type: dict
argument path: multi-task/model/model_dict

The multiple definition of the model, used in the multi-task mode.

shared_dict:#
type: dict, optional, default: {}
argument path: multi-task/model/shared_dict

The definition of the shared parameters used in the model_dict within multi-task mode.

learning_rate:#
type: dict, optional
argument path: multi-task/learning_rate

The definition of learning rate

loss_dict:#
type: dict, optional
argument path: multi-task/loss_dict

The multiple definition of the loss, used in the multi-task mode.

training:#
type: dict
argument path: multi-task/training

The training options.

model_prob:#
type: dict, optional, default: {}
argument path: multi-task/training/model_prob

The visiting probability of each model for each training step in the multi-task mode.

data_dict:#
type: dict
argument path: multi-task/training/data_dict

The multiple definition of the data, used in the multi-task mode.

This argument takes a dict with each key-value pair containing the following:

training_data:#
type: dict, optional
argument path: multi-task/training/data_dict/training_data

Configurations of training data.

systems:#
type: list[str] | str
argument path: multi-task/training/data_dict/training_data/systems

The data systems for training. This key can be provided with a list that specifies the systems, or with a string that gives the prefix of all systems, from which the list of systems is automatically generated.

batch_size:#
type: list[int] | str | int, optional, default: auto
argument path: multi-task/training/data_dict/training_data/batch_size

This key can be

  • list: the length of which is the same as systems. The batch size of each system is given by the elements of the list.

  • int: all systems use the same batch size.

  • string “auto”: automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.

  • string “auto:N”: automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.

  • string “mixed:N”: the batch data will be sampled from all systems and merged into a mixed system with the batch size N. This is only supported for the se_atten descriptor with the TensorFlow backend.

If MPI is used, the value should be considered as the batch size per task.
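
As an illustration, a training_data block that lets every batch contain at least 64 atoms for each system (paths and values are illustrative) could look like:

"training_data": {
  "systems": [
    "../../water/data/data_0/",
    "../../water/data/data_1/"
  ],
  "batch_size": "auto:64"
}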

auto_prob:#
type: str, optional, default: prob_sys_size, alias: auto_prob_style
argument path: multi-task/training/data_dict/training_data/auto_prob

Determine the probability of systems automatically. The method is assigned by this key and can be

  • “prob_uniform”: the probabilities of all the systems are equal, namely 1.0/self.get_nsystems()

  • “prob_sys_size”: the probability of a system is proportional to the number of batches in the system

  • “prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;…”: the list of systems is divided into blocks. A block is specified by stt_idx:end_idx:weight, where stt_idx is the starting index of the system, end_idx is the ending (exclusive) index of the system, the probabilities of the systems in this block sum up to weight, and the relative probabilities within this block are proportional to the number of batches in each system.
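
As an illustration, with six systems the following hypothetical setting draws 60% of the batches from the first three systems and 40% from the last three, with the within-block probabilities proportional to the number of batches in each system:

"auto_prob": "prob_sys_size;0:3:0.6;3:6:0.4"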

sys_probs:#
type: list[float] | NoneType, optional, default: None, alias: sys_weights
argument path: multi-task/training/data_dict/training_data/sys_probs

A list of floats, if specified. It should be of the same length as systems, specifying the probability of each system.
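
For example, with three training systems, an explicit (illustrative) weighting could be:

"sys_probs": [0.5, 0.3, 0.2]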

validation_data:#
type: NoneType | dict, optional, default: None
argument path: multi-task/training/data_dict/validation_data

Configurations of validation data. Similar to that of training data, except that a numb_btch argument may be configured.

systems:#
type: list[str] | str
argument path: multi-task/training/data_dict/validation_data/systems

The data systems for validation. This key can be provided with a list that specifies the systems, or with a string that gives the prefix of all systems, from which the list of systems is automatically generated.

batch_size:#
type: list[int] | str | int, optional, default: auto
argument path: multi-task/training/data_dict/validation_data/batch_size

This key can be

  • list: the length of which is the same as systems. The batch size of each system is given by the elements of the list.

  • int: all systems use the same batch size.

  • string “auto”: automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.

  • string “auto:N”: automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.

auto_prob:#
type: str, optional, default: prob_sys_size, alias: auto_prob_style
argument path: multi-task/training/data_dict/validation_data/auto_prob

Determine the probability of systems automatically. The method is assigned by this key and can be

  • “prob_uniform”: the probabilities of all the systems are equal, namely 1.0/self.get_nsystems()

  • “prob_sys_size”: the probability of a system is proportional to the number of batches in the system

  • “prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;…”: the list of systems is divided into blocks. A block is specified by stt_idx:end_idx:weight, where stt_idx is the starting index of the system, end_idx is the ending (exclusive) index of the system, the probabilities of the systems in this block sum up to weight, and the relative probabilities within this block are proportional to the number of batches in each system.

sys_probs:#
type: list[float] | NoneType, optional, default: None, alias: sys_weights
argument path: multi-task/training/data_dict/validation_data/sys_probs

A list of floats, if specified. It should be of the same length as systems, specifying the probability of each system.

numb_btch:#
type: int, optional, default: 1, alias: numb_batch
argument path: multi-task/training/data_dict/validation_data/numb_btch

An integer that specifies the number of batches to be sampled for each validation period.
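
As an illustration, a validation_data block that samples 5 batches per validation period (paths and values are illustrative) could be written as:

"validation_data": {
  "systems": [
    "../../water/data/data_3/"
  ],
  "batch_size": 1,
  "numb_btch": 5
}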

stat_file:#
type: str, optional
argument path: multi-task/training/data_dict/stat_file

(Supported Backend: PyTorch) The file path for saving the data statistics results. If set, the results will be saved and directly loaded during the next training session, avoiding the need to recalculate the statistics. If the file extension is .h5 or .hdf5, an HDF5 file is used to store the statistics; otherwise, a directory containing NumPy binary files is used.
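
For instance, assuming a hypothetical path stat_files/water_1.hdf5, the data_dict entry of one task could cache its statistics in an HDF5 file as sketched below:

"water_1": {
  "stat_file": "stat_files/water_1.hdf5",
  "training_data": {
    "systems": [
      "../../water/data/data_0/"
    ],
    "batch_size": 1
  }
}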

mixed_precision:#
type: dict, optional
argument path: multi-task/training/mixed_precision

Configurations of mixed precision.

output_prec:#
type: str, optional, default: float32
argument path: multi-task/training/mixed_precision/output_prec

The precision of the mixed-precision parameters, i.e., the precision of the trainable variables during the mixed-precision training process. Currently, only float32 is supported.

compute_prec:#
type: str
argument path: multi-task/training/mixed_precision/compute_prec

The compute precision during the mixed-precision training process. Currently, the supported options are float16 and bfloat16.
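
Putting the two keys together, an illustrative mixed_precision block inside the training section could be:

"mixed_precision": {
  "output_prec": "float32",
  "compute_prec": "float16"
}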

numb_steps:#
type: int, alias: stop_batch
argument path: multi-task/training/numb_steps

Number of training batches. Each training step uses one batch of data.

seed:#
type: NoneType | int, optional
argument path: multi-task/training/seed

The random seed for getting frames from the training data set.

disp_file:#
type: str, optional, default: lcurve.out
argument path: multi-task/training/disp_file

The file for printing learning curve.

disp_freq:#
type: int, optional, default: 1000
argument path: multi-task/training/disp_freq

The frequency of printing learning curve.

save_freq:#
type: int, optional, default: 1000
argument path: multi-task/training/save_freq

The frequency of saving check point.

save_ckpt:#
type: str, optional, default: model.ckpt
argument path: multi-task/training/save_ckpt

The path prefix of saving check point files.

max_ckpt_keep:#
type: int, optional, default: 5
argument path: multi-task/training/max_ckpt_keep

The maximum number of checkpoints to keep. The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. Defaults to 5.

change_bias_after_training:#
type: bool, optional, default: False
argument path: multi-task/training/change_bias_after_training

Whether to change the output bias after the last training step, by performing predictions with the trained model on the training data and applying a least-squares fit to the errors to add the target shift to the bias.

disp_training:#
type: bool, optional, default: True
argument path: multi-task/training/disp_training

Displaying verbose information during training.

time_training:#
type: bool, optional, default: True
argument path: multi-task/training/time_training

Timing during training.

profiling:#
type: bool, optional, default: False
argument path: multi-task/training/profiling

Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to profiling_file.

profiling_file:#
type: str, optional, default: timeline.json
argument path: multi-task/training/profiling_file

Output file for profiling.

enable_profiler:#
type: bool, optional, default: False
argument path: multi-task/training/enable_profiler

Export the profiling results to the TensorBoard log for performance analysis, driven by TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler. The log will be saved to tensorboard_log_dir.

tensorboard:#
type: bool, optional, default: False
argument path: multi-task/training/tensorboard

Enable tensorboard

tensorboard_log_dir:#
type: str, optional, default: log
argument path: multi-task/training/tensorboard_log_dir

The log directory of tensorboard outputs

tensorboard_freq:#
type: int, optional, default: 1
argument path: multi-task/training/tensorboard_freq

The frequency of writing tensorboard events.

warmup_steps:#
type: int, optional
argument path: multi-task/training/warmup_steps

(Supported Backend: PyTorch) The number of steps for learning rate warmup. During warmup, the learning rate begins at zero and increases linearly to start_lr, rather than starting directly from start_lr.

gradient_max_norm:#
type: float, optional
argument path: multi-task/training/gradient_max_norm

(Supported Backend: PyTorch) Clips the gradient norm to a maximum value. If the gradient norm exceeds this value, it will be clipped to this limit. No gradient clipping will occur if set to 0.
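
As an illustration, the following keys (with hypothetical values) inside the training section enable a 1000-step linear warmup and clip the gradient norm at 5.0:

"warmup_steps": 1000,
"gradient_max_norm": 5.0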

Depending on the value of opt_type, different sub args are accepted.

opt_type:#
type: str (flag key), default: Adam
argument path: multi-task/training/opt_type

(Supported Backend: PyTorch) The type of optimizer to use.

When opt_type is set to Adam:

When opt_type is set to LKF:

kf_blocksize:#
type: int, optional
argument path: multi-task/training[LKF]/kf_blocksize

(Supported Backend: PyTorch) The blocksize for the Kalman filter.
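
For example, to switch from the default Adam optimizer to LKF with an illustrative block size, the training section could contain:

"opt_type": "LKF",
"kf_blocksize": 5120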

nvnmd:#
type: dict, optional
argument path: multi-task/nvnmd

The nvnmd options.