# Source code for tscf_eval.evaluator.evaluator

"""Evaluator orchestration class.

This module provides the :class:`Evaluator` class that orchestrates the
computation of multiple metrics over pairs of original instances and their
counterfactuals.

Classes
-------
Evaluator
    Orchestration class that runs a collection of metrics and returns
    results as a dictionary mapping metric names to computed values.

Example
-------
>>> from tscf_eval.evaluator import Evaluator, Validity, Proximity, Sparsity
>>> import numpy as np
>>>
>>> # Create evaluator with multiple metrics
>>> evaluator = Evaluator([Validity(), Proximity(p=2), Sparsity()])
>>>
>>> # Evaluate counterfactuals
>>> X = np.random.randn(100, 50)
>>> X_cf = X + np.random.randn(100, 50) * 0.1
>>> results = evaluator.evaluate(X, X_cf, y=np.zeros(100), y_cf=np.ones(100))

See Also
--------
tscf_eval.evaluator.base : Metric abstract base class.
tscf_eval.evaluator.metrics : Built-in metric implementations.
"""

from __future__ import annotations

import contextlib
import time
from typing import TYPE_CHECKING, Any
import warnings

import numpy as np

if TYPE_CHECKING:
    from collections.abc import Iterable

    from .base import Metric


class Evaluator:
    """Run a collection of :class:`Metric` instances over example pairs.

    The Evaluator orchestrates the computation of multiple metrics over pairs
    of original instances and their counterfactuals. It handles progress
    reporting, error handling, and result aggregation.

    Parameters
    ----------
    metrics : iterable of Metric
        Collection of metric instances to compute during evaluation.

    Attributes
    ----------
    metrics : list of Metric
        The configured metric instances. The list may be modified after
        construction; metric names are re-read on every evaluation.

    Examples
    --------
    >>> from tscf_eval.evaluator import Evaluator, Validity, Proximity, Sparsity
    >>> import numpy as np
    >>>
    >>> # Create evaluator with multiple metrics
    >>> evaluator = Evaluator([Validity(), Proximity(p=2), Sparsity()])
    >>>
    >>> # Evaluate counterfactuals
    >>> X = np.random.randn(100, 50)  # 100 instances, 50 time points
    >>> X_cf = X + np.random.randn(100, 50) * 0.1
    >>> results = evaluator.evaluate(X, X_cf, y=np.zeros(100), y_cf=np.ones(100))
    >>>
    >>> print(results['validity'], results['proximity_l2'], results['sparsity'])
    """

    def __init__(self, metrics: Iterable[Metric]):
        """Initialize the Evaluator with a collection of metrics.

        Parameters
        ----------
        metrics : iterable of Metric
            Metric instances to compute. Each metric must implement
            ``name()`` and ``compute(X, X_cf, **kwargs)``.
        """
        self.metrics: list[Metric] = list(metrics)
        # Names of the configured metrics. Refreshed by ``_validate_inputs``
        # so that later edits to ``self.metrics`` are honoured.
        self._metric_names = [m.name() for m in self.metrics]

    def evaluate(self, X: np.ndarray, X_cf: np.ndarray, **kwargs) -> dict[str, Any]:
        """Compute all configured metrics and return a mapping name -> result.

        The evaluator forwards all provided ``kwargs`` to each metric's
        ``compute`` method. To avoid silent behavior, if the caller provides
        ``time_per_instance`` then an ``Efficiency``-style metric must be
        present (i.e., a metric whose ``name()`` returns
        ``"efficiency_time_s"``) that will consume that argument and report a
        canonical value. This avoids the evaluator guessing at how to
        aggregate timings.

        Parameters
        ----------
        X : np.ndarray
            Original instances, shape ``(M, ...)``.
        X_cf : np.ndarray
            Counterfactual instances, shape matching ``X``.
        **kwargs
            Forwarded to each metric. Common kwargs include:

            - ``model``: Classifier for metrics like Validity, Controllability.
            - ``X_train``: Training data for Plausibility, Robustness.
            - ``y``, ``y_cf``: Labels for Validity when model not provided.
            - ``time_per_instance``: Timings for Efficiency metric.

        Returns
        -------
        dict
            Mapping from metric name to computed result. A metric that
            returns a dict has its items merged into the result directly.
            Also includes ``'_evaluator_time_s'`` with total evaluation time.

        Raises
        ------
        ValueError
            If ``X`` or ``X_cf`` is a scalar, if they have different numbers
            of instances, or if ``time_per_instance`` is provided without an
            Efficiency metric.
        TypeError
            If a metric raises TypeError due to unexpected kwargs.
        """
        X, X_cf = self._validate_inputs(X, X_cf, **kwargs)
        # ``kwargs`` is this call's own dict, so caching into it does not
        # leak into the caller's arguments.
        self._cache_model_predictions(X, X_cf, kwargs)

        results: dict[str, Any] = {}

        tqdm = self._get_progress_bar()
        pbar = tqdm(
            total=len(self.metrics),
            desc="Evaluating metrics",
            unit="metric",
            leave=False,
        )

        # perf_counter is monotonic, so the reported duration cannot be
        # skewed by system clock adjustments.
        start_time = time.perf_counter()
        try:
            for metric in self.metrics:
                try:
                    result = metric.compute(X, X_cf, **kwargs)
                except TypeError as exc:
                    # The progress bar is closed by the ``finally`` below.
                    raise TypeError(
                        f"Metric '{metric.name()}' raised TypeError when called "
                        f"with evaluator kwargs: {exc}"
                    ) from exc

                if isinstance(result, dict):
                    results.update(result)
                else:
                    results[metric.name()] = result

                # Progress reporting is best-effort and must never abort
                # an evaluation.
                with contextlib.suppress(Exception):
                    pbar.update(1)
        finally:
            with contextlib.suppress(Exception):
                pbar.close()

        results["_evaluator_time_s"] = float(time.perf_counter() - start_time)
        return results

    def _validate_inputs(
        self, X: np.ndarray, X_cf: np.ndarray, **kwargs
    ) -> tuple[np.ndarray, np.ndarray]:
        """Validate and coerce evaluation inputs.

        Parameters
        ----------
        X : np.ndarray
            Original instances.
        X_cf : np.ndarray
            Counterfactual instances.
        **kwargs
            Evaluator kwargs (checked for ``time_per_instance``).

        Returns
        -------
        tuple of np.ndarray
            Validated ``(X, X_cf)`` as numpy arrays.

        Raises
        ------
        ValueError
            If either input is a scalar (no instance axis), if the numbers of
            instances mismatch, or if ``time_per_instance`` is provided
            without an ``Efficiency`` metric.
        """
        X = np.asarray(X)
        X_cf = np.asarray(X_cf)

        # A 0-d array has no ``shape[0]``; report it as a validation error
        # instead of letting an IndexError escape.
        if X.ndim == 0 or X_cf.ndim == 0:
            raise ValueError(
                "X and X_cf must be array-like with a leading instance axis; "
                f"got shapes {X.shape} and {X_cf.shape}"
            )

        if X.shape[0] != X_cf.shape[0]:
            raise ValueError(
                "X and X_cf must have the same number of instances; "
                f"got {X.shape[0]} vs {X_cf.shape[0]}"
            )

        # Re-read the names: ``self.metrics`` is a public list and may have
        # been modified since construction.
        self._metric_names = [m.name() for m in self.metrics]
        if "time_per_instance" in kwargs and "efficiency_time_s" not in self._metric_names:
            raise ValueError(
                "Caller provided 'time_per_instance' but evaluator does not contain "
                "an Efficiency metric. Please include an Efficiency() metric to report "
                "timings."
            )

        return X, X_cf

    @staticmethod
    def _cache_model_predictions(X: np.ndarray, X_cf: np.ndarray, kwargs: dict[str, Any]) -> None:
        """Pre-compute and cache model predictions in *kwargs* (mutated in-place).

        Caches ``_cached_y_pred``, ``_cached_y_cf_pred``, ``_cached_proba_X``,
        and ``_cached_proba_X_cf`` so that multiple metrics can reuse them
        without redundant inference calls. Does nothing when no ``model`` is
        present in *kwargs*.

        Parameters
        ----------
        X : np.ndarray
            Original instances.
        X_cf : np.ndarray
            Counterfactual instances.
        kwargs : dict
            Evaluator kwargs dict (modified in-place with cached predictions).
        """
        model = kwargs.get("model")
        if model is None:
            return

        if hasattr(model, "predict"):
            kwargs["_cached_y_pred"] = np.asarray(model.predict(X))
            kwargs["_cached_y_cf_pred"] = np.asarray(model.predict(X_cf))

        if hasattr(model, "predict_proba"):
            # Imported lazily so the helper is only required when a
            # probabilistic model is actually supplied.
            from tscf_eval.counterfactuals.utils import soft_predict_proba_fn

            soft_proba = soft_predict_proba_fn(model)
            kwargs["_cached_proba_X"] = np.asarray(soft_proba(X))
            kwargs["_cached_proba_X_cf"] = np.asarray(soft_proba(X_cf))

    @staticmethod
    def _get_progress_bar():
        """Return tqdm progress bar constructor, or a no-op fallback.

        Returns
        -------
        callable
            A tqdm-compatible constructor accepting ``total``, ``desc``,
            ``unit`` and ``leave``.
        """
        try:
            from tqdm.auto import tqdm  # type: ignore

            return tqdm
        except Exception:
            # Deliberately broad: progress reporting is optional, so any
            # failure while importing tqdm falls back to the no-op bar.
            warnings.warn(
                "tqdm is not installed. Progress bars are disabled. "
                "Install tqdm for progress reporting: pip install tqdm",
                UserWarning,
                stacklevel=2,
            )

            class _DummyTqdm:
                """No-op progress bar used when tqdm is not installed."""

                def __init__(
                    self, total: int = 0, desc: str = "", unit: str = "", leave: bool = True
                ):
                    """Initialize the dummy progress bar.

                    Parameters
                    ----------
                    total : int, default 0
                        Total number of expected iterations.
                    desc : str, default ""
                        Description prefix for the progress bar.
                    unit : str, default ""
                        Unit name for each iteration.
                    leave : bool, default True
                        Whether to leave the bar on screen after completion.
                    """
                    self.total = total
                    self.desc = desc
                    self.unit = unit
                    self.n = 0
                    self.leave = leave

                def set_description(self, desc: str):
                    """Store a new description; nothing is displayed.

                    Parameters
                    ----------
                    desc : str
                        New description string.
                    """
                    self.desc = desc

                def update(self, n: int = 1):
                    """Advance the internal counter; nothing is displayed.

                    Parameters
                    ----------
                    n : int, default 1
                        Number of iterations to advance.
                    """
                    self.n += n

                def close(self):
                    """Close the progress bar (no-op)."""
                    return None

            def _tqdm_fallback(
                total: int = 0, desc: str = "", unit: str = "", leave: bool = True
            ) -> _DummyTqdm:
                """Create a dummy progress bar as a tqdm replacement.

                Parameters
                ----------
                total : int, default 0
                    Total number of expected iterations.
                desc : str, default ""
                    Description prefix for the progress bar.
                unit : str, default ""
                    Unit name for each iteration.
                leave : bool, default True
                    Whether to leave the bar on screen after completion.

                Returns
                -------
                _DummyTqdm
                    A no-op progress bar instance.
                """
                return _DummyTqdm(total=total, desc=desc, unit=unit, leave=leave)

            return _tqdm_fallback