deepmd.pt_expt.train.wrapper
============================

.. py:module:: deepmd.pt_expt.train.wrapper


Attributes
----------

.. autoapisummary::

   deepmd.pt_expt.train.wrapper.log


Classes
-------

.. autoapisummary::

   deepmd.pt_expt.train.wrapper.ModelWrapper


Module Contents
---------------

.. py:data:: log

.. py:class:: ModelWrapper(model: torch.nn.Module, loss: torch.nn.Module | None = None, model_params: dict[str, Any] | None = None)

   Bases: :py:obj:`torch.nn.Module`

   Simplified model wrapper that bundles a model and a loss.

   Only single-task training is currently supported; multi-task training is not.

   :Parameters:

       **model** : :obj:`torch.nn.Module`
           The model to train.

       **loss** : :obj:`torch.nn.Module`, :obj:`optional`
           The loss module. Defaults to ``None``.

       **model_params** : :class:`python:dict`, :obj:`optional`
           Model parameters to store as extra state.
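
   Bundling the model and the loss in one ``torch.nn.Module`` means that, once both are registered as submodules, a single ``state_dict()``, ``.to(device)``, or ``.train()``/``.eval()`` call covers both. The following is a minimal sketch of that general pattern only. ``ToyWrapper`` and its ``forward(x, label)`` signature are hypothetical and do not reflect how this class is implemented; the real ``forward`` takes ``coord``, ``atype`` and optional ``box``, ``fparam``, ``aparam``, ``cur_lr`` and ``label``, and returns a 3-tuple.

   .. code-block:: python

      from typing import Optional

      import torch


      class ToyWrapper(torch.nn.Module):
          """Hypothetical wrapper: one Module owning both a model and a loss."""

          def __init__(
              self, model: torch.nn.Module, loss: Optional[torch.nn.Module] = None
          ) -> None:
              super().__init__()
              # Assigning Modules as attributes registers them as submodules.
              self.model = model
              self.loss = loss

          def forward(self, x: torch.Tensor, label: Optional[torch.Tensor] = None):
              pred = self.model(x)
              if self.loss is None or label is None:
                  # No loss module or no labels: return predictions only.
                  return pred, None
              return pred, self.loss(pred, label)


      wrapper = ToyWrapper(torch.nn.Linear(3, 1), torch.nn.MSELoss())
      x = torch.randn(4, 3)
      pred, loss = wrapper(x, label=torch.zeros(4, 1))
      loss.backward()  # gradients flow into the wrapped model's parameters

   Because ``MSELoss`` has no parameters, every key in ``wrapper.state_dict()`` here belongs to the wrapped model (``model.weight``, ``model.bias``).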


   .. py:attribute:: model_params


   .. py:attribute:: train_infos
      :type:  dict[str, Any]


   .. py:attribute:: model


   .. py:attribute:: loss
      :value: None



   .. py:attribute:: inference_only


   .. py:method:: forward(coord: torch.Tensor, atype: torch.Tensor, box: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, cur_lr: float | torch.Tensor | None = None, label: dict[str, torch.Tensor] | None = None, do_atomic_virial: bool = False) -> tuple[dict[str, torch.Tensor], torch.Tensor | None, dict | None]


   .. py:method:: set_extra_state(state: dict) -> None


   .. py:method:: get_extra_state() -> dict
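
      ``get_extra_state`` and ``set_extra_state`` override the standard ``torch.nn.Module`` hooks for non-tensor state. When a subclass overrides ``get_extra_state``, ``state_dict()`` stores its return value under the ``_extra_state`` key (prefixed with the module's path when the module is nested), and ``load_state_dict()`` hands that value back to ``set_extra_state``. The sketch below shows this round trip with a hypothetical ``ParamsHolder`` module and an illustrative ``model_params`` dict; it does not reproduce the exact contents of this class's extra state.

      .. code-block:: python

         from typing import Any, Optional

         import torch


         class ParamsHolder(torch.nn.Module):
             """Hypothetical module that persists a plain dict as extra state."""

             def __init__(self, model_params: Optional[dict[str, Any]] = None) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(2, 1)
                 self.model_params = model_params if model_params is not None else {}

             def get_extra_state(self) -> dict:
                 # Saved by state_dict() under the "_extra_state" key.
                 return {"model_params": self.model_params}

             def set_extra_state(self, state: dict) -> None:
                 # Called by load_state_dict() with the object saved above.
                 self.model_params = state["model_params"]


         src = ParamsHolder({"type_map": ["O", "H"]})  # illustrative params
         sd = src.state_dict()

         dst = ParamsHolder()
         dst.load_state_dict(sd)  # restores tensors and calls set_extra_state
         print(dst.model_params)  # {'type_map': ['O', 'H']}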


