deepmd.train package

Submodules

deepmd.train.run_options module

Module taking care of important package constants.

class deepmd.train.run_options.RunOptions(init_model: str | None = None, init_frz_model: str | None = None, finetune: str | None = None, restart: str | None = None, log_path: str | None = None, log_level: int = 0, mpi_log: str = 'master')[source]

Bases: object

Class with info on how to run training (cluster, MPI and GPU config).

Attributes:
gpus: Optional[List[int]]

list of GPU indices if any are present, else None

is_chief: bool

in distributed training it is True for the main MPI process; in serial mode it is always True

world_size: int

total worker count

my_rank: int

index of the MPI task

nodename: str

name of the node

node_list: List[str]

the list of nodes of the current mpirun

my_device: str

device type - gpu or cpu

Methods

print_resource_summary()

Print build and current running cluster configuration summary.

gpus: List[int] | None
property is_chief

Whether my rank is 0.

my_device: str
my_rank: int
nodelist: List[str]
nodename: str
print_resource_summary()[source]

Print build and current running cluster configuration summary.

world_size: int
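The attributes above fit together in a simple way: each MPI task knows its own rank, the total worker count, and whether it is the chief (rank 0). The following toy class is an illustrative sketch of that bookkeeping, mirroring the documented attribute names (`my_rank`, `world_size`, `gpus`, `my_device`, `is_chief`); it is an assumption about the semantics, not the deepmd implementation.

```python
from typing import List, Optional


class RankInfo:
    """Toy stand-in for the MPI-related state documented on RunOptions."""

    def __init__(self, my_rank: int, world_size: int,
                 gpus: Optional[List[int]] = None) -> None:
        self.my_rank = my_rank        # index of this MPI task
        self.world_size = world_size  # total worker count
        self.gpus = gpus              # GPU indices, or None if none present
        # Pick a device the way the attribute docs describe it:
        # "gpu" if any GPUs are present, otherwise "cpu".
        self.my_device = "gpu" if gpus else "cpu"

    @property
    def is_chief(self) -> bool:
        # "Whether my rank is 0" - true for the single process in a
        # serial run, and for the main MPI process in a distributed run.
        return self.my_rank == 0


# Serial run: one worker, always chief, no GPUs.
serial = RankInfo(my_rank=0, world_size=1)
# Distributed run: worker 3 of 4 with two GPUs, not chief.
worker = RankInfo(my_rank=3, world_size=4, gpus=[0, 1])
```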

deepmd.train.trainer module

class deepmd.train.trainer.DPTrainer(jdata, run_opt, is_compress=False)[source]

Bases: object

Methods

save_compressed()

Save the compressed graph.

build

eval_single_list

get_evaluation_results

get_feed_dict

get_global_step

print_header

print_on_training

save_checkpoint

train

valid_on_the_fly

build(data=None, stop_batch=0, origin_type_map=None, suffix='')[source]
static eval_single_list(single_batch_list, loss, sess, get_feed_dict_func, prefix='')[source]
get_evaluation_results(batch_list)[source]
get_feed_dict(batch, is_training)[source]
get_global_step()[source]
static print_header(fp, train_results, valid_results, multi_task_mode=False)[source]
static print_on_training(fp, train_results, valid_results, cur_batch, cur_lr, multi_task_mode=False)[source]
save_checkpoint(cur_batch: int)[source]
save_compressed()[source]

Save the compressed graph.

train(train_data=None, valid_data=None)[source]
valid_on_the_fly(fp, train_batches, valid_batches, print_header=False)[source]
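The method listing above suggests the call order a driver script follows: `build` the graph for a fixed number of batches, then `train`, checkpointing periodically via `save_checkpoint(cur_batch)`. The toy trainer below is a hypothetical sketch of that control flow only; the class, its `save_freq` parameter, and the loop body are illustrative assumptions, not deepmd code.

```python
from typing import List, Optional


class ToyTrainer:
    """Illustrative stand-in for the build/train/checkpoint flow."""

    def __init__(self, save_freq: int = 100) -> None:
        self.save_freq = save_freq      # checkpoint every N batches (assumed knob)
        self.stop_batch = 0
        self.checkpoints: List[int] = []

    def build(self, data: Optional[object] = None, stop_batch: int = 0) -> None:
        # In DPTrainer this constructs the computation graph; here we
        # only record how many batches the training loop should run.
        self.stop_batch = stop_batch

    def save_checkpoint(self, cur_batch: int) -> None:
        # Stands in for writing a checkpoint file at this batch index.
        self.checkpoints.append(cur_batch)

    def train(self, train_data: Optional[object] = None,
              valid_data: Optional[object] = None) -> None:
        for cur_batch in range(1, self.stop_batch + 1):
            # ... one optimizer step on a batch of train_data ...
            if cur_batch % self.save_freq == 0:
                self.save_checkpoint(cur_batch)


trainer = ToyTrainer(save_freq=100)
trainer.build(stop_batch=300)
trainer.train()
# Checkpoints are recorded at batches 100, 200, and 300.
```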