"""Multi-criteria evaluation for counterfactual benchmarks.

This module provides tools for multi-objective analysis of benchmark
results, including Pareto dominance analysis, weighted scalarization,
statistical testing, and LaTeX table generation.

Classes
-------
FriedmanResult
    Named tuple holding the outcome of :func:`friedman_test`.
ParetoAnalyzer
    Multi-objective Pareto dominance analysis with plotting support.
WeightedScalarizer
    Min-max normalized weighted composite scoring.

Functions
---------
friedman_test
    Non-parametric Friedman test across datasets.
format_latex_table
    Format a DataFrame as a LaTeX table with best-value highlighting.
"""
from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal, NamedTuple

import numpy as np
import pandas as pd

from tscf_eval.evaluator.metrics import (
    Composition,
    Confidence,
    Contiguity,
    Controllability,
    Diversity,
    Efficiency,
    Plausibility,
    Proximity,
    Robustness,
    Sparsity,
    Validity,
)

if TYPE_CHECKING:
    from matplotlib.axes import Axes

    from tscf_eval.evaluator.base import Metric

    from .results import BenchmarkResults
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "FriedmanResult",
    "ParetoAnalyzer",
    "WeightedScalarizer",
    "format_latex_table",
    "friedman_test",
]
def _build_direction_registry() -> dict[str, bool]:
    """Map metric names to ``True`` (maximize) or ``False`` (minimize).

    Every built-in metric class is instantiated and its ``name()`` is
    recorded together with whether its ``direction`` equals
    ``"maximize"``. Parameterized metrics are registered for these
    variants:

    * Validity in ``"hard"`` and ``"soft"`` mode;
    * Proximity with Lp norms ``p = 1, 2, inf`` plus the DTW distance;
    * Plausibility with the ``"lof"``, ``"if"`` and ``"dtw_lof"`` methods.

    The flattened sub-keys of dict-returning metrics (Composition,
    Confidence) and the non-metric columns produced by the benchmark
    runner are then added with fixed directions.

    Insertion order matters: ``_is_maximize`` falls back to the *first*
    registered key sharing a ``proximity_``/``plausibility_`` prefix.

    Returns
    -------
    dict[str, bool]
        Metric name -> ``True`` when higher values are better.
    """
    registry: dict[str, bool] = {}

    def _register(metric_instance: Metric) -> None:
        """Store *metric_instance*'s name with its maximize flag."""
        registry[metric_instance.name()] = metric_instance.direction == "maximize"

    plain_metric_classes = (
        Validity,
        Sparsity,
        Diversity,
        Contiguity,
        Controllability,
        Robustness,
        Efficiency,
    )
    for metric_cls in plain_metric_classes:
        _register(metric_cls())

    for validity_mode in ("hard", "soft"):
        _register(Validity(mode=validity_mode))

    for norm_order in (1, 2, float("inf")):
        _register(Proximity(p=norm_order, distance="lp"))
    _register(Proximity(distance="dtw"))

    for plaus_method in ("lof", "if", "dtw_lof"):
        _register(Plausibility(method=plaus_method))

    # Dict-returning metrics: register their own names first ...
    for dict_metric_cls in (Composition, Confidence):
        _register(dict_metric_cls())

    # ... then the flattened sub-keys and the runner's non-metric columns.
    registry.update(
        {
            # Composition sub-keys (lower is better).
            "composition_mean_n_segments": False,
            "composition_mean_avg_segment_len": False,
            # Confidence sub-keys (higher is better).
            "mean_conf_orig": True,
            "mean_conf_cf": True,
            "mean_conf_delta": True,
            # Benchmark-runner columns.
            "mean_time_s": False,
            "total_time_s": False,
            "success_rate": True,
        }
    )
    return registry


