# SPDX-License-Identifier: LGPL-3.0-or-later
"""Lightweight helpers for official MACE-OFF checkpoint selection and download."""
from __future__ import annotations

import concurrent.futures
import hashlib
import http.client
import logging
import os
import shutil
import tempfile
import time
import urllib.error
import urllib.parse
import urllib.request
from collections.abc import Callable
from pathlib import Path
logger = logging.getLogger(__name__)

# Socket timeout (seconds) passed to ``urlopen`` when downloading a model file.
DOWNLOAD_TIMEOUT_SECONDS = 120
# Socket timeout (seconds) for each per-source reachability/latency probe.
SOURCE_PROBE_TIMEOUT_SECONDS = 8

_MACE_OFF_BASE_URL = "https://raw.githubusercontent.com/ACEsuit/mace-off/v0.2"
# Proxy prefix giving an alternative route to the same GitHub URLs. Files
# fetched through it are still checked against MACE_OFF_MODEL_SHA256.
_GHFAST_PREFIX = "https://ghfast.top/"

#: Canonical model name -> primary (GitHub) download URL.
MACE_OFF_MODELS = {
    "off23_small": f"{_MACE_OFF_BASE_URL}/mace_off23/MACE-OFF23_small.model",
    "off23_medium": f"{_MACE_OFF_BASE_URL}/mace_off23/MACE-OFF23_medium.model",
    "off23_large": f"{_MACE_OFF_BASE_URL}/mace_off23/MACE-OFF23_large.model",
    "off24_medium": f"{_MACE_OFF_BASE_URL}/mace_off24/MACE-OFF24_medium.model",
}

#: Canonical model name -> candidate URLs: the primary URL, then the same URL
#: behind the proxy prefix. The order actually tried is decided at download
#: time by latency ranking.
MACE_OFF_MODEL_URLS = {
    model_name: [
        url,
        f"{_GHFAST_PREFIX}{url}",
    ]
    for model_name, url in MACE_OFF_MODELS.items()
}

#: Canonical model name -> expected SHA256 hex digest of the model file.
MACE_OFF_MODEL_SHA256 = {
    "off23_small": "165cce4cfec5a34b9c64d4ebf95de15d71106bb584b7291c8470f0749977c46f",
    "off23_medium": "4842c52ad210d6e1f84d6cf1ffa70fae25a7e0d755ed55cf223f43913f587db7",
    "off23_large": "a29e397dbf3e7a24ac50a9b0dfc919bd5a62efa346f5895a6237b0950c1d76f4",
    "off24_medium": "e5ccf5837f685899811a68754e7c994393bfd1a81720393b03c643b46c70bc69",
}

# Short compatibility aliases -> canonical model names.
_MACE_OFF_MODEL_ALIASES = {
    "small": "off23_small",
    "medium": "off23_medium",
    "large": "off23_large",
}

#: Every accepted model name (canonical names plus aliases), sorted.
MACE_OFF_MODEL_CHOICES = tuple(
    sorted({*MACE_OFF_MODELS.keys(), *_MACE_OFF_MODEL_ALIASES.keys()}),
)
def _canonical_model_name(model_name: str) -> str:
    """Translate a compatibility alias (``small``/``medium``/``large``) to its canonical name.

    Any name that is not an alias is returned untouched; whether it is a known
    model is checked separately by the caller.
    """
    if model_name in _MACE_OFF_MODEL_ALIASES:
        return _MACE_OFF_MODEL_ALIASES[model_name]
    return model_name
def _validate_download_url(url: str) -> None:
    """Refuse any download URL whose scheme is not ``https``.

    Raises
    ------
    ValueError
        If the parsed scheme of *url* is anything other than ``https``.
    """
    scheme = urllib.parse.urlparse(url).scheme
    if scheme == "https":
        return
    msg = f"Unsupported URL scheme for download: {scheme}"
    raise ValueError(msg)
def _model_download_urls(model_name: str) -> list[str]:
    """List the candidate download URLs for a canonical model name, without duplicates.

    The first occurrence of each URL wins, so the configured order is kept.
    An unknown *model_name* raises ``KeyError``.
    """
    # dict keys keep insertion order, giving an order-preserving dedupe.
    return list(dict.fromkeys(MACE_OFF_MODEL_URLS[model_name]))
def _sha256sum(path: Path) -> str:
    """Return the hex SHA256 digest of the file at *path*, streamed in 1 MiB chunks."""
    digest = hashlib.sha256()
    chunk_size = 1024 * 1024
    with path.open("rb") as stream:
        while chunk := stream.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()
