deepmd.pt_expt.train.wrapper

Attributes

Classes

ModelWrapper

Simplified model wrapper that bundles a model and a loss.

Module Contents

deepmd.pt_expt.train.wrapper.log
class deepmd.pt_expt.train.wrapper.ModelWrapper(model: torch.nn.Module, loss: torch.nn.Module | None = None, model_params: dict[str, Any] | None = None)

Bases: torch.nn.Module

Simplified model wrapper that bundles a model and a loss.

Currently supports single-task training only; multi-task training is not supported.

Parameters:
model : torch.nn.Module

The model to train.

loss : torch.nn.Module, optional

The loss module. Defaults to None.

model_params : dict, optional

Model parameters to store as extra state. Defaults to None.

model_params
train_infos: dict[str, Any]
model
loss = None
inference_only
forward(coord: torch.Tensor, atype: torch.Tensor, box: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, cur_lr: float | torch.Tensor | None = None, label: dict[str, torch.Tensor] | None = None, do_atomic_virial: bool = False) -> tuple[dict[str, torch.Tensor], torch.Tensor | None, dict | None]
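The return annotation of forward pairs a prediction dict with an optional loss and an optional dict. The sketch below shows one plausible reading of that contract; it is not the deepmd implementation.

- With no loss module attached, the wrapper returns the prediction plus two None placeholders.
- With a loss attached, the loss module drives the model and returns (prediction, loss, extra metrics).

SketchWrapper, toy_model, toy_loss, the loss-call convention, and the link between inference_only and a missing loss are all hypothetical. Plain Python callables and floats stand in for torch modules and tensors so the example runs without PyTorch.

```python
from typing import Any, Callable, Optional

# Hypothetical stand-ins for torch modules:
#   model(inputs) -> prediction dict
#   loss(inputs, model, label) -> (prediction dict, loss value, extra metrics dict)
Model = Callable[[dict], dict]
Loss = Callable[[dict, Model, dict], tuple]


class SketchWrapper:
    """Illustrative model + loss bundle (not the deepmd ModelWrapper)."""

    def __init__(self, model: Model, loss: Optional[Loss] = None,
                 model_params: Optional[dict] = None) -> None:
        self.model = model
        self.loss = loss
        self.model_params = model_params if model_params is not None else {}
        # Assumption: without a loss module the wrapper can only run inference.
        self.inference_only = loss is None

    def forward(self, inputs: dict, label: Optional[dict] = None) -> tuple:
        if self.inference_only:
            # Nothing to score: return the prediction and two placeholders.
            return self.model(inputs), None, None
        # Training path: the loss module calls the model and compares to label.
        return self.loss(inputs, self.model, label)


def toy_model(inputs: dict) -> dict[str, Any]:
    # Toy "energy": the sum of the coordinates.
    return {"energy": sum(inputs["coord"])}


def toy_loss(inputs: dict, model: Model, label: dict) -> tuple:
    pred = model(inputs)
    err = pred["energy"] - label["energy"]
    # Squared-error loss plus an absolute-error metric as the "extra" dict.
    return pred, err * err, {"abs_err_e": abs(err)}


infer = SketchWrapper(toy_model)
pred, loss, more = infer.forward({"coord": [1.0, 2.0]})

train = SketchWrapper(toy_model, toy_loss)
pred_t, loss_t, more_t = train.forward({"coord": [1.0, 2.0]}, {"energy": 2.5})
```

Routing the training path through the loss module, rather than calling the model and the loss separately, lets a single forward call produce both the prediction and the loss.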
set_extra_state(state: dict) -> None
get_extra_state() -> dict
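get_extra_state and set_extra_state override the PyTorch nn.Module hooks of the same name. When a subclass overrides get_extra_state, Module.state_dict() stores its return value under an `_extra_state` key, and load_state_dict() passes that value back to set_extra_state.

The round-trip sketch below assumes the extra state is a dict with "model_params" and "train_infos" entries, mirroring the model_params and train_infos attributes above; the actual keys are not documented here. ExtraStateSketch and the initial train_infos contents are hypothetical, and the sketch uses no torch import.

```python
import copy
from typing import Any, Optional


class ExtraStateSketch:
    """Hypothetical holder mirroring the model_params / train_infos attributes."""

    def __init__(self, model_params: Optional[dict] = None) -> None:
        self.model_params = model_params if model_params is not None else {}
        # Assumed bookkeeping entries; the real train_infos contents may differ.
        self.train_infos: dict[str, Any] = {"lr": 0.0, "step": 0}

    def get_extra_state(self) -> dict:
        # Deep-copy so a saved snapshot is not changed by later training steps.
        return copy.deepcopy(
            {"model_params": self.model_params, "train_infos": self.train_infos}
        )

    def set_extra_state(self, state: dict) -> None:
        # A missing key raises KeyError instead of silently keeping stale values.
        self.model_params = copy.deepcopy(state["model_params"])
        self.train_infos = copy.deepcopy(state["train_infos"])


# Round trip: snapshot one instance, then restore into a fresh one.
src = ExtraStateSketch({"type_map": ["O", "H"]})
src.train_infos["step"] = 1000
saved = src.get_extra_state()

dst = ExtraStateSketch()
dst.set_extra_state(saved)
```

Deep-copying on both save and load is a design choice of this sketch. It keeps a saved snapshot fixed while training continues to update train_infos.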