# Built once, when the module is imported.
_DIRECTION_REGISTRY: dict[str, bool] = _build_direction_registry()
def _is_maximize(metric: str) -> bool:
    """Return whether larger values of *metric* are better.

    Resolution order:

    1. An exact entry in ``_DIRECTION_REGISTRY``.
    2. For names beginning with ``"proximity_"`` or ``"plausibility_"``
       that are not registered verbatim, the direction of the first
       registered key carrying the same prefix.
    3. Otherwise ``False`` (assume the metric should be minimized).

    Parameters
    ----------
    metric : str
        Metric name to look up.

    Returns
    -------
    bool
        ``True`` if the metric should be maximized, ``False`` otherwise.
    """
    known = _DIRECTION_REGISTRY.get(metric)
    if known is not None:
        return known

    # A name can start with at most one of these prefixes.
    family = next(
        (prefix for prefix in ("proximity_", "plausibility_") if metric.startswith(prefix)),
        None,
    )
    if family is not None:
        for registered_name, registered_is_max in _DIRECTION_REGISTRY.items():
            if registered_name.startswith(family):
                return registered_is_max

    # Unknown metric: conservative "minimize" assumption.
    return False
class FriedmanResult(NamedTuple):
    """Result of a Friedman statistical test.

    Attributes
    ----------
    statistic : float
        Friedman chi-squared statistic.
    p_value : float
        p-value of the test.
    rankings : pd.DataFrame
        Mean rank of each treatment for the tested metric: one row per
        treatment and a single ``mean_rank`` column. Ranks are assigned
        within each group so that rank 1 is the best value. Rows are
        sorted by ascending ``mean_rank``, i.e. best first.
    """

    statistic: float
    p_value: float
    rankings: pd.DataFrame
def friedman_test(
    results: BenchmarkResults,
    metric: str,
    aggregate_by: str = "explainer",
    group_by: str = "dataset",
    *,
    maximize: bool | None = None,
) -> FriedmanResult:
    """Run a Friedman test comparing treatments (e.g. explainers) across groups.

    The Friedman test is a non-parametric test for detecting differences
    in treatments across multiple blocks (e.g., explainers across datasets).
    *metric* is first averaged per ``(group_by, aggregate_by)`` cell. Groups
    lacking a value for any treatment are dropped, because the test needs
    complete blocks.

    Parameters
    ----------
    results : BenchmarkResults
        Benchmark results to analyze.
    metric : str
        Metric name to test.
    aggregate_by : str, default "explainer"
        Treatments to compare (columns of the rank matrix).
    group_by : str, default "dataset"
        Blocking factor (rows of the rank matrix).
    maximize : bool, optional
        Whether higher values of *metric* are better. Only used to orient
        ``rankings`` (rank 1 = best); it does not affect ``statistic`` or
        ``p_value``. If ``None`` (default), the direction comes from the
        module's metric-direction registry, where unknown metrics count
        as "lower is better".

    Returns
    -------
    FriedmanResult
        Named tuple with ``statistic``, ``p_value``, and ``rankings``.

    Raises
    ------
    ImportError
        If scipy is not installed.
    ValueError
        If *metric*, *group_by* or *aggregate_by* is not a column of the
        results, or if there are fewer than 3 treatments or fewer than 2
        complete groups.
    """
    try:
        from scipy.stats import friedmanchisquare
    except ImportError:
        raise ImportError(
            "scipy is required for friedman_test. Install it with: pip install tscf-eval[full]"
        ) from None

    df = results.to_dataframe()
    if metric not in df.columns:
        raise ValueError(
            f"Metric '{metric}' not found. "
            f"Available: {sorted(df.select_dtypes(include=[np.number]).columns)}"
        )
    # Check the pivot axes up front so callers get a ValueError that names
    # the column, instead of a bare KeyError raised deep inside pandas.
    for column in (group_by, aggregate_by):
        if column not in df.columns:
            raise ValueError(f"Column '{column}' not found in results.")

    # Pivot: rows = groups (blocks), columns = treatments.
    pivot = df.pivot_table(
        values=metric,
        index=group_by,
        columns=aggregate_by,
        aggfunc="mean",
    ).dropna()

    treatments = list(pivot.columns)
    if len(treatments) < 3:
        raise ValueError(f"Friedman test requires at least 3 treatments, got {len(treatments)}.")
    if len(pivot) < 2:
        raise ValueError(f"Friedman test requires at least 2 groups, got {len(pivot)}.")

    samples = [pivot[t].to_numpy() for t in treatments]
    stat, p_val = friedmanchisquare(*samples)

    if maximize is None:
        maximize = _is_maximize(metric)
    # Rank within each group so that rank 1 is always the best value.
    ranks = pivot.rank(axis=1, ascending=not maximize)
    mean_ranks = ranks.mean(axis=0).to_frame(name="mean_rank").sort_values("mean_rank")

    return FriedmanResult(
        statistic=float(stat),
        p_value=float(p_val),
        rankings=mean_ranks,
    )
