Source code for deepmd.utils.eval_metrics

# SPDX-License-Identifier: LGPL-3.0-or-later

from __future__ import (
    annotations,
)

from dataclasses import (
    dataclass,
)

import numpy as np

[docs] FULL_VALIDATION_METRIC_KEY_MAP = { "e:mae": "mae_e_per_atom", "e:rmse": "rmse_e_per_atom", "f:mae": "mae_f", "f:rmse": "rmse_f", "v:mae": "mae_v_per_atom", "v:rmse": "rmse_v_per_atom", }
[docs] FULL_VALIDATION_WEIGHTED_METRIC_KEYS = { "energy_per_atom": ("mae_e_per_atom", "rmse_e_per_atom"), "force": ("mae_f", "rmse_f"), "virial_per_atom": ("mae_v_per_atom", "rmse_v_per_atom"), }
[docs] FULL_VALIDATION_METRIC_FAMILY_BY_KEY = { "mae_e_per_atom": "e", "rmse_e_per_atom": "e", "mae_f": "f", "rmse_f": "f", "mae_v_per_atom": "v", "rmse_v_per_atom": "v", }
[docs] DP_TEST_WEIGHTED_METRIC_KEYS = { "energy": ("mae_e", "rmse_e"), "energy_per_atom": ("mae_ea", "rmse_ea"), "force": ("mae_f", "rmse_f"), "virial": ("mae_v", "rmse_v"), "virial_per_atom": ("mae_va", "rmse_va"), }
[docs] DP_TEST_SPIN_WEIGHTED_METRIC_KEYS = { "force_real": ("mae_fr", "rmse_fr"), "force_magnetic": ("mae_fm", "rmse_fm"), }
[docs] DP_TEST_WEIGHTED_FORCE_METRIC_KEYS = ("mae_fw", "rmse_fw")
[docs] DP_TEST_HESSIAN_METRIC_KEYS = ("mae_h", "rmse_h")
[docs] def mae(diff: np.ndarray) -> float: """Calculate mean absolute error.""" return float(np.mean(np.abs(diff)))
[docs] def rmse(diff: np.ndarray) -> float: """Calculate root mean square error.""" return float(np.sqrt(np.mean(diff * diff)))
@dataclass(frozen=True)
[docs] class ErrorStat: """One weighted MAE/RMSE pair."""
[docs] mae: float
[docs] rmse: float
[docs] weight: float
[docs] def as_weighted_average_errors( self, mae_key: str, rmse_key: str, ) -> dict[str, tuple[float, float]]: """Convert one metric pair into `weighted_average` inputs.""" return { mae_key: (self.mae, self.weight), rmse_key: (self.rmse, self.weight), }
@dataclass(frozen=True)
[docs] class EnergyTypeEvalMetrics: """Shared energy-type metrics for one evaluation batch or system."""
[docs] energy: ErrorStat | None = None
[docs] energy_per_atom: ErrorStat | None = None
[docs] force: ErrorStat | None = None
[docs] virial: ErrorStat | None = None
[docs] virial_per_atom: ErrorStat | None = None
[docs] def as_weighted_average_errors( self, metric_keys: dict[str, tuple[str, str]], ) -> dict[str, tuple[float, float]]: """Project shared metrics into caller-specific error dict keys.""" errors: dict[str, tuple[float, float]] = {} for metric_name, (mae_key, rmse_key) in metric_keys.items(): metric = getattr(self, metric_name) if metric is not None: errors.update(metric.as_weighted_average_errors(mae_key, rmse_key)) return errors
@dataclass(frozen=True)
[docs] class SpinForceEvalMetrics: """Shared spin-force metrics for one evaluation batch or system."""
[docs] force_real: ErrorStat | None = None
[docs] force_magnetic: ErrorStat | None = None
[docs] def as_weighted_average_errors( self, metric_keys: dict[str, tuple[str, str]], ) -> dict[str, tuple[float, float]]: """Project shared spin metrics into caller-specific error dict keys.""" errors: dict[str, tuple[float, float]] = {} for metric_name, (mae_key, rmse_key) in metric_keys.items(): metric = getattr(self, metric_name) if metric is not None: errors.update(metric.as_weighted_average_errors(mae_key, rmse_key)) return errors
[docs] def compute_error_stat( prediction: np.ndarray, reference: np.ndarray, *, scale: float = 1.0, ) -> ErrorStat: """Compute one MAE/RMSE pair from aligned prediction and reference arrays.""" diff = prediction - reference return ErrorStat( mae=mae(diff) * scale, rmse=rmse(diff) * scale, weight=float(diff.size), )
[docs] def compute_weighted_error_stat( prediction: np.ndarray, reference: np.ndarray, weight: np.ndarray, ) -> ErrorStat: """Compute weighted MAE/RMSE from aligned prediction and reference arrays.""" diff = prediction - reference weight_sum = float(np.sum(weight)) if weight_sum <= 0.0: return ErrorStat(mae=0.0, rmse=0.0, weight=weight_sum) return ErrorStat( mae=float(np.sum(np.abs(diff) * weight) / weight_sum), rmse=float(np.sqrt(np.sum(diff * diff * weight) / weight_sum)), weight=weight_sum, )
[docs] def compute_energy_type_metrics( prediction: dict[str, np.ndarray], test_data: dict[str, np.ndarray], natoms: int, has_pbc: bool, ) -> EnergyTypeEvalMetrics: """Compute shared energy-type metrics for one evaluation dataset.""" energy = None energy_per_atom = None force = None virial = None virial_per_atom = None if bool(test_data.get("find_energy", 0.0)): energy = compute_error_stat( prediction["energy"].reshape(-1, 1), test_data["energy"].reshape(-1, 1), ) energy_per_atom = compute_error_stat( prediction["energy"].reshape(-1, 1), test_data["energy"].reshape(-1, 1), scale=1.0 / natoms, ) if bool(test_data.get("find_force", 0.0)): force = compute_error_stat( prediction["force"].reshape(-1), test_data["force"].reshape(-1), ) if has_pbc and bool(test_data.get("find_virial", 0.0)): virial = compute_error_stat( prediction["virial"].reshape(-1, 9), test_data["virial"].reshape(-1, 9), ) virial_per_atom = compute_error_stat( prediction["virial"].reshape(-1, 9), test_data["virial"].reshape(-1, 9), scale=1.0 / natoms, ) return EnergyTypeEvalMetrics( energy=energy, energy_per_atom=energy_per_atom, force=force, virial=virial, virial_per_atom=virial_per_atom, )
[docs] def compute_spin_force_metrics( force_real_prediction: np.ndarray, force_real_reference: np.ndarray, *, force_magnetic_prediction: np.ndarray | None = None, force_magnetic_reference: np.ndarray | None = None, ) -> SpinForceEvalMetrics: """Compute spin-aware force metrics from aligned real and magnetic forces.""" force_real = compute_error_stat(force_real_prediction, force_real_reference) force_magnetic = None if force_magnetic_prediction is not None or force_magnetic_reference is not None: if force_magnetic_prediction is None or force_magnetic_reference is None: raise ValueError( "Spin magnetic force metrics require both prediction and reference." ) force_magnetic = compute_error_stat( force_magnetic_prediction, force_magnetic_reference, ) return SpinForceEvalMetrics( force_real=force_real, force_magnetic=force_magnetic, )