5.6. Multi-task training (PyTorch)

Note

Supported backends: PyTorch

5.6.1. Theory

The multi-task training process can simultaneously handle different datasets with properties that cannot be fitted in one network (e.g. properties from DFT calculations under different exchange-correlation functionals or different basis sets). These datasets are denoted by \(\boldsymbol x^{(1)}, \dots, \boldsymbol x^{(n_t)}\). For each dataset, a training task is defined as

\[ \min_{\boldsymbol \theta^{(t)}} L^{(t)} (\boldsymbol x^{(t)}; \boldsymbol \theta^{(t)}, \tau), \quad t=1, \dots, n_t.\]

In the PyTorch implementation, during the multi-task training process, all tasks can share any portion of the model parameters. A typical scenario is that each task shares the same descriptor with trainable parameters \(\boldsymbol{\theta}_ {d}\), while each has its own fitting network with trainable parameters \(\boldsymbol{\theta}_ f^{(t)}\), thus \(\boldsymbol{\theta}^{(t)} = \{ \boldsymbol{\theta}_ {d} , \boldsymbol{\theta}_ {f}^{(t)} \}\). At each training step, a task is randomly selected from \(\{1, \dots, n_t\}\) according to the user-specified probability, and the Adam optimizer is executed for one step to minimize \(L^{(t)}\) and update the parameters \(\boldsymbol \theta^{(t)}\). In multi-GPU parallel training, different GPUs independently select their tasks. This multi-task training framework is adopted in the DPA-2 model.[^1]

[^1] Duo Zhang, Xinzijian Liu, Xiangyu Zhang, Chengqian Zhang, Chun Cai, Hangrui Bi, Yiming Du, Xuejian Qin, Jiameng Huang, Bowen Li, Yifan Shan, Jinzhe Zeng, Yuzhi Zhang, Siyuan Liu, Yifan Li, Junhan Chang, Xinyan Wang, Shuo Zhou, Jianchuan Liu, Xiaoshan Luo, Zhenyu Wang, Wanrun Jiang, Jing Wu, Yudi Yang, Jiyuan Yang, Manyi Yang, Fu-Qiang Gong, Linshuang Zhang, Mengchao Shi, Fu-Zhi Dai, Darrin M. York, Shi Liu, Tong Zhu, Zhicheng Zhong, Jian Lv, Jun Cheng, Weile Jia, Mohan Chen, Guolin Ke, Weinan E, Linfeng Zhang, Han Wang, arXiv preprint arXiv:2312.15492 (2023), licensed under a Creative Commons Attribution (CC BY) license.
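The per-step procedure described above can be illustrated with a minimal, self-contained PyTorch sketch. This is an illustration only, not the DeePMD-kit implementation; the layer sizes, variable names, and random data are made up for demonstration. It shows a descriptor shared by all tasks, one fitting network per task, and a single Adam step on the loss of a randomly selected task.

import torch

n_tasks = 2
descriptor = torch.nn.Linear(8, 16)  # shared descriptor parameters theta_d (toy sizes)
fitting_nets = [torch.nn.Linear(16, 1) for _ in range(n_tasks)]  # per-task theta_f^(t)
loss_fn = torch.nn.MSELoss()

# One Adam optimizer per task over theta^(t) = {theta_d, theta_f^(t)}.
optimizers = [
    torch.optim.Adam(
        list(descriptor.parameters()) + list(fitting_nets[t].parameters()), lr=1e-3
    )
    for t in range(n_tasks)
]

model_prob = torch.tensor([0.5, 0.5])  # user-specified task sampling weights

def train_one_step(batches):
    # Randomly select a task according to model_prob, then take one Adam step
    # on that task's loss; the shared descriptor is updated by every task.
    t = int(torch.multinomial(model_prob, 1))
    x, y = batches[t]
    loss = loss_fn(fitting_nets[t](descriptor(x)), y)
    optimizers[t].zero_grad()
    loss.backward()
    optimizers[t].step()
    return t, loss.item()

# Toy data standing in for the per-task datasets x^(1), ..., x^(n_t).
batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(n_tasks)]
for _ in range(10):
    train_one_step(batches)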

Compared with the previous TensorFlow implementation, the PyTorch support is more flexible and efficient. In particular, it makes multi-GPU parallel training and even tasks beyond DFT possible, enabling larger-scale multi-task training to obtain more general pre-trained models.

5.6.2. Perform the multi-task training using PyTorch

Training on multiple data sets (each data set contains several data systems) can be performed in multi-task mode, typically with one common descriptor and one specific fitting net for each data set. To proceed, one needs to change the representation of the model definition in the input script. The core idea is to replace the previous single model definition model with multiple model definitions under model/model_dict/model_key, define the shared parameters of the model in shared_dict, and then expand the other parts for multi-model settings. Specifically, the following parts need to be modified:

  • model/shared_dict: The parameter definitions of the shared parts, including descriptors and type maps (even fitting nets can be shared). Each module is defined under a user-defined part_key, such as my_descriptor. The content needs to align with the corresponding definition in the single-task training model component, such as the definition of the descriptor.

  • model/model_dict: The core definition of the model part and the explanation of sharing rules, starting with user-defined model name keys model_key, such as my_model_1. Each model part needs to align with the components of the single-task training model, but with the following sharing rules:

    • If you want to share the current model component with other tasks, it must be defined in model/shared_dict; you can then directly fill in the corresponding part_key, such as "descriptor": "my_descriptor", instead of the detailed parameters. You can also specify a shared_level, such as "descriptor": "my_descriptor:shared_level", where the user-defined integer shared_level controls to what degree the corresponding module is shared (the default, shared_level=0, shares all parameters). The parts that are exclusive to each model are written following the single-task definition.

  • loss_dict: The loss settings corresponding to each task model, specified by the model_key. Each loss_dict/model_key contains the corresponding loss settings, which are the same as the loss definition in single-task training.

  • training/data_dict: The data settings corresponding to each task model, specified by the model_key. Each training/data_dict/model_key contains the corresponding training_data and validation_data settings, which are the same as the training_data and validation_data definitions in single-task training.

  • (Optional) training/model_prob: The sampling weight settings corresponding to each model_key, i.e., the probability of being sampled at each training step. You can specify any positive real number as the weight for each task; the higher the weight, the higher the probability of that task being sampled, as illustrated in the sketch after this list. This setting is optional; if not set, tasks are sampled with equal weights.
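As a small illustration of the weight setting (the exact normalization scheme used internally is an assumption here, and the weights shown are hypothetical), positive weights translate into sampling probabilities by normalizing them to sum to one:

# Hedged sketch: turning positive model_prob weights into sampling probabilities.
weights = {"water_1": 2.0, "water_2": 1.0}  # hypothetical per-task weights
total = sum(weights.values())
probs = {key: w / total for key, w in weights.items()}
print(probs)  # {'water_1': 0.666..., 'water_2': 0.333...}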

An example input for multi-task training of two models on a water system is shown below:

{
  "_comment": "that's all",
  "model": {
    "shared_dict": {
      "type_map_all": [
        "O",
        "H"
      ],
      "sea_descriptor_1": {
        "type": "se_e2_a",
        "sel": [
          46,
          92
        ],
        "rcut_smth": 0.50,
        "rcut": 6.00,
        "neuron": [
          25,
          50,
          100
        ],
        "resnet_dt": false,
        "axis_neuron": 16,
        "type_one_side": true,
        "seed": 1,
        "_comment": " that's all"
      },
      "_comment": "that's all"
    },
    "model_dict": {
      "water_1": {
        "type_map": "type_map_all",
        "descriptor": "sea_descriptor_1",
        "fitting_net": {
          "neuron": [
            240,
            240,
            240
          ],
          "resnet_dt": true,
          "seed": 1,
          "_comment": " that's all"
        }
      },
      "water_2": {
        "type_map": "type_map_all",
        "descriptor": "sea_descriptor_1",
        "fitting_net": {
          "neuron": [
            240,
            240,
            240
          ],
          "resnet_dt": true,
          "seed": 1,
          "_comment": " that's all"
        }
      }
    }
  },
  "learning_rate": {
    "type": "exp",
    "decay_steps": 5000,
    "start_lr": 0.0002,
    "decay_rate": 0.98,
    "stop_lr": 3.51e-08,
    "_comment": "that's all"
  },
  "loss_dict": {
    "_comment": " that's all",
    "water_1": {
      "type": "ener",
      "start_pref_e": 0.02,
      "limit_pref_e": 1,
      "start_pref_f": 1000,
      "limit_pref_f": 1,
      "start_pref_v": 0,
      "limit_pref_v": 0
    },
    "water_2": {
      "type": "ener",
      "start_pref_e": 0.02,
      "limit_pref_e": 1,
      "start_pref_f": 1000,
      "limit_pref_f": 1,
      "start_pref_v": 0,
      "limit_pref_v": 0
    }
  },
  "training": {
    "model_prob": {
      "water_1": 0.5,
      "water_2": 0.5
    },
    "data_dict": {
      "water_1": {
        "training_data": {
          "systems": [
            "../../water/data/data_0/",
            "../../water/data/data_1/",
            "../../water/data/data_2/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        },
        "validation_data": {
          "systems": [
            "../../water/data/data_3/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        }
      },
      "water_2": {
        "training_data": {
          "systems": [
            "../../water/data/data_0/",
            "../../water/data/data_1/",
            "../../water/data/data_2/"
          ],
          "batch_size": 1,
          "_comment": "that's all"
        }
      }
    },
    "numb_steps": 100000,
    "seed": 10,
    "disp_file": "lcurve.out",
    "disp_freq": 100,
    "save_freq": 100,
    "_comment": "that's all"
  }
}
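With the above input saved to a file, here assumed to be named input.json, the multi-task training is started in the same way as single-task training:

$ dp --pt train input.json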

5.6.3. Finetune from the pretrained multi-task model

To finetune from the checkpoint model.pt after the multi-task pre-training is completed, users only need to prepare a normal single-task training input input_single.json, select one of the trained model’s task names model_key, and run the following command:

$ dp --pt train input_single.json --finetune model.pt --model-branch model_key