Source code for tscf_eval.evaluator.metrics.distribution

"""Distribution-based counterfactual evaluation metrics.

This module provides metrics that evaluate counterfactuals based on their
relationship to the data distribution: Plausibility and Diversity.

Classes
-------
Plausibility
    Evaluates whether counterfactuals lie within the training data distribution.
Diversity
    Measures diversity among multiple counterfactuals for the same query.
"""

from __future__ import annotations

import hashlib
from typing import Any, Literal
import warnings

import numpy as np
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor
from sklearn.svm import OneClassSVM
from sklearn.utils.extmath import randomized_svd

from tscf_eval.counterfactuals.utils._distance import dtw_distance_vec_multich

from ..base import Metric


class Plausibility(Metric):
    """Plausibility scored via an outlier detector.

    Fits an outlier/novelty detector on reference (training) data and reports
    the fraction of counterfactuals that the detector labels as inliers.

    Parameters
    ----------
    method : {'lof', 'if', 'mp_ocsvm', 'dtw_lof'}, default 'dtw_lof'
        Detector backend:

        - ``'lof'``: LocalOutlierFactor in novelty mode (Breunig et al., 2000)
          on flattened series.
        - ``'if'``: IsolationForest (Liu et al., 2008) on flattened series.
        - ``'mp_ocsvm'``: Matrix Profile features (Yeh et al., 2016) with
          OneClassSVM. If ``stumpy`` cannot be imported, a ``UserWarning`` is
          emitted and a OneClassSVM is fitted on flattened series instead.
        - ``'dtw_lof'``: LOF on a precomputed DTW distance matrix built with
          ``dtw_distance_vec_multich``. That helper lives in another module;
          it is described as using ``tslearn`` with a Euclidean fallback.
    **kwargs
        Additional arguments passed to the detector. For ``'dtw_lof'`` only
        ``n_neighbors``, ``algorithm``, ``leaf_size``, ``contamination`` and
        ``n_jobs`` are forwarded; other keys are ignored.

    Notes
    -----
    Fitted detectors are cached per training set. The cache key is derived
    from the *content* of the training data (a hash of dtype, shape and
    bytes) rather than from ``id()``: object ids are reused once an array is
    garbage collected, which could otherwise return a detector fitted on
    different data of the same shape. Call :meth:`clear_cache` to free memory.
    """

    direction = "maximize"

    # Methods understood by :meth:`compute`.
    _METHODS = ("lof", "if", "mp_ocsvm", "dtw_lof")
    # LOF constructor arguments that remain meaningful with metric="precomputed".
    _LOF_PRECOMPUTED_KWARGS = (
        "n_neighbors",
        "algorithm",
        "leaf_size",
        "contamination",
        "n_jobs",
    )

    def __init__(
        self,
        method: Literal["lof", "if", "mp_ocsvm", "dtw_lof"] = "dtw_lof",
        **kwargs,
    ):
        """Initialize the Plausibility metric.

        Parameters
        ----------
        method : {"lof", "if", "mp_ocsvm", "dtw_lof"}, default "dtw_lof"
            Outlier detection backend to use.
        **kwargs
            Additional arguments passed to the underlying detector.
        """
        self.method = method
        self.kwargs = kwargs
        # Cache fitted detectors to avoid refitting on the same training data.
        # Keys are (tag, train_id, shape) where train_id is the fingerprint
        # produced by _train_fingerprint() when called through compute().
        self._detector_cache: dict[tuple, Any] = {}
        self._mp_feature_cache: dict[tuple, tuple[np.ndarray, Any]] = {}

    def name(self) -> str:
        """Return the metric name.

        Returns
        -------
        str
            ``'plausibility_{method}'``.
        """
        return f"plausibility_{self.method}"

    @staticmethod
    def _train_fingerprint(train: np.ndarray) -> int:
        """Return an integer identifying the content of ``train``.

        Parameters
        ----------
        train : np.ndarray
            Training data.

        Returns
        -------
        int
            64-bit BLAKE2b digest over dtype, shape and raw bytes. Falls back
            to ``id(train)`` when the data cannot be exposed as a byte buffer
            (e.g. object dtype), which restores identity-based caching.
        """
        arr = np.asarray(train)
        if arr.dtype.hasobject:
            # Raw bytes of an object array are pointers, not content.
            return id(train)
        try:
            hasher = hashlib.blake2b(digest_size=8)
            hasher.update(f"{arr.dtype.str}|{arr.shape}".encode())
            hasher.update(np.ascontiguousarray(arr))
        except (TypeError, ValueError, BufferError):
            return id(train)
        return int.from_bytes(hasher.digest(), "big")

    @staticmethod
    def _inlier_fraction(pred: np.ndarray) -> float:
        """Return the share of predictions equal to ``1`` (sklearn's inlier label)."""
        return float(np.mean(np.asarray(pred) == 1))

    def _get_or_fit_detector(self, Y: np.ndarray, train_id: int) -> Any:
        """Get cached detector or fit a new one.

        Parameters
        ----------
        Y : np.ndarray
            Flattened training data for fitting the detector.
        train_id : int
            Integer identifying the training data (part of the cache key).

        Returns
        -------
        detector
            Fitted LOF or IsolationForest detector.

        Raises
        ------
        ValueError
            If ``self.method`` is neither ``'lof'`` nor ``'if'``.
        """
        # Validate first so that a key collision with an entry stored by
        # another backend can never be returned from here.
        if self.method not in ("lof", "if"):
            raise ValueError(f"Cannot cache detector for method: {self.method}")

        cache_key = (self.method, train_id, Y.shape)
        detector = self._detector_cache.get(cache_key)
        if detector is not None:
            return detector

        if self.method == "lof":
            detector = LocalOutlierFactor(novelty=True, **self.kwargs)
        else:
            detector = IsolationForest(**self.kwargs)
        detector.fit(Y)
        self._detector_cache[cache_key] = detector
        return detector

    def clear_cache(self) -> None:
        """Clear cached fitted detectors and matrix profile features to free memory."""
        self._detector_cache.clear()
        self._mp_feature_cache.clear()

    def compute(
        self,
        X: np.ndarray,
        X_cf: np.ndarray,
        X_train: np.ndarray | None = None,
        **kwargs,
    ) -> float:
        """Compute plausibility score.

        Parameters
        ----------
        X : np.ndarray
            Original instances.
        X_cf : np.ndarray
            Counterfactual instances.
        X_train : np.ndarray, optional
            Training data for fitting the detector. If ``None``, uses ``X``.
        **kwargs
            Ignored.

        Returns
        -------
        float
            Fraction of counterfactuals classified as inliers, in ``[0, 1]``.
            ``nan`` when ``X_cf`` contains no instances.

        Raises
        ------
        ValueError
            If ``self.method`` is not a supported backend.
        """
        if self.method not in self._METHODS:
            raise ValueError(f"Unknown plausibility method: {self.method}")

        X_cf = np.asarray(X_cf)
        # Convert whichever reference set is used so array-likes (lists) work.
        train = np.asarray(X if X_train is None else X_train)

        if X_cf.ndim > 0 and X_cf.shape[0] == 0:
            # No counterfactuals: the inlier fraction is undefined.
            return float("nan")

        Y = train.reshape(train.shape[0], -1)
        Z = X_cf.reshape(X_cf.shape[0], -1)
        train_id = self._train_fingerprint(train)

        if self.method in ("lof", "if"):
            detector = self._get_or_fit_detector(Y, train_id)
            return self._inlier_fraction(detector.predict(Z))
        if self.method == "mp_ocsvm":
            return self._compute_mp_ocsvm(train, X_cf, Y, Z, train_id)
        return self._compute_dtw_lof(train, X_cf, train_id)

    @staticmethod
    def _sliding_min_distance(a: np.ndarray, b: np.ndarray) -> float:
        """Return the minimum Euclidean distance of the shorter 1-D series
        against every equally long window of the longer one."""
        short, long_ = (a, b) if a.size <= b.size else (b, a)
        windows = np.lib.stride_tricks.sliding_window_view(long_, short.size)
        dists = np.sqrt(np.sum((windows - short) ** 2, axis=1))
        return float(np.min(dists))

    def _mp_feature(self, series: np.ndarray, train_set: np.ndarray, mass: Any) -> float:
        """Compute mean matrix profile distance of a series against a training set.

        Parameters
        ----------
        series : np.ndarray
            Query time series (flattened internally).
        train_set : np.ndarray
            Training set to compare against; iterated along its first axis.
        mass : callable
            Distance-profile function, called as ``mass(query, series)`` with
            the shorter series as the query.

        Returns
        -------
        float
            Mean over the training set of the minimum distance-profile value.
        """
        q = np.asarray(series).reshape(-1)
        mins: list[float] = []
        for t in train_set:
            tflat = np.asarray(t).reshape(-1)
            try:
                if tflat.size < q.size:
                    prof = mass(tflat, q)
                else:
                    prof = mass(q, tflat)
                mins.append(float(np.min(prof)))
            except Exception:
                # Best effort: if the distance-profile call fails, use a plain
                # sliding-window Euclidean distance (shorter over longer).
                mins.append(self._sliding_min_distance(q, tflat))
        return float(np.mean(mins))

    def _compute_mp_ocsvm(
        self,
        train_for_detector: np.ndarray,
        X_cf: np.ndarray,
        Y: np.ndarray,
        Z: np.ndarray,
        train_id: int,
    ) -> float:
        """Compute plausibility using Matrix Profile features with OneClassSVM.

        Parameters
        ----------
        train_for_detector : np.ndarray
            Training data in original shape, used for matrix profile computation.
        X_cf : np.ndarray
            Counterfactual instances in original shape.
        Y : np.ndarray
            Flattened training data, shape ``(n_train, n_features)``.
        Z : np.ndarray
            Flattened counterfactual data, shape ``(n_cf, n_features)``.
        train_id : int
            Integer identifying the training data (part of the cache key).

        Returns
        -------
        float
            Fraction of counterfactuals classified as inliers, in ``[0, 1]``.
        """
        try:
            import stumpy
        except Exception:
            # Deliberately broad: treat any import failure as "unavailable".
            stumpy = None

        if stumpy is None:
            warnings.warn(
                "stumpy is not installed. Plausibility(method='mp_ocsvm') is falling "
                "back to OneClassSVM on flattened features instead of Matrix Profile "
                "features. Install stumpy for proper matrix profile computation: "
                "pip install stumpy",
                UserWarning,
                stacklevel=2,
            )
            # Fallback: use cached OneClassSVM on flattened features.
            cache_key = ("mp_ocsvm_fallback", train_id, Y.shape)
            oc = self._detector_cache.get(cache_key)
            if oc is None:
                oc = OneClassSVM(**self.kwargs)
                oc.fit(Y)
                self._detector_cache[cache_key] = oc
            return self._inlier_fraction(oc.predict(Z))

        mass = stumpy.core.mass

        # Reuse cached training features and fitted SVM when available.
        cache_key = ("mp_ocsvm", train_id, train_for_detector.shape)
        cached = self._mp_feature_cache.get(cache_key)
        if cached is None:
            # NOTE(review): each training series is also compared with itself
            # here, whereas counterfactuals have no self-match.
            train_feats = np.array(
                [
                    self._mp_feature(train_for_detector[i], train_for_detector, mass)
                    for i in range(train_for_detector.shape[0])
                ]
            ).reshape(-1, 1)
            oc = OneClassSVM(**self.kwargs)
            oc.fit(train_feats)
            self._mp_feature_cache[cache_key] = (train_feats, oc)
        else:
            _, oc = cached

        cf_feats = np.array(
            [
                self._mp_feature(X_cf[j].reshape(-1), train_for_detector, mass)
                for j in range(X_cf.shape[0])
            ]
        ).reshape(-1, 1)
        return self._inlier_fraction(oc.predict(cf_feats))

    def _compute_dtw_lof(
        self,
        X_train: np.ndarray,
        X_cf: np.ndarray,
        train_id: int,
    ) -> float:
        """Compute plausibility using LOF with precomputed DTW distances.

        Builds the pairwise distance matrix of the training data (cached) and
        the counterfactual-to-training distance matrix with
        ``dtw_distance_vec_multich``, then uses LOF in
        ``metric="precomputed"`` novelty mode.

        Parameters
        ----------
        X_train : np.ndarray
            Training data in original shape, shape ``(n_train, ...)``.
        X_cf : np.ndarray
            Counterfactual instances, shape ``(n_cf, ...)``.
        train_id : int
            Integer identifying the training data (part of the cache key).

        Returns
        -------
        float
            Fraction of counterfactuals classified as inliers, in ``[0, 1]``.
        """
        n_train = X_train.shape[0]
        n_cf = X_cf.shape[0]

        cache_key = ("dtw_lof", train_id, X_train.shape)
        cached = self._detector_cache.get(cache_key)
        if cached is None:
            # Row i holds the distances from training series i to every
            # training series.
            D_train = np.zeros((n_train, n_train), dtype=float)
            for i in range(n_train):
                D_train[i] = dtw_distance_vec_multich(X_train[i], X_train)
            # Ensure symmetry (numerical precision).
            D_train = 0.5 * (D_train + D_train.T)

            # Only forward arguments that LOF accepts in precomputed mode.
            lof_kwargs = {
                key: value
                for key, value in self.kwargs.items()
                if key in self._LOF_PRECOMPUTED_KWARGS
            }
            lof = LocalOutlierFactor(novelty=True, metric="precomputed", **lof_kwargs)
            lof.fit(D_train)
            self._detector_cache[cache_key] = (lof, D_train)
        else:
            lof, D_train = cached

        # Distances from each counterfactual to each training instance.
        D_cf = np.zeros((n_cf, n_train), dtype=float)
        for i in range(n_cf):
            D_cf[i] = dtw_distance_vec_multich(X_cf[i], X_train)

        return self._inlier_fraction(lof.predict(D_cf))
