Source code for deepmd_gnn.export

# SPDX-License-Identifier: LGPL-3.0-or-later
"""Helpers used by PyTorch-exportable model paths."""

import torch


def pad_nlist_for_export(nlist: torch.Tensor) -> torch.Tensor:
    """Return ``nlist`` with one extra ``-1`` entry appended along the last axis.

    The extra sentinel neighbor keeps the neighbor axis dynamic during
    symbolic export.

    The sentinel column is shaped from the first two dimensions of ``nlist``
    and concatenated on the last dimension, so ``nlist`` is expected to be
    3-D (presumably frames x atoms x neighbors -- confirm against callers).
    The result has the same dtype and device as the input.
    """
    # One sentinel slot per leading-index pair.
    sentinel_shape = (*nlist.shape[:2], 1)
    sentinel = torch.neg(
        torch.ones(sentinel_shape, dtype=nlist.dtype, device=nlist.device),
    )
    return torch.cat([nlist, sentinel], dim=-1)
[docs] def clear_export_guards_once(traced: torch.nn.Module) -> None: """Clear over-specialized guards from the next export of ``traced``. ``make_fx`` traces may specialize symbolic atom counts too aggressively. DeePMD's export path calls ``torch.export.export`` immediately after tracing, so this one-shot wrapper relaxes those constraints only for that export. """ original_export = torch.export.export def strip_deferred_assertions(exported: torch.export.ExportedProgram) -> None: graph = exported.graph_module.graph for node in list(graph.nodes): if ( node.op == "call_function" and node.target is torch.ops.aten._assert_scalar.default # noqa: SLF001 ): node.args = (True, node.args[1]) exported.graph_module.recompile() def relax_range_constraints(exported: torch.export.ExportedProgram) -> None: relaxed = exported.range_constraints.copy() for symbol, value_range in exported.range_constraints.items(): try: should_relax = bool(value_range.lower > 1) except TypeError: should_relax = False if should_relax: relaxed[symbol] = type(value_range)(1, value_range.upper) exported._range_constraints = relaxed # noqa: SLF001 def export_with_guard_cleanup( *export_args: object, **export_kwargs: object, ) -> torch.export.ExportedProgram: try: exported = original_export(*export_args, **export_kwargs) if export_args and export_args[0] is traced: exported._guards_code = [] # noqa: SLF001 strip_deferred_assertions(exported) relax_range_constraints(exported) return exported finally: if torch.export.export is export_with_guard_cleanup: torch.export.export = original_export torch.export.export = export_with_guard_cleanup