Source code for deepmd_gnn.cli

"""Command-line interface for DeePMD-GNN."""

from __future__ import annotations

import argparse
import sys
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, NoReturn

from deepmd_gnn.mace_off_cli import (
    MACE_OFF_MODEL_CHOICES,
    download_mace_off_model,
    get_mace_off_cache_dir,
)

if TYPE_CHECKING:
    from collections.abc import Sequence


[docs] def build_parser() -> argparse.ArgumentParser: """Build the top-level command-line parser.""" parser = argparse.ArgumentParser( prog="deepmd-gnn", description="Command-line tools for DeePMD-GNN.", ) subparsers = parser.add_subparsers(dest="command", required=True) mace_off_parser = subparsers.add_parser( "mace-off", help="Utilities for selected official MACE-OFF checkpoints.", ) mace_off_subparsers = mace_off_parser.add_subparsers( dest="mace_off_command", required=True, ) convert_parser = mace_off_subparsers.add_parser( "convert", help="Convert a supported MACE-OFF checkpoint to a DeePMD-GNN TorchScript model.", ) convert_parser.add_argument( "output_file", help="Output path for the exported DeePMD-GNN TorchScript model.", ) convert_parser.add_argument( "--sel", type=int, required=True, help="Required DeePMD runtime neighbor cap.", ) convert_source_group = convert_parser.add_mutually_exclusive_group(required=True) convert_source_group.add_argument( "--model", choices=MACE_OFF_MODEL_CHOICES, help="Official MACE-OFF model name or alias to download.", ) convert_source_group.add_argument( "--model-path", type=Path, help="Use a trusted local checkpoint path instead of downloading an official model.", ) convert_parser.add_argument( "--cache-dir", type=Path, help="Cache directory for downloaded official checkpoints.", ) download_parser = mace_off_subparsers.add_parser( "download", help="Download a supported official MACE-OFF checkpoint.", ) download_parser.add_argument( "--model", required=True, choices=MACE_OFF_MODEL_CHOICES, help="Official MACE-OFF model name or alias to download.", ) download_parser.add_argument( "--cache-dir", type=Path, help="Cache directory for downloaded official checkpoints.", ) mace_off_subparsers.add_parser( "cache-dir", help="Print the cache directory used for MACE-OFF checkpoints.", ) return parser
[docs] def _parser_error(parser: argparse.ArgumentParser, message: str) -> NoReturn: parser.error(message) msg = "unreachable" raise AssertionError(msg)
[docs] def main(argv: Sequence[str] | None = None) -> int: """Run the command-line interface.""" parser = build_parser() args = parser.parse_args(list(argv) if argv is not None else None) if args.command == "mace-off": if args.mace_off_command == "convert": convert_mace_off_to_deepmd = import_module( "deepmd_gnn.mace_off", ).convert_mace_off_to_deepmd output_path = convert_mace_off_to_deepmd( output_file=args.output_file, sel=args.sel, model_name=args.model, model_path=args.model_path, cache_dir=args.cache_dir, ) sys.stdout.write(f"{output_path}\n") return 0 if args.mace_off_command == "download": model_path = download_mace_off_model( model_name=args.model, cache_dir=args.cache_dir, ) sys.stdout.write(f"{model_path}\n") return 0 if args.mace_off_command == "cache-dir": sys.stdout.write(f"{get_mace_off_cache_dir()}\n") return 0 _parser_error(parser, "Unhandled command")