"""Base interface for counterfactual explainers.
This module defines the abstract :class:`Counterfactual` class which other
explainers should subclass. The single required method is ``explain`` which
produces a counterfactual for a single query instance.

Classes
-------
Counterfactual
    Abstract base class for counterfactual explainers. Defines the standard
    interface: ``explain(x, y_pred) -> (counterfactual, label, metadata)``.

Notes
-----
The ``explain`` method operates on single instances, not batches. This design
allows explainers to maintain instance-specific state and metadata during
generation. For batch processing, wrap calls in a loop or use parallel
execution.

The ``explain_k`` method generates multiple counterfactuals for a single
query. The default implementation simply calls ``explain`` ``k`` times, so
any diversity comes from randomness inside the explainer itself; subclasses
can override it to implement more sophisticated diversity mechanisms.

Examples
--------
Creating a custom explainer:

>>> from tscf_eval.counterfactuals.base import Counterfactual
>>> import numpy as np
>>>
>>> class SimpleExplainer(Counterfactual):
...     def __init__(self, model):
...         self.model = model
...
...     def explain(self, x, y_pred=None):
...         # Simple example: perturb the series slightly
...         cf = x + np.random.randn(*x.shape) * 0.1
...         cf_label = int(self.model.predict(cf[None, ...])[0])
...         meta = {"method": "random_perturbation"}
...         return cf, cf_label, meta

See Also
--------
tscf_eval.counterfactuals.COMTE : CoMTE algorithm implementation.
tscf_eval.counterfactuals.NativeGuide : NativeGuide algorithm implementation.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
class Counterfactual(ABC):
    """Minimal base interface for counterfactual explainers.

    Subclasses must implement ``explain``. The method operates on a single
    instance (not batches) and returns the generated counterfactual, its
    predicted label, and an optional metadata dictionary describing how the
    counterfactual was produced.

    Label Mapping
    -------------
    Subclasses that accept a ``model`` and ``data`` should call
    :meth:`_init_label_mapping` during initialisation (e.g. in
    ``__post_init__``). This populates ``_classes`` and enables
    :meth:`_idx_to_label` / :meth:`_label_to_idx` for converting between
    probability-column indices and actual class labels.
    """

    # Populated by _init_label_mapping; declared here so type checkers
    # and IDE autocompletion see them on the base class.
    _classes: np.ndarray
    _label_to_idx_map: dict[Any, int]

    def _init_label_mapping(self, model: Any, y_ref: np.ndarray) -> None:
        """Build the label-to-index and index-to-label mapping.

        Must be called once during subclass initialisation.

        Parameters
        ----------
        model
            Trained classifier. If it exposes a non-``None`` ``classes_``
            attribute the mapping is derived from it (guaranteeing agreement
            with the column order of ``predict_proba``). Otherwise
            ``np.unique(y_ref)`` is used as a fallback.
        y_ref : np.ndarray
            Ground-truth labels from the reference dataset.
        """
        # A ``classes_`` attribute that exists but is ``None`` is treated the
        # same as a missing attribute so the fallback still applies.
        model_classes = getattr(model, "classes_", None)
        if model_classes is not None:
            self._classes = np.asarray(model_classes)
        else:
            self._classes = np.unique(y_ref)
        self._label_to_idx_map = {lbl: i for i, lbl in enumerate(self._classes)}

    def _idx_to_label(self, class_idx: int) -> Any:
        """Convert a probability column index to the actual class label.

        Parameters
        ----------
        class_idx : int
            Zero-based index into the probability output columns.

        Returns
        -------
        Any
            Corresponding class label from ``_classes``.
        """
        return self._classes[class_idx]

    def _label_to_idx(self, label: Any) -> int:
        """Convert an actual class label to a probability column index.

        Parameters
        ----------
        label : Any
            Class label to look up. If the exact value is not a known class,
            the label and the classes are compared by their string form, so
            ``1`` matches a ``"1"`` class and ``"1"`` matches a ``1`` class.

        Returns
        -------
        int
            Zero-based index into the probability output columns.

        Raises
        ------
        ValueError
            If ``label`` matches no class, either exactly or by string form.
        """
        try:
            return self._label_to_idx_map[label]
        except (KeyError, TypeError):
            # KeyError: no exact match. TypeError: unhashable label.
            # Both fall through to the string-based comparison below.
            pass

        # Label might be int while classes are str, or vice versa: compare
        # string forms in both directions.
        wanted = str(label)
        for idx, known in enumerate(self._classes):
            if str(known) == wanted:
                return idx
        raise ValueError(f"Label {label!r} not found in classes {self._classes}")

    @abstractmethod
    def explain(
        self, x: np.ndarray, y_pred: int | None = None
    ) -> tuple[np.ndarray, int, dict[str, Any]]:
        """Return a counterfactual for a single instance `x`.

        Parameters
        ----------
        x
            A single time-series instance. Supported shapes include ``(T,)``,
            ``(1, T)`` or ``(1, 1, T)`` for compatibility with callers that may
            add a leading batch or channel dimension.
        y_pred
            Optional precomputed predicted label for ``x``. If ``None``, the
            explainer implementation may compute the prediction from its
            internally-held model.

        Returns
        -------
        cf_x
            Counterfactual series (shape ``(T,)`` or matching input format).
        cf_label
            Predicted label for the counterfactual.
        meta
            Metadata dictionary with information about the generation
            process (e.g., neighbor indices, distances, edits, timings).
        """
        raise NotImplementedError()

    def explain_k(
        self,
        x: np.ndarray,
        k: int = 5,
        y_pred: int | None = None,
    ) -> tuple[np.ndarray, np.ndarray, list[dict[str, Any]]]:
        """Generate k counterfactuals for a single instance.

        The default implementation calls ``explain`` k times. Subclasses may
        override this method to implement more sophisticated diversity
        mechanisms (e.g., different random seeds, target classes, or
        optimization restarts).

        Parameters
        ----------
        x : np.ndarray
            A single time-series instance.
        k : int, default 5
            Number of counterfactuals to generate. Values ``<= 0`` yield
            empty results.
        y_pred : int, optional
            Optional precomputed predicted label for ``x``.

        Returns
        -------
        cfs : np.ndarray
            Array of counterfactuals with shape ``(k, ...)``, where ``...``
            matches the shape of the counterfactuals returned by ``explain``
            (the shape of ``x`` when no counterfactual was generated).
        cf_labels : np.ndarray
            Array of predicted labels for each counterfactual, shape ``(k,)``.
        metas : list[dict]
            List of k metadata dictionaries. Each one is a shallow copy of
            the dictionary returned by ``explain`` with an added
            ``"k_index"`` entry holding the position of the call.

        Examples
        --------
        >>> cfs, labels, metas = explainer.explain_k(x, k=5)  # doctest: +SKIP
        >>> cfs.shape  # (5, T) for univariate or (5, C, T) for multivariate
        """
        cfs = []
        cf_labels = []
        metas = []
        for i in range(k):
            cf, label, meta = self.explain(x, y_pred=y_pred)
            # Copy before annotating: an explainer may hand back the same
            # dict on every call (or None), and writing into it would alias
            # all entries and leak "k_index" into the explainer's own state.
            meta = {} if meta is None else dict(meta)
            meta["k_index"] = i
            cfs.append(cf)
            cf_labels.append(label)
            metas.append(meta)

        if not cfs:
            # Keep the documented ``(k, ...)`` layout even when k <= 0.
            x_arr = np.asarray(x)
            empty_cfs = np.empty((0, *x_arr.shape), dtype=x_arr.dtype)
            return empty_cfs, np.array(cf_labels), metas
        return np.array(cfs), np.array(cf_labels), metas