deepmd.pt.model.descriptor.repflow_layer

Classes

RepFlowLayer

Module Contents

class deepmd.pt.model.descriptor.repflow_layer.RepFlowLayer(e_rcut: float, e_rcut_smth: float, e_sel: int, a_rcut: float, a_rcut_smth: float, a_sel: int, ntypes: int, n_dim: int = 128, e_dim: int = 16, a_dim: int = 64, a_compress_rate: int = 0, a_compress_use_split: bool = False, a_compress_e_rate: int = 1, n_multi_edge_message: int = 1, axis_neuron: int = 4, update_angle: bool = True, optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, smooth_edge_update: bool = False, activation_function: str = 'silu', update_style: str = 'res_residual', update_residual: float = 0.1, update_residual_init: str = 'const', precision: str = 'float64', seed: int | list[int] | None = None)

Bases: torch.nn.Module

epsilon = 0.0001
e_rcut
e_rcut_smth
ntypes
nnei
e_sel
sec
a_rcut
a_rcut_smth
a_sel
n_dim = 128
e_dim = 16
a_dim = 64
a_compress_rate = 0
n_multi_edge_message = 1
axis_neuron = 4
update_angle = True
activation_function = 'silu'
act
update_style = 'res_residual'
update_residual = 0.1
update_residual_init = 'const'
a_compress_e_rate = 1
a_compress_use_split = False
precision = 'float64'
seed = None
prec
optim_update = True
smooth_edge_update = False
use_dynamic_sel = False
sel_reduce_factor = 10.0
dynamic_e_sel
dynamic_a_sel
n_residual = []
e_residual = []
a_residual = []
edge_info_dim = 272
node_self_mlp
n_sym_dim = 576
node_sym_linear
node_edge_linear
edge_self_linear
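For the default arguments, the two derived widths above are edge_info_dim = 272 and n_sym_dim = 576. The hypothetical helpers below reproduce those numbers. The relations are inferred from the arithmetic (272 = 2 x 128 + 16 and 576 = 4 x 128 + 4 x 16), not read from the source, so the interpretation in the comments is an assumption.

```python
def edge_info_dim(n_dim: int = 128, e_dim: int = 16) -> int:
    # Assumed layout: node_i (n_dim) + node_j (n_dim) + edge_ij (e_dim)
    return 2 * n_dim + e_dim


def n_sym_dim(n_dim: int = 128, e_dim: int = 16, axis_neuron: int = 4) -> int:
    # Assumed layout: symmetrized node block (axis_neuron * n_dim)
    # concatenated with symmetrized edge block (axis_neuron * e_dim)
    return axis_neuron * n_dim + axis_neuron * e_dim


print(edge_info_dim(), n_sym_dim())  # 272 576
```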
static _cal_hg(edge_ebd: torch.Tensor, h2: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor) -> torch.Tensor

Calculate the transposed rotation matrix.

Parameters:
edge_ebd

Neighbor-wise/Pair-wise edge embeddings, with shape nb x nloc x nnei x e_dim.

h2

Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3.

nlist_mask

Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.

sw

The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape nb x nloc x nnei.

Returns:
hg

The transposed rotation matrix, with shape nb x nloc x 3 x e_dim.
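Here is a minimal dense sketch of this reduction. It assumes hg is the masked, switch-weighted sum over neighbors of the outer product h2_ij^T edge_ebd_ij. The `scale` argument and the quintic `smooth_switch` are hypothetical. This page documents neither the normalization nor the exact switch polynomial, only that the switch is 1 below rcut_smth, decays smoothly to 0 at rcut, and stays 0 beyond.

```python
import torch


def smooth_switch(r: torch.Tensor, rcut_smth: float, rcut: float) -> torch.Tensor:
    """Hypothetical switch: 1 for r <= rcut_smth, 0 for r >= rcut, C2-smooth in between."""
    u = ((r - rcut_smth) / (rcut - rcut_smth)).clamp(0.0, 1.0)
    # quintic smoothstep, flipped so that it goes from 1 down to 0
    return u**3 * (-6.0 * u**2 + 15.0 * u - 10.0) + 1.0


def cal_hg(
    edge_ebd: torch.Tensor,  # nb x nloc x nnei x e_dim
    h2: torch.Tensor,  # nb x nloc x nnei x 3
    nlist_mask: torch.Tensor,  # nb x nloc x nnei, False/0 for padded slots
    sw: torch.Tensor,  # nb x nloc x nnei
    scale: float = 1.0,  # hypothetical normalization factor
) -> torch.Tensor:
    # weight each neighbor by its switch value; padded neighbors get weight 0
    weight = (sw * nlist_mask.to(sw.dtype)).unsqueeze(-1)  # nb x nloc x nnei x 1
    # (3 x nnei) @ (nnei x e_dim) sums the per-neighbor outer products
    return torch.matmul(h2.transpose(-1, -2), edge_ebd * weight) * scale


torch.manual_seed(0)
nb, nloc, nnei, e_dim = 1, 2, 5, 4
r = torch.rand(nb, nloc, nnei, dtype=torch.float64) * 6.0
sw = smooth_switch(r, rcut_smth=3.0, rcut=5.0)
hg = cal_hg(
    torch.randn(nb, nloc, nnei, e_dim, dtype=torch.float64),
    torch.randn(nb, nloc, nnei, 3, dtype=torch.float64),
    r < 5.0,  # treat neighbors beyond rcut as padding
    sw,
)
print(hg.shape)  # torch.Size([1, 2, 3, 4])
```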

static _cal_hg_dynamic(flat_edge_ebd: torch.Tensor, flat_h2: torch.Tensor, flat_sw: torch.Tensor, owner: torch.Tensor, num_owner: int, nb: int, nloc: int, scale_factor: float) -> torch.Tensor

Calculate the transposed rotation matrix.

Parameters:
flat_edge_ebd

Flattened neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim.

flat_h2

Flattened neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3.

flat_sw

Flattened switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape n_edge.

owner

The owner index of each edge, i.e. the index of the atom the edge is reduced onto, with shape n_edge.

num_owner : int

The total number of owners.

nb : int

The number of batches.

nloc : int

The number of local atoms.

scale_factor : float

The scale factor applied after the reduction.

Returns:
hg

The transposed rotation matrix, with shape nb x nloc x 3 x e_dim.
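The dynamic variant works on flat per-edge arrays instead of a padded neighbor axis. Here is a sketch of the same reduction. It assumes `owner` holds flat local-atom ids in [0, num_owner) with num_owner = nb x nloc, and it uses `Tensor.index_add_` as the scatter-sum. The actual implementation may use a different scatter routine.

```python
import torch


def cal_hg_dynamic(
    flat_edge_ebd: torch.Tensor,  # n_edge x e_dim
    flat_h2: torch.Tensor,  # n_edge x 3
    flat_sw: torch.Tensor,  # n_edge
    owner: torch.Tensor,  # n_edge, int64 id of the central atom of each edge
    num_owner: int,  # assumed to equal nb * nloc
    nb: int,
    nloc: int,
    scale_factor: float,
) -> torch.Tensor:
    e_dim = flat_edge_ebd.shape[-1]
    # per-edge outer product h2_ij^T (sw_ij * edge_ij): n_edge x 3 x e_dim
    weighted = flat_edge_ebd * flat_sw.unsqueeze(-1)
    contrib = flat_h2.unsqueeze(-1) * weighted.unsqueeze(-2)
    # scatter-sum every edge onto its owner atom; atoms without edges stay zero
    out = torch.zeros(num_owner, 3, e_dim, dtype=contrib.dtype, device=contrib.device)
    out.index_add_(0, owner, contrib)
    return (out * scale_factor).view(nb, nloc, 3, e_dim)


# 2 frames x 2 local atoms = 4 owners; owner 2 has no edges
owner = torch.tensor([0, 0, 1, 3, 3])
e = torch.randn(5, 4, dtype=torch.float64)
h = torch.randn(5, 3, dtype=torch.float64)
hg = cal_hg_dynamic(e, h, torch.ones(5, dtype=torch.float64), owner, 4, 2, 2, 1.0)
print(hg.shape)  # torch.Size([2, 2, 3, 4])
```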

static _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor

Calculate the atomic invariant rep.

Parameters:
h2g2

The transposed rotation matrix, with shape nb x nloc x 3 x e_dim.

axis_neuron

Size of the submatrix.

Returns:
grrg

Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim).
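Here is a sketch of the contraction. It assumes the submatrix is the first `axis_neuron` embedding channels of hg, which is a common DeePMD-style choice but is not confirmed on this page. The 3-axis is summed out, so rotating hg as R @ hg with an orthogonal R leaves the result unchanged. The usage lines check this.

```python
import torch


def cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor:
    """h2g2: nb x nloc x 3 x e_dim -> nb x nloc x (axis_neuron * e_dim)."""
    nb, nloc, _, e_dim = h2g2.shape
    # assumed submatrix: the leading axis_neuron embedding channels
    sub = h2g2[..., :axis_neuron]  # nb x nloc x 3 x axis_neuron
    # (axis_neuron x 3) @ (3 x e_dim): for R @ h2g2, R^T R = I cancels
    grrg = torch.matmul(sub.transpose(-1, -2), h2g2)  # nb x nloc x axis_neuron x e_dim
    return grrg.reshape(nb, nloc, axis_neuron * e_dim)


torch.manual_seed(0)
hg = torch.randn(2, 3, 3, 16, dtype=torch.float64)
q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))  # random orthogonal matrix
out = cal_grrg(hg, axis_neuron=4)
out_rot = cal_grrg(q @ hg, axis_neuron=4)
print(out.shape, torch.allclose(out, out_rot))  # torch.Size([2, 3, 64]) True
```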

symmetrization_op(edge_ebd: torch.Tensor, h2: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor, axis_neuron: int) -> torch.Tensor

Symmetrization operator to obtain atomic invariant rep.

Parameters:
edge_ebd

Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x e_dim.

h2

Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3.

nlist_mask

Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.

sw

The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape nb x nloc x nnei.

axis_neuron

Size of the submatrix.

Returns:
grrg

Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim).
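Composing the reduction and the contraction gives the full operator. The sketch below assumes a plain switch-weighted sum with a hypothetical `scale` and takes the first `axis_neuron` channels as the submatrix. It checks two properties implied by the parameter descriptions. First, the result does not depend on the order of neighbors. Second, values stored in masked (padded) slots have no effect.

```python
import torch


def symmetrization_op(edge_ebd, h2, nlist_mask, sw, axis_neuron, scale=1.0):
    nb, nloc, nnei, e_dim = edge_ebd.shape
    # switch-weighted, masked reduction over neighbors: nb x nloc x 3 x e_dim
    w = (sw * nlist_mask.to(sw.dtype)).unsqueeze(-1)
    hg = torch.matmul(h2.transpose(-1, -2), edge_ebd * w) * scale
    # contract with the leading axis_neuron channels: nb x nloc x axis_neuron x e_dim
    grrg = torch.matmul(hg[..., :axis_neuron].transpose(-1, -2), hg)
    return grrg.reshape(nb, nloc, axis_neuron * e_dim)


torch.manual_seed(0)
nb, nloc, nnei, e_dim = 1, 2, 6, 8
edge = torch.randn(nb, nloc, nnei, e_dim, dtype=torch.float64)
h2 = torch.randn(nb, nloc, nnei, 3, dtype=torch.float64)
sw = torch.rand(nb, nloc, nnei, dtype=torch.float64)
mask = torch.ones(nb, nloc, nnei, dtype=torch.bool)
mask[..., -2:] = False  # the last two slots are padding
base = symmetrization_op(edge, h2, mask, sw, axis_neuron=4)

perm = torch.tensor([3, 1, 0, 2, 4, 5])  # reorder the neighbor slots
shuffled = symmetrization_op(edge[:, :, perm], h2[:, :, perm], mask[:, :, perm], sw[:, :, perm], axis_neuron=4)
junk = edge.clone()
junk[..., -2:, :] = 1e3  # garbage in the padded slots
padded = symmetrization_op(junk, h2, mask, sw, axis_neuron=4)
print(torch.allclose(base, shuffled), torch.allclose(base, padded))  # True True
```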

symmetrization_op_dynamic(flat_edge_ebd: torch.Tensor, flat_h2: torch.Tensor, flat_sw: torch.Tensor, owner: torch.Tensor, num_owner: int, nb: int, nloc: int, scale_factor: float, axis_neuron: int) -> torch.Tensor

Symmetrization operator to obtain atomic invariant rep.

Parameters:
flat_edge_ebd

Flattened neighbor-wise/pair-wise invariant rep tensors, with shape n_edge x e_dim.

flat_h2

Flattened neighbor-wise/pair-wise equivariant rep tensors, with shape n_edge x 3.

flat_sw

Flattened switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape n_edge.

owner

The owner index of each edge, i.e. the index of the atom the edge is reduced onto, with shape n_edge.

num_owner : int

The total number of owners.

nb : int

The number of batches.

nloc : int

The number of local atoms.

scale_factor : float

The scale factor applied after the reduction.

axis_neuron

Size of the submatrix.

Returns:
grrg

Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim).

optim_angle_update(angle_ebd: torch.Tensor, node_ebd: torch.Tensor, edge_ebd: torch.Tensor, feat: str = 'edge') -> torch.Tensor
optim_angle_update_dynamic(flat_angle_ebd: torch.Tensor, node_ebd: torch.Tensor, flat_edge_ebd: torch.Tensor, n2a_index: torch.Tensor, eij2a_index: torch.Tensor, eik2a_index: torch.Tensor, feat: str = 'edge') -> torch.Tensor
optim_edge_update(node_ebd: torch.Tensor, node_ebd_ext: torch.Tensor, edge_ebd: torch.Tensor, nlist: torch.Tensor, feat: str = 'node') -> torch.Tensor
optim_edge_update_dynamic(node_ebd: torch.Tensor, node_ebd_ext: torch.Tensor, flat_edge_ebd: torch.Tensor, n2e_index: torch.Tensor, n_ext2e_index: torch.Tensor, feat: str = 'node') -> torch.Tensor
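The `optim_*` method names and the `optim_update` flag suggest a standard trick, though this page does not document it. A linear layer applied to the concatenation [node_i, node_j, edge_ij] equals the sum of three smaller matmuls with row blocks of the same weight. The node terms can then be projected once per atom and gathered through `nlist`. For the defaults, the concatenated width 2 x n_dim + e_dim equals edge_info_dim = 272. The sketch below checks the identity. The weight layout, the gather via `nlist`, and the assumption that local atoms come first in the extended list are all assumptions.

```python
import torch


def edge_update_concat(node, node_ext, edge, nlist, weight):
    """Reference: gather neighbors, concatenate [node_i, node_j, edge_ij], one matmul."""
    nb, nloc, nnei = nlist.shape
    n_dim = node.shape[-1]
    node_i = node.unsqueeze(2).expand(-1, -1, nnei, -1)
    idx = nlist.reshape(nb, nloc * nnei, 1).expand(-1, -1, n_dim)
    node_j = torch.gather(node_ext, 1, idx).reshape(nb, nloc, nnei, n_dim)
    return torch.cat([node_i, node_j, edge], dim=-1) @ weight


def edge_update_split(node, node_ext, edge, nlist, weight):
    """Split the weight into row blocks, project atoms first, then broadcast/gather."""
    nb, nloc, nnei = nlist.shape
    n_dim, out_dim = node.shape[-1], weight.shape[-1]
    w_i, w_j, w_e = torch.split(weight, [n_dim, n_dim, edge.shape[-1]], dim=0)
    proj_i = (node @ w_i).unsqueeze(2)  # nb x nloc x 1 x out_dim, one matmul per atom
    proj_ext = node_ext @ w_j  # nb x nall x out_dim
    idx = nlist.reshape(nb, nloc * nnei, 1).expand(-1, -1, out_dim)
    proj_j = torch.gather(proj_ext, 1, idx).reshape(nb, nloc, nnei, out_dim)
    return proj_i + proj_j + edge @ w_e


torch.manual_seed(0)
nb, nloc, nall, nnei, n_dim, e_dim, out_dim = 2, 3, 5, 4, 8, 2, 6
node_ext = torch.randn(nb, nall, n_dim, dtype=torch.float64)
node = node_ext[:, :nloc]  # assumed: local atoms come first in the extended list
edge = torch.randn(nb, nloc, nnei, e_dim, dtype=torch.float64)
nlist = torch.randint(0, nall, (nb, nloc, nnei))
weight = torch.randn(2 * n_dim + e_dim, out_dim, dtype=torch.float64)
a = edge_update_concat(node, node_ext, edge, nlist, weight)
b = edge_update_split(node, node_ext, edge, nlist, weight)
print(a.shape, torch.allclose(a, b))  # torch.Size([2, 3, 4, 6]) True
```

The split form never materializes the nb x nloc x nnei x (2 x n_dim + e_dim) concatenation, which is where the memory saving would come from.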
forward(node_ebd_ext: torch.Tensor, edge_ebd: torch.Tensor, h2: torch.Tensor, angle_ebd: torch.Tensor, nlist: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor, a_nlist: torch.Tensor, a_nlist_mask: torch.Tensor, a_sw: torch.Tensor, edge_index: torch.Tensor, angle_index: torch.Tensor)

Update the node, edge, and angle embeddings of one RepFlow layer.

Parameters:
node_ebd_ext : nf x nall x n_dim

Extended node embedding.

edge_ebd : nf x nloc x nnei x e_dim

Edge embedding.

h2 : nf x nloc x nnei x 3

Pair-atom channel, equivariant.

angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim

Angle embedding.

nlist : nf x nloc x nnei

Neighbor list (padded neighbors are set to 0).

nlist_mask : nf x nloc x nnei

Mask of the neighbor list: 1 for a real neighbor, 0 otherwise.

sw : nf x nloc x nnei

Switch function.

a_nlist : nf x nloc x a_nnei

Neighbor list for angles (padded neighbors are set to 0).

a_nlist_mask : nf x nloc x a_nnei

Mask of the angle neighbor list: 1 for a real neighbor, 0 otherwise.

a_sw : nf x nloc x a_nnei

Switch function for angles.

edge_index : n_edge x 2, optional (used only with dynamic sel)

Stacked edge indices with the following columns:

n2e_index : n_edge

Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).

n_ext2e_index : n_edge

Broadcast indices from extended node(j) to edge(ij).

angle_index : n_angle x 3, optional (used only with dynamic sel)

Stacked angle indices with the following columns:

n2a_index : n_angle

Broadcast indices from node(i) to angle(ijk).

eij2a_index : n_angle

Broadcast indices from edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij).

eik2a_index : n_angle

Broadcast indices from edge(ik) to angle(ijk).

Returns:
n_updated : nf x nloc x n_dim

Updated node embedding.

e_updated : nf x nloc x nnei x e_dim

Updated edge embedding.

a_updated : nf x nloc x a_nnei x a_nnei x a_dim

Updated angle embedding.
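With dynamic sel, edges are passed as flat index arrays rather than through a padded nlist. Below is one plausible way to build edge_index from a padded neighbor list. The helper name `build_edge_index` is hypothetical. The frame offsets (f x nloc for node(i), f x nall for extended node(j)) assume that the node and extended-node embeddings are flattened over frames. This page documents only the meaning of the two columns.

```python
import torch


def build_edge_index(nlist: torch.Tensor, nlist_mask: torch.Tensor, nall: int) -> torch.Tensor:
    """Hypothetical helper: padded nlist (nf x nloc x nnei) -> edge_index (n_edge x 2).

    Column 0 (n2e_index): flat local-atom id f * nloc + i of the central atom.
    Column 1 (n_ext2e_index): flat extended-atom id f * nall + j of the neighbor.
    """
    nf, nloc, nnei = nlist.shape
    # id of the central atom i for every (frame, i, neighbor slot)
    n2e = torch.arange(nf * nloc, device=nlist.device).view(nf, nloc, 1).expand(-1, -1, nnei)
    # offset neighbor ids per frame so that all frames share one flat extended array
    frame_shift = (torch.arange(nf, device=nlist.device) * nall).view(nf, 1, 1)
    n_ext2e = nlist + frame_shift
    # keep only real neighbors; padded slots (mask False) are dropped
    return torch.stack([n2e[nlist_mask], n_ext2e[nlist_mask]], dim=-1)


nlist = torch.tensor([[[2, 3, 0], [0, 0, 0]],
                      [[1, 0, 0], [0, 2, 0]]])
mask = torch.tensor([[[True, True, False], [True, False, False]],
                     [[True, False, False], [True, True, False]]])
edge_index = build_edge_index(nlist, mask, nall=4)
print(edge_index.tolist())  # [[0, 2], [0, 3], [1, 0], [2, 5], [3, 4], [3, 6]]
```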

list_update_res_avg(update_list: list[torch.Tensor]) -> torch.Tensor
list_update_res_incr(update_list: list[torch.Tensor]) -> torch.Tensor
list_update_res_residual(update_list: list[torch.Tensor], update_name: str = 'node') -> torch.Tensor
list_update(update_list: list[torch.Tensor], update_name: str = 'node') -> torch.Tensor
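The list-update methods are undocumented here. Their names, the `update_style`, `update_residual` and `update_residual_init` parameters, and the `n_residual`/`e_residual`/`a_residual` lists suggest a plausible reading, sketched below. `update_name` presumably selects which residual list is used. This is an illustration of that reading, not the library's code. The 1/sqrt(n) scalings and the per-update weights in particular are assumptions.

```python
import torch


def res_avg(updates):
    # assumed: sum of all terms scaled by 1/sqrt(n)
    return sum(updates[1:], updates[0]) / len(updates) ** 0.5


def res_incr(updates):
    # assumed: keep the first term, add the rest scaled by 1/sqrt(n - 1)
    n = len(updates)
    scale = (n - 1) ** -0.5 if n > 1 else 0.0
    return updates[0] + scale * sum(updates[1:], torch.zeros_like(updates[0]))


def res_residual(updates, residual_weights):
    # assumed: first term plus one (learnable) weight per additional update
    out = updates[0]
    for w, u in zip(residual_weights, updates[1:]):
        out = out + w * u
    return out


def list_update(updates, update_style="res_residual", residual_weights=None):
    if update_style == "res_avg":
        return res_avg(updates)
    if update_style == "res_incr":
        return res_incr(updates)
    if update_style == "res_residual":
        return res_residual(updates, residual_weights)
    raise ValueError(f"unknown update_style: {update_style}")


dim = 4
x = torch.ones(dim)
d1, d2 = torch.full((dim,), 2.0), torch.full((dim,), 4.0)
# update_residual_init='const' read as: every weight starts at update_residual = 0.1
weights = [torch.full((dim,), 0.1), torch.full((dim,), 0.1)]
# every entry is 1 + 0.1 * 2 + 0.1 * 4 = 1.6 (up to float rounding)
print(list_update([x, d1, d2], "res_residual", weights))
```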
serialize() -> dict

Serialize the networks to a dict.

Returns:
dict

The serialized networks.

classmethod deserialize(data: dict) -> RepFlowLayer

Deserialize the networks from a dict.

Parameters:
data : dict

The dict to deserialize from.

Returns:
RepFlowLayer

The deserialized layer.