"""NativeGuide counterfactual explainer implementation.
This module provides the ``NativeGuide`` class, an implementation of the
Native Guide algorithm for generating counterfactual explanations for
time series classification.
The algorithm was originally developed by Eoin Delaney, Derek Greene, and
Mark T. Keane at University College Dublin's Insight Centre for Data Analytics.
Original implementation: https://github.com/e-delaney/Instance-Based_CFE_TSC
Classes
-------
NativeGuide
NativeGuide counterfactual generator using nearest-unlike-neighbor guidance.
Algorithm Overview
------------------
NativeGuide generates counterfactuals through instance-based reasoning:
1. Find the nearest unlike neighbor (NUN) - the closest instance in the
reference set that is predicted as a different class.
2. Generate a counterfactual by blending the query with the NUN using one
of several methods:
- **blend**: Weighted DTW barycenter averaging, incrementally increasing
the NUN's influence until the prediction flips (original paper method).
- **ng**: Copy a contiguous window from the NUN, growing until flip.
- **dtw_dba**: Like 'ng' but uses a DTW-DBA barycenter of k unlike neighbors.
- **cam**: Like 'ng' but uses CAM importance to select the window location.
Examples
--------
>>> from tscf_eval.counterfactuals import NativeGuide
>>> import numpy as np
>>>
>>> # Assume clf is a trained classifier
>>> ng = NativeGuide(
... model=clf,
... data=(X_train, y_train),
... method="blend", # Original paper method
... distance="dtw",
... )
>>>
>>> # Generate counterfactual for a test instance
>>> cf, cf_label, meta = ng.explain(x_test)
>>> print(f"Beta (blend weight): {meta['beta']}")
>>> print(f"NUN index: {meta['nun_index_in_X']}")
References
----------
.. [ng1] Delaney, E., Greene, D., & Keane, M. T. (2021).
Instance-Based Counterfactual Explanations for Time Series Classification.
In Case-Based Reasoning Research and Development (ICCBR 2021),
pp. 32-47. Springer International Publishing.
DOI: 10.1007/978-3-030-86957-1_3
.. [ng2] Hollig, J., Kulbach, C., & Thoma, S. (2023).
TSInterpret: A Python Package for the Interpretability of Time Series
Classification. Journal of Open Source Software, 8(85), 5220.
https://doi.org/10.21105/joss.05220
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
import warnings
import numpy as np
if TYPE_CHECKING:
from collections.abc import Callable
from .base import Counterfactual
from .utils import (
dba_barycenter_multich,
dtw_distance_vec_multich,
ensure_batch_shape,
euclidean_cdist_flat,
soft_predict_proba_fn,
strip_batch,
weighted_dba_multich,
)
try:
from tslearn.metrics import dtw as _tslearn_dtw # noqa: F401
from tslearn.neighbors import KNeighborsTimeSeries
TSLEARN_AVAILABLE = True
except ImportError: # pragma: no cover
TSLEARN_AVAILABLE = False
KNeighborsTimeSeries = None # type: ignore
[docs]
@dataclass
class NativeGuide(Counterfactual):
"""NativeGuide counterfactual generator for time-series.
Implementation of the NativeGuide algorithm by Delaney et al. (2021) [ng1]_.
The algorithm retrieves a "native guide" (nearest-unlike neighbor, NUN)
from a reference set. Depending on the method, it either:
- **'blend'** (original paper): Blends the query with the NUN using
weighted DTW barycenter averaging, incrementally increasing the
guide's influence until prediction flips.
- **'ng'**: Copies a contiguous window from the NUN into the query,
growing the window until prediction flips.
- **'dtw_dba'**: Like 'ng' but uses a DTW-DBA barycenter of k unlike
neighbors as the guide.
- **'cam'**: Like 'ng' but uses a CAM importance function to select
the discriminative window.
Parameters
----------
model : object
A classifier-like object that exposes a probability estimator. The
internal helper ``predict_proba_fn`` adapts common interfaces (e.g.
scikit-learn, aeon).
data : tuple
A tuple ``(X_ref, y_ref)`` containing the reference dataset used to
select distractors. ``X_ref`` can have shape ``(N, T)`` or ``(N, C, T)``.
method : {'blend', 'ng', 'dtw_dba', 'cam'}, default 'blend'
Strategy for counterfactual generation:
- 'blend': Original paper method. Weighted averaging of query and NUN
using DTW barycenter, incrementally increasing NUN influence.
- 'ng': Window replacement using nearest-unlike neighbor.
- 'dtw_dba': Window replacement using DTW-DBA barycenter of k neighbors.
- 'cam': Window replacement guided by CAM importance function.
distance : {'euclidean', 'dtw'}, default 'dtw'
Distance metric used to rank distractors when searching the reference set.
- ``'euclidean'``: Euclidean distance on flattened vectors. Faster but
ignores temporal alignment.
- ``'dtw'``: Dynamic Time Warping distance (per-channel, averaged).
Respects temporal shifts and is recommended for time series.
k_unlike : int, default 5
Number of unlike neighbors to consider when computing a DTW-DBA guide.
random_state : int or None, default 0
PRNG seed for deterministic behaviour where applicable.
beta_step : float, default 0.01
For ``method='blend'``: increment for the blending weight beta at each
iteration (original paper uses 0.01).
target_prob : float, default 0.5
For ``method='blend'``: target probability threshold for the counterfactual
class (original paper uses 0.5).
cam_importance_fn : callable or None
When ``method=='cam'``, a function with signature
``(series, y_pred) -> np.ndarray`` that returns an importance map of
shape ``(T,)`` or ``(C, T)``.
Notes
-----
The public API is ``explain(x, y_pred=None) -> (cf, cf_label, meta)``.
The returned ``meta`` dictionary contains keys such as
``nun_index_in_X``, ``neighbor_indices``, ``neighbor_distance``,
``window_start``, ``window_len``, and ``beta`` (for blend method).
References
----------
.. [ng1] Delaney, E., Greene, D., & Keane, M. T. (2021).
Instance-Based Counterfactual Explanations for Time Series
Classification. ICCBR 2021.
https://github.com/e-delaney/Instance-Based_CFE_TSC
"""
model: Any
data: tuple[np.ndarray, np.ndarray]
method: Literal["blend", "ng", "dtw_dba", "cam"] = "blend"
distance: Literal["euclidean", "dtw"] = "dtw"
k_unlike: int = 5
random_state: int | None = 0
# Blend method hyperparameters (original paper)
beta_step: float = 0.01 # increment for blending weight
target_prob: float = 0.5 # target probability threshold
# Only used when method="cam": importance fn(series, y_pred) -> (T,) or (C,T)
cam_importance_fn: Callable[[np.ndarray, int], np.ndarray] | None = None
[docs]
def __post_init__(self):
"""Initialise probability wrapper, RNG, reference data, and label mapping.
Validates all hyperparameters, pre-computes reference-set
predictions, and checks method-specific requirements (e.g.
``cam_importance_fn`` when ``method='cam'``).
Raises
------
ValueError
If ``X`` and ``y`` have mismatched sample counts, ``method`` or
``distance`` is not in the allowed set, ``beta_step`` or
``target_prob`` is outside ``(0, 1]``, or ``method='cam'``
without a ``cam_importance_fn``.
"""
X_ref, y_ref = self.data
self.X_ref = np.asarray(X_ref)
self.y_ref = np.asarray(y_ref).ravel()
if self.X_ref.shape[0] != self.y_ref.shape[0]:
raise ValueError("X and y must have the same number of samples.")
self.predict_proba = soft_predict_proba_fn(self.model)
self.rng = np.random.default_rng(self.random_state)
self._init_label_mapping(self.model, self.y_ref)
# Pre-compute reference set predictions to avoid redundant calls
self._ref_probs = self.predict_proba(self.X_ref)
# Store as probability column indices (consistent with internal index space)
self._ref_yhat = np.argmax(self._ref_probs, axis=1)
if self.method not in {"blend", "ng", "dtw_dba", "cam"}:
raise ValueError("method must be one of {'blend', 'ng', 'dtw_dba', 'cam'}")
if self.distance not in {"euclidean", "dtw"}:
raise ValueError("distance must be one of {'euclidean', 'dtw'}")
if self.method == "dtw_dba" and self.k_unlike < 2:
self.k_unlike = 2
if self.method == "cam" and self.cam_importance_fn is None:
raise ValueError("cam_importance_fn must be provided when method='cam'.")
# sanity checks for blend method
if not (0.0 < self.beta_step <= 1.0):
raise ValueError("beta_step must be in (0, 1].")
if not (0.0 < self.target_prob <= 1.0):
raise ValueError("target_prob must be in (0, 1].")
[docs]
def explain(
self, x: np.ndarray, y_pred: int | None = None
) -> tuple[np.ndarray, int, dict[str, Any]]:
"""Generate a counterfactual explanation for a time series instance.
Parameters
----------
x : np.ndarray
Input time series of shape ``(T,)`` for univariate or ``(C, T)``
for multivariate data.
y_pred : int, optional
Precomputed predicted class for ``x``. If ``None``, computed
via the model.
Returns
-------
cf : np.ndarray
Counterfactual time series with the same shape as ``x``.
cf_label : int
Predicted class label for the counterfactual.
meta : dict
Metadata dictionary containing:
- ``method``: Algorithm variant used.
- ``distance``: Distance metric used.
- ``nun_index_in_X``: Index of nearest unlike neighbor.
- ``neighbor_indices``: Indices of neighbors (for dtw_dba).
- ``neighbor_distance``: Distance to nearest unlike neighbor.
- ``beta``: Blending weight (for blend method, else ``None``).
- ``window_start``: Start of replacement window (else ``None``).
- ``window_len``: Length of replacement window (else ``None``).
"""
xb, added = ensure_batch_shape(x)
x1 = strip_batch(xb, added)
if y_pred is None:
base_idx = int(np.argmax(self.predict_proba(xb)[0]))
else:
base_idx = self._label_to_idx(y_pred)
# Step 1: Retrieve the native guide
if self.method == "dtw_dba":
guide, guide_meta = self._build_dba_guide(x1, base_idx)
else:
guide, guide_meta = self._find_nearest_unlike_neighbor(x1, base_idx)
# Step 2: Generate counterfactual using the chosen strategy
if self.method == "blend":
cf, cf_idx, beta = self._blend_query_with_guide(x1, guide, base_idx)
window_start, window_len = None, None
else:
importance = self._compute_cam_importance(guide) if self.method == "cam" else None
window_start, window_len, cf, cf_idx = self._grow_window_until_flip(
x1, guide, base_idx, importance
)
beta = None
# Step 3: Assemble result
cf_label = self._idx_to_label(cf_idx)
meta = self._build_meta(
guide_meta, beta=beta, window_start=window_start, window_len=window_len
)
return cf, cf_label, meta
[docs]
def explain_k(
self,
x: np.ndarray,
k: int = 5,
y_pred: int | None = None,
) -> tuple[np.ndarray, np.ndarray, list[dict[str, Any]]]:
"""Generate k diverse counterfactuals using different unlike neighbors.
NativeGuide naturally supports diverse counterfactual generation by
using different unlike neighbors as guides. Each counterfactual is
generated using a different neighbor, producing structurally diverse
explanations.
Parameters
----------
x : np.ndarray
Input time series of shape ``(T,)`` or ``(C, T)``.
k : int, default 5
Number of counterfactuals to generate.
y_pred : int, optional
Precomputed predicted label for ``x``.
Returns
-------
cfs : np.ndarray
Array of k counterfactuals with shape ``(k, ...)``.
cf_labels : np.ndarray
Array of k predicted labels.
metas : list[dict]
List of k metadata dictionaries.
"""
xb, added = ensure_batch_shape(x)
x1 = strip_batch(xb, added)
if y_pred is None:
base_idx = int(np.argmax(self.predict_proba(xb)[0]))
else:
base_idx = self._label_to_idx(y_pred)
# Step 1: Find k unlike neighbors to use as diverse guides
result = self._find_unlike_neighbors(x1, base_idx, k=k)
if result is None:
return super().explain_k(x, k=k, y_pred=y_pred)
distances, indices_in_unlike, ref_indices, X_unlike = result
# Step 2: Generate a counterfactual for each unlike neighbor
cfs: list[np.ndarray] = []
cf_labels: list[int] = []
metas: list[dict[str, Any]] = []
for i in range(len(indices_in_unlike)):
nun_in_unlike = int(indices_in_unlike[i])
guide = X_unlike[nun_in_unlike]
guide_meta = {
"nun_index_in_X": int(ref_indices[nun_in_unlike]),
"neighbor_indices": None,
"neighbor_distance": float(distances[i]),
}
if self.method == "blend":
cf, cf_idx, beta = self._blend_query_with_guide(x1, guide, base_idx)
window_start, window_len = None, None
else:
importance = self._compute_cam_importance(guide) if self.method == "cam" else None
window_start, window_len, cf, cf_idx = self._grow_window_until_flip(
x1, guide, base_idx, importance
)
beta = None
cf_label = self._idx_to_label(cf_idx)
meta = self._build_meta(
guide_meta,
beta=beta,
window_start=window_start,
window_len=window_len,
k_index=i,
)
cfs.append(cf)
cf_labels.append(cf_label)
metas.append(meta)
# Step 3: Pad with nearest neighbor result if fewer than k available
while len(cfs) < k:
best_idx = 0
cf = cfs[best_idx].copy()
cf_label = cf_labels[best_idx]
new_meta: dict[str, Any] = metas[best_idx].copy()
new_meta["k_index"] = len(cfs)
new_meta["note"] = "duplicated from nearest neighbor"
cfs.append(cf)
cf_labels.append(cf_label)
metas.append(new_meta)
return np.array(cfs), np.array(cf_labels), metas
def _find_nearest_unlike_neighbor(
self, x: np.ndarray, base_idx: int
) -> tuple[np.ndarray, dict[str, Any]]:
"""Find the single nearest unlike neighbor (NUN) from the reference set.
Used by the 'blend', 'ng', and 'cam' methods.
Parameters
----------
x : np.ndarray
Query time series of shape ``(T,)`` or ``(C, T)``.
base_idx : int
Probability column index of the original predicted class.
Returns
-------
nun : np.ndarray
Nearest unlike neighbor with the same shape as ``x``.
metadata : dict
Dictionary with ``'nun_index_in_X'``, ``'neighbor_indices'``,
and ``'neighbor_distance'``.
"""
result = self._find_unlike_neighbors(x, base_idx, k=1)
if result is None:
guide = self._fallback_global_mean(self.X_ref)
return guide, {
"nun_index_in_X": None,
"neighbor_indices": None,
"neighbor_distance": float(np.nan),
"failure_reason": "no_unlike_neighbors",
}
distances, indices_in_unlike, ref_indices, X_unlike = result
nun_in_unlike = int(indices_in_unlike[0])
nun = X_unlike[nun_in_unlike]
return nun, {
"nun_index_in_X": int(ref_indices[nun_in_unlike]),
"neighbor_indices": None,
"neighbor_distance": float(distances[0]),
}
def _build_dba_guide(self, x: np.ndarray, base_idx: int) -> tuple[np.ndarray, dict[str, Any]]:
"""Build a DTW-DBA barycenter guide from k unlike neighbors.
Used by the 'dtw_dba' method.
Parameters
----------
x : np.ndarray
Query time series of shape ``(T,)`` or ``(C, T)``.
base_idx : int
Probability column index of the original predicted class.
Returns
-------
guide : np.ndarray
DBA barycenter of k unlike neighbors, same shape as ``x``.
metadata : dict
Dictionary with ``'nun_index_in_X'``, ``'neighbor_indices'``,
and ``'neighbor_distance'``.
"""
result = self._find_unlike_neighbors(x, base_idx, k=self.k_unlike)
if result is None:
guide = self._fallback_global_mean(self.X_ref)
return guide, {
"nun_index_in_X": None,
"neighbor_indices": None,
"neighbor_distance": float(np.nan),
"failure_reason": "no_unlike_neighbors",
}
distances, indices_in_unlike, ref_indices, X_unlike = result
nbrs = X_unlike[indices_in_unlike]
guide = dba_barycenter_multich(nbrs)
nun_in_unlike = int(indices_in_unlike[0])
return guide, {
"nun_index_in_X": int(ref_indices[nun_in_unlike]),
"neighbor_indices": ref_indices[indices_in_unlike].tolist(),
"neighbor_distance": float(distances[0]),
}
def _find_unlike_neighbors(
self, x: np.ndarray, base_idx: int, k: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] | None:
"""Find k unlike neighbors from the reference set.
Shared helper for both single-NUN and multi-neighbor retrieval.
Parameters
----------
x : np.ndarray
Query time series of shape ``(T,)`` or ``(C, T)``.
base_idx : int
Probability column index of the original predicted class.
k : int
Number of unlike neighbors to retrieve.
Returns
-------
tuple or None
``(distances, indices_in_unlike, ref_indices, X_unlike)`` if
unlike neighbors exist, otherwise ``None`` (with a warning).
"""
yhat = self._ref_yhat
unlike_mask = yhat != base_idx
if not np.any(unlike_mask):
warnings.warn(
f"NativeGuide: No unlike neighbors found in reference set. "
f"The classifier predicts all {len(yhat)} reference samples as class "
f"{self._idx_to_label(base_idx)}. Falling back to global mean, which "
f"may not produce a valid counterfactual. Consider using a different "
f"dataset or classifier with more diverse predictions.",
UserWarning,
stacklevel=3,
)
return None
X_unlike = self.X_ref[unlike_mask]
ref_indices = np.flatnonzero(unlike_mask)
k_actual = min(k, len(X_unlike))
if TSLEARN_AVAILABLE:
distances, indices_in_unlike = self._find_k_neighbors_tslearn(x, X_unlike, k=k_actual)
else:
warnings.warn(
"tslearn is not installed. NativeGuide is using a manual fallback "
"for neighbor search instead of KNeighborsTimeSeries. "
"Install tslearn for the original algorithm: pip install tslearn",
UserWarning,
stacklevel=3,
)
dvec = self._compute_distances(x, X_unlike)
order = np.argsort(dvec)[:k_actual]
indices_in_unlike = order
distances = dvec[order]
return distances, indices_in_unlike, ref_indices, X_unlike
def _blend_query_with_guide(
self, x: np.ndarray, guide: np.ndarray, base_idx: int
) -> tuple[np.ndarray, int, float]:
"""Blend query with guide via weighted DTW barycenter averaging.
Incrementally increases the guide's influence (beta) until the model's
predicted class flips or the target probability is reached.
Parameters
----------
x : np.ndarray
Query series of shape ``(T,)`` or ``(C, T)``.
guide : np.ndarray
Native guide (NUN) of same shape as ``x``.
base_idx : int
Probability column index of the original predicted class.
Returns
-------
cf : np.ndarray
Counterfactual series with the same shape as ``x``.
cf_idx : int
Probability column index for the counterfactual prediction.
beta : float
Final blending weight in ``[0, 1]``.
"""
beta = 0.0
cf = x.copy()
cf_idx, alt_prob = self._predict_idx_and_max_alt_prob(cf, base_idx)
while alt_prob < self.target_prob and beta < 1.0:
beta = min(beta + self.beta_step, 1.0)
cf = weighted_dba_multich(x, guide, beta)
cf_idx, alt_prob = self._predict_idx_and_max_alt_prob(cf, base_idx)
if cf_idx != base_idx:
break
return cf, cf_idx, beta
def _compute_cam_importance(self, guide: np.ndarray) -> np.ndarray:
"""Compute the CAM importance map from the guide's class activation.
Predicts the guide's class internally and passes it to the user-provided
``cam_importance_fn``. This matches the original paper where the NUN's
CAM weights identify which region to swap.
Parameters
----------
guide : np.ndarray
The native guide (NUN) of shape ``(T,)`` or ``(C, T)``.
Returns
-------
np.ndarray
Importance map of shape ``(T,)``.
Raises
------
ValueError
If the shape returned by ``cam_importance_fn`` does not match
``(T,)`` for univariate input or ``(T,)`` / ``(C, T)`` for
multivariate input.
"""
assert self.cam_importance_fn is not None
guide_b, _ = ensure_batch_shape(guide)
guide_pred_idx = int(np.argmax(self.predict_proba(guide_b)[0]))
guide_label = self._idx_to_label(guide_pred_idx)
imp = np.asarray(self.cam_importance_fn(guide, guide_label))
if guide.ndim == 1:
if imp.ndim != 1 or imp.shape[0] != guide.shape[0]:
raise ValueError("cam_importance_fn must return shape (T,) for univariate input.")
return imp
# multivariate input: accept (T,) or (C, T)
if imp.ndim == 1:
if imp.shape[0] != guide.shape[1]:
raise ValueError("When returning (T,), length must match time length.")
return imp
if imp.ndim == 2 and imp.shape == guide.shape:
sum_result: np.ndarray = imp.sum(axis=0)
return sum_result
raise ValueError("cam_importance_fn must return (T,) or (C,T) for multivariate input.")
def _grow_window_until_flip(
self,
x: np.ndarray,
guide: np.ndarray,
base_idx: int,
importance: np.ndarray | None,
) -> tuple[int, int, np.ndarray, int]:
"""Grow a window from the guide into the query until the prediction flips.
Iteratively increases the window length, positioning it at the most
important region, and copies guide content into the query.
Parameters
----------
x : np.ndarray
Query time series of shape ``(T,)`` or ``(C, T)``.
guide : np.ndarray
Native guide series with the same shape as ``x``.
base_idx : int
Probability column index of the original predicted class.
importance : np.ndarray or None
Importance map of shape ``(T,)``. If ``None``, uses
``|guide - x|`` as a heuristic.
Returns
-------
start : int
Start index of the replacement window.
length : int
Length of the replacement window.
cf : np.ndarray
Counterfactual series with the same shape as ``x``.
cf_idx : int
Probability column index for the counterfactual prediction.
Raises
------
ValueError
If ``importance`` is not ``None`` and its shape is not ``(T,)``.
"""
T = x.shape[-1] if x.ndim == 2 else x.shape[0]
# Compute importance if not provided (heuristic: |guide - x|)
if importance is None:
importance = np.abs(guide - x) if x.ndim == 1 else np.abs(guide - x).sum(axis=0)
elif importance.ndim != 1 or importance.shape[0] != T:
raise ValueError("Internal: importance must be (T,) at this point.")
# Precompute cumulative sum for fast window-start queries
cumsum = np.concatenate([[0.0], np.cumsum(importance)])
# Grow window from length 1 to T, checking for flip at each step
length = 1
while length <= T:
start = self._best_window_start(cumsum, length)
cf = self._splice_guide_segment(x, guide, start, length)
cf_idx = self._predict_class_idx(cf)
if cf_idx != base_idx:
return start, length, cf, cf_idx
length += 1
# Worst case: replace entire series
cf = self._splice_guide_segment(x, guide, 0, T)
cf_idx = self._predict_class_idx(cf)
return 0, T, cf, cf_idx
def _predict_class_idx(self, arr: np.ndarray) -> int:
"""Return the predicted probability column index for a series.
Parameters
----------
arr : np.ndarray
Time series of shape ``(T,)`` or ``(C, T)``.
Returns
-------
int
Argmax index of the model's probability vector.
"""
cb, _ = ensure_batch_shape(arr)
return int(np.argmax(self.predict_proba(cb)[0]))
def _predict_idx_and_max_alt_prob(self, arr: np.ndarray, base_idx: int) -> tuple[int, float]:
"""Predict class index and highest non-base probability.
Parameters
----------
arr : np.ndarray
Time series of shape ``(T,)`` or ``(C, T)``.
base_idx : int
Probability column index of the base (original) class.
Returns
-------
pred_idx : int
Argmax index of the model's probability vector.
alt_prob : float
Highest probability among classes other than ``base_idx``.
"""
cb, _ = ensure_batch_shape(arr)
probs = self.predict_proba(cb)[0]
pred_idx = int(np.argmax(probs))
alt_prob = max((p for i, p in enumerate(probs) if i != base_idx), default=0.0)
return pred_idx, alt_prob
@staticmethod
def _best_window_start(cumsum: np.ndarray, length: int) -> int:
"""Return the start index maximising importance over a window.
Uses a precomputed cumulative sum for O(1) window-sum queries.
Parameters
----------
cumsum : np.ndarray
Cumulative sum of the importance map, of length ``T + 1``
(prepended with 0).
length : int
Window length to evaluate.
Returns
-------
int
Start index of the window with the highest total importance.
"""
return int(np.argmax(cumsum[length:] - cumsum[:-length]))
def _splice_guide_segment(
self, x: np.ndarray, guide: np.ndarray, start: int, length: int
) -> np.ndarray:
"""Copy a contiguous segment from guide into x.
Parameters
----------
x : np.ndarray
Original time series of shape ``(T,)`` or ``(C, T)``.
guide : np.ndarray
Guide series with the same shape as ``x``.
start : int
Start index of the segment to replace.
length : int
Length of the segment to replace.
Returns
-------
np.ndarray
Modified series with ``[start, start+length)`` replaced from guide.
"""
end = min(start + length, x.shape[-1] if x.ndim == 2 else x.shape[0])
out = x.copy()
if x.ndim == 1:
out[start:end] = guide[start:end]
else:
out[:, start:end] = guide[:, start:end]
return out
def _fallback_global_mean(self, X: np.ndarray) -> np.ndarray:
"""Compute the global mean of the reference set (fallback guide).
Used when no unlike neighbors exist in the reference set.
Parameters
----------
X : np.ndarray
Reference set of shape ``(N, T)`` or ``(N, C, T)``.
Returns
-------
np.ndarray
Mean series of shape ``(T,)`` or ``(C, T)``.
Raises
------
ValueError
If ``X`` has fewer than 2 or more than 3 dimensions.
"""
if X.ndim in (2, 3):
result: np.ndarray = X.mean(axis=0)
return result
raise ValueError(f"Unsupported X shape for mean: {X.shape}")
def _compute_distances(self, x: np.ndarray, X_candidates: np.ndarray) -> np.ndarray:
"""Compute distances from query to candidate set.
Parameters
----------
x : np.ndarray
Query time series of shape ``(T,)`` or ``(C, T)``.
X_candidates : np.ndarray
Candidate set of shape ``(N, T)`` or ``(N, C, T)``.
Returns
-------
np.ndarray
1-D array of distances of length ``N``.
"""
if self.distance != "euclidean" and not TSLEARN_AVAILABLE:
warnings.warn(
f"NativeGuide: distance='{self.distance}' was requested but tslearn "
f"is not installed. Falling back to Euclidean distance. "
f"Install tslearn for DTW support: pip install tslearn",
UserWarning,
stacklevel=2,
)
if self.distance == "euclidean" or not TSLEARN_AVAILABLE:
xb, _ = ensure_batch_shape(x)
return euclidean_cdist_flat(xb, X_candidates).ravel()
return dtw_distance_vec_multich(x, X_candidates)
def _find_k_neighbors_tslearn(
self, x: np.ndarray, X_candidates: np.ndarray, k: int
) -> tuple[np.ndarray, np.ndarray]:
"""Find k nearest neighbors using tslearn's KNeighborsTimeSeries.
Handles shape conversion between codebase convention ``(N, C, T)``
and tslearn convention ``(N, T, C)``.
Parameters
----------
x : np.ndarray
Query series of shape ``(T,)`` or ``(C, T)``.
X_candidates : np.ndarray
Candidate set of shape ``(N, T)`` or ``(N, C, T)``.
k : int
Number of neighbors to find.
Returns
-------
distances : np.ndarray
1-D array of k distances.
indices : np.ndarray
1-D array of k indices into ``X_candidates``.
"""
if X_candidates.ndim == 2:
X_tl = X_candidates[:, :, np.newaxis]
else:
X_tl = np.transpose(X_candidates, (0, 2, 1))
q_tl = x[np.newaxis, :, np.newaxis] if x.ndim == 1 else x.T[np.newaxis, :, :]
knn = KNeighborsTimeSeries(n_neighbors=k, metric=self.distance)
knn.fit(X_tl)
dists, inds = knn.kneighbors(q_tl, return_distance=True)
return dists[0], inds[0]
def _build_meta(
self,
guide_meta: dict[str, Any],
*,
beta: float | None = None,
window_start: int | None = None,
window_len: int | None = None,
k_index: int | None = None,
) -> dict[str, Any]:
"""Build the metadata dictionary for an explanation result.
Parameters
----------
guide_meta : dict
Metadata from guide retrieval, containing ``'nun_index_in_X'``,
``'neighbor_indices'``, and ``'neighbor_distance'``.
beta : float or None
Blending weight (only for ``method='blend'``).
window_start : int or None
Start index of the replacement window (window-based methods).
window_len : int or None
Length of the replacement window (window-based methods).
k_index : int or None
Index of this result within a ``explain_k`` batch. Omitted
from the dict when ``None``.
Returns
-------
dict
Metadata dictionary with keys ``method``, ``distance``,
``nun_index_in_X``, ``neighbor_indices``, ``neighbor_distance``,
``beta``, ``window_start``, and ``window_len`` (plus ``k_index``
when applicable).
"""
meta: dict[str, Any] = {
"method": self.method,
"distance": self.distance,
"nun_index_in_X": guide_meta.get("nun_index_in_X"),
"neighbor_indices": guide_meta.get("neighbor_indices"),
"neighbor_distance": guide_meta.get("neighbor_distance"),
"beta": beta,
"window_start": window_start,
"window_len": window_len,
}
if k_index is not None:
meta["k_index"] = k_index
return meta