class Diversity(Metric):
    """Diversity of multiple counterfactuals using DPP-inspired log-determinant.

    For every query, a similarity kernel ``1 / (1 + D)`` is built from the
    pairwise distance matrix ``D`` of its counterfactuals and the
    log-determinant of that kernel is taken. The score is ``exp`` of the mean
    finite log-determinant over queries. Higher values indicate more diverse
    counterfactuals.

    Parameters
    ----------
    distance : {"euclidean", "dtw"}, default "dtw"
        Distance function used to build the pairwise distance matrix between
        counterfactuals for each query.

        - ``"euclidean"``: Euclidean distance on flattened vectors.
        - ``"dtw"``: distances from ``dtw_distance_vec_multich``. That helper
          lives in another module; it is described as a per-channel DTW
          distance using ``tslearn`` with a Euclidean fallback.

    Notes
    -----
    Expects ``X_cf`` with shape ``(M, K, ...)`` where ``K`` is the number of
    counterfactuals per query. Queries whose kernel has a non-positive
    determinant (log-determinant of ``-inf``) are excluded from the mean.
    See Mothilal et al. (2020) and Kulesza & Taskar (2012) for details.
    """

    direction = "maximize"

    def __init__(self, distance: Literal["euclidean", "dtw"] = "dtw"):
        """Initialize the Diversity metric.

        Parameters
        ----------
        distance : {"euclidean", "dtw"}, default "dtw"
            Distance function for building pairwise distance matrices.
        """
        self.distance = distance

    def name(self) -> str:
        """Return the metric name.

        Returns
        -------
        str
            ``'diversity_dpp_dtw'`` for DTW distance, ``'diversity_dpp'``
            otherwise.
        """
        return "diversity_dpp_dtw" if self.distance == "dtw" else "diversity_dpp"

    def compute(
        self,
        X: np.ndarray,
        X_cf: np.ndarray,
        max_components: int = 50,
        **kwargs,
    ) -> float:
        """Compute diversity score.

        Parameters
        ----------
        X : np.ndarray
            Original instances (unused).
        X_cf : np.ndarray
            Counterfactual instances of shape ``(M, K, ...)`` where ``K`` is
            the number of counterfactuals per query.
        max_components : int, default 50
            When ``K`` exceeds this value, the log-determinant is approximated
            from the leading singular values of a randomized SVD.
        **kwargs
            May contain ``_X_cf_all`` with full counterfactuals when the
            benchmark passes first-CF-only as ``X_cf`` for other metrics.
            A value of ``None`` is treated as absent.

        Returns
        -------
        float
            Diversity score (higher = more diverse). Returns ``np.nan`` if
            the input has fewer than 3 dimensions (single CF per query), if
            there are no counterfactuals per query, or if no query yields a
            finite log-determinant.

        Raises
        ------
        ValueError
            If ``distance`` is not a supported value.
        """
        # Prefer the full set of counterfactuals when the caller supplies it.
        all_cfs = kwargs.get("_X_cf_all")
        X_cfs = np.asarray(X_cf if all_cfs is None else all_cfs)

        if X_cfs.ndim < 3:
            # Single counterfactual per query - diversity is not applicable.
            return float("nan")

        K = X_cfs.shape[1]
        if K == 0:
            # No counterfactuals per query: nothing to compare.
            return float("nan")

        logdets: list[float] = []
        for cfs in X_cfs:
            D = self._pairwise_distances(cfs, K)
            kernel = 1.0 / (1.0 + D + 1e-12)
            kernel = 0.5 * (kernel + kernel.T)  # enforce exact symmetry
            logdets.append(self._kernel_logdet(kernel, max_components))

        valid = [ld for ld in logdets if np.isfinite(ld) and ld > -1e300]
        if not valid:
            return float("nan")
        return float(np.exp(float(np.mean(valid))))

    @staticmethod
    def _eig_logdet(kernel: np.ndarray) -> float:
        """Return the sum of logs of the positive eigenvalues of ``kernel``
        (``-inf`` when there are none)."""
        vals = np.linalg.eigvalsh(kernel)
        vals = vals[vals > 0]
        return float(np.sum(np.log(vals))) if vals.size else float(-np.inf)

    def _kernel_logdet(self, kernel: np.ndarray, max_components: int) -> float:
        """Return the (possibly approximated) log-determinant of ``kernel``.

        Uses ``np.linalg.slogdet`` when ``max_components`` covers the full
        matrix, otherwise sums the logs of the leading singular values from a
        randomized SVD. Any failure falls back to an eigenvalue computation.
        """
        size = kernel.shape[0]
        try:
            if max_components >= size:
                sign, logdet = np.linalg.slogdet(kernel)
                return float(logdet) if sign > 0 else float(-np.inf)
            rank = min(max_components, size - 1)
            # Fixed seed: the metric must be reproducible across runs.
            _, svals, _ = randomized_svd(kernel, n_components=rank, random_state=0)
            return float(np.sum(np.log(svals + 1e-12)))
        except Exception:
            return self._eig_logdet(kernel)

    def _pairwise_distances(self, cfs: np.ndarray, K: int) -> np.ndarray:
        """Compute pairwise distance matrix between K counterfactuals.

        Parameters
        ----------
        cfs : np.ndarray
            Counterfactuals for a single query, shape ``(K, ...)``.
        K : int
            Number of counterfactuals.

        Returns
        -------
        np.ndarray
            Pairwise distance matrix of shape ``(K, K)``.

        Raises
        ------
        ValueError
            If ``self.distance`` is not ``'euclidean'`` or ``'dtw'``.
        """
        if self.distance == "euclidean":
            flat = cfs.reshape(K, -1)
            diffs = flat[:, None, :] - flat[None, :, :]
            return np.sqrt((diffs**2).sum(axis=2))

        if self.distance == "dtw":
            D = np.zeros((K, K), dtype=float)
            for j in range(K):
                # Row j: distances from counterfactual j to all counterfactuals.
                D[j] = dtw_distance_vec_multich(cfs[j], cfs)
            return 0.5 * (D + D.T)  # symmetrize against numerical noise

        raise ValueError(f"Unknown distance: {self.distance!r}. Expected 'euclidean' or 'dtw'.")