"""Glacier counterfactual explainer implementation.
This module provides the ``Glacier`` class, an implementation of the Glacier
(Guided Locally Constrained Counterfactual Explanations) algorithm for
generating counterfactual explanations for time series classification.
The algorithm was originally developed by Zhendong Wang, Isak Samsten,
Ioanna Miliou, Rami Mochaourab, and Panagiotis Papapetrou at Stockholm University.
Original implementation: https://github.com/zhendong3wang/learning-time-series-counterfactuals
Classes
-------
Glacier
Glacier counterfactual generator using gradient-based optimization with
guided constraints.
Algorithm Overview
------------------
Glacier generates counterfactuals through gradient-based optimization:
1. Optionally encode the input time series into a latent space using an
autoencoder (CNN or LSTM-based).
2. Compute importance weights using segment-based LIME (local importance),
uniform weights, or unconstrained weights.
3. Optimize a composite loss function that balances:
- **Prediction margin loss**: Drives the counterfactual toward the target class
- **Proximity loss**: Penalizes deviations from the original, weighted by importance
4. Iterate until the classifier predicts the target class with sufficient
confidence or the maximum iterations are reached.
5. If using an autoencoder, decode the optimized latent representation back
to the original time series space.
Examples
--------
>>> from tscf_eval.counterfactuals import Glacier
>>> import numpy as np
>>>
>>> # Assume clf is a trained classifier with predict_proba
>>> glacier = Glacier(
... model=clf,
... data=(X_train, y_train),
... pred_margin_weight=0.5,
... learning_rate=0.01,
... max_iter=100,
... )
>>>
>>> # Generate counterfactual for a test instance
>>> cf, cf_label, meta = glacier.explain(x_test)
>>> print(f"Converged: {meta['converged']}")
>>> print(f"Iterations: {meta['n_iterations']}")
References
----------
.. [glacier1] Wang, Z., Samsten, I., Miliou, I., Mochaourab, R., & Papapetrou, P. (2024).
Glacier: Guided Locally Constrained Counterfactual Explanations for
Time Series Classification. Machine Learning, 113(3).
DOI: 10.1007/s10994-023-06502-x
.. [glacier2] Wang, Z., Samsten, I., Mochaourab, R., & Papapetrou, P. (2021).
Learning Time Series Counterfactuals via Latent Space Representations.
In International Conference on Discovery Science (DS'2021).
Notes
-----
This implementation provides a simplified version of Glacier that works
directly in the original time series space (without autoencoder) for
compatibility with any scikit-learn compatible classifier. The core
gradient-based optimization with weighted proximity constraints is preserved.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
import warnings
import numpy as np
from sklearn.linear_model import Ridge
from .base import Counterfactual
from .utils import (
ensure_batch_shape,
has_expensive_transform,
soft_predict_proba_fn,
strip_batch,
supports_soft_probabilities,
)
from .utils._adam import AdamState
# Optional: stumpy for matrix-profile-based segmentation
try:
import stumpy
STUMPY_AVAILABLE = True
except ImportError: # pragma: no cover
stumpy = None # type: ignore[assignment]
STUMPY_AVAILABLE = False
# Optional: scipy for STFT-based background identification
try:
from scipy import signal as sp_signal
SCIPY_AVAILABLE = True
except ImportError: # pragma: no cover
sp_signal = None # type: ignore[assignment]
SCIPY_AVAILABLE = False
WeightType = Literal["uniform", "local", "unconstrained"]
@dataclass
class Glacier(Counterfactual):
"""Glacier counterfactual generator using gradient-based optimization.
Implementation of the Glacier algorithm by Wang et al. (2024) [glacier1]_.
Glacier uses gradient-based optimization with guided constraints to generate
counterfactual explanations. The key innovation is applying importance-based
weights that allow free modification of less-important time series regions
while preserving critical features.
The optimization minimizes a composite loss:
L = w * L_pred + (1 - w) * L_proximity
where:
- L_pred: Prediction margin loss, ``(tau - P(target_class)) ** 2``.
- L_proximity: Importance-weighted mean absolute deviation from the
original, ``mean(weights * |cf - x|)`` (``0`` when unconstrained).
- w: ``pred_margin_weight`` parameter.
Gradients of L_pred are estimated with central finite differences through
the (soft) probability function; the gradient of L_proximity is analytical.
Both gradient components are normalised to unit norm before being mixed
with ``w`` and ``1 - w``, and the update is applied with Adam.
Parameters
----------
model : object
A classifier with a probability estimator (``predict_proba`` or a
compatible interface). Must be differentiable or approximable.
A ``UserWarning`` is emitted at construction time when
``supports_soft_probabilities(model)`` is false.
data : tuple (``X_ref``, ``y_ref``)
Reference dataset used for the label mapping and for normalization
statistics (per-position mean and standard deviation of ``X_ref``).
pred_margin_weight : float, default 0.75
Weight balancing prediction margin loss vs proximity loss.
Higher values prioritize changing the prediction over staying close
to the original. Range: [0, 1]. Values >= 0.75 recommended for
non-neural-network classifiers where finite-difference gradients
are weak relative to the proximity gradient.
learning_rate : float, default 0.01
Step size for Adam optimizer; must be > 0. Internally multiplied by
the mean reference standard deviation so the effective step adapts
to input magnitude.
max_iter : int, default 300
Maximum number of optimization iterations; must be >= 1.
tau : float, default 0.5
Decision threshold for target class probability, in (0, 1].
Optimization stops once ``P(target_class) >= tau`` and either the
target class is the arg-max class or the prediction margin loss is
below ``tolerance``.
tolerance : float, default 1e-4
Convergence tolerance for prediction margin loss (see ``tau``).
weight_type : {'uniform', 'local', 'unconstrained'}, default 'uniform'
Type of importance weighting:
- 'uniform': Equal weights across all timesteps
- 'local': Segment-based LIME importance following the paper.
Uses matrix-profile changepoint segmentation, STFT background
perturbation, and Ridge regression surrogate to compute
per-segment importance, producing binary timestep weights.
Requires ``stumpy`` and ``scipy`` for full functionality
(falls back to uniform segments / mean background otherwise).
- 'unconstrained': No proximity penalty (pure prediction optimization)
random_state : int or None, default 0
PRNG seed for reproducible optimization. The generator drives the
gradient feature subsampling and the LIME perturbation draws.
gradient_subsample : int or None, default 50
Number of features to randomly sample for gradient computation each
iteration; must be >= 1 or None. Uses stochastic gradient descent when
set to a value less than the total number of features. Set to None to
use all features (full gradient). Lower values speed up computation
but may require more iterations to converge. Unless
``has_expensive_transform(model)`` is true, the number actually
sampled is raised to at least ``n_features // 2``.
temperature : float or None, default None
Temperature scaling for soft probability computation. Higher values
produce smoother gradients by preventing sigmoid saturation when
decision function values are large. If None, auto-calibrates based
on model decision function values (recommended for most use cases).
Increase manually (e.g., 2.0-5.0) if counterfactuals are unchanged
with ROCKET or other margin-based classifiers.
n_segments : int, default 10
Number of changepoints for segment-based local importance
(``weight_type='local'``); must be >= 1. Produces at most
``n_segments + 1`` segments (fewer when fewer changepoints are found
or the series is very short).
Ignored when ``weight_type`` is not ``'local'``.
segment_window : int, default 10
Window size for the matrix-profile segmentation algorithm; must be
>= 2. Capped at half the series length; windows below 3 trigger the
uniform-segmentation fallback.
Ignored when ``weight_type`` is not ``'local'``.
n_perturbations : int, default 100
Number of binary perturbation samples for the LIME surrogate model
used in segment-based local importance; must be >= 10. Ignored when
``weight_type`` is not ``'local'``.
Attributes
----------
predict_proba : callable
Wrapped probability prediction function.
rng : numpy.random.Generator
Random number generator for reproducibility.
X_ref : np.ndarray
Reference dataset features.
y_ref : np.ndarray
Reference dataset labels (flattened to 1-D).
_mean : np.ndarray
Mean of reference data along the sample axis (for normalization).
_std : np.ndarray
Standard deviation of reference data along the sample axis, plus
``1e-8`` to avoid division by zero (for normalization).
Raises
------
ValueError
If any hyperparameter is outside the range stated above.
References
----------
.. [glacier1] Wang, Z., Samsten, I., Miliou, I., Mochaourab, R., & Papapetrou, P.
(2024). Glacier: Guided Locally Constrained Counterfactual
Explanations for Time Series Classification. Machine Learning, 113(3).
https://github.com/zhendong3wang/learning-time-series-counterfactuals
"""
# Required constructor arguments.
model: Any
data: tuple[np.ndarray, np.ndarray]
# Optimisation hyperparameters (validated in __post_init__).
pred_margin_weight: float = 0.75
learning_rate: float = 0.01
max_iter: int = 300
tau: float = 0.5
tolerance: float = 1e-4
weight_type: WeightType = "uniform"
random_state: int | None = 0
gradient_subsample: int | None = 50
temperature: float | None = None
# Settings used only by the 'local' (segment-based LIME) weighting.
n_segments: int = 10
segment_window: int = 10
n_perturbations: int = 100
# Internal state: placeholders that __post_init__ overwrites with the
# statistics of the reference data; excluded from __init__ and repr.
_mean: np.ndarray = field(default_factory=lambda: np.array([0.0]), init=False, repr=False)
_std: np.ndarray = field(default_factory=lambda: np.array([1.0]), init=False, repr=False)
def __post_init__(self):
    """Finish construction: wrap the model, load reference data, validate.

    Steps, in order:

    1. Emit a ``UserWarning`` when ``supports_soft_probabilities`` reports
       that the model is a poor fit for gradient-based search.
    2. Build the soft probability wrapper and the random generator.
    3. Store the reference arrays, initialise the label mapping and
       compute the per-position mean / standard deviation of ``X_ref``.
    4. Check every hyperparameter and raise on the first violation.

    Raises
    ------
    ValueError
        If a hyperparameter lies outside its documented range.
    """
    if not supports_soft_probabilities(self.model):
        model_name = type(self.model).__name__
        # NOTE: the warning is raised from this frame on purpose so that
        # ``stacklevel=2`` keeps pointing at the caller.
        warnings.warn(
            f"Glacier is a gradient-based method that may not work well with "
            f"{model_name}. Tree-based classifiers return discrete "
            f"probabilities that don't respond well to gradient optimization. "
            f"Consider using COMTE or NativeGuide instead.",
            UserWarning,
            stacklevel=2,
        )

    self.predict_proba = soft_predict_proba_fn(self.model, temperature=self.temperature)
    self.rng = np.random.default_rng(self.random_state)

    reference = self.data
    self.X_ref = np.asarray(reference[0])
    self.y_ref = np.asarray(reference[1]).ravel()
    self._init_label_mapping(self.model, self.y_ref)

    # Normalisation statistics; the small offset keeps the std strictly
    # positive for constant positions.
    self._mean = self.X_ref.mean(axis=0)
    self._std = self.X_ref.std(axis=0) + 1e-8

    # Each entry is (predicate that is true when the setting is invalid,
    # error message). Predicates are lazy so that only the checks up to
    # the first failure are evaluated, exactly as in a chain of ``if``s.
    violations = (
        (
            lambda: not (0.0 <= self.pred_margin_weight <= 1.0),
            "pred_margin_weight must be in [0, 1]",
        ),
        (lambda: self.learning_rate <= 0, "learning_rate must be > 0"),
        (lambda: self.max_iter < 1, "max_iter must be >= 1"),
        (lambda: not (0.0 < self.tau <= 1.0), "tau must be in (0, 1]"),
        (
            lambda: self.weight_type not in ("uniform", "local", "unconstrained"),
            "weight_type must be one of {'uniform', 'local', 'unconstrained'}",
        ),
        (
            lambda: self.gradient_subsample is not None and self.gradient_subsample < 1,
            "gradient_subsample must be >= 1 or None",
        ),
        (lambda: self.n_segments < 1, "n_segments must be >= 1"),
        (lambda: self.segment_window < 2, "segment_window must be >= 2"),
        (lambda: self.n_perturbations < 10, "n_perturbations must be >= 10"),
    )
    for is_violated, message in violations:
        if is_violated():
            raise ValueError(message)
def explain(
    self,
    x: np.ndarray,
    y_pred: int | None = None,
    *,
    class_of_interest: int | None = None,
) -> tuple[np.ndarray, int, dict[str, Any]]:
    """Generate a counterfactual explanation using gradient-based optimization.

    Parameters
    ----------
    x : np.ndarray
        Input time series of shape ``(T,)`` for univariate or ``(C, T)``
        for multivariate data.
    y_pred : int, optional
        Base predicted class label for ``x``. If ``None``, the arg-max of
        the soft probabilities is used.
    class_of_interest : int, optional
        Target class label for the counterfactual. If ``None``, uses the
        highest-probability alternative to the base class.

    Returns
    -------
    cf : np.ndarray
        Counterfactual time series with the same shape as ``x``.
    cf_label : int
        Predicted class label for the counterfactual (arg-max of the soft
        probabilities, mapped back to a label).
    meta : dict
        Metadata dictionary containing:

        - ``method``: Algorithm identifier (``'glacier'``).
        - ``weight_type``: Constraint type used.
        - ``class_of_interest``: Target class label.
        - ``pred_margin_weight``: Weight parameter used.
        - ``learning_rate``: Learning rate used.
        - ``n_iterations``: Number of iterations performed.
        - ``converged``: Whether optimization converged.
        - ``final_target_prob``: Final probability of target class.
        - ``final_loss``: Final composite loss value.
        - ``validity``: Whether the soft-probability class of the
          counterfactual differs from the base class.
        - ``model_cf_prediction``: Output of ``model.predict`` for the
          counterfactual.

    Raises
    ------
    ValueError
        If ``class_of_interest`` is ``None`` and the probability output
        contains no class other than the base class, so no target can be
        chosen.
    """
    xb, added = ensure_batch_shape(x)
    x1 = strip_batch(xb, added)

    # All internal work uses probability column *indices* (0, 1, ...),
    # while y_pred and the returned cf_label use actual class *labels*.
    base_probs = self.predict_proba(xb)[0]
    base_idx = int(np.argmax(base_probs)) if y_pred is None else self._label_to_idx(y_pred)

    if class_of_interest is None:
        # Most probable column that is not the base class. A default is
        # passed to next() so that a single-class model produces a clear
        # error instead of leaking a bare StopIteration to the caller.
        ranked = np.argsort(-base_probs)
        target_idx = next((int(c) for c in ranked if c != base_idx), None)
        if target_idx is None:
            raise ValueError(
                "Cannot choose a counterfactual target: the model exposes no "
                "class other than the base prediction."
            )
    else:
        target_idx = self._label_to_idx(class_of_interest)

    # Importance weights for the proximity term.
    step_weights = self._compute_weights(x1, base_idx)

    # Gradient-based search.
    cf, n_iter, converged, final_prob, final_loss = self._optimize(
        x1, base_idx, target_idx, step_weights
    )

    # Label the result with the same soft probabilities the optimiser saw.
    cf_probs = self.predict_proba(cf[None, ...])[0]
    cf_idx = int(np.argmax(cf_probs))

    # The model's own hard prediction is reported alongside for reference.
    model_cf_pred = self.model.predict(cf[None, ...])[0]

    cf_label = self._idx_to_label(cf_idx)
    meta: dict[str, Any] = {
        "method": "glacier",
        "weight_type": self.weight_type,
        "class_of_interest": self._idx_to_label(target_idx),
        "pred_margin_weight": float(self.pred_margin_weight),
        "learning_rate": float(self.learning_rate),
        "n_iterations": n_iter,
        "converged": converged,
        "final_target_prob": float(final_prob),
        "final_loss": float(final_loss),
        "validity": cf_idx != base_idx,
        "model_cf_prediction": model_cf_pred,
    }
    return cf, cf_label, meta
def _compute_weights(self, x: np.ndarray, base_label: int) -> np.ndarray:
    """Build the per-timestep weights used by the proximity loss.

    Parameters
    ----------
    x : np.ndarray
        Original time series of shape ``(T,)`` or ``(C, T)``.
    base_label : int
        Index of the originally predicted class; only used by the
        ``'local'`` weighting.

    Returns
    -------
    np.ndarray
        Weights shaped like ``x``: all ones for ``'uniform'``, all zeros
        for ``'unconstrained'``, otherwise the result of the segment-based
        local importance computation.
    """
    mode = self.weight_type
    if mode == "uniform":
        # Every timestep is penalised equally.
        return np.ones_like(x)
    if mode == "unconstrained":
        # Zero weights switch the proximity penalty off entirely.
        return np.zeros_like(x)
    # 'local': segment-based LIME (matrix-profile segmentation, STFT
    # background perturbation, Ridge surrogate) mapped to binary weights.
    return self._compute_local_importance(x, base_label)
# ------------------------------------------------------------------
# Segment-based local importance (following the Glacier paper)
# ------------------------------------------------------------------
def _compute_local_importance(self, x: np.ndarray, base_label: int) -> np.ndarray:
    """Derive binary proximity weights from a segment-level LIME surrogate.

    Pipeline:

    1. Split the series into segments (changepoints on the first channel).
    2. Build a background signal used to "switch off" segments.
    3. Draw random on/off segment masks and the matching perturbed series.
    4. Query the base-class probability of every perturbed series.
    5. Fit a locality-weighted Ridge regression from masks to probabilities.
    6. Segments whose coefficient is at or below the 25th percentile are
       released (weight 0); all other timesteps stay protected (weight 1).

    Parameters
    ----------
    x : np.ndarray
        Original time series of shape ``(T,)`` or ``(C, T)``.
    base_label : int
        Probability column index of the original predicted class.

    Returns
    -------
    np.ndarray
        Binary weights with the same shape as ``x``
        (0 = freely modifiable, 1 = protected).
    """
    multichannel = x.ndim == 2
    # Segmentation and background use the first channel only; the final
    # weights are broadcast over all channels.
    reference = x[0] if multichannel else x
    n_steps = len(reference)

    bounds = self._segment_time_series(reference)
    segment_count = len(bounds) - 1
    background = self._compute_background(reference)

    masks, perturbed = self._generate_perturbation_samples(
        x, reference, background, bounds, segment_count, multichannel
    )
    base_probs = self.predict_proba(perturbed)[:, base_label]

    # Locality kernel: distance of every mask from the all-on mask,
    # standardised, then mapped through exp(-|z|).
    distances = np.linalg.norm(masks - np.ones(segment_count), axis=1)
    spread = distances.std()
    if spread > 0:
        standardized = (distances - distances.mean()) / spread
    else:
        standardized = np.zeros_like(distances)
    locality = np.exp(-np.abs(standardized))

    surrogate = Ridge(alpha=1.0)
    surrogate.fit(masks, base_probs, sample_weight=locality)
    coefficients = surrogate.coef_  # one coefficient per segment

    # Low-importance segments for the base class are safe to modify, so
    # their timesteps receive no proximity penalty.
    cutoff = np.percentile(coefficients, 25)
    protection = np.ones(n_steps)
    for seg in np.nonzero(coefficients <= cutoff)[0]:
        protection[bounds[seg] : bounds[seg + 1]] = 0.0

    if not multichannel:
        return protection
    return np.broadcast_to(protection[None, :], x.shape).copy()
def _segment_time_series(self, x_1d: np.ndarray) -> list[int]:
    """Split a univariate series at matrix-profile changepoints.

    NNSegment-style procedure: compute the matrix profile, treat every
    break in the run of consecutive nearest-neighbour indices as a
    candidate, score candidates by the local mean/std shift, and greedily
    keep the best-scoring ones that are at least half a window apart.

    Uniform segmentation is used instead when ``stumpy`` is missing (with
    a ``UserWarning``), when the effective window is shorter than 3, or
    when no usable changepoint is found.

    Parameters
    ----------
    x_1d : np.ndarray
        Univariate time series of shape ``(T,)``.

    Returns
    -------
    list[int]
        Sorted segment boundary indices including ``0`` and ``T``.
    """
    length = len(x_1d)
    max_changepoints = self.n_segments
    # The window may not exceed half the series.
    m = min(self.segment_window, length // 2)

    if not STUMPY_AVAILABLE:
        warnings.warn(
            "stumpy is not installed. Glacier local importance is using "
            "uniform segmentation instead of matrix-profile changepoints. "
            "Install stumpy for proper segmentation: pip install stumpy",
            UserWarning,
            stacklevel=3,
        )
        return self._uniform_segments(length, max_changepoints)
    if m < 3 or 2 * m > length:
        return self._uniform_segments(length, max_changepoints)

    profile = stumpy.stump(x_1d.astype(np.float64), m=m)
    neighbours = profile[:, 1].astype(int)

    # A candidate sits wherever the neighbour index does not simply
    # advance by one; the boundary is the position after the break.
    breaks = (np.flatnonzero(np.diff(neighbours) != 1) + 1).tolist()
    if not breaks:
        return self._uniform_segments(length, max_changepoints)

    half = max(m // 2, 1)
    ranked: list[tuple[float, int]] = []
    for cp in breaks:
        lo = max(0, cp - half)
        hi = min(length, cp + half)
        # Need at least two points on each side for a meaningful std.
        if cp - lo < 2 or hi - cp < 2:
            continue
        before, after = x_1d[lo:cp], x_1d[cp:hi]
        sd_before, sd_after = before.std(), after.std()
        mean_change = abs(float(before.mean() - after.mean()))
        std_change = abs(float(sd_before - sd_after))
        std_mean = float((sd_before + sd_after) / 2)
        ranked.append((mean_change * std_change / (std_mean + 1e-10), cp))
    if not ranked:
        return self._uniform_segments(length, max_changepoints)

    # Best score first; accept a candidate only if it keeps its distance
    # from everything accepted so far.
    ranked.sort(reverse=True)
    chosen: list[int] = []
    for _, cp in ranked:
        if len(chosen) >= max_changepoints:
            break
        if all(abs(cp - kept) >= half for kept in chosen):
            chosen.append(cp)
    if not chosen:
        return self._uniform_segments(length, max_changepoints)

    return [0, *sorted(chosen), length]
@staticmethod
def _uniform_segments(T: int, n_cp: int) -> list[int]:
    """Return evenly spaced segment boundaries (fallback segmentation).

    Parameters
    ----------
    T : int
        Length of the time series.
    n_cp : int
        Desired number of changepoints (``n_cp + 1`` segments, capped at
        ``T`` segments).

    Returns
    -------
    list[int]
        Sorted, duplicate-free boundary indices including ``0`` and ``T``.
    """
    segment_count = min(n_cp + 1, T)
    edges = np.linspace(0, T, segment_count + 1, dtype=int)
    # np.unique sorts and removes the duplicates that integer truncation
    # can create on very short series.
    return np.unique(edges).tolist()
def _compute_background(self, x_1d: np.ndarray) -> np.ndarray:
    """Compute a background signal via STFT (Realistic Background Perturbation).

    Picks the frequency bin whose STFT magnitude is most stable over time
    (highest mean/std ratio), zeroes every other bin and inverts the
    transform. The result is used as a realistic replacement when a
    segment is "turned off".

    When ``scipy`` is not installed, a ``UserWarning`` is emitted and a
    constant series holding the global mean is returned instead.

    Parameters
    ----------
    x_1d : np.ndarray
        Univariate time series of shape ``(T,)``.

    Returns
    -------
    np.ndarray
        ``float64`` background signal of the same length as ``x_1d``.
        Both the STFT path and the fallback return ``float64`` so that an
        integer-typed input does not truncate the mean.
    """
    series = np.asarray(x_1d, dtype=np.float64)

    if not SCIPY_AVAILABLE:
        warnings.warn(
            "scipy is not installed. Glacier local importance is using the "
            "global mean as background instead of STFT-based background "
            "identification. Install scipy for proper background perturbation: "
            "pip install scipy",
            UserWarning,
            stacklevel=3,
        )
        # np.full (not full_like on the input) keeps the fractional part
        # of the mean even when the caller passed an integer array.
        return np.full(series.shape, series.mean(), dtype=np.float64)

    T = len(series)
    nperseg = min(40, T)
    _f, _t, Zxx = sp_signal.stft(series, fs=1.0, nperseg=nperseg)

    # Stability of each frequency bin = mean / std of its magnitude over
    # time; bins with (near-)zero spread are scored 0.
    magnitudes = np.abs(Zxx)
    with np.errstate(divide="ignore", invalid="ignore"):
        stds = magnitudes.std(axis=1)
        means = magnitudes.mean(axis=1)
        stability = np.where(stds > 1e-12, means / stds, 0.0)
    best_freq = int(np.argmax(stability))

    # Keep only the most stable bin and reconstruct.
    mask = np.zeros_like(Zxx)
    mask[best_freq, :] = 1.0
    _, bg_raw = sp_signal.istft(Zxx * mask, fs=1.0, nperseg=nperseg)
    background: np.ndarray = np.asarray(bg_raw, dtype=np.float64)

    # STFT/ISTFT padding can change the length slightly; trim or extend.
    if len(background) >= T:
        return background[:T]
    return np.pad(background, (0, T - len(background)), mode="edge")
def _generate_perturbation_samples(
    self,
    x: np.ndarray,
    x_1d: np.ndarray,
    background: np.ndarray,
    seg_bounds: list[int],
    n_segs: int,
    is_multivariate: bool,
) -> tuple[np.ndarray, np.ndarray]:
    """Draw random segment on/off masks and build the perturbed series.

    Every row of the mask holds one fair coin flip per segment: ``1``
    keeps the segment as in ``x``, ``0`` overwrites its timesteps with the
    background signal (in every channel for multivariate input).

    Parameters
    ----------
    x : np.ndarray
        Original time series of shape ``(T,)`` or ``(C, T)``.
    x_1d : np.ndarray
        Univariate view of ``x``; accepted for interface compatibility
        and not read here.
    background : np.ndarray
        Background signal of shape ``(T,)``.
    seg_bounds : list[int]
        Segment boundary indices.
    n_segs : int
        Number of segments.
    is_multivariate : bool
        Whether the original input is multivariate.

    Returns
    -------
    interpretable : np.ndarray
        Binary matrix of shape ``(n_perturbations, n_segs)``.
    raw_samples : np.ndarray
        Perturbed time series of shape ``(n_perturbations, *x.shape)``,
        in the dtype of ``x``.
    """
    n_draws = self.n_perturbations
    keep_mask = self.rng.binomial(1, 0.5, size=(n_draws, n_segs))

    # One copy of x per draw.
    repeats = (n_draws, 1) if x.ndim == 1 else (n_draws, 1, 1)
    perturbed = np.tile(x, repeats)

    # Handle one segment at a time, overwriting it in every draw that
    # switched it off.
    for seg in range(n_segs):
        lo, hi = seg_bounds[seg], seg_bounds[seg + 1]
        switched_off = keep_mask[:, seg] == 0
        if is_multivariate:
            perturbed[switched_off, :, lo:hi] = background[lo:hi]
        else:
            perturbed[switched_off, lo:hi] = background[lo:hi]

    return keep_mask, perturbed
def _optimize(
    self,
    x_original: np.ndarray,
    base_label: int,
    target_class: int,
    step_weights: np.ndarray,
) -> tuple[np.ndarray, int, bool, float, float]:
    """Run gradient-based optimization to find a counterfactual.

    Uses Adam with finite-difference gradients, following the original
    Glacier implementation which uses ``tf.GradientTape`` + Adam.

    Each iteration evaluates the current candidate, stops if it satisfies
    the convergence rule, and otherwise takes one Adam step. If the
    iteration budget runs out, the candidate produced by the last step is
    evaluated once more, so that the reported probability, loss and
    convergence flag describe the counterfactual that is actually
    returned rather than its predecessor.

    Parameters
    ----------
    x_original : np.ndarray
        Original time series of shape ``(T,)`` or ``(C, T)``.
    base_label : int
        Index of the originally predicted class (not read here; kept for
        interface compatibility).
    target_class : int
        Probability column index of the target class.
    step_weights : np.ndarray
        Importance weights for proximity loss.

    Returns
    -------
    cf : np.ndarray
        Optimized counterfactual (``float64``).
    n_iterations : int
        Number of iterations performed.
    converged : bool
        Whether ``cf`` reaches ``P(target) >= tau`` and is either
        classified as the target or has a prediction margin loss below
        ``tolerance``.
    final_target_prob : float
        Probability of the target class for ``cf``.
    final_loss : float
        Composite loss value for ``cf``.
    """
    cf = x_original.copy().astype(np.float64)
    w = self.pred_margin_weight
    unconstrained = self.weight_type == "unconstrained"

    def evaluate(candidate: np.ndarray) -> tuple[float, float, bool]:
        """Return (target probability, composite loss, converged?)."""
        probs = self.predict_proba(candidate[None, ...])[0]
        target_prob = float(probs[target_class])
        pred_label = int(np.argmax(probs))

        # Prediction margin loss: squared gap between tau and P(target).
        pred_margin_loss = (self.tau - target_prob) ** 2
        # Proximity loss: weighted mean absolute deviation from original.
        if unconstrained:
            proximity_loss = 0.0
        else:
            proximity_loss = float(np.mean(step_weights * np.abs(candidate - x_original)))
        total_loss = w * pred_margin_loss + (1 - w) * proximity_loss

        reached_tau = target_prob >= self.tau
        done = reached_tau and (
            pred_margin_loss < self.tolerance or pred_label == target_class
        )
        return target_prob, float(total_loss), bool(done)

    # Scale learning rate to data magnitude. Adam normalizes gradients so
    # effective step ≈ lr regardless of gradient magnitude. For finite
    # differences through non-differentiable transforms (ROCKET), the step
    # needs to be proportional to data scale to cross the decision boundary.
    data_scale = float(self._std.mean())
    effective_lr = self.learning_rate * data_scale
    adam = AdamState.zeros_like(cf)

    converged = False
    n_iterations = 0
    final_prob = 0.0
    final_loss = float("inf")
    for iteration in range(self.max_iter):
        n_iterations = iteration + 1
        final_prob, final_loss, converged = evaluate(cf)
        if converged:
            break
        gradient = self._compute_gradient(cf, x_original, target_class, step_weights, w)
        cf = cf - adam.step(gradient, effective_lr)
    else:
        # Budget exhausted: cf was updated after its last evaluation, so
        # refresh the statistics for the candidate being returned.
        final_prob, final_loss, converged = evaluate(cf)

    return cf, n_iterations, converged, final_prob, final_loss
def _compute_gradient(
    self,
    cf: np.ndarray,
    x_original: np.ndarray,
    target_class: int,
    step_weights: np.ndarray,
    w: float,
) -> np.ndarray:
    """Estimate the gradient of the composite loss at ``cf``.

    The prediction-margin term ``(tau - P(target))**2`` is differentiated
    with central finite differences (one batched model call), optionally
    on a random subset of features. The proximity term
    ``mean(weights * |cf - x_original|)`` is differentiated analytically.
    Each part is scaled to unit norm before mixing so that ``w`` alone
    sets the balance; a part whose norm is negligible contributes zero.

    Parameters
    ----------
    cf : np.ndarray
        Current counterfactual estimate.
    x_original : np.ndarray
        Original time series.
    target_class : int
        Probability column index of the target class.
    step_weights : np.ndarray
        Importance weights for proximity loss.
    w : float
        Prediction margin weight.

    Returns
    -------
    np.ndarray
        Gradient of composite loss with respect to ``cf`` (same shape).
    """
    # Finite-difference step proportional to the data scale, so that the
    # probe is large enough to register through non-differentiable
    # transforms (e.g. ROCKET random kernels).
    step = max(float(self._std.mean()) * 0.01, 1e-4)

    point = cf.flatten()
    dim = len(point)

    # Choose the coordinates to probe. Cheap models get at least half of
    # the features; models with expensive transforms honour the requested
    # budget exactly to limit the number of transform calls.
    budget = self.gradient_subsample
    if budget is None or not (budget < dim):
        k = dim
        picked = np.arange(dim)
    else:
        k = budget if has_expensive_transform(self.model) else max(budget, dim // 2)
        k = min(k, dim)
        picked = self.rng.choice(dim, size=k, replace=False)

    # Rows 0..k-1 hold the "+step" probes, rows k..2k-1 the "-step" ones.
    rows = np.arange(k)
    probes = np.tile(point, (2 * k, 1))
    probes[rows, picked] += step
    probes[rows + k, picked] -= step

    target_probs = self.predict_proba(probes.reshape(2 * k, *cf.shape))[:, target_class]
    loss_up = (self.tau - target_probs[:k]) ** 2
    loss_down = (self.tau - target_probs[k:]) ** 2

    pred_gradient = np.zeros(dim)
    pred_gradient[picked] = (loss_up - loss_down) / (2 * step)

    # d/dcf mean(weights * |cf - x|) = weights * sign(cf - x) / dim
    if self.weight_type == "unconstrained":
        prox_gradient = np.zeros(dim)
    else:
        prox_gradient = step_weights.flatten() * np.sign(point - x_original.flatten()) / dim

    def weighted_unit(vec: np.ndarray, weight: float) -> np.ndarray:
        """Scale ``vec`` to unit norm times ``weight``; zero if negligible."""
        # Zeroing tiny vectors avoids blowing floating-point noise up to
        # unit length (e.g. flat probabilities from tree-based models).
        norm = float(np.linalg.norm(vec))
        if norm > 1e-10:
            return weight * vec / norm
        return np.zeros_like(vec)

    combined = weighted_unit(pred_gradient, w) + weighted_unit(prox_gradient, 1 - w)
    return combined.reshape(cf.shape)