deepmd.pt.train.wrapper#

Attributes#

log

Classes#

ModelWrapper

Module Contents#

deepmd.pt.train.wrapper.log[source]#
class deepmd.pt.train.wrapper.ModelWrapper(model: torch.nn.Module | dict, loss: torch.nn.Module | dict = None, model_params=None, shared_links=None)[source]#

Bases: torch.nn.Module

model_params[source]#
train_infos[source]#
multi_task = False[source]#
model[source]#
loss = None[source]#
inference_only[source]#
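
A minimal construction sketch based on the signature above. The plain torch modules below are placeholders standing in for a real deepmd PT model and loss; whether such stand-ins are accepted end to end is an assumption made only for illustration.

import torch

from deepmd.pt.train.wrapper import ModelWrapper

# Placeholders for a real deepmd PT model and loss; a genuine setup would
# build these from the deepmd training configuration.
toy_model = torch.nn.Linear(3, 1)
toy_loss = torch.nn.MSELoss()

# Single-task form: a single module each for `model` and `loss`.
wrapper = ModelWrapper(model=toy_model, loss=toy_loss)

# Omitting `loss` (its default is None) presumably leaves the wrapper in
# inference-only mode, matching the `inference_only` attribute above.
infer_wrapper = ModelWrapper(model=toy_model)

# Per the `torch.nn.Module | dict` annotations, a multitask setup would pass
# dicts keyed by task name instead, typically together with `model_params`
# and `shared_links` (not shown here).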
share_params(shared_links, resume=False) None[source]#

Share the parameters of classes following the rules defined in shared_links during multitask training. If not starting from a checkpoint (resume is False), some separately stored parameters (e.g. mean and stddev) will be recalculated across the different classes.
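
The format of shared_links is not documented in this section; the sketch below only illustrates the underlying PyTorch mechanism that such sharing relies on, namely pointing two task branches at the same parameter objects, and is not the deepmd implementation itself.

import torch

# Two task branches with structurally identical "descriptor" sub-modules.
branch_a = torch.nn.ModuleDict(
    {"descriptor": torch.nn.Linear(8, 16), "head": torch.nn.Linear(16, 1)}
)
branch_b = torch.nn.ModuleDict(
    {"descriptor": torch.nn.Linear(8, 16), "head": torch.nn.Linear(16, 1)}
)

# Sharing in the torch sense: make branch_b reuse branch_a's descriptor, so
# gradient updates from either task move the same underlying weights.
branch_b["descriptor"] = branch_a["descriptor"]
assert branch_b["descriptor"].weight is branch_a["descriptor"].weight

Per the docstring, statistics such as mean and stddev are additionally recalculated across the branches when the run does not resume from a checkpoint.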

forward(coord, atype, spin: torch.Tensor | None = None, box: torch.Tensor | None = None, cur_lr: torch.Tensor | None = None, label: torch.Tensor | None = None, task_key: torch.Tensor | None = None, inference_only=False, do_atomic_virial=False, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None)[source]#
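
A hedged call sketch for the signature above. It assumes `wrapper` was built from a real deepmd PT model (a toy placeholder module would not implement the expected forward interface), the shapes shown (one frame, three atoms) are illustrative assumptions, and the commented-out calls indicate the intended pattern rather than verified return values.

import torch

nframes, natoms = 1, 3
coord = torch.rand(nframes, natoms, 3)                   # Cartesian coordinates
atype = torch.zeros(nframes, natoms, dtype=torch.long)   # per-atom type indices
box = 10.0 * torch.eye(3).repeat(nframes, 1, 1)          # periodic cell, if any

# Training-style call (assumed pattern): pass the current learning rate and
# the labels so the wrapped loss can be evaluated.
# outputs = wrapper(coord, atype, box=box, cur_lr=torch.tensor(1e-3),
#                   label=batch_labels)

# Inference-style call (assumed pattern): skip the loss entirely.
# outputs = wrapper(coord, atype, box=box, inference_only=True)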
set_extra_state(state: dict) None[source]#
get_extra_state() dict[source]#
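
get_extra_state and set_extra_state hook into PyTorch's standard extra-state mechanism, so whatever the wrapper stores there (presumably metadata such as model_params and train_infos) travels with state_dict() and load_state_dict(). The generic module below demonstrates only that mechanism and is not the deepmd class.

import torch

class WithExtraState(torch.nn.Module):
    # Generic demo of the extra-state hooks that ModelWrapper overrides.

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.train_infos = {"lr": 1e-3, "step": 0}  # illustrative metadata

    def get_extra_state(self) -> dict:
        # Called by state_dict(); the returned object is stored alongside
        # the parameter tensors.
        return {"train_infos": self.train_infos}

    def set_extra_state(self, state: dict) -> None:
        # Called by load_state_dict(); restores the metadata.
        self.train_infos = state["train_infos"]

src, dst = WithExtraState(), WithExtraState()
src.train_infos["step"] = 100
dst.load_state_dict(src.state_dict())
assert dst.train_infos["step"] == 100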