Source code for tscf_eval.counterfactuals.latent_cf

"""LatentCF++ counterfactual explainer implementation.

This module provides the ``LatentCF`` class, an implementation of the LatentCF++
algorithm for generating counterfactual explanations for time series classification.

The algorithm was originally developed by Zhendong Wang, Isak Samsten,
Rami Mochaourab, and Panagiotis Papapetrou at Stockholm University, with
multivariate extensions by Stella Gerantoni.

Original implementations:
- LatentCF++: https://github.com/zhendong3wang/learning-time-series-counterfactuals
- Multivariate: https://github.com/stellagerantoni/LatentCfMultivariate

Classes
-------
LatentCF
    LatentCF++ counterfactual generator using gradient-based optimization
    in latent space representations.

Algorithm Overview
------------------
LatentCF++ generates counterfactuals through latent space optimization:

1. Optionally encode the input time series into a latent representation using
   an autoencoder (user-provided or None for direct optimization).
2. Compute importance weights for each timestep using:
   - 'uniform': Equal weights across all timesteps
   - 'local': Per-sample importance computed via perturbation-based sensitivity
   - 'global': Dataset-level importance computed across reference samples
3. Optimize a composite loss function:

   - **Prediction margin loss**: Drives the sample toward target class probability
   - **Weighted proximity loss**: Penalizes deviations, weighted by importance

4. Iterate until the target probability is reached or max iterations exhausted.
5. If using an autoencoder, decode the optimized latent representation.

Examples
--------
>>> from tscf_eval.counterfactuals import LatentCF
>>> import numpy as np
>>>
>>> # Assume clf is a trained classifier with predict_proba
>>> latent_cf = LatentCF(
...     model=clf,
...     data=(X_train, y_train),
...     pred_margin_weight=1.0,
...     learning_rate=0.0001,
...     max_iter=100,
... )
>>>
>>> # Generate counterfactual for a test instance
>>> cf, cf_label, meta = latent_cf.explain(x_test)
>>> print(f"Converged: {meta['converged']}")
>>> print(f"Iterations: {meta['n_iterations']}")

References
----------
.. [latentcf1] Wang, Z., Samsten, I., Mochaourab, R., & Papapetrou, P. (2021).
       Learning Time Series Counterfactuals via Latent Space Representations.
       In International Conference on Discovery Science (DS 2021),
       Lecture Notes in Computer Science, vol 12986, pp. 369-384. Springer.
       DOI: 10.1007/978-3-030-88942-5_29

Notes
-----
This implementation provides a NumPy-based version of LatentCF++ that works
directly in the original time series space for compatibility with any
scikit-learn compatible classifier. For TensorFlow/Keras-based models with
autoencoders, the original implementation is recommended.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal
import warnings

import numpy as np

from .base import Counterfactual
from .utils import (
    ensure_batch_shape,
    has_expensive_transform,
    soft_predict_proba_fn,
    strip_batch,
    supports_soft_probabilities,
)
from .utils._adam import AdamState

WeightStrategy = Literal["uniform", "local", "global"]


[docs] @dataclass class LatentCF(Counterfactual): """LatentCF++ counterfactual generator using gradient-based optimization. Implementation of the LatentCF++ algorithm by Wang et al. (2021) [latentcf1]_. LatentCF++ generates counterfactuals by optimizing in the latent space (or directly in input space when no autoencoder is provided). The algorithm balances prediction margin loss (driving toward target class) with weighted proximity loss (staying close to original, prioritizing less important regions). The optimization minimizes a composite loss: L = w * L_pred + (1-w) * L_proximity where: - L_pred: Mean squared error between desired probability (1.0) and current - L_proximity: Weighted mean absolute error from original - w: pred_margin_weight parameter Parameters ---------- model : object A classifier with a probability estimator (``predict_proba`` or a compatible interface). data : tuple (``X_ref``, ``y_ref``) Reference dataset used for computing feature importance (for 'global' weight strategy) and normalization statistics. probability : float, default 0.5 Target probability threshold. Optimization aims for P(target) >= probability. tolerance : float, default 1e-6 Convergence tolerance. Optimization stops when prediction margin loss is below tolerance AND target probability is reached. max_iter : int, default 300 Maximum number of optimization iterations. learning_rate : float, default 0.01 Step size for Adam optimizer. Internally scaled by data standard deviation so the effective step adapts to input magnitude. pred_margin_weight : float, default 0.75 Weight balancing prediction margin loss vs proximity loss. Range: [0, 1]. Higher values prioritize changing the prediction. Values >= 0.75 recommended for non-neural-network classifiers. step_weights : {'uniform', 'local', 'global'}, default 'uniform' Strategy for computing importance weights: - 'uniform': Equal weights across all timesteps - 'local': Per-sample importance via perturbation-based sensitivity - 'global': Dataset-level importance computed across reference samples random_state : int or None, default 0 PRNG seed for reproducible optimization. gradient_subsample : int or None, default 50 Number of features to randomly sample for gradient computation each iteration. Uses stochastic gradient descent when set to a value less than the total number of features. Set to None to use all features (full gradient). Lower values speed up computation but may require more iterations to converge. temperature : float or None, default None Temperature scaling for soft probability computation. Higher values produce smoother gradients by preventing sigmoid saturation when decision function values are large. If None, auto-calibrates based on model decision function values (recommended for most use cases). Increase manually (e.g., 2.0-5.0) if counterfactuals are unchanged with ROCKET or other margin-based classifiers. Attributes ---------- predict_proba : callable Wrapped probability prediction function. rng : numpy.random.Generator Random number generator for reproducibility. X_ref : np.ndarray Reference dataset features. y_ref : np.ndarray Reference dataset labels. _global_weights : np.ndarray or None Precomputed global weights (cached after first use). References ---------- .. [latentcf1] Wang, Z., Samsten, I., Mochaourab, R., & Papapetrou, P. (2021). Learning Time Series Counterfactuals via Latent Space Representations. In International Conference on Discovery Science (DS 2021). https://github.com/zhendong3wang/learning-time-series-counterfactuals """ model: Any data: tuple[np.ndarray, np.ndarray] probability: float = 0.5 tolerance: float = 1e-6 max_iter: int = 300 learning_rate: float = 0.01 pred_margin_weight: float = 0.75 step_weights: WeightStrategy = "uniform" random_state: int | None = 0 gradient_subsample: int | None = 50 temperature: float | None = None # Internal state _global_weights: np.ndarray | None = field(default=None, init=False, repr=False)
[docs] def __post_init__(self): """Initialise probability wrapper, RNG, reference data, and label mapping. Validates all hyperparameters. Warns if the model is unlikely to work well with gradient-based optimisation. """ # Warn about classifiers that may not work well with gradient optimization if not supports_soft_probabilities(self.model): warnings.warn( f"LatentCF is a gradient-based method that may not work well with " f"{type(self.model).__name__}. Tree-based classifiers return discrete " f"probabilities that don't respond well to gradient optimization. " f"Consider using COMTE or NativeGuide instead.", UserWarning, stacklevel=2, ) self.predict_proba = soft_predict_proba_fn(self.model, temperature=self.temperature) self.rng = np.random.default_rng(self.random_state) self.X_ref = np.asarray(self.data[0]) self.y_ref = np.asarray(self.data[1]).ravel() self._init_label_mapping(self.model, self.y_ref) # Validate parameters if not (0.0 < self.probability <= 1.0): raise ValueError("probability must be in (0, 1]") if self.tolerance <= 0: raise ValueError("tolerance must be > 0") if self.max_iter < 1: raise ValueError("max_iter must be >= 1") if self.learning_rate <= 0: raise ValueError("learning_rate must be > 0") if not (0.0 <= self.pred_margin_weight <= 1.0): raise ValueError("pred_margin_weight must be in [0, 1]") if self.step_weights not in ("uniform", "local", "global"): raise ValueError("step_weights must be one of {'uniform', 'local', 'global'}") if self.gradient_subsample is not None and self.gradient_subsample < 1: raise ValueError("gradient_subsample must be >= 1 or None")
[docs] def explain( self, x: np.ndarray, y_pred: int | None = None, *, class_of_interest: int | None = None, ) -> tuple[np.ndarray, int, dict[str, Any]]: """Generate a counterfactual explanation using LatentCF++ optimization. Parameters ---------- x : np.ndarray Input time series of shape ``(T,)`` for univariate or ``(C, T)`` for multivariate data. y_pred : int, optional Base predicted class for ``x``. If ``None``, computed via the model. class_of_interest : int, optional Target class for the counterfactual. If ``None``, uses the highest-probability alternative to ``y_pred``. Returns ------- cf : np.ndarray Counterfactual time series with the same shape as ``x``. cf_label : int Predicted class label for the counterfactual. meta : dict Metadata dictionary containing: - ``method``: Algorithm identifier (``'latent_cf'``). - ``step_weights``: Weight strategy used. - ``class_of_interest``: Target class. - ``pred_margin_weight``: Weight parameter used. - ``learning_rate``: Learning rate used. - ``n_iterations``: Number of iterations performed. - ``converged``: Whether optimization converged. - ``final_target_prob``: Final probability of target class. - ``final_loss``: Final composite loss value. - ``validity``: Whether counterfactual changed prediction. """ xb, added = ensure_batch_shape(x) x1 = strip_batch(xb, added) # Determine base prediction and target class (all in index space) base_probs = self.predict_proba(xb)[0] base_idx = int(np.argmax(base_probs)) if y_pred is None else self._label_to_idx(y_pred) if class_of_interest is not None: target_idx = self._label_to_idx(class_of_interest) else: probs_sorted = np.argsort(-base_probs) target_idx = int(next(c for c in probs_sorted if c != base_idx)) # Compute importance weights weights = self._compute_weights(x1, base_idx) # Run gradient-based optimization cf, n_iter, converged, final_prob, final_loss = self._optimize( x1, base_idx, target_idx, weights ) # Get final prediction using soft probabilities (for consistency with optimization) cf_probs = self.predict_proba(cf[None, ...])[0] cf_idx = int(np.argmax(cf_probs)) # Also get model's actual prediction for metadata model_cf_pred = self.model.predict(cf[None, ...])[0] meta: dict[str, Any] = { "method": "latent_cf", "step_weights": self.step_weights, "class_of_interest": self._idx_to_label(target_idx), "pred_margin_weight": float(self.pred_margin_weight), "learning_rate": float(self.learning_rate), "probability_threshold": float(self.probability), "n_iterations": n_iter, "converged": converged, "final_target_prob": float(final_prob), "final_loss": float(final_loss), "validity": cf_idx != base_idx, "model_cf_prediction": model_cf_pred, } return cf, self._idx_to_label(cf_idx), meta
def _compute_weights(self, x: np.ndarray, base_label: int) -> np.ndarray: """Compute importance weights for proximity loss. Parameters ---------- x : np.ndarray Original time series of shape ``(T,)`` or ``(C, T)``. base_label : int Original predicted class label. Returns ------- np.ndarray Importance weights with the same shape as ``x``. Lower weight = more freedom to modify. """ if self.step_weights == "uniform": return np.ones_like(x) if self.step_weights == "local": return self._compute_local_importance(x, base_label) # Global weights return self._compute_global_importance(x) def _compute_local_importance(self, x: np.ndarray, base_label: int) -> np.ndarray: """Compute local feature importance via batched perturbation. Uses batch prediction to compute all perturbations at once, measuring how much the prediction changes when each timestep is perturbed. Parameters ---------- x : np.ndarray Original time series of shape ``(T,)`` or ``(C, T)``. base_label : int Original predicted class label. Returns ------- np.ndarray Importance weights (higher = more important = less modifiable). Low-importance regions (below 25th percentile) are set to 0. """ # Compute standard deviation for perturbation scaling x_std = np.std(x) if np.std(x) > 0 else 1.0 epsilon = 0.01 * x_std flat_x = x.flatten() n_features = len(flat_x) # Sample subset for efficiency on long series n_samples = min(n_features, 100) indices = self.rng.choice(n_features, size=n_samples, replace=False) # Create all perturbed versions at once: 2 * n_samples samples perturbations = np.tile(flat_x, (2 * n_samples, 1)) for j, i in enumerate(indices): perturbations[j, i] += epsilon # x_plus perturbations[n_samples + j, i] -= epsilon # x_minus # Reshape to match expected input shape perturbations_reshaped = perturbations.reshape(2 * n_samples, *x.shape) # Single batch prediction for all perturbations probs_batch = self.predict_proba(perturbations_reshaped) probs_base = probs_batch[:, base_label] # Extract probabilities for plus and minus perturbations probs_plus = probs_base[:n_samples] probs_minus = probs_base[n_samples:] # Gradient magnitude as importance for sampled indices flat_importance = np.zeros(n_features) sampled_importance = np.abs(probs_plus - probs_minus) / (2 * epsilon) for j, i in enumerate(indices): flat_importance[i] = sampled_importance[j] # Interpolate non-sampled indices if n_samples < n_features: sampled_mask = np.zeros(n_features, dtype=bool) sampled_mask[indices] = True for i in range(n_features): if not sampled_mask[i]: # Find nearest sampled neighbors left = i - 1 right = i + 1 while left >= 0 and not sampled_mask[left]: left -= 1 while right < n_features and not sampled_mask[right]: right += 1 if left >= 0 and right < n_features: flat_importance[i] = (flat_importance[left] + flat_importance[right]) / 2 elif left >= 0: flat_importance[i] = flat_importance[left] elif right < n_features: flat_importance[i] = flat_importance[right] importance = flat_importance.reshape(x.shape) # Normalize to [0, 1] if importance.max() > importance.min(): importance = (importance - importance.min()) / (importance.max() - importance.min()) # Mask low-importance regions (below 25th percentile) threshold = np.percentile(importance, 25) weights = np.where(importance <= threshold, 0.0, importance) return weights def _compute_global_importance(self, x: np.ndarray) -> np.ndarray: """Compute global feature importance from reference dataset. Computes importance across all samples in the reference set using perturbation-based sensitivity analysis, then thresholds at the 75th percentile. Parameters ---------- x : np.ndarray Original time series (used for shape reference). Returns ------- np.ndarray Importance weights matching shape of ``x``. """ # Use cached global weights if available and shape matches if self._global_weights is not None and self._global_weights.shape == x.shape: return self._global_weights.copy() # Compute global importance from reference set n_ref = min(len(self.X_ref), 20) # Sample for efficiency ref_indices = self.rng.choice(len(self.X_ref), size=n_ref, replace=False) all_importance = [] for idx in ref_indices: x_ref = self.X_ref[idx] # Reshape if needed to match expected shape if x_ref.shape != x.shape: if x_ref.ndim == 1 and x.ndim == 1: # Both univariate but different lengths - skip continue elif x_ref.ndim == 2 and x.ndim == 2: if x_ref.shape != x.shape: continue else: continue # Get the probability index for this reference sample # (use soft probabilities to get the predicted class index) ref_probs = self.predict_proba(x_ref[None, ...])[0] ref_label_idx = int(np.argmax(ref_probs)) importance = self._compute_local_importance(x_ref, ref_label_idx) all_importance.append(importance) if not all_importance: return np.ones_like(x) # Average importance across reference samples global_importance = np.mean(all_importance, axis=0) # Threshold at 75th percentile (high importance = protected) threshold = np.percentile(global_importance, 75) weights = np.where(global_importance >= threshold, global_importance, 0.0) # Cache for reuse self._global_weights = weights.copy() return weights def _optimize( self, x_original: np.ndarray, base_label: int, target_class: int, step_weights: np.ndarray, ) -> tuple[np.ndarray, int, bool, float, float]: """Run gradient-based optimization to find counterfactual. Uses Adam optimizer with finite-difference gradients, following the original LatentCF++ implementation which uses ``tf.GradientTape`` + Adam. Parameters ---------- x_original : np.ndarray Original time series of shape ``(T,)`` or ``(C, T)``. base_label : int Original predicted class label. target_class : int Target class for counterfactual. step_weights : np.ndarray Importance weights for proximity loss. Returns ------- cf : np.ndarray Optimized counterfactual. n_iterations : int Number of iterations performed. converged : bool Whether optimization converged. final_target_prob : float Final probability of target class. final_loss : float Final composite loss value. """ # Initialize counterfactual as copy of original cf = x_original.copy().astype(np.float64) w = self.pred_margin_weight converged = False n_iterations = 0 final_prob = 0.0 final_loss = float("inf") # Scale learning rate to data magnitude. Adam normalizes gradients so # effective step ≈ lr regardless of gradient magnitude. For finite # differences through non-differentiable transforms (ROCKET), the step # needs to be proportional to data scale to cross the decision boundary. data_scale = float(np.std(self.X_ref)) effective_lr = self.learning_rate * data_scale adam = AdamState.zeros_like(cf) for iteration in range(self.max_iter): n_iterations = iteration + 1 # Compute current prediction probs = self.predict_proba(cf[None, ...])[0] target_prob = probs[target_class] # Prediction margin loss: MSE between desired probability and current pred_margin_loss = (self.probability - target_prob) ** 2 # Proximity loss: weighted MAE from original diff = cf - x_original proximity_loss = float(np.mean(step_weights * np.abs(diff))) # Composite loss total_loss = w * pred_margin_loss + (1 - w) * proximity_loss final_prob = target_prob final_loss = total_loss # Check convergence: loss below tolerance AND target prob reached if pred_margin_loss < self.tolerance and target_prob >= self.probability: converged = True break # Compute gradient via finite differences gradient = self._compute_gradient(cf, x_original, target_class, step_weights, w) # Adam update cf = cf - adam.step(gradient, effective_lr) return cf, n_iterations, converged, final_prob, final_loss def _compute_gradient( self, cf: np.ndarray, x_original: np.ndarray, target_class: int, step_weights: np.ndarray, w: float, ) -> np.ndarray: """Compute gradient of composite loss. Uses finite differences for the prediction margin loss (which requires model evaluation) and an analytical gradient for the proximity loss (which has a simple closed form). This separation avoids the proximity gradient dominating due to scale differences. Parameters ---------- cf : np.ndarray Current counterfactual estimate. x_original : np.ndarray Original time series. target_class : int Target class for counterfactual. step_weights : np.ndarray Importance weights for proximity loss. w : float Prediction margin weight. Returns ------- np.ndarray Gradient of composite loss with respect to cf. """ # Scale epsilon to data magnitude so gradients are meaningful # through non-differentiable transforms (e.g. ROCKET random kernels). x_std = float(np.std(self.X_ref)) epsilon = max(x_std * 0.01, 1e-4) flat_cf = cf.flatten() n_features = len(flat_cf) # Subsample features for stochastic gradient estimation. # For classifiers with expensive transforms (ROCKET, RDST), skip the # n_features // 2 floor so the user's gradient_subsample is respected, # reducing the number of costly transform calls per iteration. if self.gradient_subsample is not None and self.gradient_subsample < n_features: if has_expensive_transform(self.model): n_sample = self.gradient_subsample else: n_sample = max(self.gradient_subsample, n_features // 2) n_sample = min(n_sample, n_features) sampled_idx = self.rng.choice(n_features, size=n_sample, replace=False) else: n_sample = n_features sampled_idx = np.arange(n_features) # --- Prediction margin gradient via finite differences --- perturbations = np.tile(flat_cf, (2 * n_sample, 1)) for i, feat_idx in enumerate(sampled_idx): perturbations[i, feat_idx] += epsilon # cf_plus perturbations[n_sample + i, feat_idx] -= epsilon # cf_minus perturbations_reshaped = perturbations.reshape(2 * n_sample, *cf.shape) probs_batch = self.predict_proba(perturbations_reshaped) target_probs = probs_batch[:, target_class] # d/dcf [ (probability - p(cf))^2 ] via finite differences pred_losses_plus = (self.probability - target_probs[:n_sample]) ** 2 pred_losses_minus = (self.probability - target_probs[n_sample:]) ** 2 pred_grad_sampled = (pred_losses_plus - pred_losses_minus) / (2 * epsilon) pred_gradient = np.zeros(n_features) pred_gradient[sampled_idx] = pred_grad_sampled # --- Proximity gradient (analytical) --- # L_prox = mean(weights * |cf - x_orig|) # d/dcf L_prox = weights * sign(cf - x_orig) / n_features flat_weights = step_weights.flatten() flat_original = x_original.flatten() diff = flat_cf - flat_original prox_gradient = flat_weights * np.sign(diff) / n_features # Normalize each component so that `w` controls the balance # regardless of raw gradient magnitudes. Without this, the # proximity gradient (order ~0.04) dominates the prediction # gradient (order ~0.0001) through non-differentiable classifiers. pred_norm = float(np.linalg.norm(pred_gradient)) + 1e-30 prox_norm = float(np.linalg.norm(prox_gradient)) + 1e-30 full_gradient = w * pred_gradient / pred_norm + (1 - w) * prox_gradient / prox_norm result: np.ndarray = full_gradient.reshape(cf.shape) return result