"""Result containers for benchmark outputs.
This module provides dataclasses for storing and analyzing benchmark results
across multiple datasets, models, and explainers.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from collections.abc import Iterator
import numpy as np
import pandas as pd
__all__ = [
"BenchmarkResults",
"ExplainerResult",
]
@dataclass
class ExplainerResult:
    """Results for a single explainer on a single dataset/model combination.

    Parameters
    ----------
    explainer_name : str
        Name of the explainer configuration.
    dataset_name : str
        Name of the dataset.
    model_name : str
        Name of the model.
    X_cf : np.ndarray
        Generated counterfactuals, shape ``(n_instances, ...)`` or
        ``(n_instances, k, ...)`` if k > 1.
    y_cf : np.ndarray
        Predicted labels for counterfactuals.
    success_mask : np.ndarray
        Boolean mask indicating successful generations.
    metrics : dict[str, Any]
        Evaluation metrics computed by the Evaluator.
    generation_times : list[float]
        Per-instance generation times in seconds. Any sized sequence
        (including a NumPy array) is accepted by the timing properties.
    metadata : list[dict]
        Per-instance metadata from the explainer.
    """

    explainer_name: str
    dataset_name: str
    model_name: str
    X_cf: np.ndarray
    y_cf: np.ndarray
    success_mask: np.ndarray
    metrics: dict[str, Any]
    generation_times: list[float]
    metadata: list[dict[str, Any]]

    @property
    def n_instances(self) -> int:
        """Return the number of test instances.

        Returns
        -------
        int
            Length of ``X_cf`` along axis 0.
        """
        return len(self.X_cf)

    @property
    def n_successful(self) -> int:
        """Return the number of successfully generated counterfactuals.

        Returns
        -------
        int
            Sum of ``success_mask`` (the count of ``True`` values for a
            boolean mask).
        """
        return int(np.sum(self.success_mask))

    @property
    def success_rate(self) -> float:
        """Return the fraction of successful generations.

        Returns
        -------
        float
            Ratio ``n_successful / n_instances``, or ``0.0`` when there are
            no instances.
        """
        total = self.n_instances
        if total <= 0:
            return 0.0
        return self.n_successful / total

    @property
    def mean_time(self) -> float:
        """Return the mean generation time per instance in seconds.

        Returns
        -------
        float
            Average of ``generation_times``, or ``0.0`` if empty.
        """
        # Use an explicit length check instead of truthiness: ``not array``
        # raises ValueError for a NumPy array with more than one element,
        # while ``len`` works for both lists and arrays.
        if len(self.generation_times) == 0:
            return 0.0
        return float(np.mean(self.generation_times))

    @property
    def total_time(self) -> float:
        """Return the total generation time in seconds.

        Returns
        -------
        float
            Sum of ``generation_times`` (``0.0`` if empty).
        """
        return float(np.sum(self.generation_times))

    def get_metric(self, name: str, default: Any = None) -> Any:
        """Get a metric value by name, with optional default.

        Parameters
        ----------
        name : str
            Metric key to look up in ``self.metrics``.
        default : Any, default None
            Value to return if the metric is not found.

        Returns
        -------
        Any
            The metric value, or *default* if not present.
        """
        return self.metrics.get(name, default)
@dataclass
class BenchmarkResults:
    """Container for all benchmark results with analysis methods.

    Stores results indexed by (dataset, model, explainer) combinations.
    Provides methods for querying, aggregating, and exporting results.

    Examples
    --------
    >>> results = runner.run()
    >>>
    >>> # Get specific result
    >>> result = results.get("GunPoint", "knn", "comte")
    >>>
    >>> # Get comparison DataFrame
    >>> df = results.to_dataframe()
    >>>
    >>> # Iterate over results
    >>> for result in results:
    ...     print(f"{result.dataset_name}/{result.model_name}: {result.metrics}")
    """

    _results: dict[tuple[str, str, str], ExplainerResult] = field(default_factory=dict)

    @staticmethod
    def _as_name_set(names: Any) -> frozenset[str] | None:
        """Normalize a name filter to a frozenset; ``None`` means "no filter"."""
        if names is None:
            return None
        if isinstance(names, str):
            # A bare string is one name, not a collection of characters, and
            # must not fall back to substring matching via ``in``.
            return frozenset((names,))
        return frozenset(names)

    def add(self, result: ExplainerResult) -> None:
        """Add a result to the collection.

        An existing entry with the same (dataset, model, explainer) key is
        replaced.

        Parameters
        ----------
        result : ExplainerResult
            Result to store, keyed by (dataset, model, explainer).
        """
        key = (result.dataset_name, result.model_name, result.explainer_name)
        self._results[key] = result

    def get(
        self,
        dataset: str,
        model: str,
        explainer: str,
    ) -> ExplainerResult | None:
        """Get the result for a specific (dataset, model, explainer) combination.

        Parameters
        ----------
        dataset : str
            Dataset name.
        model : str
            Model name.
        explainer : str
            Explainer name.

        Returns
        -------
        ExplainerResult or None
            The matching result, or ``None`` if not found.
        """
        return self._results.get((dataset, model, explainer))

    def __iter__(self) -> Iterator[ExplainerResult]:
        """Iterate over all stored results.

        Returns
        -------
        Iterator[ExplainerResult]
            Iterator yielding each result in insertion order.
        """
        return iter(self._results.values())

    def __len__(self) -> int:
        """Return the number of stored results.

        Returns
        -------
        int
            Total number of (dataset, model, explainer) entries.
        """
        return len(self._results)

    @property
    def datasets(self) -> list[str]:
        """Return the sorted list of unique dataset names.

        Returns
        -------
        list[str]
            Unique dataset names across all stored results.
        """
        return sorted({k[0] for k in self._results})

    @property
    def models(self) -> list[str]:
        """Return the sorted list of unique model names.

        Returns
        -------
        list[str]
            Unique model names across all stored results.
        """
        return sorted({k[1] for k in self._results})

    @property
    def explainers(self) -> list[str]:
        """Return the sorted list of unique explainer names.

        Returns
        -------
        list[str]
            Unique explainer names across all stored results.
        """
        return sorted({k[2] for k in self._results})

    def filter(
        self,
        datasets: list[str] | None = None,
        models: list[str] | None = None,
        explainers: list[str] | None = None,
    ) -> BenchmarkResults:
        """Create a filtered copy of results.

        Names are matched exactly. A single string is accepted in place of a
        list and is treated as one name.

        Parameters
        ----------
        datasets : list[str], optional
            Filter to these datasets. None means all.
        models : list[str], optional
            Filter to these models. None means all.
        explainers : list[str], optional
            Filter to these explainers. None means all.

        Returns
        -------
        BenchmarkResults
            New results containing only matching entries. The contained
            ``ExplainerResult`` objects are shared, not copied.
        """
        # Build the lookup sets once; order matches the (dataset, model,
        # explainer) key layout so the two can be zipped together.
        allowed_per_part = (
            self._as_name_set(datasets),
            self._as_name_set(models),
            self._as_name_set(explainers),
        )
        filtered = BenchmarkResults()
        for key, result in self._results.items():
            if all(
                allowed is None or part in allowed
                for part, allowed in zip(key, allowed_per_part)
            ):
                filtered.add(result)
        return filtered

    def to_dataframe(
        self,
        metrics: list[str] | None = None,
        include_timing: bool = True,
    ) -> pd.DataFrame:
        """Convert results to a pandas DataFrame.

        Parameters
        ----------
        metrics : list[str], optional
            Specific metrics to include. None means all available. The
            selection is applied to top-level metric names, before any
            legacy flattening of dict-valued metrics.
        include_timing : bool, default True
            Include timing columns (``mean_time_s``, ``total_time_s``).

        Returns
        -------
        pd.DataFrame
            One row per stored result with columns for dataset, model,
            explainer, instance counts, success rate, optional timing, and
            metrics. Metric names starting with ``_`` are omitted.
        """
        selected = self._as_name_set(metrics)
        rows: list[dict[str, Any]] = []
        for result in self:
            row: dict[str, Any] = {
                "dataset": result.dataset_name,
                "model": result.model_name,
                "explainer": result.explainer_name,
                "n_instances": result.n_instances,
                "n_successful": result.n_successful,
                "success_rate": result.success_rate,
            }
            if include_timing:
                row["mean_time_s"] = result.mean_time
                row["total_time_s"] = result.total_time

            # Note: The evaluator already flattens dict results from metrics
            # like Confidence, so we don't need to re-flatten here. The dict
            # branch is kept only for backwards compatibility with older
            # saved results.
            for name, value in result.metrics.items():
                if name.startswith("_"):
                    continue
                if selected is not None and name not in selected:
                    continue
                if not isinstance(value, dict):
                    row[name] = value
                    continue
                # A dict whose first key also exists as a top-level metric is
                # taken to be already flattened; skip the nested duplicate.
                first_key = next(iter(value), None)
                if first_key is not None and first_key in result.metrics:
                    continue
                # Legacy: flatten for backwards compatibility.
                for sub_name, sub_value in value.items():
                    row[f"{name}_{sub_name}"] = sub_value
            rows.append(row)
        return pd.DataFrame(rows)

    def aggregate(
        self,
        by: str = "explainer",
        metrics: list[str] | None = None,
        aggfunc: str | list[str] = "mean",
    ) -> pd.DataFrame:
        """Aggregate metrics across a dimension.

        Parameters
        ----------
        by : str, default "explainer"
            Dimension to group by: "explainer", "model", or "dataset".
        metrics : list[str], optional
            Metrics to aggregate. None means all numeric.
        aggfunc : str or list[str], default "mean"
            Aggregation function(s): "mean", "median", "std", "min", "max".
            When a list is provided (e.g. ``["mean", "std"]``), the returned
            DataFrame has a ``MultiIndex`` on columns with levels
            ``(metric, aggfunc)``.

        Returns
        -------
        pd.DataFrame
            Aggregated results. If there are no stored results the empty
            DataFrame from :meth:`to_dataframe` is returned unchanged.
        """
        df = self.to_dataframe(metrics=metrics)
        if df.empty:
            return df
        # Aggregate every numeric column except the raw instance count.
        numeric_cols = [
            col
            for col in df.select_dtypes(include=[np.number]).columns
            if col != "n_instances"
        ]
        return df.groupby(by)[numeric_cols].agg(aggfunc).reset_index()

    def summary(self) -> pd.DataFrame:
        """Get summary statistics aggregated by explainer.

        Returns
        -------
        pd.DataFrame
            Summary with mean metrics per explainer across all datasets/models.
        """
        return self.aggregate(by="explainer", aggfunc="mean")

    def to_dict(self) -> dict[str, Any]:
        """Convert to nested dictionary for serialization.

        Returns
        -------
        dict[str, Any]
            Dictionary with the sorted ``datasets``, ``models`` and
            ``explainers`` name lists and a ``results`` list holding, per
            stored result, its names, ``n_instances``, ``n_successful``,
            ``metrics`` and ``mean_time_s``. Counterfactual arrays,
            per-instance times and metadata are not included.
        """
        return {
            "datasets": self.datasets,
            "models": self.models,
            "explainers": self.explainers,
            "results": [
                {
                    "dataset": r.dataset_name,
                    "model": r.model_name,
                    "explainer": r.explainer_name,
                    "n_instances": r.n_instances,
                    "n_successful": r.n_successful,
                    # Shallow copy so editing the exported payload cannot
                    # mutate the stored result.
                    "metrics": dict(r.metrics),
                    "mean_time_s": r.mean_time,
                }
                for r in self
            ],
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> BenchmarkResults:
        """Reconstruct from a dictionary produced by :meth:`to_dict`.

        Only names, metrics and summary counts are restored. Everything
        ``to_dict`` does not store is replaced by placeholders: ``X_cf`` is
        an empty ``(n, 0)`` array, ``y_cf`` is an uninitialized length-``n``
        array, ``success_mask`` marks the first ``n_successful`` entries as
        successful, ``generation_times`` repeats the stored mean, and
        ``metadata`` is a list of empty dicts.

        Parameters
        ----------
        data : dict
            Dictionary as returned by ``to_dict`` or loaded from JSON.

        Returns
        -------
        BenchmarkResults
            Container with one entry per item in ``data["results"]``.
        """
        results = cls()
        for entry in data.get("results", []):
            n = entry.get("n_instances", 0)
            n_ok = entry.get("n_successful")
            if n_ok is None:
                n_ok = n
            # Clamp to [0, n]: a negative count would otherwise wrap around
            # through negative slicing and mark the wrong instances.
            n_ok = min(max(n_ok, 0), n)
            mean_time = entry.get("mean_time_s", 0.0)
            mask = np.zeros(n, dtype=bool)
            mask[:n_ok] = True
            results.add(
                ExplainerResult(
                    explainer_name=entry["explainer"],
                    dataset_name=entry["dataset"],
                    model_name=entry["model"],
                    X_cf=np.empty((n, 0)),
                    y_cf=np.empty(n),
                    success_mask=mask,
                    # Copy so the restored result does not alias the input.
                    metrics=dict(entry.get("metrics") or {}),
                    generation_times=[mean_time] * n,
                    # One independent dict per instance; ``[{}] * n`` would
                    # repeat a single shared dict n times.
                    metadata=[{} for _ in range(n)],
                )
            )
        return results