"""SETS counterfactual explainer implementation.
This module provides the ``SETS`` class, an implementation of the SETS
(Shapelet Explainer for Time Series) algorithm for generating counterfactual
explanations for time series classification.
The algorithm was originally developed by Omar Bahri, Soukaina Filali
Boubrahimi, and Shah Muhammad Hamdi at Utah State University and New Mexico
State University.
Original implementation: https://github.com/omarbahri/SETS
Classes
-------
SETS
SETS counterfactual generator using class-specific shapelet manipulation.
Algorithm Overview
------------------
SETS generates counterfactuals through class-specific shapelet manipulation:
1. Extract discriminative shapelets using the Random Shapelet Transform.
2. Compute a per-shapelet occlusion threshold from the bottom percentile
of that shapelet's distances to the reference series.
3. Assign each shapelet to its exclusive class (discard multi-class ones).
4. Build occurrence heat maps describing typical shapelet positions.
5. For a test instance, per dimension (ordered by the highest information
gain among that dimension's shapelets):
a. **Phase A - Remove original-class shapelets**: replace detected
original-class shapelet regions with min-max rescaled segments from the
nearest unlike neighbor (NUN).
b. **Phase B - Introduce target-class shapelets**: insert target-class
shapelets at heat-map-guided positions, min-max rescaled.
6. Check the classifier prediction after each edit; stop early if the target
class is achieved.
7. If no single-dimension edit succeeds, try combinations of the
per-dimension edits (at most ``max_combination_dims`` dimensions at a time).
Examples
--------
>>> from tscf_eval.counterfactuals import SETS
>>> import numpy as np
>>>
>>> # Assume clf is a trained classifier with predict_proba
>>> sets = SETS(
... model=clf,
... data=(X_train, y_train),
... n_shapelet_samples=5000,
... max_shapelets=200,
... )
>>>
>>> # Generate counterfactual for a test instance
>>> cf, cf_label, meta = sets.explain(x_test)
>>> print(f"Valid: {meta['validity']}")
>>> print(f"Dimensions modified: {meta['dimensions_modified']}")
References
----------
.. [sets1] Bahri, O., Filali Boubrahimi, S., & Hamdi, S. M. (2022).
Shapelet-Based Counterfactual Explanations for Multivariate Time Series.
In Proceedings of the ACM SIGKDD Workshop on Mining and Learning from
Time Series (KDD-MiLeTS 2022).
DOI: 10.48550/arXiv.2208.10462
See Also
--------
tscf_eval.counterfactuals.COMTE : CoMTE algorithm implementation.
tscf_eval.counterfactuals.NativeGuide : NativeGuide algorithm implementation.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from itertools import combinations
from typing import Any, NamedTuple
import warnings
import numpy as np
from .base import Counterfactual
from .utils import (
ensure_batch_shape,
predict_proba_fn,
strip_batch,
)
from .utils._nun import find_nearest_unlike_neighbor
# Optional: aeon for shapelet transform
try:
from aeon.transformations.collection.shapelet_based import (
RandomShapeletTransform,
)
_AEON_SHAPELET_AVAILABLE = True
except ImportError: # pragma: no cover
_AEON_SHAPELET_AVAILABLE = False
class ShapeletInfo(NamedTuple):
    """Parsed shapelet metadata with both normalized and raw arrays.

    Attributes
    ----------
    idx : int
        Index of the shapelet in the transformer's shapelet list.
    info_gain : float
        Information gain reported by the shapelet transform.
    length : int
        Shapelet length in time steps.
    start_pos : int
        Start index of the shapelet inside its source series.
    channel : int
        Channel (dimension) the shapelet was taken from.
    series_id : int
        Index of the source series in the reference data.
    class_value : Any
        Class value reported for the shapelet by the transform.
    z_norm_array : np.ndarray
        Z-normalised shapelet values (as produced by aeon).
    raw_array : np.ndarray
        Un-normalised shapelet values copied from the reference data.
    """

    idx: int
    info_gain: float
    length: int
    start_pos: int
    channel: int
    series_id: int
    class_value: Any
    z_norm_array: np.ndarray
    raw_array: np.ndarray
@dataclass
class SETS(Counterfactual):
    """SETS counterfactual generator using class-specific shapelets.

    Implementation of the SETS algorithm by Bahri et al. (2022) [sets1]_.

    SETS leverages the inherent interpretability of shapelets to produce
    counterfactual explanations with contiguous, visually meaningful
    perturbations. The preprocessing phase discovers class-exclusive
    shapelets and their typical occurrence positions; the generation phase
    removes original-class shapelets and introduces target-class shapelets
    to flip the classifier prediction.

    Parameters
    ----------
    model : object
        A classifier with ``predict_proba`` (or compatible interface).
    data : tuple (``X_ref``, ``y_ref``)
        Reference dataset for shapelet extraction and NUN lookup. ``X_ref``
        is ``(N, T)`` (univariate) or ``(N, C, T)`` (multivariate).
    n_shapelet_samples : int, default 10000
        Number of candidate shapelets to evaluate during extraction.
    max_shapelets : int or None, default None
        Maximum shapelets to retain. ``None`` uses aeon's default
        (``min(10 * n_cases, 1000)``).
    min_shapelet_length : int, default 3
        Minimum shapelet length.
    max_shapelet_length : int or None, default None
        Maximum shapelet length. ``None`` uses the full series length.
    time_limit_in_minutes : float, default 0.0
        Time budget for shapelet extraction (0 = use ``n_shapelet_samples``).
    threshold_percentile : float, default 10.0
        Percentile, in ``(0, 100]``, of each shapelet's distances to the
        reference series that becomes that shapelet's occlusion threshold.
        Lower values are stricter.
    max_combination_dims : int, default 3
        Maximum number of dimensions to combine when single-dimension
        edits fail. Caps the combinatorial search at C(D, k) for
        k ≤ ``max_combination_dims``.
    random_state : int or None, default 0
        PRNG seed for reproducibility.
    n_jobs : int, default 1
        Number of parallel jobs for shapelet extraction.

    Attributes
    ----------
    predict_proba : callable
        Wrapped probability prediction function.
    rng : numpy.random.Generator
        Random number generator.
    X_ref : np.ndarray
        Reference dataset features.
    y_ref : np.ndarray
        Reference dataset labels (flattened to 1-D).

    Raises
    ------
    ImportError
        If aeon's ``RandomShapeletTransform`` is not importable.
    ValueError
        If a parameter is out of range or ``data`` is malformed.

    References
    ----------
    .. [sets1] Bahri, O., Filali Boubrahimi, S., & Hamdi, S. M. (2022).
       Shapelet-Based Counterfactual Explanations for Multivariate Time
       Series. In Proceedings of the ACM SIGKDD Workshop on Mining and
       Learning from Time Series (KDD-MiLeTS 2022).
       https://github.com/omarbahri/SETS
    """

    model: Any
    data: tuple[np.ndarray, np.ndarray]

    # Shapelet transform parameters
    n_shapelet_samples: int = 10000
    max_shapelets: int | None = None
    min_shapelet_length: int = 3
    max_shapelet_length: int | None = None
    time_limit_in_minutes: float = 0.0

    # SETS algorithm parameters
    threshold_percentile: float = 10.0
    max_combination_dims: int = 3
    random_state: int | None = 0
    n_jobs: int = 1

    # Internal state (not user-facing), filled in by ``__post_init__``.
    # Class-exclusive shapelets that survived filtering.
    _shapelets: list[ShapeletInfo] = field(default_factory=list, init=False, repr=False)
    # Class label -> ``ShapeletInfo.idx`` values of its exclusive shapelets.
    _class_shapelets: dict[Any, list[int]] = field(default_factory=dict, init=False, repr=False)
    # ``ShapeletInfo.idx`` -> occlusion threshold.
    _thresholds: dict[int, float] = field(default_factory=dict, init=False, repr=False)
    # ``ShapeletInfo.idx`` -> occurrence heat map of length ``_series_length``.
    _heat_maps: dict[int, np.ndarray] = field(default_factory=dict, init=False, repr=False)
    # Channel -> highest information gain among its shapelets.
    _dim_ig: dict[int, float] = field(default_factory=dict, init=False, repr=False)
    _n_channels: int = field(default=1, init=False, repr=False)
    _series_length: int = field(default=0, init=False, repr=False)

    def __post_init__(self) -> None:
        """Initialise prediction wrapper, reference data, and shapelet pipeline.

        Validates parameters, fits the shapelet transform and, when at least
        one shapelet was extracted, computes the occlusion thresholds,
        assigns class-exclusive shapelets, builds heat maps, and computes
        per-channel information gain.

        Raises
        ------
        ImportError
            If aeon's ``RandomShapeletTransform`` could not be imported.
        """
        if not _AEON_SHAPELET_AVAILABLE:
            raise ImportError(
                "SETS requires aeon's RandomShapeletTransform. Install aeon: pip install aeon"
            )
        self._validate_params()
        self._preprocess_data()
        # Run full preprocessing pipeline
        self._fit_shapelets()
        if self._shapelets:
            self._compute_thresholds()
            self._assign_classes()
            self._build_heat_maps()
            self._compute_dim_ig()

    def _validate_params(self) -> None:
        """Validate user-facing parameters.

        Raises
        ------
        ValueError
            If any parameter is out of its valid range.
        """
        if self.n_shapelet_samples < 1:
            raise ValueError("n_shapelet_samples must be >= 1")
        if self.max_shapelets is not None and self.max_shapelets < 1:
            raise ValueError("max_shapelets must be >= 1 or None")
        if self.min_shapelet_length < 1:
            raise ValueError("min_shapelet_length must be >= 1")
        if (
            self.max_shapelet_length is not None
            and self.max_shapelet_length < self.min_shapelet_length
        ):
            raise ValueError("max_shapelet_length must be >= min_shapelet_length")
        if self.time_limit_in_minutes < 0:
            raise ValueError("time_limit_in_minutes must be >= 0")
        # Written as a negated range check so that NaN is rejected as well.
        if not (0.0 < self.threshold_percentile <= 100.0):
            raise ValueError("threshold_percentile must be in (0, 100]")
        if self.max_combination_dims < 1:
            raise ValueError("max_combination_dims must be >= 1")

    def _preprocess_data(self) -> None:
        """Initialise internal data structures from constructor arguments.

        Wraps the model's predict function, seeds ``rng``, normalises
        ``X_ref`` to 3-D ``(N, C, T)`` for aeon, records the channel count
        and series length, and initialises the label mapping.

        Raises
        ------
        ValueError
            If ``X_ref`` is neither 2-D nor 3-D, or if ``X_ref`` and
            ``y_ref`` disagree on the number of cases.
        """
        self.predict_proba = predict_proba_fn(self.model)
        self.rng = np.random.default_rng(self.random_state)
        self.X_ref = np.asarray(self.data[0])
        self.y_ref = np.asarray(self.data[1]).ravel()

        # Normalise to 3-D (N, C, T) for aeon
        if self.X_ref.ndim == 2:  # (N, T) univariate
            self._X_ref_3d = self.X_ref[:, np.newaxis, :]
            self._n_channels = 1
            self._series_length = self.X_ref.shape[1]
        elif self.X_ref.ndim == 3:  # (N, C, T) multivariate
            self._X_ref_3d = self.X_ref
            self._n_channels = self.X_ref.shape[1]
            self._series_length = self.X_ref.shape[2]
        else:
            raise ValueError(f"Unsupported X_ref shape: {self.X_ref.shape}")

        # Labels are indexed with masks built from the distance matrix rows,
        # so a length mismatch would silently mislabel shapelets.
        if self.y_ref.shape[0] != self.X_ref.shape[0]:
            raise ValueError(
                "X_ref and y_ref must contain the same number of cases, got "
                f"{self.X_ref.shape[0]} and {self.y_ref.shape[0]}"
            )
        self._init_label_mapping(self.model, self.y_ref)

    def _fit_shapelets(self) -> None:
        """Extract shapelets using aeon's RandomShapeletTransform.

        Fits the transform on the 3-D reference data, keeps the resulting
        ``(n_cases, n_shapelets)`` distance matrix for later thresholding,
        and parses each entry of ``transformer.shapelets`` into a
        :class:`ShapeletInfo` (adding the raw, un-normalised values copied
        from ``X_ref``). Warns if no shapelets were extracted.
        """
        max_len = self.max_shapelet_length
        if max_len is None:
            max_len = self._series_length
        transformer = RandomShapeletTransform(
            n_shapelet_samples=self.n_shapelet_samples,
            max_shapelets=self.max_shapelets,
            min_shapelet_length=self.min_shapelet_length,
            max_shapelet_length=max_len,
            time_limit_in_minutes=self.time_limit_in_minutes,
            random_state=self.random_state,
            n_jobs=self.n_jobs,
        )
        self._dist_matrix: np.ndarray = transformer.fit_transform(self._X_ref_3d, self.y_ref)
        self._transformer = transformer

        parsed: list[ShapeletInfo] = []
        for i, shp in enumerate(transformer.shapelets):
            ig, length, start_pos, channel, series_id, class_val, z_norm = shp
            length = int(length)
            start_pos = int(start_pos)
            channel = int(channel)
            series_id = int(series_id)
            end = start_pos + length
            raw = self._X_ref_3d[series_id, channel, start_pos:end].copy()
            parsed.append(
                ShapeletInfo(
                    idx=i,
                    info_gain=float(ig),
                    length=length,
                    start_pos=start_pos,
                    channel=channel,
                    series_id=series_id,
                    class_value=class_val,
                    z_norm_array=np.asarray(z_norm),
                    raw_array=raw,
                )
            )
        self._shapelets = parsed

        if not parsed:
            warnings.warn(
                "SETS: No shapelets extracted. Counterfactual generation will "
                "return the original instance unchanged.",
                UserWarning,
                stacklevel=2,
            )

    def _compute_thresholds(self) -> None:
        """Compute per-shapelet occlusion thresholds.

        For each shapelet, takes the ``threshold_percentile``-th percentile
        of its column in the shapelet-transform distance matrix. This applies
        a separate threshold to each shapelet rather than a single global
        threshold over all distances.
        """
        D = self._dist_matrix  # (n_cases, n_shapelets)
        self._thresholds = {
            s.idx: float(np.percentile(D[:, s.idx], self.threshold_percentile))
            for s in self._shapelets
        }

    def _assign_classes(self) -> None:
        """Filter to class-exclusive shapelets and build class mapping.

        A shapelet "occurs" in a reference case when its transform distance
        is at or below its threshold. Only shapelets whose occurrences all
        belong to a single class are kept; shapelets with no occurrences or
        with occurrences in several classes are discarded. Builds the
        ``_class_shapelets`` mapping from class label to shapelet indices.
        """
        class_shapelets: dict[Any, list[int]] = {}
        surviving_indices: set[int] = set()

        for s in self._shapelets:
            # Find training instances where this shapelet "occurs"
            threshold = self._thresholds[s.idx]
            occ_mask = self._dist_matrix[:, s.idx] <= threshold
            if not np.any(occ_mask):
                continue  # No occurrences — discard
            occ_classes = set(self.y_ref[occ_mask].tolist())
            if len(occ_classes) != 1:
                continue  # Multi-class — discard
            cls = occ_classes.pop()
            class_shapelets.setdefault(cls, []).append(s.idx)
            surviving_indices.add(s.idx)

        self._class_shapelets = class_shapelets
        # Keep only surviving shapelets
        self._shapelets = [s for s in self._shapelets if s.idx in surviving_indices]

        if not self._shapelets:
            warnings.warn(
                "SETS: No class-exclusive shapelets found after filtering. "
                "Try increasing n_shapelet_samples or threshold_percentile.",
                UserWarning,
                stacklevel=2,
            )

    def _build_heat_maps(self) -> None:
        """Build normalised occurrence-position heat maps per shapelet.

        For each shapelet, scans the same-class reference cases whose
        transform distance is within the shapelet's threshold, finds every
        start position where the shapelet matches (sliding-window distance
        at or below the threshold), and counts how often each time step is
        covered by such a match. The counts are divided by the total number
        of matches, so each entry is the fraction of matches covering that
        time step. This is not a probability distribution, because the
        entries do not sum to one. The maps guide Phase B shapelet insertion.
        """
        self._heat_maps = {}
        for s in self._shapelets:
            heat = np.zeros(self._series_length, dtype=np.float64)
            threshold = self._thresholds[s.idx]
            # Iterate only over training instances of this shapelet's class
            class_mask = self.y_ref == s.class_value
            n_occ = 0
            for i in np.where(class_mask)[0]:
                if self._dist_matrix[i, s.idx] > threshold:
                    continue
                # Find occurrence position(s) in this instance
                channel_data = self._X_ref_3d[i, s.channel]
                positions = self._find_occurrence_positions(
                    channel_data, s.z_norm_array, s.length, s.idx
                )
                for p in positions:
                    end = min(p + s.length, self._series_length)
                    heat[p:end] += 1.0
                    n_occ += 1
            if n_occ > 0:
                heat /= n_occ
            self._heat_maps[s.idx] = heat

    def _compute_dim_ig(self) -> None:
        """Compute maximum information gain per channel.

        For each channel, stores the highest information gain across all
        class-exclusive shapelets assigned to that channel. Channels with
        no shapelets receive ``0.0``. The result is stored in ``_dim_ig``
        and used to order dimensions during counterfactual generation.
        """
        self._dim_ig = {}
        for c in range(self._n_channels):
            igs = [s.info_gain for s in self._shapelets if s.channel == c]
            self._dim_ig[c] = max(igs) if igs else 0.0

    @staticmethod
    def _sliding_window_distances(
        series_channel: np.ndarray,
        z_norm_shapelet: np.ndarray,
        length: int,
    ) -> np.ndarray:
        """Z-normalised sliding-window squared Euclidean distance profile.

        Every window of ``length`` consecutive values is z-normalised
        (windows whose standard deviation is below ``1e-8`` become all
        zeros) and compared with ``z_norm_shapelet``. The squared Euclidean
        distance is then divided by ``length``. The computation is
        vectorised over window positions and processed in chunks, so
        temporary arrays stay bounded in size for long series.

        Parameters
        ----------
        series_channel : np.ndarray
            1-D array of shape ``(T,)``.
        z_norm_shapelet : np.ndarray
            Z-normalised shapelet of shape ``(L,)``.
        length : int
            Shapelet length.

        Returns
        -------
        np.ndarray
            Distance at each valid position, shape ``(T - L + 1,)``. An
            empty array is returned when ``length`` is not in ``[1, T]``.
        """
        series = np.asarray(series_channel, dtype=np.float64)
        shapelet = np.asarray(z_norm_shapelet, dtype=np.float64)
        n_pos = series.shape[0] - length + 1
        if length < 1 or n_pos < 1:
            return np.empty(0, dtype=np.float64)

        windows = np.lib.stride_tricks.sliding_window_view(series, length)  # (n_pos, L) view
        dists = np.empty(n_pos, dtype=np.float64)
        # Keep each chunk around one million window elements to bound memory.
        chunk = max(1, (1 << 20) // length)
        for lo in range(0, n_pos, chunk):
            block = np.ascontiguousarray(windows[lo : lo + chunk])
            mean = block.mean(axis=1, keepdims=True)
            std = block.std(axis=1, keepdims=True)
            flat = std < 1e-8
            # Divide by 1.0 for flat windows to avoid warnings; they are zeroed anyway.
            normed = np.where(flat, 0.0, (block - mean) / np.where(flat, 1.0, std))
            dists[lo : lo + chunk] = np.sum((shapelet - normed) ** 2, axis=1) / length
        return dists

    def _find_occurrence_positions(
        self,
        series_channel: np.ndarray,
        z_norm_shapelet: np.ndarray,
        length: int,
        shapelet_idx: int,
    ) -> list[int]:
        """Return positions where a shapelet occurs below its occlusion threshold.

        Parameters
        ----------
        series_channel : np.ndarray
            1-D channel data of shape ``(T,)``.
        z_norm_shapelet : np.ndarray
            Z-normalised shapelet of shape ``(L,)``.
        length : int
            Shapelet length.
        shapelet_idx : int
            Index of the shapelet, used to look up its per-shapelet
            threshold in ``_thresholds``.

        Returns
        -------
        list[int]
            Start positions where the distance is at or below the
            shapelet's threshold (empty if the shapelet is longer than
            the series).
        """
        if length > len(series_channel):
            return []
        dists = self._sliding_window_distances(series_channel, z_norm_shapelet, length)
        threshold = self._thresholds[shapelet_idx]
        positions: list[int] = np.where(dists <= threshold)[0].tolist()
        return positions

    @staticmethod
    def _rescale_segment(
        source: np.ndarray,
        target_min: float,
        target_max: float,
    ) -> np.ndarray:
        """Min-max rescale *source* values into ``[target_min, target_max]``.

        A constant *source* is mapped to the midpoint of the target range.
        Non-floating inputs produce ``float64`` output so that fractional
        results (for example that midpoint) are not truncated.

        Parameters
        ----------
        source : np.ndarray
            1-D array of values to rescale.
        target_min : float
            Desired minimum of the output range.
        target_max : float
            Desired maximum of the output range.

        Returns
        -------
        np.ndarray
            Rescaled array with the same shape as *source* (empty if
            *source* is empty).
        """
        source = np.asarray(source)
        out_dtype = source.dtype if np.issubdtype(source.dtype, np.inexact) else np.float64
        if source.size == 0:
            return source.astype(out_dtype, copy=True)
        s_min = float(source.min())
        s_max = float(source.max())
        if s_max - s_min < 1e-12:
            return np.full(source.shape, (target_min + target_max) / 2.0, dtype=out_dtype)
        rescaled = (target_max - target_min) * (source - s_min) / (s_max - s_min) + target_min
        return rescaled

    @staticmethod
    def _best_position_from_heatmap(heat_map: np.ndarray, length: int) -> int:
        """Find insertion position from the center of the heat map's active region.

        Computes the center of the non-zero region of the heat map, then
        positions the shapelet so that its midpoint aligns with that center,
        shifting it back inside ``[0, T]`` when it would overflow.

        Parameters
        ----------
        heat_map : np.ndarray
            1-D heat map of shape ``(T,)``.
        length : int
            Shapelet length to insert.

        Returns
        -------
        int
            Start position for the shapelet insertion (``0`` for an
            all-zero heat map).
        """
        T = len(heat_map)
        nonzero = np.argwhere(heat_map > 0)
        if len(nonzero) == 0:
            return 0
        first = int(nonzero[0, 0])
        last = int(nonzero[-1, 0])
        center = (last - first) // 2 + first
        start = center - length // 2
        end = center + (length - length // 2)
        # Boundary adjustments
        if start < 0:
            end = end - start
            start = 0
        if end > T:
            start = start - (end - T)
            end = T
        start = max(start, 0)
        return start

    def _find_nun(
        self,
        x_internal: np.ndarray,
        target_class: Any,
    ) -> tuple[np.ndarray, int]:
        """Find nearest unlike neighbor from *target_class*.

        Searches the reference instances whose ground-truth label matches
        the target class. If none exist, warns and falls back to the nearest
        reference instance of any class.

        Parameters
        ----------
        x_internal : np.ndarray
            Instance in ``(C, T)`` shape.
        target_class
            Target class (probability index).

        Returns
        -------
        nun : np.ndarray
            Nearest unlike neighbor in ``(C, T)`` shape.
        nun_idx : int
            Index in ``X_ref``.
        """
        target_label = self._idx_to_label(target_class)
        nuns, indices = find_nearest_unlike_neighbor(
            x_internal,
            self._X_ref_3d,
            self.y_ref,
            target_label,
            k=1,
        )
        if nuns:
            return nuns[0], indices[0]

        # Fallback: no instances of target class in training data
        warnings.warn(
            f"SETS: No instances of target class {target_label!r} found in "
            "training labels. Using closest instance regardless of class.",
            UserWarning,
            stacklevel=2,
        )
        # All-zero labels with target 0 make every reference case eligible.
        nuns, indices = find_nearest_unlike_neighbor(
            x_internal,
            self._X_ref_3d,
            np.zeros(len(self.y_ref)),
            0,
            fallback_all=True,
            k=1,
        )
        return nuns[0], indices[0]

    def _predict_class_idx(self, x_internal: np.ndarray) -> int:
        """Predict class for a ``(C, T)`` internal representation.

        Parameters
        ----------
        x_internal : np.ndarray
            Time series in internal ``(C, T)`` format.

        Returns
        -------
        int
            Probability column index of the predicted class.
        """
        # Convert back to the dimensionality the model was given in X_ref
        if self._n_channels == 1 and self.X_ref.ndim == 2:
            x_model = x_internal[0][np.newaxis, :]  # (1, T)
        else:
            x_model = x_internal[np.newaxis, ...]  # (1, C, T)
        probs = self.predict_proba(x_model)[0]
        return int(np.argmax(probs))

    def _generate_cf(
        self,
        x_internal: np.ndarray,
        orig_class: Any,
        target_class: Any,
        nun: np.ndarray,
        dim_order: list[int],
    ) -> tuple[np.ndarray, bool, dict[str, Any]]:
        """Core generation loop with Phase A/B and dimension combinations.

        For each dimension in ``dim_order``, starts from a fresh copy of
        ``x_internal`` and applies two phases, querying the classifier after
        every edit:

        * Phase A removes original-class shapelets by replacing their
          occurrences with min-max rescaled NUN segments.
        * Phase B inserts target-class shapelets at heat-map-guided
          positions.

        If the single-dimension edit fails, combinations of the
        per-dimension edits processed so far (up to ``max_combination_dims``
        dimensions) are tried. Only combinations that contain the newly
        processed dimension are evaluated. Every other combination was
        already evaluated and rejected in an earlier iteration, so skipping
        it does not change the result for a deterministic classifier.

        Parameters
        ----------
        x_internal : np.ndarray
            Input series in ``(C, T)`` shape.
        orig_class : int
            Probability index of the original predicted class.
        target_class : int
            Probability index of the target class.
        nun : np.ndarray
            Nearest unlike neighbor in ``(C, T)`` shape.
        dim_order : list[int]
            Dimension indices ordered by information gain (descending).

        Returns
        -------
        cf : np.ndarray
            Counterfactual in ``(C, T)`` shape. On failure, this is the edit
            of ``dim_order[0]`` (or a copy of the input if nothing ran).
        success : bool
            Whether the target class was achieved.
        info : dict
            ``dimensions_modified`` (empty on failure), ``phase_a_edits`` and
            ``phase_b_edits`` (cumulative over all dimensions tried).
        """
        per_dim_cfs: dict[int, np.ndarray] = {}
        total_a = 0
        total_b = 0

        def _is_target(candidate: np.ndarray) -> bool:
            # Single helper so every model query goes through one place.
            return self._predict_class_idx(candidate) == target_class

        def _outcome(
            candidate: np.ndarray, success: bool, dims: list[int]
        ) -> tuple[np.ndarray, bool, dict[str, Any]]:
            # Reads the running edit counters at call time.
            return (
                candidate,
                success,
                {
                    "dimensions_modified": dims,
                    "phase_a_edits": total_a,
                    "phase_b_edits": total_b,
                },
            )

        # Map probability indices to original class labels for shapelet lookup
        orig_label = self._classes[orig_class] if orig_class < len(self._classes) else orig_class
        target_label = (
            self._classes[target_class] if target_class < len(self._classes) else target_class
        )

        # Collect shapelets per class, per channel for quick lookup
        orig_shps_by_ch: dict[int, list[ShapeletInfo]] = {}
        target_shps_by_ch: dict[int, list[ShapeletInfo]] = {}
        for s in self._shapelets:
            if s.channel not in dim_order:
                continue
            if s.class_value == orig_label:
                orig_shps_by_ch.setdefault(s.channel, []).append(s)
            if s.class_value == target_label:
                target_shps_by_ch.setdefault(s.channel, []).append(s)

        # Sort shapelets within each channel by info gain (descending, stable)
        for ch_list in (*orig_shps_by_ch.values(), *target_shps_by_ch.values()):
            ch_list.sort(key=lambda shp: shp.info_gain, reverse=True)

        for d in dim_order:
            working = x_internal.copy()
            n_edits = 0  # edits applied to this dimension's copy

            # Phase A: Remove original-class shapelets in dimension d
            for s in orig_shps_by_ch.get(d, []):
                # Positions are located once, before any replacement in d.
                positions = self._find_occurrence_positions(
                    working[d], s.z_norm_array, s.length, s.idx
                )
                for p in positions:
                    end = min(p + s.length, self._series_length)
                    seg_len = end - p
                    local_min = float(working[d, p:end].min())
                    local_max = float(working[d, p:end].max())
                    nun_seg = nun[d, p:end]
                    working[d, p:end] = self._rescale_segment(
                        nun_seg[:seg_len], local_min, local_max
                    )
                    total_a += 1
                    n_edits += 1
                    if _is_target(working):
                        return _outcome(working, True, [d])

            # Phase B: Introduce target-class shapelets in dimension d
            for s in target_shps_by_ch.get(d, []):
                heat_map = self._heat_maps.get(s.idx)
                if heat_map is None:
                    continue
                insert_pos = self._best_position_from_heatmap(heat_map, s.length)
                end = min(insert_pos + s.length, self._series_length)
                seg_len = end - insert_pos
                local_min = float(working[d, insert_pos:end].min())
                local_max = float(working[d, insert_pos:end].max())
                working[d, insert_pos:end] = self._rescale_segment(
                    s.raw_array[:seg_len], local_min, local_max
                )
                total_b += 1
                n_edits += 1
                if _is_target(working):
                    return _outcome(working, True, [d])

            per_dim_cfs[d] = working

            # Every edit above was already followed by a prediction on this
            # exact array; only an untouched copy still needs to be checked.
            if n_edits == 0 and _is_target(working):
                return _outcome(working, True, [d])

            # Try combinations of processed dimensions that include d
            available_dims = [dd for dd in dim_order if dd in per_dim_cfs]
            max_k = min(self.max_combination_dims, len(available_dims))
            for n_dims in range(2, max_k + 1):
                for combo in combinations(available_dims, n_dims):
                    if d not in combo:
                        continue  # evaluated and rejected in an earlier iteration
                    combined = x_internal.copy()
                    for cd in combo:
                        combined[cd] = per_dim_cfs[cd][cd]
                    if _is_target(combined):
                        return _outcome(combined, True, list(combo))

        # Failed — return the highest-ranked dimension's single-dimension attempt
        best_cf = (
            per_dim_cfs.get(dim_order[0], x_internal.copy()) if per_dim_cfs else x_internal.copy()
        )
        return _outcome(best_cf, False, [])

    @staticmethod
    def _prepare_input(x: np.ndarray) -> tuple[np.ndarray, np.ndarray, bool, np.ndarray]:
        """Normalise a user instance for the model and for internal processing.

        Returns
        -------
        xb : np.ndarray
            Batched instance passed to ``predict_proba``.
        x1 : np.ndarray
            Un-batched instance, ``(T,)`` or ``(C, T)``.
        was_univariate : bool
            Whether ``x1`` is 1-D.
        x_internal : np.ndarray
            Instance in internal ``(C, T)`` layout.
        """
        xb, added = ensure_batch_shape(x)
        x1 = strip_batch(xb, added)
        was_univariate = x1.ndim == 1
        x_internal = x1[np.newaxis, :] if was_univariate else x1.copy()
        return xb, x1, was_univariate, x_internal

    def _resolve_target(
        self,
        base_probs: np.ndarray,
        y_pred: Any,
        class_of_interest: Any,
    ) -> tuple[int, int]:
        """Return ``(base_idx, target_idx)`` in probability-column index space.

        The base class is ``y_pred`` when given, otherwise the arg-max of
        ``base_probs``. The target is ``class_of_interest`` when given,
        otherwise the most probable class other than the base class.

        Raises
        ------
        ValueError
            If no target is given and no class other than the base exists.
        """
        base_idx = int(np.argmax(base_probs)) if y_pred is None else self._label_to_idx(y_pred)
        if class_of_interest is not None:
            return base_idx, self._label_to_idx(class_of_interest)
        for c in np.argsort(-base_probs):
            if c != base_idx:
                return base_idx, int(c)
        raise ValueError(
            "SETS: cannot choose a target class automatically because "
            "predict_proba exposes no class other than the predicted one."
        )

    def _dimension_order(self) -> list[int]:
        """Return channel indices sorted by maximum information gain (descending)."""
        return sorted(
            range(self._n_channels),
            key=lambda c: self._dim_ig.get(c, 0.0),
            reverse=True,
        )

    def explain(
        self,
        x: np.ndarray,
        y_pred: int | None = None,
        *,
        class_of_interest: int | None = None,
    ) -> tuple[np.ndarray, int, dict[str, Any]]:
        """Generate a counterfactual explanation using SETS.

        Parameters
        ----------
        x : np.ndarray
            Input time series of shape ``(T,)`` for univariate or ``(C, T)``
            for multivariate data.
        y_pred : int, optional
            Base predicted class for ``x``. If ``None``, computed via model.
        class_of_interest : int, optional
            Target class. If ``None``, uses the highest-probability
            alternative to ``y_pred``.

        Returns
        -------
        cf : np.ndarray
            Counterfactual time series with the same shape as ``x``.
        cf_label : int
            Predicted class label for the counterfactual.
        meta : dict
            Metadata dictionary containing:

            - ``method``: ``'sets'``
            - ``class_of_interest``: Target class.
            - ``nun_index_in_ref``: Index of the NUN used.
            - ``dimensions_modified``: Channels edited.
            - ``phase_a_edits``: Number of Phase A replacements.
            - ``phase_b_edits``: Number of Phase B insertions.
            - ``n_class_shapelets``: Total surviving class-exclusive shapelets.
            - ``validity``: Whether the target class was achieved.
            - ``failure_reason``: ``None`` if successful, description otherwise.

        Raises
        ------
        ValueError
            If ``class_of_interest`` is ``None`` and the model exposes no
            alternative class.
        """
        xb, x1, was_univariate, x_internal = self._prepare_input(x)

        # Determine base prediction and target class (all in index space)
        base_probs = self.predict_proba(xb)[0]
        base_idx, target_idx = self._resolve_target(base_probs, y_pred, class_of_interest)

        # Edge case: no class-exclusive shapelets
        if not self._shapelets:
            cf = x1.copy()
            cf_label = self._idx_to_label(base_idx)
            return (
                cf,
                cf_label,
                self._build_meta(
                    self._idx_to_label(target_idx),
                    None,
                    [],
                    0,
                    0,
                    False,
                    "no_class_exclusive_shapelets",
                ),
            )

        dim_order = self._dimension_order()
        nun, nun_idx = self._find_nun(x_internal, target_idx)
        cf_internal, success, edit_info = self._generate_cf(
            x_internal, base_idx, target_idx, nun, dim_order
        )
        failure_reason = None if success else "no_valid_cf_found"

        # Convert back to original shape
        cf = cf_internal[0] if was_univariate else cf_internal
        cf_label = self._idx_to_label(self._predict_class_idx(cf_internal))
        return (
            cf,
            cf_label,
            self._build_meta(
                self._idx_to_label(target_idx),
                nun_idx,
                edit_info.get("dimensions_modified", []),
                edit_info.get("phase_a_edits", 0),
                edit_info.get("phase_b_edits", 0),
                success,
                failure_reason,
            ),
        )

    def explain_k(
        self,
        x: np.ndarray,
        k: int = 5,
        y_pred: int | None = None,
        *,
        class_of_interest: int | None = None,
    ) -> tuple[np.ndarray, np.ndarray, list[dict[str, Any]]]:
        """Generate k diverse counterfactuals using different NUNs.

        Each counterfactual uses a different nearest unlike neighbor as the
        replacement source for Phase A, producing structurally diverse
        explanations. If fewer than ``k`` NUNs are available, the result is
        padded with copies of the first counterfactual. The padded metas are
        marked with a ``note`` and are independent copies.

        Parameters
        ----------
        x : np.ndarray
            Input time series of shape ``(T,)`` or ``(C, T)``.
        k : int, default 5
            Number of counterfactuals to generate.
        y_pred : int, optional
            Precomputed predicted label for ``x``.
        class_of_interest : int, optional
            Target class for counterfactuals.

        Returns
        -------
        cfs : np.ndarray
            Array of k counterfactuals with shape ``(k, ...)``.
        cf_labels : np.ndarray
            Array of k predicted labels.
        metas : list[dict]
            List of k metadata dictionaries (each with a ``k_index`` key).

        Raises
        ------
        ValueError
            If ``class_of_interest`` is ``None`` and the model exposes no
            alternative class.
        """
        xb, x1, was_univariate, x_internal = self._prepare_input(x)

        # Determine base prediction and target class (all in index space)
        base_probs = self.predict_proba(xb)[0]
        base_idx, target_idx = self._resolve_target(base_probs, y_pred, class_of_interest)
        target_label = self._idx_to_label(target_idx)

        # Edge case: no class-exclusive shapelets
        if not self._shapelets:
            base_label = self._idx_to_label(base_idx)
            cfs_out = np.array([x1.copy() for _ in range(k)])
            labels_out = np.array([base_label] * k)
            metas_out = [
                {
                    **self._build_meta(
                        target_label, None, [], 0, 0, False, "no_class_exclusive_shapelets"
                    ),
                    "k_index": i,
                }
                for i in range(k)
            ]
            return cfs_out, labels_out, metas_out

        dim_order = self._dimension_order()

        # Find k NUNs for diversity
        nuns, nun_indices = self._find_k_nuns(x_internal, target_idx, k)

        cfs: list[np.ndarray] = []
        cf_labels: list[Any] = []
        metas: list[dict[str, Any]] = []

        for i, (nun, nun_idx) in enumerate(zip(nuns, nun_indices, strict=True)):
            cf_internal, success, edit_info = self._generate_cf(
                x_internal, base_idx, target_idx, nun, dim_order
            )
            failure_reason = None if success else "no_valid_cf_found"
            cf = cf_internal[0] if was_univariate else cf_internal
            cf_label = self._idx_to_label(self._predict_class_idx(cf_internal))
            meta = self._build_meta(
                target_label,
                nun_idx,
                edit_info.get("dimensions_modified", []),
                edit_info.get("phase_a_edits", 0),
                edit_info.get("phase_b_edits", 0),
                success,
                failure_reason,
            )
            meta["k_index"] = i
            cfs.append(cf)
            cf_labels.append(cf_label)
            metas.append(meta)

        # Pad with the nearest-neighbor result if fewer NUNs than k
        while len(cfs) < k:
            source_meta = metas[0]
            new_meta = {
                **source_meta,
                # Copy the list so padded metas do not alias the original's.
                "dimensions_modified": list(source_meta["dimensions_modified"]),
                "k_index": len(cfs),
                "note": "duplicated from nearest neighbor",
            }
            cfs.append(cfs[0].copy())
            cf_labels.append(cf_labels[0])
            metas.append(new_meta)

        return np.array(cfs), np.array(cf_labels), metas

    def _find_k_nuns(
        self,
        x_internal: np.ndarray,
        target_class: Any,
        k: int,
    ) -> tuple[list[np.ndarray], list[int]]:
        """Find k nearest unlike neighbors from *target_class*.

        Like ``_find_nun``, warns and falls back to the nearest reference
        instances of any class when the target class has no instances.

        Parameters
        ----------
        x_internal : np.ndarray
            Instance in ``(C, T)`` shape.
        target_class
            Target class (probability index).
        k : int
            Number of NUNs to retrieve.

        Returns
        -------
        nuns : list[np.ndarray]
            Up to k NUNs in ``(C, T)`` shape.
        nun_indices : list[int]
            Indices in ``X_ref``.
        """
        target_label = self._idx_to_label(target_class)
        nuns, indices = find_nearest_unlike_neighbor(
            x_internal,
            self._X_ref_3d,
            self.y_ref,
            target_label,
            k=k,
        )
        if nuns:
            return nuns, indices

        # Fallback: no instances of target class in training data
        warnings.warn(
            f"SETS: No instances of target class {target_label!r} found in "
            "training labels. Using closest instances regardless of class.",
            UserWarning,
            stacklevel=2,
        )
        return find_nearest_unlike_neighbor(
            x_internal,
            self._X_ref_3d,
            np.zeros(len(self.y_ref)),
            0,
            fallback_all=True,
            k=k,
        )

    def _build_meta(
        self,
        class_of_interest: Any,
        nun_idx: int | None,
        dims_modified: list[int],
        phase_a: int,
        phase_b: int,
        validity: bool,
        failure_reason: str | None,
    ) -> dict[str, Any]:
        """Build the metadata dictionary returned by ``explain``.

        Parameters
        ----------
        class_of_interest : Any
            Target class label.
        nun_idx : int or None
            Index of the NUN in ``X_ref``, or ``None`` if unavailable.
        dims_modified : list[int]
            Channel indices that were edited.
        phase_a : int
            Number of Phase A (removal) edits applied.
        phase_b : int
            Number of Phase B (insertion) edits applied.
        validity : bool
            Whether the target class was achieved.
        failure_reason : str or None
            Description of the failure, or ``None`` on success.

        Returns
        -------
        dict[str, Any]
            Metadata dictionary suitable for the ``explain`` return value.
        """
        return {
            "method": "sets",
            "class_of_interest": class_of_interest,
            "nun_index_in_ref": nun_idx,
            "dimensions_modified": dims_modified,
            "phase_a_edits": phase_a,
            "phase_b_edits": phase_b,
            "n_class_shapelets": len(self._shapelets),
            "validity": validity,
            "failure_reason": failure_reason,
        }