# Source code for tscf_eval.counterfactuals.comte

"""CoMTE counterfactual explainer implementation.

This module provides the ``COMTE`` class, an implementation of the CoMTE
(Counterfactual Explanations for Multivariate Time Series) algorithm for
generating counterfactual explanations for time series classification.

The algorithm was originally developed by Emre Ates, Burak Aksar, Vitus J. Leung,
and Ayse K. Coskun at Boston University's PEAC Lab.

Original implementation: https://github.com/peaclab/CoMTE

Classes
-------
COMTE
    CoMTE counterfactual generator using greedy channel substitution.

Algorithm Overview
------------------
CoMTE generates counterfactuals through a sequential greedy approach:

1. Select distractor candidates from the reference set that are predicted
   as the target class.
2. For each distractor, greedily swap channels from the distractor into
   the query series, selecting the channel that most increases the target
   class probability at each step.
3. Choose the best counterfactual across all distractors using the loss
   function: ``L = max(0, tau - f_c)^2 + lambda_reg * max(0, n_vars - delta)``

Examples
--------
>>> from tscf_eval.counterfactuals import COMTE
>>> import numpy as np
>>>
>>> # Assume clf is a trained classifier
>>> comte = COMTE(
...     model=clf,
...     data=(X_train, y_train),
...     distance="dtw",
...     n_distractors=10,
...     tau=0.95,
... )
>>>
>>> # Generate counterfactual for a test instance
>>> cf, cf_label, meta = comte.explain(x_test)
>>> print(f"Edited channels: {meta['edits_variables']}")
>>> print(f"Target probability: {meta['target_prob']:.3f}")

References
----------
.. [comte1] Ates, E., Aksar, B., Leung, V. J., & Coskun, A. K. (2021).
       Counterfactual Explanations for Multivariate Time Series.
       In Proceedings of the 2021 International Conference on Applied
       Artificial Intelligence (ICAPAI), pp. 1-8.
       DOI: 10.1109/ICAPAI49758.2021.9462056

.. [comte2] Hollig, J., Kulbach, C., & Thoma, S. (2023).
       TSInterpret: A Python Package for the Interpretability of Time Series
       Classification. Journal of Open Source Software, 8(85), 5220.
       https://doi.org/10.21105/joss.05220
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal
import warnings

import numpy as np

from .base import Counterfactual
from .utils import (
    dtw_distance_vec_multich,
    ensure_batch_shape,
    euclidean_cdist_flat,
    soft_predict_proba_fn,
    strip_batch,
)

# Smallest target-probability gain that counts as a real improvement: the
# greedy channel search treats any best gain at or below this value as
# "no improvement" and falls back or stops instead of committing the swap.
_MIN_GAIN = 1e-9


@dataclass
class COMTE(Counterfactual):
    """CoMTE (Sequential Greedy) counterfactual generator for time-series.

    Implementation of the CoMTE algorithm by Ates et al. (2021) [comte1]_.
    Produces counterfactuals by greedily replacing whole variables (channels)
    from distractor series drawn from a reference set.

    Distractors are selected among reference instances predicted as the
    target class. For each distractor the algorithm performs a sequential
    greedy search that replaces channels one-by-one, choosing at each step
    the channel swap that most increases the model probability ``f_c`` of
    the target class. The best counterfactual across distractors is chosen
    using the paper's loss::

        L = max(0, tau - f_c)^2 + lambda_reg * max(0, n_vars - delta)

    **Supported distances:**

    - ``'dtw'`` : multivariate DTW via ``dtw_distance_vec_multich``
    - ``'euclidean'`` : Euclidean distance using flattened pairwise distances

    Parameters
    ----------
    model : object
        A classifier with a probability estimator (``predict_proba`` or a
        compatible interface). The helper ``soft_predict_proba_fn`` wraps
        model inference.
    data : tuple (``X_ref``, ``y_ref``)
        Reference dataset used to select distractors. ``y_ref`` is flattened
        and must contain exactly one label per reference instance.
    distance : {'euclidean', 'dtw'}, default 'dtw'
        Distance metric to find nearest distractors.

        - ``'euclidean'``: Euclidean distance on flattened vectors.
          Faster but ignores temporal alignment.
        - ``'dtw'``: Dynamic Time Warping distance (per-channel, averaged).
          Respects temporal shifts and is recommended for time series.
    n_distractors : int
        Maximum number of distractors to try.
    tau : float
        Target probability threshold for class ``c``.
    delta : int
        Preferred number of variable edits (paper's sweet spot). Edits beyond
        this count are penalised by ``lambda_reg`` in the loss.
    lambda_reg : float
        Regularization weight in the paper loss.
    random_state : Optional[int]
        Seed for the NumPy random generator stored on the instance as
        ``rng``.

    References
    ----------
    .. [comte1] Ates, E., Aksar, B., Leung, V. J., & Coskun, A. K. (2021).
           Counterfactual Explanations for Multivariate Time Series.
           ICAPAI 2021. https://github.com/peaclab/CoMTE
    """

    model: Any
    data: tuple[np.ndarray, np.ndarray]  # (X_ref, y_ref)
    distance: Literal["euclidean", "dtw"] = "dtw"
    n_distractors: int = 10  # try up to n candidates
    tau: float = 0.95  # target prob threshold
    delta: int = 3  # min-len sweet spot in L
    lambda_reg: float = 0.8  # λ in the paper's loss
    random_state: int | None = 0

    def __post_init__(self):
        """Initialise probability wrapper, RNG, reference data, and label mapping.

        Validates all hyperparameters and pre-computes reference-set
        predictions to avoid redundant calls during distractor selection.

        Raises
        ------
        ValueError
            If ``distance`` is not in ``{'euclidean', 'dtw'}``,
            ``n_distractors < 1``, ``tau`` is outside ``(0, 1]``,
            ``delta < 1``, ``lambda_reg < 0``, or ``X_ref`` and ``y_ref``
            do not contain the same number of instances.
        """
        # Validate parameters
        if self.distance not in ("euclidean", "dtw"):
            raise ValueError("distance must be one of {'euclidean', 'dtw'}")
        if self.n_distractors < 1:
            raise ValueError("n_distractors must be >= 1")
        if not (0.0 < self.tau <= 1.0):
            raise ValueError("tau must be in (0, 1]")
        if self.delta < 1:
            raise ValueError("delta must be >= 1")
        if self.lambda_reg < 0:
            raise ValueError("lambda_reg must be >= 0")

        self.predict_proba = soft_predict_proba_fn(self.model)
        self.rng = np.random.default_rng(self.random_state)
        self.X_ref = np.asarray(self.data[0])
        self.y_ref = np.asarray(self.data[1]).ravel()

        # Fail early on misaligned reference data: the boolean masks built
        # during distractor selection combine y_ref with per-instance
        # predictions, so a length mismatch would otherwise surface later as
        # an opaque broadcasting error (or silently broadcast a single label).
        if self.X_ref.ndim == 0 or self.X_ref.shape[0] != self.y_ref.shape[0]:
            raise ValueError(
                "data must be (X_ref, y_ref) with exactly one label per reference instance"
            )

        self._init_label_mapping(self.model, self.y_ref)

        # Pre-compute reference set predictions to avoid redundant calls
        self._ref_probs = self.predict_proba(self.X_ref)
        self._ref_yhat = np.argmax(self._ref_probs, axis=1)

    def explain(
        self,
        x: np.ndarray,
        y_pred: int | None = None,
        *,
        class_of_interest: int | None = None,
    ) -> tuple[np.ndarray, int, dict[str, Any]]:
        """Generate a counterfactual toward a class of interest.

        Parameters
        ----------
        x : np.ndarray
            Input time series of shape ``(T,)`` for univariate or ``(C, T)``
            for multivariate data.
        y_pred : int, optional
            Base predicted class for ``x``. If ``None``, computed via the model.
        class_of_interest : int, optional
            Target class for the counterfactual. If ``None``, uses the
            highest-probability alternative to ``y_pred``.

        Returns
        -------
        cf : np.ndarray
            Counterfactual time series with the same shape as ``x``.
        cf_label : int
            Predicted class label for the counterfactual.
        meta : dict
            Metadata dictionary containing:

            - ``method``: Algorithm identifier (``'comte_greedy'``).
            - ``distance``: Distance metric used.
            - ``class_of_interest``: Target class.
            - ``tau``, ``delta``, ``lambda``, ``lambda_reg``: Algorithm
              parameters (``lambda`` and ``lambda_reg`` hold the same value).
            - ``distractor_index_in_ref``: Index of selected distractor.
            - ``distractor_distance``: Distance to selected distractor.
            - ``edits_variables``: List of edited channel indices.
            - ``target_prob``: Final target class probability.
            - ``loss``: Final loss value.

            When no distractor is available the distractor/edit/loss keys are
            absent and ``validity=False``, ``failure_reason='no_distractors'``
            and a ``note`` are reported instead.
        """
        xb, added = ensure_batch_shape(x)
        x1 = strip_batch(xb, added)

        base_probs = self.predict_proba(xb)[0]
        base_idx = int(np.argmax(base_probs)) if y_pred is None else self._label_to_idx(y_pred)
        target_idx = self._resolve_target_class(base_probs, base_idx, class_of_interest)

        # Step 1: Find distractor candidates predicted as the target class
        distractor_meta, distractors = self._find_target_class_distractors(x1, target_idx)

        # Step 2: Greedy channel swap per distractor, keep the best
        best = self._select_best_distractor_result(x1, distractors, target_idx)

        # Step 3: Assemble result
        if best is None:
            return self._no_distractor_fallback(x1, base_idx, target_idx)

        loss, i, cf, edits, fc = best
        cf_idx = self._predict_class_idx(cf)
        cf_label = self._idx_to_label(cf_idx)
        meta = self._build_meta(target_idx, distractor_meta, i, edits, fc, loss)
        return cf, cf_label, meta

    def explain_k(
        self,
        x: np.ndarray,
        k: int = 5,
        y_pred: int | None = None,
        *,
        class_of_interest: int | None = None,
    ) -> tuple[np.ndarray, np.ndarray, list[dict[str, Any]]]:
        """Generate k diverse counterfactuals using different distractors.

        COMTE naturally supports diverse counterfactual generation by using
        different distractor instances from the reference set. Each CF is
        generated using a different distractor, producing structurally
        diverse explanations. If fewer than ``k`` distractors are available,
        the lowest-loss result is duplicated to pad the output.

        Parameters
        ----------
        x : np.ndarray
            Input time series.
        k : int, default 5
            Number of counterfactuals to generate.
        y_pred : int, optional
            Precomputed predicted label for ``x``.
        class_of_interest : int, optional
            Target class for counterfactuals.

        Returns
        -------
        cfs : np.ndarray
            Array of k counterfactuals.
        cf_labels : np.ndarray
            Array of k predicted labels.
        metas : list[dict]
            List of k metadata dictionaries.
        """
        xb, added = ensure_batch_shape(x)
        x1 = strip_batch(xb, added)

        base_probs = self.predict_proba(xb)[0]
        base_idx = int(np.argmax(base_probs)) if y_pred is None else self._label_to_idx(y_pred)
        target_idx = self._resolve_target_class(base_probs, base_idx, class_of_interest)

        # Step 1: Get distractors (request more than k to have options).
        # The override is restored in ``finally`` so that an exception raised
        # during distractor lookup cannot leave the instance misconfigured.
        orig_n_distractors = self.n_distractors
        self.n_distractors = max(k * 2, orig_n_distractors)
        try:
            distractor_meta, distractors = self._find_target_class_distractors(x1, target_idx)
        finally:
            self.n_distractors = orig_n_distractors

        if not distractors:
            base_label = self._idx_to_label(base_idx)
            cfs = np.array([x1 for _ in range(k)])
            cf_labels = np.array([base_label for _ in range(k)])
            metas = [
                {
                    "method": "comte_greedy",
                    "k_index": i,
                    "validity": False,
                    "failure_reason": "no_distractors",
                }
                for i in range(k)
            ]
            return cfs, cf_labels, metas

        # Step 2: Generate a counterfactual per distractor (up to k)
        results: list[tuple[np.ndarray, Any, dict[str, Any]]] = []
        for i, distractor in enumerate(distractors[:k]):
            cf, edits, fc = self._swap_channels_greedily(x1, distractor, target_idx)
            loss = self._compute_loss(fc, len(edits))
            cf_idx = self._predict_class_idx(cf)
            meta = self._build_meta(target_idx, distractor_meta, i, edits, fc, loss, k_index=i)
            results.append((cf, self._idx_to_label(cf_idx), meta))

        # Step 3: Pad with best result if fewer than k distractors.
        # The best result is loop-invariant (padding only appends copies of
        # it), so it is selected once; ``min`` keeps the first minimum.
        if len(results) < k:
            best_cf, best_label, best_meta = min(results, key=lambda r: float(r[2]["loss"]))
            while len(results) < k:
                new_meta = best_meta.copy()
                new_meta["k_index"] = len(results)
                new_meta["note"] = "duplicated from best result"
                results.append((best_cf.copy(), best_label, new_meta))

        cfs = np.array([r[0] for r in results])
        cf_labels = np.array([r[1] for r in results])
        metas = [r[2] for r in results]
        return cfs, cf_labels, metas

    def _find_target_class_distractors(
        self, x: np.ndarray, target_idx: int
    ) -> tuple[dict[str, Any], list[np.ndarray]]:
        """Find distractor candidates from the reference set.

        Selects instances predicted as the target class, ranked by distance
        to the query. Uses a cascading fallback strategy:

        1. Correctly-classified instances of the target class.
        2. Any instance predicted as the target class.
        3. Any instance with ground-truth target class label.

        Parameters
        ----------
        x : np.ndarray
            Query time series of shape ``(T,)`` or ``(C, T)``.
        target_idx : int
            Target class index.

        Returns
        -------
        metadata : dict
            Dictionary with ``'indices'`` (reference set indices) and
            ``'distances'`` (distances to query).
        distractors : list of np.ndarray
            List of distractor time series arrays, nearest first, at most
            ``n_distractors`` long.

        Raises
        ------
        ValueError
            If ``X_ref`` has fewer than 2 or more than 3 dimensions.
        """
        if self.X_ref.ndim not in (2, 3):
            raise ValueError(f"Unsupported X_ref shape {self.X_ref.shape}")

        yhat = self._ref_yhat
        target_label = self._idx_to_label(target_idx)

        # Primary: correctly-classified instances of target class
        mask = (self.y_ref == target_label) & (yhat == target_idx)
        if not np.any(mask):
            # Fallback A: any instance predicted as target class
            mask = yhat == target_idx
        if not np.any(mask):
            # Fallback B: any instance with ground-truth target class
            mask = self.y_ref == target_label
        if not np.any(mask):
            return {"indices": [], "distances": []}, []

        Xc = self.X_ref[mask]
        if self.distance == "dtw":
            dvec = dtw_distance_vec_multich(x, Xc)
        else:
            xb, _ = ensure_batch_shape(x)
            dvec = euclidean_cdist_flat(xb, Xc).ravel()

        order = np.argsort(dvec)
        k = min(self.n_distractors, len(order))
        picks = order[:k]
        # Map positions within the masked subset back to reference-set indices
        idx_in_ref = np.flatnonzero(mask)[picks]
        distractors = [Xc[j] for j in picks]
        return {
            "indices": idx_in_ref.tolist(),
            "distances": [float(dvec[j]) for j in picks],
        }, distractors

    def _select_best_distractor_result(
        self,
        x: np.ndarray,
        distractors: list[np.ndarray],
        target_idx: int,
    ) -> tuple[float, int, np.ndarray, list[int], float] | None:
        """Run greedy channel swap per distractor, return the best by loss.

        Ties are resolved in favour of the earlier distractor.

        Parameters
        ----------
        x : np.ndarray
            Query time series.
        distractors : list of np.ndarray
            Distractor candidates.
        target_idx : int
            Target class index.

        Returns
        -------
        tuple or None
            ``(loss, distractor_index, cf, edits, target_prob)`` for the
            best distractor, or ``None`` if no distractors were provided.
        """
        best = None
        for i, distractor in enumerate(distractors):
            cf, edits, fc = self._swap_channels_greedily(x, distractor, target_idx)
            loss = self._compute_loss(fc, len(edits))
            item = (loss, i, cf, edits, fc)
            if best is None or loss < best[0]:
                best = item
        return best

    def _swap_channels_greedily(
        self, x: np.ndarray, distractor: np.ndarray, target_idx: int
    ) -> tuple[np.ndarray, list[int], float]:
        """Greedily swap channels from distractor to maximize target class probability.

        At each step, selects the channel whose substitution most increases
        the target class probability. Stops when probability reaches tau
        or no further improvement is possible.

        Parameters
        ----------
        x : np.ndarray
            Original time series of shape ``(T,)`` or ``(C, T)``.
        distractor : np.ndarray
            Distractor series to copy channels from, same shape as ``x``.
        target_idx : int
            Target class index to optimize probability for.

        Returns
        -------
        cf : np.ndarray
            Counterfactual series.
        edited_channels : list of int
            Channel indices swapped, in order.
        target_prob : float
            Final probability of the target class.
        """
        # Univariate short-circuit: single swap-or-not decision
        if x.ndim == 1:
            return self._try_univariate_swap(x, distractor, target_idx)

        C, _ = x.shape
        edited: list[int] = []
        cf = x.copy()
        current_prob = self._predict_target_prob(cf, target_idx)

        # Only consider channels where the distractor actually differs
        remaining = {c for c in range(C) if not np.array_equal(x[c, :], distractor[c, :])}

        while current_prob < self.tau and remaining:
            # Step A: Find the channel swap with the highest probability gain
            best_ch, best_prob, best_gain = self._find_best_channel_swap(
                cf, distractor, target_idx, remaining, current_prob
            )

            # Step B: If no gain, try fallback (single swap minimizing loss)
            if best_ch is None or best_gain <= _MIN_GAIN:
                if not edited:
                    fallback = self._try_fallback_swap(
                        cf, distractor, target_idx, remaining, current_prob
                    )
                    if fallback is not None:
                        ch, prob = fallback
                        cf[ch, :] = distractor[ch, :]
                        edited.append(ch)
                        remaining.remove(ch)
                        current_prob = prob
                        continue
                break

            # Step C: Commit the best channel swap
            cf[best_ch, :] = distractor[best_ch, :]
            edited.append(best_ch)
            remaining.remove(best_ch)
            current_prob = best_prob

        return cf, edited, current_prob

    def _try_univariate_swap(
        self, x: np.ndarray, distractor: np.ndarray, target_idx: int
    ) -> tuple[np.ndarray, list[int], float]:
        """Try swapping the entire univariate series, keep if loss improves.

        Parameters
        ----------
        x : np.ndarray
            Original univariate time series of shape ``(T,)``.
        distractor : np.ndarray
            Distractor series of shape ``(T,)``.
        target_idx : int
            Target class index.

        Returns
        -------
        cf : np.ndarray
            Counterfactual series of shape ``(T,)``.
        edited_channels : list of int
            ``[0]`` if swapped, ``[]`` if not.
        target_prob : float
            Final probability of the target class.
        """
        base_prob = self._predict_target_prob(x, target_idx)

        # An identical distractor cannot change anything; skip the extra
        # model call.
        if np.array_equal(x, distractor):
            return x.copy(), [], base_prob

        cf_prob = self._predict_target_prob(distractor, target_idx)
        if self._compute_loss(cf_prob, 1) < self._compute_loss(base_prob, 0):
            return distractor.copy(), [0], cf_prob
        return x.copy(), [], base_prob

    def _find_best_channel_swap(
        self,
        cf: np.ndarray,
        distractor: np.ndarray,
        target_idx: int,
        remaining: set[int],
        current_prob: float,
    ) -> tuple[int | None, float, float]:
        """Evaluate all remaining channels, return the best swap.

        Parameters
        ----------
        cf : np.ndarray
            Current counterfactual of shape ``(C, T)``.
        distractor : np.ndarray
            Distractor series of shape ``(C, T)``.
        target_idx : int
            Target class index.
        remaining : set of int
            Channel indices still available for swapping.
        current_prob : float
            Current target class probability.

        Returns
        -------
        best_channel : int or None
            Channel with the highest gain, or ``None`` if set is empty.
        best_prob : float
            Probability after swapping ``best_channel``.
        best_gain : float
            Probability gain from the swap (``-inf`` if set is empty).
        """
        best_gain = -np.inf
        best_channel = None
        best_prob = current_prob

        for ch in remaining:
            # Evaluate each swap on a scratch copy so ``cf`` stays untouched
            candidate = cf.copy()
            candidate[ch, :] = distractor[ch, :]
            prob = self._predict_target_prob(candidate, target_idx)
            gain = prob - current_prob
            if gain > best_gain:
                best_gain = gain
                best_channel = ch
                best_prob = prob

        return best_channel, best_prob, best_gain

    def _try_fallback_swap(
        self,
        cf: np.ndarray,
        distractor: np.ndarray,
        target_idx: int,
        remaining: set[int],
        current_prob: float,
    ) -> tuple[int, float] | None:
        """When no channel gives positive gain, try the swap that minimizes loss.

        This fallback is only used when no channels have been edited yet,
        to ensure at least one swap is attempted.

        Parameters
        ----------
        cf : np.ndarray
            Current counterfactual of shape ``(C, T)``.
        distractor : np.ndarray
            Distractor series of shape ``(C, T)``.
        target_idx : int
            Target class index.
        remaining : set of int
            Channel indices still available.
        current_prob : float
            Current target class probability.

        Returns
        -------
        tuple or None
            ``(channel, prob)`` if a loss-improving swap exists, else ``None``.
        """
        # Baseline: loss of leaving the series unedited
        best_loss = self._compute_loss(current_prob, 0)
        best_channel = None
        best_prob = current_prob

        for ch in remaining:
            candidate = cf.copy()
            candidate[ch, :] = distractor[ch, :]
            prob = self._predict_target_prob(candidate, target_idx)
            loss = self._compute_loss(prob, 1)
            if loss < best_loss:
                best_loss = loss
                best_prob = prob
                best_channel = ch

        if best_channel is not None:
            return best_channel, best_prob
        return None

    def _predict_target_prob(self, arr: np.ndarray, target_idx: int) -> float:
        """Return the model's probability for target_idx.

        Parameters
        ----------
        arr : np.ndarray
            Time series of shape ``(T,)`` or ``(C, T)``.
        target_idx : int
            Probability column index of the target class.

        Returns
        -------
        float
            Predicted probability for the target class.
        """
        return float(self.predict_proba(arr[None, ...])[0][target_idx])

    def _predict_class_idx(self, arr: np.ndarray) -> int:
        """Return the predicted class index for a single instance.

        Parameters
        ----------
        arr : np.ndarray
            Time series of shape ``(T,)`` or ``(C, T)``.

        Returns
        -------
        int
            Argmax index of the model's probability vector.
        """
        return int(np.argmax(self.predict_proba(arr[None, ...])[0]))

    def _resolve_target_class(
        self,
        base_probs: np.ndarray,
        base_idx: int,
        class_of_interest: int | None,
    ) -> int:
        """Determine the target class index for counterfactual generation.

        If ``class_of_interest`` is provided, it is converted to an internal
        index. Otherwise, the highest-probability class other than
        ``base_idx`` is selected.

        Parameters
        ----------
        base_probs : np.ndarray
            Probability vector for the query instance.
        base_idx : int
            Probability column index of the base (original) class.
        class_of_interest : int or None
            User-specified target class label, or ``None`` for automatic
            selection.

        Returns
        -------
        int
            Probability column index for the target class.

        Raises
        ------
        ValueError
            If no class other than ``base_idx`` exists in ``base_probs``
            (e.g. the model exposes a single class).
        """
        if class_of_interest is not None:
            return self._label_to_idx(class_of_interest)

        probs_sorted = np.argsort(-base_probs)
        # Use a default so a degenerate probability vector produces a clear
        # error instead of leaking a bare StopIteration to the caller.
        target = next((int(c) for c in probs_sorted if c != base_idx), None)
        if target is None:
            raise ValueError(
                "Cannot pick a target class automatically: the model exposes no class "
                "other than the base class"
            )
        return target

    def _compute_loss(self, target_prob: float, n_edits: int) -> float:
        """Compute the counterfactual loss balancing validity and sparsity.

        Loss = max(0, tau - target_prob)^2 + lambda_reg * max(0, n_edits - delta)

        Parameters
        ----------
        target_prob : float
            Probability of the target class for the candidate counterfactual.
        n_edits : int
            Number of variables (channels) edited in the counterfactual.

        Returns
        -------
        float
            Combined loss value (lower is better).
        """
        validity_penalty = max(0.0, self.tau - target_prob) ** 2
        sparsity_penalty = float(self.lambda_reg) * max(0, n_edits - int(self.delta))
        return float(validity_penalty + sparsity_penalty)

    def _no_distractor_fallback(
        self, x: np.ndarray, base_idx: int, target_idx: int
    ) -> tuple[np.ndarray, int, dict[str, Any]]:
        """Return the original instance unchanged when no distractors exist.

        Emits a warning and returns a metadata dict with ``validity=False``
        and ``failure_reason='no_distractors'``.

        Parameters
        ----------
        x : np.ndarray
            Original time series of shape ``(T,)`` or ``(C, T)``.
        base_idx : int
            Probability column index of the base class.
        target_idx : int
            Probability column index of the target class.

        Returns
        -------
        cf : np.ndarray
            The original ``x`` (unchanged).
        cf_label : int
            Base class label.
        meta : dict
            Metadata dictionary flagged as invalid. Carries the regularization
            weight under both ``lambda`` and ``lambda_reg``, matching the keys
            produced for successful explanations.

        Warns
        -----
        UserWarning
            When no distractors are found for the target class.
        """
        warnings.warn(
            f"COMTE: No distractors found for target class "
            f"{self._idx_to_label(target_idx)}. "
            f"This typically occurs when the classifier predicts all reference "
            f"samples as the same class (base class="
            f"{self._idx_to_label(base_idx)}). The original "
            f"instance is returned unchanged. Consider using a different dataset "
            f"or a classifier with more diverse predictions.",
            UserWarning,
            stacklevel=3,
        )
        return (
            x,
            self._idx_to_label(base_idx),
            {
                "method": "comte_greedy",
                "distance": self.distance,
                "class_of_interest": self._idx_to_label(target_idx),
                "validity": False,
                "failure_reason": "no_distractors",
                "note": "no distractors found; returning original unchanged",
                "tau": self.tau,
                "delta": self.delta,
                "lambda": self.lambda_reg,
                "lambda_reg": self.lambda_reg,
            },
        )

    def _build_meta(
        self,
        target_idx: int,
        distractor_meta: dict[str, Any],
        distractor_i: int,
        edits: list[int],
        target_prob: float,
        loss: float,
        *,
        k_index: int | None = None,
    ) -> dict[str, Any]:
        """Build the metadata dictionary for an explanation result.

        Parameters
        ----------
        target_idx : int
            Probability column index of the target class.
        distractor_meta : dict
            Metadata from distractor selection, containing ``'indices'``
            and ``'distances'``.
        distractor_i : int
            Index into ``distractor_meta`` lists identifying the selected
            distractor.
        edits : list of int
            Channel indices swapped, in order.
        target_prob : float
            Final probability of the target class.
        loss : float
            Computed loss for this counterfactual.
        k_index : int or None
            Index of this result within an ``explain_k`` batch. Omitted
            from the dict when ``None``.

        Returns
        -------
        dict
            Metadata dictionary with keys ``method``, ``distance``,
            ``class_of_interest``, ``tau``, ``delta``, ``lambda``,
            ``lambda_reg``, ``distractor_index_in_ref``,
            ``distractor_distance``, ``edits_variables``, ``target_prob``,
            and ``loss`` (plus ``k_index`` when applicable).
        """
        meta: dict[str, Any] = {
            "method": "comte_greedy",
            "distance": self.distance,
            "class_of_interest": self._idx_to_label(target_idx),
            "tau": float(self.tau),
            "delta": int(self.delta),
            "lambda": float(self.lambda_reg),
            "lambda_reg": float(self.lambda_reg),
            "distractor_index_in_ref": (
                distractor_meta["indices"][distractor_i] if distractor_meta["indices"] else None
            ),
            "distractor_distance": (
                distractor_meta["distances"][distractor_i]
                if distractor_meta["distances"]
                else None
            ),
            "edits_variables": edits,
            "target_prob": float(target_prob),
            "loss": float(loss),
        }
        if k_index is not None:
            meta["k_index"] = k_index
        return meta