def _validate_model_file(model_path: Path, expected_sha256: str) -> bool:
    """Tell whether *model_path* exists and hashes to *expected_sha256*.

    The file is only hashed when it exists.
    """
    if not model_path.exists():
        return False
    return _sha256sum(model_path) == expected_sha256
def _probe_download_url(url: str) -> float | None:
    """Time how long it takes to open *url*; return ``None`` if it cannot be opened.

    The probe is a GET carrying ``Range: bytes=0-0``, so servers that honour
    ranges send at most one byte of body; the body is never read. A non-HTTPS
    URL raises ``ValueError`` before any request is made. Network failures
    (``URLError``/``OSError``/``ValueError`` while opening) yield ``None``.
    """
    _validate_download_url(url)
    probe_request = urllib.request.Request(  # noqa: S310
        url,
        method="GET",
        headers={"Range": "bytes=0-0"},
    )
    started = time.monotonic()
    try:
        with urllib.request.urlopen(  # noqa: S310
            probe_request,
            timeout=SOURCE_PROBE_TIMEOUT_SECONDS,
        ):
            pass  # opening the response is the whole measurement
    except (urllib.error.URLError, OSError, ValueError):
        return None
    else:
        return time.monotonic() - started
def _rank_download_urls(
    urls: list[str],
    probe: Callable[[str], float | None] | None = None,
) -> list[str]:
    """Rank candidate URLs by probe latency (fastest first).

    Parameters
    ----------
    urls
        Candidate download URLs.
    probe
        Callable returning the latency in seconds for one URL, or ``None`` if
        it is unreachable. Defaults to :func:`_probe_download_url`. Probes run
        concurrently in up to four threads.

    Returns
    -------
    list[str]
        Reachable URLs sorted by ascending latency, followed by the
        unreachable ones in their original order. Inputs with at most one URL
        are returned unchanged without probing.
    """
    if len(urls) <= 1:
        return urls
    if probe is None:
        probe = _probe_download_url
    latencies: dict[str, float] = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len(urls))) as exe:
        future_to_url = {exe.submit(probe, url): url for url in urls}
        for future in concurrent.futures.as_completed(future_to_url):
            url = future_to_url[future]
            try:
                latency = future.result()
            except Exception:  # noqa: BLE001
                # Ranking is best-effort. A probe that raises something it did
                # not handle (e.g. http.client.HTTPException, which is not an
                # OSError) must not abort the whole download. The URL is simply
                # ranked last; it is still attempted later, where a real
                # failure gets reported.
                latency = None
            if latency is not None:
                latencies[url] = latency
    ranked_ok = sorted(latencies, key=latencies.__getitem__)
    ranked_fail = [url for url in urls if url not in latencies]
    return ranked_ok + ranked_fail
def _download_file(url: str, destination: Path) -> None:
    """Download URL content to *destination* atomically.

    The response is streamed into a hidden ``.<name>.*.part`` temporary file
    in the destination directory, which is then renamed over *destination*.
    *destination* is therefore never left half-written.

    Raises
    ------
    ValueError
        If *url* is not an HTTPS URL (raised before anything is written).
    Exception
        Any error from the request, the copy or the final rename is re-raised
        after the temporary ``.part`` file has been removed.
    """
    _validate_download_url(url)
    destination.parent.mkdir(parents=True, exist_ok=True)
    tmp_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            "wb",
            dir=destination.parent,
            prefix=f".{destination.name}.",
            suffix=".part",
            delete=False,
        ) as out_file:
            tmp_path = Path(out_file.name)
            with urllib.request.urlopen(  # noqa: S310
                url,
                timeout=DOWNLOAD_TIMEOUT_SECONDS,
            ) as response:
                shutil.copyfileobj(response, out_file)
        # Rename inside the ``try`` so a failed rename also cleans up the .part file.
        tmp_path.replace(destination)
    except BaseException:
        # BaseException (not just Exception) so that Ctrl-C during a long
        # download does not leave a stale .part file in the cache directory.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise
def get_mace_off_cache_dir() -> Path:
    """Get the cache directory for MACE-OFF models, creating it if needed.

    Uses ``$XDG_CACHE_HOME/deepmd-gnn/mace-off`` when ``XDG_CACHE_HOME`` is set
    to an absolute path. Otherwise it falls back to
    ``~/.cache/deepmd-gnn/mace-off``.

    Returns
    -------
    Path
        The cache directory, which exists when this function returns.
    """
    xdg_cache_home = os.environ.get("XDG_CACHE_HOME", "")
    # The XDG Base Directory spec says an empty or relative value must be
    # ignored. Using it verbatim would put the cache under the current working
    # directory.
    if xdg_cache_home and Path(xdg_cache_home).is_absolute():
        base_dir = Path(xdg_cache_home)
    else:
        base_dir = Path.home() / ".cache"
    cache_dir = base_dir / "deepmd-gnn" / "mace-off"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
