Source code for deepmd_gnn.stat_compat
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Compatibility accessors for DeePMD observed-type statistics helpers."""
import importlib
from collections.abc import Callable
from typing import cast
[docs]
# Signature of the "restore" helper (first item of the tuple returned by
# ``load_observed_type_stat_compat``): it is called with a single object --
# named ``stat_file_path`` by the fallback in this file -- and returns either
# a list of strings or ``None``.
ObservedTypeRestore = Callable[[object], list[str] | None]
[docs]
# Signature of the "save" helper (second item of the tuple returned by
# ``load_observed_type_stat_compat``): it is called with an object and a list
# of strings -- named ``stat_file_path`` and ``observed_type`` by the fallback
# in this file -- and returns nothing.
ObservedTypeSave = Callable[[object, list[str]], None]
[docs]
# Signature of the "collect" helper (third item of the tuple returned by
# ``load_observed_type_stat_compat``): it is called with an object and a list
# of strings -- named ``sampled`` and ``type_map`` by the fallback in this
# file -- and returns a list of strings.
ObservedTypeCollect = Callable[[object, list[str]], list[str]]
[docs]
def load_observed_type_stat_compat() -> tuple[
    ObservedTypeRestore,
    ObservedTypeSave,
    ObservedTypeCollect,
]:
    """Return observed-type statistic helpers across supported DeePMD versions.

    The helpers are looked up in ``deepmd.dpmodel.utils.stat``. Each helper
    that cannot be found -- because the module cannot be imported, or because
    the module does not define that helper -- is replaced by a no-op fallback,
    so the returned tuple is always complete and every item is callable.

    Returns
    -------
    tuple
        ``(restore, save, collect)``:

        * ``restore(stat_file_path)`` returns a list of strings or ``None``;
          the fallback always returns ``None``.
        * ``save(stat_file_path, observed_type)`` returns ``None``; the
          fallback does nothing.
        * ``collect(sampled, type_map)`` returns a list of strings; the
          fallback always returns ``[]``.
    """

    def restore_observed_type_from_file(
        stat_file_path: object,
    ) -> list[str] | None:
        """Fallback for older deepmd-kit without observed-type helpers."""
        _ = stat_file_path
        return None

    def save_observed_type_to_file(
        stat_file_path: object,
        observed_type: list[str],
    ) -> None:
        """Fallback for older deepmd-kit without observed-type helpers."""
        _ = stat_file_path, observed_type

    def collect_observed_types(sampled: object, type_map: list[str]) -> list[str]:
        """Fallback for older deepmd-kit without observed-type helpers."""
        _ = sampled, type_map
        return []

    try:
        stat_mod = importlib.import_module("deepmd.dpmodel.utils.stat")
    except ImportError:
        # Module absent: ``getattr`` on ``None`` below yields every fallback.
        stat_mod = None

    # Look each helper up with a default instead of plain attribute access, so
    # a deepmd-kit whose ``stat`` module exists but lacks a helper (two of
    # them are private names) degrades to the no-op instead of raising
    # ``AttributeError``.
    restore = cast(
        "ObservedTypeRestore",
        getattr(
            stat_mod,
            "_restore_observed_type_from_file",
            restore_observed_type_from_file,
        ),
    )
    save = cast(
        "ObservedTypeSave",
        getattr(stat_mod, "_save_observed_type_to_file", save_observed_type_to_file),
    )
    collect = cast(
        "ObservedTypeCollect",
        getattr(stat_mod, "collect_observed_types", collect_observed_types),
    )
    return (restore, save, collect)