deepmd.pt.train.wrapper

Module Contents

Classes

ModelWrapper

Wraps a model and its loss module (or dictionaries of them, keyed by task, for multi-task training).

Attributes

log

deepmd.pt.train.wrapper.log

Module-level logger.

class deepmd.pt.train.wrapper.ModelWrapper(model: torch.nn.Module | Dict, loss: torch.nn.Module | Dict = None, model_params=None, shared_links=None)

Bases: torch.nn.Module

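Wraps a single model and loss, or per-task dictionaries of models and losses for multi-task training, in one torch.nn.Module. The rest of this docstring is inherited from torch.nn.Module, the base class for all neural network modules.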

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()  # initialise nn.Module before assigning submodules
        # Both convolutions are registered as submodules of Model.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Variables:

training (bool) – Whether this module is in training mode (True) or evaluation mode (False).
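The registration rules, the training flag, and the note about calling the parent __init__() first can all be checked with plain PyTorch. This is a sketch assuming a recent PyTorch release; Block and TooEarly are illustrative names, not part of DeePMD-kit.

```python
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()  # must run before any submodule is assigned
        self.linear = nn.Linear(4, 2)  # registered automatically as a child


class TooEarly(nn.Module):
    def __init__(self):
        # Assigning a submodule before super().__init__() fails because
        # nn.Module's internal submodule registry does not exist yet.
        self.linear = nn.Linear(4, 2)
        super().__init__()


block = Block()
child_names = [name for name, _ in block.named_children()]
print(child_names)  # ['linear']

block.eval()  # switches this module and every registered child to eval mode
print(block.training, block.linear.training)  # False False

try:
    TooEarly()
except AttributeError as err:
    print(err)
```

Because eval() and train() recurse through registered children, a wrapper module only needs to toggle its own mode for every wrapped submodule to follow.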

share_params(shared_links, resume=False)

Shares the parameters of classes according to the rules defined in shared_links during multi-task training. When not resuming from a checkpoint (resume is False), some separately stored parameters (e.g. mean and stddev) are recalculated across the different classes.
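How the separately stored statistics are recombined is not specified here. As an assumption, not DeePMD-kit's actual code, one common approach pools per-task sample counts, means, and standard deviations into a single mean and population standard deviation, using E[x²] = σᵢ² + μᵢ² for each task. pool_mean_std is a hypothetical helper.

```python
import math
from typing import List, Tuple


def pool_mean_std(stats: List[Tuple[int, float, float]]) -> Tuple[float, float]:
    """Combine (count, mean, std) triples from several tasks into one mean/std.

    Hypothetical helper: pooled variance = E[x^2] - E[x]^2, where each task
    contributes count * (std^2 + mean^2) to the pooled second moment.
    """
    total = sum(n for n, _, _ in stats)
    if total == 0:
        raise ValueError("no samples to pool")
    mean = sum(n * mu for n, mu, _ in stats) / total
    second_moment = sum(n * (sd**2 + mu**2) for n, mu, sd in stats) / total
    # Clamp tiny negative values caused by floating-point round-off.
    var = max(second_moment - mean**2, 0.0)
    return mean, math.sqrt(var)


# Task A saw the values 1, 2, 3; task B saw 4, 5 (population std used for both).
pooled_mean, pooled_std = pool_mean_std(
    [(3, 2.0, math.sqrt(2.0 / 3.0)), (2, 4.5, 0.5)]
)
print(pooled_mean, pooled_std)  # approximately 3.0 and sqrt(2)
```

The pooled result matches computing the mean and population standard deviation over all five values at once, which is why count weighting (rather than a plain average of the per-task statistics) is used.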

forward(coord, atype, spin: torch.Tensor | None = None, box: torch.Tensor | None = None, cur_lr: torch.Tensor | None = None, label: torch.Tensor | None = None, task_key: torch.Tensor | None = None, inference_only=False, do_atomic_virial=False, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None)

Evaluates the model selected by task_key on the given inputs and, unless inference_only is True, also evaluates the corresponding loss against label.

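The signature suggests a dispatch pattern: task_key picks one model/loss pair in multi-task mode, and inference_only skips the loss. The pure-Python sketch below models that pattern only. WrapperSketch, the "Default" key, and the (prediction, loss, extra terms) return shape are assumptions for illustration, not the verified DeePMD-kit API.

```python
from typing import Any, Callable, Dict, Optional, Tuple, Union

# Hypothetical callable types standing in for torch modules.
ModelFn = Callable[..., Dict[str, Any]]
LossFn = Callable[..., Tuple[Dict[str, Any], float, Dict[str, Any]]]


class WrapperSketch:
    """Hypothetical stand-in for ModelWrapper's task dispatch (not the real class)."""

    def __init__(
        self,
        model: Union[ModelFn, Dict[str, ModelFn]],
        loss: Union[LossFn, Dict[str, LossFn], None] = None,
    ) -> None:
        # Assumption: a dict of models enables multi-task mode; a single
        # model is stored under one default key.
        self.multi_task = isinstance(model, dict)
        self.model = model if self.multi_task else {"Default": model}
        if loss is None:
            self.loss = {}
        else:
            self.loss = loss if isinstance(loss, dict) else {"Default": loss}

    def forward(
        self,
        coord,
        atype,
        label: Optional[Dict[str, Any]] = None,
        task_key: Optional[str] = None,
        cur_lr: Optional[float] = None,
        inference_only: bool = False,
    ):
        if not self.multi_task:
            task_key = "Default"
        elif task_key is None:
            raise ValueError("task_key is required in multi-task mode")
        inputs = {"coord": coord, "atype": atype}
        if inference_only:
            # Prediction only; no loss is evaluated.
            return self.model[task_key](**inputs), None, None
        # The loss callable runs the model and returns (prediction, loss, extra terms).
        return self.loss[task_key](
            inputs, self.model[task_key], label, natoms=len(atype), learning_rate=cur_lr
        )


def toy_model(coord, atype):
    # Toy "energy": the sum of the coordinates.
    return {"energy": float(sum(coord))}


def toy_loss(inputs, model, label, natoms, learning_rate):
    pred = model(**inputs)
    err = pred["energy"] - label["energy"]
    return pred, err**2, {"rmse_e": abs(err) / natoms}


wrapper = WrapperSketch(
    {"water": toy_model, "metal": toy_model},
    {"water": toy_loss, "metal": toy_loss},
)
pred, loss, more = wrapper.forward(
    [1.0, 2.0], [0, 1], label={"energy": 2.5}, task_key="water", cur_lr=1e-3
)
print(pred, loss, more)  # {'energy': 3.0} 0.25 {'rmse_e': 0.25}
```

Handing the model to the loss callable, rather than calling the model first, lets each task's loss decide which outputs it needs from the model.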
set_extra_state(state: Dict)

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

Parameters:

state (dict) – Extra state from the state_dict

get_extra_state() → Dict

Return any extra state to include in the module’s state_dict.

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be picklable so that the state_dict can be serialized. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

Any extra state to store in the module’s state_dict

Return type:

object
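A minimal round trip through get_extra_state() and set_extra_state() with plain PyTorch, as a sketch. WithMetadata and its model_params field are illustrative; the actual contents of ModelWrapper's extra state are not documented here, although its constructor does accept a model_params argument.

```python
from typing import Any, Dict, Optional

import torch


class WithMetadata(torch.nn.Module):
    """Illustrative module that keeps non-tensor metadata in its state_dict."""

    def __init__(self, model_params: Optional[Dict[str, Any]] = None):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)
        self.model_params = model_params or {}

    def get_extra_state(self) -> Dict[str, Any]:
        # Stored under the "_extra_state" key of state_dict(); must be picklable.
        return {"model_params": self.model_params}

    def set_extra_state(self, state: Dict[str, Any]) -> None:
        # Called by load_state_dict() with whatever get_extra_state() returned.
        self.model_params = state["model_params"]


src = WithMetadata({"type_map": ["O", "H"]})
sd = src.state_dict()
print(sorted(sd.keys()))  # ['_extra_state', 'linear.bias', 'linear.weight']

dst = WithMetadata()
dst.load_state_dict(sd)  # restores the weights and calls set_extra_state()
print(dst.model_params)  # {'type_map': ['O', 'H']}
```

Keeping the metadata in a plain dict of built-in types keeps it picklable and avoids depending on the pickled form of custom classes.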