def download_mace_off_model(
    model_name: str,
    cache_dir: str | os.PathLike[str] | None = None,
) -> Path:
    """Download a selected official MACE-OFF model file.

    A cached file whose SHA256 matches is reused. A mismatching cached file is
    deleted. Otherwise the candidate sources from ``MACE_OFF_MODEL_URLS`` are
    ranked by probe latency and tried in that order until one produces a file
    with the expected SHA256.

    Parameters
    ----------
    model_name
        Canonical model name (for example ``off23_small``) or one of the
        compatibility aliases ``small`` / ``medium`` / ``large``.
    cache_dir
        Cache directory (created if missing). Defaults to
        :func:`get_mace_off_cache_dir`.

    Returns
    -------
    Path
        Path of the SHA256-verified model file inside *cache_dir*.

    Raises
    ------
    ValueError
        If *model_name* is unknown, or if the last failed attempt was a
        SHA256 mismatch (or otherwise raised ``ValueError``).
    RuntimeError
        If every source failed for another reason (network, HTTP or
        filesystem error); the last such error is chained as the cause.
    """
    canonical_name = _canonical_model_name(model_name)
    if canonical_name not in MACE_OFF_MODELS:
        msg = (
            f"Unknown MACE-OFF model: {model_name}. Available models: "
            f"{sorted(MACE_OFF_MODELS)}"
        )
        raise ValueError(msg)

    if cache_dir is None:
        cache_dir = get_mace_off_cache_dir()
    else:
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)

    urls = _model_download_urls(canonical_name)
    expected_sha256 = MACE_OFF_MODEL_SHA256[canonical_name]
    filename = Path(MACE_OFF_MODELS[canonical_name]).name
    model_path = cache_dir / filename

    if _validate_model_file(model_path, expected_sha256):
        logger.info("Using cached model: %s", model_path)
        return model_path
    if model_path.exists():
        logger.warning(
            "Cached model SHA256 mismatch for %s; re-downloading %s",
            model_path,
            canonical_name,
        )
        model_path.unlink(missing_ok=True)
    else:
        logger.info("Downloading MACE-OFF model %s", canonical_name)

    if len(urls) > 1:
        # Log before ranking: probing every source can block for a few seconds.
        logger.info(
            "Selecting fastest source among %d candidates...",
            len(urls),
        )
    ranked_urls = _rank_download_urls(urls)

    last_error: Exception | None = None
    for idx, url in enumerate(ranked_urls, start=1):
        logger.info(
            "Downloading MACE-OFF model %s (source %d/%d): %s",
            canonical_name,
            idx,
            len(ranked_urls),
            url,
        )
        try:
            _download_file(url, model_path)
        except (
            urllib.error.URLError,
            OSError,
            ValueError,
            http.client.HTTPException,
        ) as exc:
            # http.client.HTTPException (e.g. IncompleteRead on a truncated
            # chunked response, BadStatusLine) is not an OSError. Without it,
            # one flaky source would abort the download instead of falling back.
            last_error = exc
            logger.warning("Download attempt failed from %s: %s", url, exc)
            continue

        actual_sha256 = _sha256sum(model_path)
        if actual_sha256 != expected_sha256:
            model_path.unlink(missing_ok=True)
            msg = (
                f"SHA256 mismatch for downloaded model {canonical_name}: "
                f"expected {expected_sha256}, got {actual_sha256}"
            )
            last_error = ValueError(msg)
            logger.warning("SHA256 verification failed from source: %s", url)
            logger.warning("Expected: %s", expected_sha256)
            logger.warning("Actual: %s", actual_sha256)
            continue
        return model_path

    if isinstance(last_error, ValueError):
        raise last_error
    msg = f"Failed to download model '{canonical_name}' from all sources"
    raise RuntimeError(msg) from last_error
# Public API exported by ``from <module> import *``.
__all__ = [
    "MACE_OFF_MODELS",
    "MACE_OFF_MODEL_CHOICES",
    "MACE_OFF_MODEL_SHA256",
    "MACE_OFF_MODEL_URLS",
    "download_mace_off_model",
    "get_mace_off_cache_dir",
]