[docs]
@dataclass
class ParetoAnalyzer:
"""Multi-objective Pareto dominance analysis.
Analyzes benchmark results to identify Pareto-optimal solutions
and compute dominance rankings. Includes plotting utilities for
Pareto fronts and cross-dataset consistency analysis.
Parameters
----------
metrics : list[str]
Metric names to use for Pareto analysis.
directions : dict[str, Literal["min", "max"]], optional
Override metric directions. Keys are metric names,
values are ``"min"`` or ``"max"``.
Examples
--------
>>> from tscf_eval.benchmark import ParetoAnalyzer
>>>
>>> analyzer = ParetoAnalyzer(["validity", "proximity_l2", "mean_time_s"])
>>> ranking = analyzer.dominance_ranking(results)
>>> pareto_optimal = analyzer.pareto_front(results)
"""
metrics: list[str]
directions: dict[str, Literal["min", "max"]] = field(default_factory=dict)
[docs]
def __post_init__(self) -> None:
"""Validate that at least one metric is provided."""
if not self.metrics:
raise ValueError("At least one metric is required.")
def _get_direction(self, metric: str) -> bool:
"""Get the optimization direction for a metric.
Parameters
----------
metric : str
Metric name to look up.
Returns
-------
bool
``True`` if the metric should be maximized.
"""
if metric in self.directions:
return self.directions[metric] == "max"
return _is_maximize(metric)
def _extract_values(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
) -> tuple[list[str], np.ndarray]:
"""Extract metric values as array oriented for minimization.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str
Dimension to aggregate by: ``"explainer"``, ``"model"``, or
``"dataset"``.
Returns
-------
names : list[str]
Names of the aggregated entities.
values : np.ndarray
Shape ``(n_entities, n_metrics)``, oriented for minimization.
"""
df = results.to_dataframe()
if df.empty:
return [], np.zeros((0, len(self.metrics)))
if aggregate_by not in df.columns:
raise ValueError(f"Cannot aggregate by '{aggregate_by}'")
available = [m for m in self.metrics if m in df.columns]
if not available:
raise ValueError(
f"None of the specified metrics found. "
f"Available: {list(df.select_dtypes(include=[np.number]).columns)}"
)
grouped = df.groupby(aggregate_by)[available].mean()
names = list(grouped.index)
values = grouped.values.copy()
for i, metric in enumerate(available):
if self._get_direction(metric):
values[:, i] = -values[:, i]
return names, values
def _dominates(self, a: np.ndarray, b: np.ndarray) -> bool:
"""Check if solution *a* Pareto-dominates solution *b* (lower is better).
Parameters
----------
a : np.ndarray
Objective values for solution a, shape ``(n_metrics,)``.
b : np.ndarray
Objective values for solution b, shape ``(n_metrics,)``.
Returns
-------
bool
``True`` if ``a`` is at least as good in all objectives and
strictly better in at least one.
"""
at_least_as_good = np.all(a <= b)
strictly_better = np.any(a < b)
return bool(at_least_as_good and strictly_better)
[docs]
def pareto_front(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
) -> list[str]:
"""Find Pareto-optimal (non-dominated) solutions.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
Returns
-------
list[str]
Names of Pareto-optimal solutions.
"""
names, values = self._extract_values(results, aggregate_by)
n = len(names)
if n == 0:
return []
is_dominated = np.zeros(n, dtype=bool)
for i in range(n):
for j in range(n):
if i != j and self._dominates(values[j], values[i]):
is_dominated[i] = True
break
return [names[i] for i in range(n) if not is_dominated[i]]
[docs]
def dominance_count(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
) -> dict[str, int]:
"""Count how many solutions each solution dominates.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
Returns
-------
dict[str, int]
Mapping from name to number of dominated solutions.
"""
names, values = self._extract_values(results, aggregate_by)
n = len(names)
return {
name: sum(1 for j in range(n) if i != j and self._dominates(values[i], values[j]))
for i, name in enumerate(names)
}
[docs]
def dominated_by_count(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
) -> dict[str, int]:
"""Count how many solutions dominate each solution.
Lower is better (0 means Pareto-optimal).
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
Returns
-------
dict[str, int]
Mapping from name to number of dominating solutions.
"""
names, values = self._extract_values(results, aggregate_by)
n = len(names)
return {
name: sum(1 for j in range(n) if i != j and self._dominates(values[j], values[i]))
for i, name in enumerate(names)
}
[docs]
def dominance_ranking(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
) -> pd.DataFrame:
"""Compute dominance ranking table.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
Returns
-------
pd.DataFrame
DataFrame with columns ``name``, ``dominated_by``,
``dominates``, ``pareto``, plus one column per metric.
"""
names, values = self._extract_values(results, aggregate_by)
if len(names) == 0:
return pd.DataFrame(
columns=["name", "dominated_by", "dominates", "pareto", *self.metrics]
)
dominated_by = self.dominated_by_count(results, aggregate_by)
dominates = self.dominance_count(results, aggregate_by)
pareto_set = set(self.pareto_front(results, aggregate_by))
display_values = values.copy()
available_metrics = [m for m in self.metrics if m in results.to_dataframe().columns]
for i, metric in enumerate(available_metrics):
if self._get_direction(metric):
display_values[:, i] = -display_values[:, i]
rows = []
for idx, name in enumerate(names):
row: dict = {
"name": name,
"dominated_by": dominated_by.get(name, 0),
"dominates": dominates.get(name, 0),
"pareto": name in pareto_set,
}
for j, metric in enumerate(available_metrics):
row[metric] = display_values[idx, j]
rows.append(row)
df = pd.DataFrame(rows)
if available_metrics:
first_metric = available_metrics[0]
ascending_first = not self._get_direction(first_metric)
df = df.sort_values(
by=["dominated_by", first_metric],
ascending=[True, ascending_first],
).reset_index(drop=True)
else:
df = df.sort_values(by="dominated_by").reset_index(drop=True)
return df
[docs]
def to_dataframe(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
) -> pd.DataFrame:
"""Get DataFrame with metric values and Pareto status.
Alias for :meth:`dominance_ranking`.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
Returns
-------
pd.DataFrame
DataFrame with columns ``name``, ``dominated_by``,
``dominates``, ``pareto``, plus one column per metric.
"""
return self.dominance_ranking(results, aggregate_by)
[docs]
def plot_front(
self,
results: BenchmarkResults,
x_metric: str,
y_metric: str,
aggregate_by: str = "explainer",
ax: Axes | None = None,
annotate: bool = True,
pareto_color: str = "tab:blue",
other_color: str = "grey",
pareto_marker: str = "o",
other_marker: str = "x",
title: str | None = None,
) -> Axes:
"""Plot a 2-D Pareto front scatter.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
x_metric, y_metric : str
Metrics for the x and y axes.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
ax : matplotlib Axes, optional
Axes to plot on. Created if ``None``.
annotate : bool, default True
Label each point with the entity name.
pareto_color : str, default "tab:blue"
Color for Pareto-optimal points.
other_color : str, default "grey"
Color for dominated points.
pareto_marker : str, default "o"
Marker for Pareto-optimal points.
other_marker : str, default "x"
Marker for dominated points.
title : str, optional
Plot title. Defaults to ``"Pareto Front"``.
Returns
-------
matplotlib.axes.Axes
"""
import matplotlib.pyplot as plt
df = results.to_dataframe()
if aggregate_by not in df.columns:
raise ValueError(f"Cannot aggregate by '{aggregate_by}'")
for m in (x_metric, y_metric):
if m not in df.columns:
raise ValueError(f"Metric '{m}' not found in results.")
grouped = df.groupby(aggregate_by)[[x_metric, y_metric]].mean()
pareto_set = set(self.pareto_front(results, aggregate_by))
if ax is None:
_, ax = plt.subplots(figsize=(8, 6))
pareto_names: list[str] = []
pareto_xs: list[float] = []
pareto_ys: list[float] = []
other_names: list[str] = []
other_xs: list[float] = []
other_ys: list[float] = []
for name in grouped.index:
xv = grouped.loc[name, x_metric]
yv = grouped.loc[name, y_metric]
if name in pareto_set:
pareto_names.append(name)
pareto_xs.append(xv)
pareto_ys.append(yv)
else:
other_names.append(name)
other_xs.append(xv)
other_ys.append(yv)
if other_xs:
ax.scatter(
other_xs,
other_ys,
color=other_color,
marker=other_marker,
alpha=0.6,
label="Dominated",
zorder=2,
)
if pareto_xs:
ax.scatter(
pareto_xs,
pareto_ys,
color=pareto_color,
marker=pareto_marker,
s=80,
label="Pareto-optimal",
zorder=3,
)
if len(pareto_xs) > 1:
order = np.argsort(pareto_xs)
ax.plot(
[pareto_xs[i] for i in order],
[pareto_ys[i] for i in order],
color=pareto_color,
linestyle="--",
alpha=0.5,
zorder=1,
)
if annotate:
all_names = pareto_names + other_names
all_xs = pareto_xs + other_xs
all_ys = pareto_ys + other_ys
try:
from adjustText import adjust_text
texts = [
ax.text(xv, yv, name, fontsize=8)
for name, xv, yv in zip(all_names, all_xs, all_ys, strict=True)
]
adjust_text(texts, ax=ax)
except ImportError:
warnings.warn(
"adjustText is not installed. Label placement in the Pareto "
"plot may overlap. Install adjustText for improved annotation: "
"pip install adjustText",
UserWarning,
stacklevel=2,
)
for name, xv, yv in zip(all_names, all_xs, all_ys, strict=True):
ax.annotate(
name,
(xv, yv),
textcoords="offset points",
xytext=(5, 5),
fontsize=8,
)
x_arrow = "\u2191" if self._get_direction(x_metric) else "\u2193"
y_arrow = "\u2191" if self._get_direction(y_metric) else "\u2193"
ax.set_xlabel(f"{x_metric} ({x_arrow})")
ax.set_ylabel(f"{y_metric} ({y_arrow})")
ax.set_title(title or "Pareto Front")
ax.legend()
return ax
[docs]
def consistency(
self,
results_dict: dict[str, BenchmarkResults],
aggregate_by: str = "explainer",
) -> pd.DataFrame:
"""Compute cross-dataset Pareto consistency matrix.
For each dataset identifies the Pareto-optimal solutions and
returns a boolean matrix (entity x dataset) with a ``count``
column.
Parameters
----------
results_dict : dict[str, BenchmarkResults]
Mapping from dataset/group name to its benchmark results.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
Returns
-------
pd.DataFrame
Boolean DataFrame with entities as rows, datasets as columns,
plus a ``count`` column. Sorted by count descending.
"""
if not results_dict:
return pd.DataFrame()
all_entities: set[str] = set()
per_dataset: dict[str, set[str]] = {}
for ds_name, ds_results in results_dict.items():
front = self.pareto_front(ds_results, aggregate_by)
per_dataset[ds_name] = set(front)
df = ds_results.to_dataframe()
if aggregate_by in df.columns:
all_entities.update(df[aggregate_by].unique())
entities = sorted(all_entities)
datasets = list(results_dict.keys())
data: dict[str, list[bool]] = {
ds: [e in per_dataset.get(ds, set()) for e in entities] for ds in datasets
}
result_df = pd.DataFrame(data, index=entities)
result_df["count"] = result_df.sum(axis=1)
return result_df.sort_values("count", ascending=False)
[docs]
def plot_consistency_heatmap(
self,
consistency_df: pd.DataFrame,
ax: Axes | None = None,
cmap: str = "YlGn",
title: str | None = None,
) -> Axes:
"""Plot Pareto consistency as a heatmap.
Parameters
----------
consistency_df : pd.DataFrame
Output of :meth:`consistency`.
ax : matplotlib Axes, optional
Axes to plot on. Created if ``None``.
cmap : str, default "YlGn"
Matplotlib colormap name.
title : str, optional
Plot title.
Returns
-------
matplotlib.axes.Axes
"""
import matplotlib.pyplot as plt
plot_df = consistency_df.drop(columns=["count"], errors="ignore")
if ax is None:
fig_w = max(8, len(plot_df.columns) * 1.2)
fig_h = max(4, len(plot_df) * 0.5)
_, ax = plt.subplots(figsize=(fig_w, fig_h))
im = ax.imshow(
plot_df.values.astype(float),
aspect="auto",
cmap=cmap,
vmin=0,
vmax=1,
)
ax.set_xticks(range(len(plot_df.columns)))
ax.set_xticklabels(plot_df.columns, rotation=45, ha="right")
ax.set_yticks(range(len(plot_df.index)))
ax.set_yticklabels(plot_df.index)
for i in range(len(plot_df.index)):
for j in range(len(plot_df.columns)):
val = plot_df.iloc[i, j]
ax.text(
j,
i,
"\u2713" if val else "",
ha="center",
va="center",
fontsize=10,
)
ax.set_title(title or "Pareto Consistency Across Datasets")
ax.figure.colorbar(im, ax=ax, label="Pareto-optimal", shrink=0.8)
return ax
[docs]
def to_latex(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
precision: int = 3,
caption: str | None = None,
label: str | None = None,
) -> str:
"""Generate a LaTeX table of the dominance ranking.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
precision : int, default 3
Number of decimal places.
caption, label : str, optional
LaTeX caption and label.
Returns
-------
str
LaTeX table source code.
"""
ranking = self.dominance_ranking(results, aggregate_by)
available = [m for m in self.metrics if m in ranking.columns]
dirs = {m: self._get_direction(m) for m in available}
return format_latex_table(
ranking,
directions=dirs,
precision=precision,
caption=caption,
label=label,
)
[docs]
@dataclass
class WeightedScalarizer:
"""Min-max normalized weighted composite scoring.
Normalizes each metric to ``[0, 1]`` via min-max scaling (respecting
metric directions so that higher normalized values are always better),
then computes a weighted sum.
Parameters
----------
metrics : list[str]
Metric names to include in the composite score.
weights : dict[str, float], optional
Per-metric weights. Automatically normalized to sum to 1.
If ``None``, all metrics are weighted equally.
directions : dict[str, Literal["min", "max"]], optional
Override metric directions.
Examples
--------
>>> scalarizer = WeightedScalarizer(
... ["validity", "proximity_l2", "sparsity"],
... weights={"validity": 2.0, "proximity_l2": 1.0, "sparsity": 1.0},
... )
>>> scores = scalarizer.score(results)
"""
metrics: list[str]
weights: dict[str, float] = field(default_factory=dict)
directions: dict[str, Literal["min", "max"]] = field(default_factory=dict)
[docs]
def __post_init__(self) -> None:
"""Validate metrics and normalize weights to sum to 1."""
if not self.metrics:
raise ValueError("At least one metric is required.")
if not self.weights:
w = 1.0 / len(self.metrics)
self.weights = dict.fromkeys(self.metrics, w)
else:
total = sum(self.weights.get(m, 0.0) for m in self.metrics)
if total <= 0:
raise ValueError("Sum of weights must be positive.")
self.weights = {m: self.weights.get(m, 0.0) / total for m in self.metrics}
def _get_direction(self, metric: str) -> bool:
"""Get the optimization direction for a metric.
Parameters
----------
metric : str
Metric name to look up.
Returns
-------
bool
``True`` if the metric should be maximized.
"""
if metric in self.directions:
return self.directions[metric] == "max"
return _is_maximize(metric)
[docs]
def score(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
) -> pd.DataFrame:
"""Compute weighted composite scores.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
Returns
-------
pd.DataFrame
DataFrame with normalized metric columns plus ``composite``.
Sorted by composite descending.
"""
df = results.to_dataframe()
if df.empty:
return pd.DataFrame(columns=[aggregate_by, *self.metrics, "composite"])
available = [m for m in self.metrics if m in df.columns]
if not available:
raise ValueError(
f"None of the specified metrics found. "
f"Available: {list(df.select_dtypes(include=[np.number]).columns)}"
)
grouped = df.groupby(aggregate_by)[available].mean()
normalized = pd.DataFrame(index=grouped.index)
for col in available:
col_min = grouped[col].min()
col_max = grouped[col].max()
rng = col_max - col_min
if rng == 0:
normalized[col] = 1.0
elif self._get_direction(col):
normalized[col] = (grouped[col] - col_min) / rng
else:
normalized[col] = (col_max - grouped[col]) / rng
composite = np.zeros(len(normalized))
for col in available:
composite += self.weights.get(col, 0.0) * normalized[col].values
available_weight_sum = sum(self.weights.get(m, 0.0) for m in available)
if available_weight_sum > 0 and available_weight_sum != 1.0:
composite /= available_weight_sum
normalized["composite"] = composite
normalized = normalized.sort_values("composite", ascending=False)
return normalized.reset_index()
[docs]
def sensitivity(
self,
results: BenchmarkResults,
vary_metric: str,
n_steps: int = 11,
aggregate_by: str = "explainer",
) -> pd.DataFrame:
"""Sensitivity analysis by sweeping one metric's weight.
Varies the weight of *vary_metric* from 0 to 1, redistributing
the remaining weight proportionally among the other metrics.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
vary_metric : str
The metric whose weight to sweep.
n_steps : int, default 11
Number of weight values (0 to 1 inclusive).
aggregate_by : str, default "explainer"
Dimension to aggregate by.
Returns
-------
pd.DataFrame
Long-format DataFrame with columns
``weight``, ``<aggregate_by>``, ``composite``.
"""
if vary_metric not in self.metrics:
raise ValueError(f"'{vary_metric}' is not in the metrics list.")
other_metrics = [m for m in self.metrics if m != vary_metric]
other_total = sum(self.weights.get(m, 0.0) for m in other_metrics)
rows = []
for w_vary in np.linspace(0, 1, n_steps):
w_rest = 1.0 - w_vary
new_weights: dict[str, float] = {vary_metric: float(w_vary)}
for m in other_metrics:
if other_total > 0:
new_weights[m] = w_rest * (self.weights.get(m, 0.0) / other_total)
else:
new_weights[m] = w_rest / len(other_metrics) if other_metrics else 0.0
temp = WeightedScalarizer(
metrics=self.metrics,
weights=new_weights,
directions=self.directions,
)
scored = temp.score(results, aggregate_by)
for _, row in scored.iterrows():
rows.append(
{
"weight": float(w_vary),
aggregate_by: row[aggregate_by],
"composite": row["composite"],
}
)
return pd.DataFrame(rows)
[docs]
def plot_sensitivity(
self,
sensitivity_df: pd.DataFrame,
aggregate_by: str = "explainer",
ax: Axes | None = None,
title: str | None = None,
) -> Axes:
"""Plot sensitivity analysis results.
Parameters
----------
sensitivity_df : pd.DataFrame
Output of :meth:`sensitivity`.
aggregate_by : str, default "explainer"
Column name for the entity dimension.
ax : matplotlib Axes, optional
Axes to plot on. Created if ``None``.
title : str, optional
Plot title.
Returns
-------
matplotlib.axes.Axes
"""
import matplotlib.pyplot as plt
if ax is None:
_, ax = plt.subplots(figsize=(8, 5))
for entity in sensitivity_df[aggregate_by].unique():
subset = sensitivity_df[sensitivity_df[aggregate_by] == entity].sort_values("weight")
ax.plot(
subset["weight"],
subset["composite"],
marker="o",
label=entity,
markersize=4,
)
ax.set_xlabel("Weight")
ax.set_ylabel("Composite Score")
ax.set_title(title or "Sensitivity Analysis")
ax.legend()
ax.set_xlim(0, 1)
return ax
[docs]
def to_latex(
self,
results: BenchmarkResults,
aggregate_by: str = "explainer",
precision: int = 3,
caption: str | None = None,
label: str | None = None,
) -> str:
"""Generate a LaTeX table of weighted scores.
Parameters
----------
results : BenchmarkResults
Benchmark results to analyze.
aggregate_by : str, default "explainer"
Dimension to aggregate by.
precision : int, default 3
Number of decimal places.
caption, label : str, optional
LaTeX caption and label.
Returns
-------
str
LaTeX table source code.
"""
scored = self.score(results, aggregate_by)
available = [m for m in self.metrics if m in scored.columns]
dirs: dict[str, bool] = dict.fromkeys(available, True)
dirs["composite"] = True
return format_latex_table(
scored,
directions=dirs,
precision=precision,
caption=caption,
label=label,
)