deepmd.pt_expt.train.wrapper
Attributes
Classes
ModelWrapper | Simplified model wrapper that bundles a model and a loss.
Module Contents
- class deepmd.pt_expt.train.wrapper.ModelWrapper(model: torch.nn.Module, loss: torch.nn.Module | None = None, model_params: dict[str, Any] | None = None)
Bases: torch.nn.Module
Simplified model wrapper that bundles a model and a loss.
Single-task only for now (no multi-task support).
- Parameters:
  - model : torch.nn.Module
    The model to train.
  - loss : torch.nn.Module, optional
    The loss module.
  - model_params : dict, optional
    Model parameters to store as extra state.
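A minimal sketch of the bundling pattern described above, written against plain PyTorch. It assumes that "extra state" refers to PyTorch's `get_extra_state`/`set_extra_state` hooks, which this page does not confirm; `BundledWrapper` is a hypothetical stand-in, not the deepmd implementation. Registering the model and the loss as submodules lets a single `state_dict` cover both, while `model_params` travels with it under the `_extra_state` key.

```python
from __future__ import annotations

from typing import Any

import torch


class BundledWrapper(torch.nn.Module):
    """Hypothetical stand-in for ModelWrapper; not the deepmd code."""

    def __init__(
        self,
        model: torch.nn.Module,
        loss: torch.nn.Module | None = None,
        model_params: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        # Assigning Modules as attributes registers them as submodules, so
        # their parameters appear under "model." and "loss." in state_dict().
        self.model = model
        # A None loss is stored as a plain attribute (no submodule registered).
        self.loss = loss
        self.model_params = model_params if model_params is not None else {}

    # Assumption: "extra state" means these PyTorch hooks. Overriding
    # get_extra_state makes state_dict() include an "_extra_state" entry.
    def get_extra_state(self) -> dict[str, Any]:
        return {"model_params": self.model_params}

    # Called by load_state_dict() with the object saved above.
    def set_extra_state(self, state: dict[str, Any]) -> None:
        self.model_params = state["model_params"]
```

Round-tripping through `state_dict()` and `load_state_dict()` then restores `model_params` on a freshly constructed wrapper, even one built without any `model_params` argument.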
- model
- forward(coord: torch.Tensor, atype: torch.Tensor, box: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, cur_lr: float | torch.Tensor | None = None, label: dict[str, torch.Tensor] | None = None, do_atomic_virial: bool = False) -> tuple[dict[str, torch.Tensor], torch.Tensor | None, dict | None]
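The return annotation of `forward` is a triple whose last two elements may be `None`. The sketch below shows one plausible contract, assuming (not confirmed by this page) that the elements are the model predictions, the scalar loss, and a dict of auxiliary loss terms, and that the loss is skipped when there is no loss module or no `label`. `ToyEnergyModel`, `ToyEnergyLoss`, and `WrapperForwardSketch` are hypothetical; the toy model ignores `atype`, `box`, `fparam`, and `aparam`, and `cur_lr` is accepted only to mirror the signature.

```python
from __future__ import annotations

import torch


class ToyEnergyModel(torch.nn.Module):
    """Hypothetical model: per-atom linear energy summed over atoms."""

    def __init__(self) -> None:
        super().__init__()
        self.fit = torch.nn.Linear(3, 1)

    def forward(self, coord, atype, box=None, fparam=None, aparam=None,
                do_atomic_virial=False):
        # coord: (nframes, natoms, 3); the other inputs are unused here.
        atom_energy = self.fit(coord).squeeze(-1)    # (nframes, natoms)
        return {"energy": atom_energy.sum(dim=-1)}   # (nframes,)


class ToyEnergyLoss(torch.nn.Module):
    """Hypothetical loss: MSE on "energy"; returns (loss, more_loss)."""

    def forward(self, pred, label):
        loss = torch.mean((pred["energy"] - label["energy"]) ** 2)
        # more_loss holds detached values intended for logging only.
        return loss, {"rmse_e": loss.detach().sqrt()}


class WrapperForwardSketch(torch.nn.Module):
    """Hypothetical wrapper mirroring the documented forward signature."""

    def __init__(self, model, loss=None):
        super().__init__()
        self.model = model
        self.loss = loss

    def forward(self, coord, atype, box=None, fparam=None, aparam=None,
                cur_lr=None, label=None, do_atomic_virial=False):
        pred = self.model(coord, atype, box=box, fparam=fparam, aparam=aparam,
                          do_atomic_virial=do_atomic_virial)
        # Assumed inference path: no loss module or no labels.
        if self.loss is None or label is None:
            return pred, None, None
        # cur_lr is unused in this sketch; an LR-dependent loss could use it.
        loss, more_loss = self.loss(pred, label)
        return pred, loss, more_loss
```

Returning `None` placeholders rather than a shorter tuple keeps the unpacking `pred, loss, more_loss = wrapper(...)` valid in both training and inference.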