Source code for deepmd_gnn.edge
"""Neighbor-list C++ op wrappers used by graph model export paths."""
import torch
import deepmd_gnn.op # noqa: F401
[docs]
def _mm_tensor(mm_types: list[int]) -> torch.Tensor:
    """Pack ``mm_types`` into an int64 tensor located on the CPU.

    Parameters
    ----------
    mm_types : list[int]
        Integer type indices to convert.

    Returns
    -------
    torch.Tensor
        Newly created tensor holding the values of ``mm_types`` with dtype
        ``torch.int64`` on device ``"cpu"``.
    """
    # The dtype and device are pinned explicitly instead of being inferred
    # from the list contents or the current default device.
    tensor_options = {"dtype": torch.int64, "device": "cpu"}
    return torch.tensor(mm_types, **tensor_options)
# Register the fake (shape-only) kernel only when the compiled op is present
# in the ``deepmd_gnn`` op namespace.
if hasattr(torch.ops.deepmd_gnn, "dense_edge_index"):

    @torch.library.register_fake("deepmd_gnn::dense_edge_index")
    def _fake_dense_edge_index(
        nlist: torch.Tensor,
        _atype: torch.Tensor,
        _mm: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Describe the output shapes and dtypes of ``dense_edge_index``.

        Only the shape of ``nlist`` is inspected; ``_atype`` and ``_mm`` are
        accepted to mirror the op's arguments and are otherwise unused.

        Parameters
        ----------
        nlist : torch.Tensor
            2D or 3D tensor; the number of dense edges is the product of all
            of its dimensions.
        _atype : torch.Tensor
            Unused.
        _mm : torch.Tensor
            Unused.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            An uninitialized ``(n_edges, 2)`` int64 tensor and an
            uninitialized ``(n_edges,)`` bool tensor, both allocated with
            ``nlist.new_empty``.

        Raises
        ------
        ValueError
            If ``nlist`` is neither 2D nor 3D.
        """
        rank = nlist.dim()
        if rank not in (2, 3):
            msg = "nlist must be 2D or 3D"
            raise ValueError(msg)

        dims = nlist.shape
        # The first two axes always contribute; the third one only for 3D input.
        n_edges = dims[0] * dims[1]
        if rank == 3:
            n_edges = n_edges * dims[2]

        index_out = nlist.new_empty((n_edges, 2), dtype=torch.int64)
        mask_out = nlist.new_empty((n_edges,), dtype=torch.bool)
        return index_out, mask_out
[docs]
def dense_edge_index(
    nlist: torch.Tensor,
    extended_atype: torch.Tensor,
    mm_types: list[int],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute dense edge indices and their validity mask via the C++ op.

    Parameters
    ----------
    nlist : torch.Tensor
        Neighbor list tensor, forwarded unchanged to the op.
    extended_atype : torch.Tensor
        Atom type tensor, forwarded unchanged to the op.
    mm_types : list[int]
        Integer type indices; converted to an int64 CPU tensor with
        ``_mm_tensor`` before being handed to the op.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The pair returned by ``torch.ops.deepmd_gnn.dense_edge_index``.
    """
    # Resolve the op before building the tensor argument so a missing op is
    # reported first.
    cxx_op = torch.ops.deepmd_gnn.dense_edge_index
    mm_tensor = _mm_tensor(mm_types)
    return cxx_op(nlist, extended_atype, mm_tensor)