diff --git a/docs/api/interpret.rst b/docs/api/interpret.rst index 08d41ab41..2758d42b2 100644 --- a/docs/api/interpret.rst +++ b/docs/api/interpret.rst @@ -50,6 +50,16 @@ New to interpretability in PyHealth? Check out these complete examples: - Compare different baseline strategies for background sample generation - Decode attributions to human-readable medical codes and lab measurements +**LIME Example:** + +- ``examples/lime_stagenet_mimic4.py`` - Demonstrates LIME (Local Interpretable Model-agnostic Explanations) for StageNet mortality prediction. Shows how to: + + - Compute local linear approximations to explain model predictions + - Generate perturbations around input samples to train interpretable models + - Compare different regularization methods (Lasso vs Ridge) for feature selection + - Test various distance kernels (cosine vs euclidean) and sample sizes + - Decode attributions to human-readable medical codes and lab measurements + These examples provide end-to-end workflows from loading data to interpreting and evaluating attributions. Available Methods @@ -64,4 +74,5 @@ Available Methods interpret/pyhealth.interpret.methods.deeplift interpret/pyhealth.interpret.methods.integrated_gradients interpret/pyhealth.interpret.methods.shap + interpret/pyhealth.interpret.methods.lime \ No newline at end of file diff --git a/docs/api/interpret/pyhealth.interpret.methods.lime.rst b/docs/api/interpret/pyhealth.interpret.methods.lime.rst new file mode 100644 index 000000000..3da1c4273 --- /dev/null +++ b/docs/api/interpret/pyhealth.interpret.methods.lime.rst @@ -0,0 +1,19 @@ +pyhealth.interpret.methods.lime +=============================== + +Overview +-------- + +:class:`pyhealth.interpret.methods.lime.LimeExplainer` provides LIME (Local Interpretable Model-agnostic Explanations) +attributions for PyHealth models by approximating the model locally with an interpretable linear model. Consult the class docstring for +detailed guidance, usage notes, and examples. A full workflow is demonstrated in +``examples/lime_stagenet_mimic4.py``. + +API Reference +------------- + +.. autoclass:: pyhealth.interpret.methods.lime.LimeExplainer + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/examples/lime_stagenet_mimic4.py b/examples/lime_stagenet_mimic4.py new file mode 100644 index 000000000..fe1a69922 --- /dev/null +++ b/examples/lime_stagenet_mimic4.py @@ -0,0 +1,356 @@ +# %% Loading MIMIC-IV dataset with LIME explanations +from pathlib import Path + +import polars as pl +import torch + +from pyhealth.datasets import ( + MIMIC4EHRDataset, + get_dataloader, + load_processors, + split_by_patient, +) +from pyhealth.interpret.methods import LimeExplainer +from pyhealth.models import StageNet +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 + + +def load_icd_description_map(dataset_root: str) -> dict: + """Load ICD code → long title mappings from MIMIC-IV reference tables.""" + mapping = {} + root_path = Path(dataset_root).expanduser() + diag_path = root_path / "hosp" / "d_icd_diagnoses.csv.gz" + proc_path = root_path / "hosp" / "d_icd_procedures.csv.gz" + + icd_dtype = {"icd_code": pl.Utf8, "long_title": pl.Utf8} + + if diag_path.exists(): + diag_df = pl.read_csv( + diag_path, + columns=["icd_code", "long_title"], + dtypes=icd_dtype, + ) + mapping.update( + zip(diag_df["icd_code"].to_list(), diag_df["long_title"].to_list()) + ) + + if proc_path.exists(): + proc_df = pl.read_csv( + proc_path, + columns=["icd_code", "long_title"], + dtypes=icd_dtype, + ) + mapping.update( + zip(proc_df["icd_code"].to_list(), proc_df["long_title"].to_list()) + ) + + return mapping + + +LAB_CATEGORY_NAMES = MortalityPredictionStageNetMIMIC4.LAB_CATEGORY_NAMES + + +def move_batch_to_device(batch, target_device): + """Move all tensors in batch to target device.""" + moved = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + moved[key] = value.to(target_device) + elif isinstance(value, tuple): + moved[key] = tuple(v.to(target_device) for v in value) + else: + moved[key] = value + return moved + + +def decode_token(idx: int, processor, feature_key: str, icd_code_to_desc: dict): + """Decode token index to human-readable string.""" + if processor is None or not hasattr(processor, "code_vocab"): + return str(idx) + reverse_vocab = {index: token for token, index in processor.code_vocab.items()} + token = reverse_vocab.get(idx, f"") + + if feature_key == "icd_codes" and token not in {"", ""}: + desc = icd_code_to_desc.get(token) + if desc: + return f"{token}: {desc}" + + return token + + +def unravel(flat_index: int, shape: torch.Size): + """Convert flat index to multi-dimensional coordinates.""" + coords = [] + remaining = flat_index + for dim in reversed(shape): + coords.append(remaining % dim) + remaining //= dim + return list(reversed(coords)) + + +def print_top_attributions( + attributions, + batch, + processors, + top_k: int = 10, + icd_code_to_desc: dict = None, + method_name: str = "LIME", +): + """Print top-k most important features from LIME attributions.""" + if icd_code_to_desc is None: + icd_code_to_desc = {} + + for feature_key, attr in attributions.items(): + attr_cpu = attr.detach().cpu() + if attr_cpu.dim() == 0 or attr_cpu.size(0) == 0: + continue + + feature_input = batch[feature_key] + if isinstance(feature_input, tuple): + feature_input = feature_input[1] + feature_input = feature_input.detach().cpu() + + flattened = attr_cpu[0].flatten() + if flattened.numel() == 0: + continue + + print(f"\nFeature: {feature_key}") + print(f" Shape: {attr_cpu[0].shape}") + print(f" Total attribution sum: {flattened.sum().item():+.6f}") + print(f" Mean attribution: {flattened.mean().item():+.6f}") + + k = min(top_k, flattened.numel()) + top_values, top_indices = torch.topk(flattened.abs(), k=k) + processor = processors.get(feature_key) if processors else None + is_continuous = torch.is_floating_point(feature_input) + + print(f"\n Top {k} most important features:") + for rank, (_, flat_idx) in enumerate(zip(top_values, top_indices), 1): + attribution_value = flattened[flat_idx].item() + coords = unravel(flat_idx.item(), attr_cpu[0].shape) + + if is_continuous: + actual_value = feature_input[0][tuple(coords)].item() + label = "" + if feature_key == "labs" and len(coords) >= 1: + lab_idx = coords[-1] + if lab_idx < len(LAB_CATEGORY_NAMES): + label = f"{LAB_CATEGORY_NAMES[lab_idx]} " + print( + f" {rank:2d}. idx={coords} {label}value={actual_value:.4f} " + f"{method_name}={attribution_value:+.6f}" + ) + else: + token_idx = int(feature_input[0][tuple(coords)].item()) + token = decode_token(token_idx, processor, feature_key, icd_code_to_desc) + print( + f" {rank:2d}. idx={coords} token='{token}' " + f"{method_name}={attribution_value:+.6f}" + ) + + +def main(): + """Main function to run LIME analysis on MIMIC-IV StageNet model.""" + # Configure dataset location and load cached processors + dataset = MIMIC4EHRDataset( + #root="/home/naveen-baskaran/physionet.org/files/mimic-iv-demo/2.2/", + #root="/Users/naveenbaskaran/data/physionet.org/files/mimic-iv-demo/2.2/", + root="~/data/physionet.org/files/mimic-iv-demo/2.2/", + tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + ) + + # %% Setting StageNet Mortality Prediction Task + input_processors, output_processors = load_processors("../resources/") + + sample_dataset = dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + cache_dir="~/.cache/pyhealth/mimic4_stagenet_mortality", + input_processors=input_processors, + output_processors=output_processors, + ) + print(f"Total samples: {len(sample_dataset)}") + + ICD_CODE_TO_DESC = load_icd_description_map(dataset.root) + + # %% Loading Pretrained StageNet Model + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) + + state_dict = torch.load("../resources/best.ckpt", map_location=device) + model.load_state_dict(state_dict) + model = model.to(device) + model.eval() + print(model) + + # %% Preparing dataloaders + _, _, test_data = split_by_patient(sample_dataset, [0.7, 0.1, 0.2], seed=42) + test_loader = get_dataloader(test_data, batch_size=1, shuffle=False) + + +# %% Run LIME on a held-out sample + print("\n" + "="*80) + print("Initializing LIME Explainer") + print("="*80) + + # Initialize LIME explainer + lime_explainer = LimeExplainer(model) + + print("\nLIME Configuration:") + print(f" Use embeddings: {lime_explainer.use_embeddings}") + print(f" Number of samples: {lime_explainer.n_samples}") + print(f" Kernel width: {lime_explainer.kernel_width}") + print(f" Distance mode: {lime_explainer.distance_mode}") + print(f" Feature selection: {lime_explainer.feature_selection}") + print(f" Regularization alpha: {lime_explainer.alpha}") + + # Get a sample from test set + sample_batch = next(iter(test_loader)) + sample_batch_device = move_batch_to_device(sample_batch, device) + + # Get model prediction + with torch.no_grad(): + output = model(**sample_batch_device) + probs = output["y_prob"] + label_key = model.label_key + true_label = sample_batch_device[label_key] + + # Handle binary classification (single probability output) + if probs.shape[-1] == 1: + prob_death = probs[0].item() + prob_survive = 1 - prob_death + preds = (probs > 0.5).long() + else: + # Multi-class classification + preds = torch.argmax(probs, dim=-1) + prob_survive = probs[0][0].item() + prob_death = probs[0][1].item() + + print("\n" + "="*80) + print("Model Prediction for Sampled Patient") + print("="*80) + print(f" True label: {int(true_label.item())} {'(Deceased)' if true_label.item() == 1 else '(Survived)'}") + print(f" Predicted class: {int(preds.item())} {'(Deceased)' if preds.item() == 1 else '(Survived)'}") + print(f" Probabilities: [Survive={prob_survive:.4f}, Death={prob_death:.4f}]") + + # Compute LIME values + print("\n" + "="*80) + print("Computing LIME Attributions (...........)") + print("="*80) + print("\nLIME trains a local linear model by sampling perturbed inputs") + print("around the example to be explained. The linear model coefficients") + print("represent feature importance in the local neighborhood.") + + attributions = lime_explainer.attribute(**sample_batch_device, target_class_idx=1) + + print("\n" + "="*80) + print("LIME Attribution Results") + print("="*80) + print("\nLIME coefficients explain the contribution of each feature to the") + print("local linear approximation of the model's MORTALITY prediction (class 1).") + print("Positive values increase the mortality prediction, negative values decrease it.") + + print_top_attributions( + attributions, + sample_batch_device, + input_processors, + top_k=15, + icd_code_to_desc=ICD_CODE_TO_DESC, + method_name="LIME" + ) + + # %% Compare different LIME configurations + print("\n\n" + "="*80) + print("Testing Different LIME Configurations") + print("="*80) + + # 1. Default configuration (already computed) + print("\n1. Default LIME (Lasso, 1000 samples, cosine distance):") + print(f" Total attribution (icd_codes): {attributions['icd_codes'][0].sum().item():+.6f}") + print(f" Total attribution (labs): {attributions['labs'][0].sum().item():+.6f}") + + # 2. Ridge regression instead of Lasso + print("\n2. Ridge regression (L2 regularization):") + lime_ridge = LimeExplainer( + model, + use_embeddings=True, + n_samples=1000, + feature_selection="ridge", + alpha=0.01, + random_seed=42, + ) + attr_ridge = lime_ridge.attribute(**sample_batch_device, target_class_idx=1) + print(f" Total attribution (icd_codes): {attr_ridge['icd_codes'][0].sum().item():+.6f}") + print(f" Total attribution (labs): {attr_ridge['labs'][0].sum().item():+.6f}") + + # 3. Euclidean distance instead of cosine + print("\n3. Euclidean distance kernel:") + lime_euclidean = LimeExplainer( + model, + use_embeddings=True, + n_samples=1000, + distance_mode="euclidean", + kernel_width=0.25, + random_seed=42, + ) + attr_euclidean = lime_euclidean.attribute(**sample_batch_device, target_class_idx=1) + print(f" Total attribution (icd_codes): {attr_euclidean['icd_codes'][0].sum().item():+.6f}") + print(f" Total attribution (labs): {attr_euclidean['labs'][0].sum().item():+.6f}") + + # 4. More samples for better local approximation + print("\n4. More samples (2000) for better local approximation:") + lime_more_samples = LimeExplainer( + model, + use_embeddings=True, + n_samples=2000, + random_seed=42, + ) + attr_more_samples = lime_more_samples.attribute(**sample_batch_device, target_class_idx=1) + print(f" Total attribution (icd_codes): {attr_more_samples['icd_codes'][0].sum().item():+.6f}") + print(f" Total attribution (labs): {attr_more_samples['labs'][0].sum().item():+.6f}") + + print("\n" + "="*80) + print("Comparison of Regularization Methods") + print("="*80) + print("\nLasso (L1) vs Ridge (L2) regularization:") + print(" - Lasso tends to produce sparser explanations (more zeros)") + print(" - Ridge distributes importance more evenly across features") + print(" - Choose based on your interpretation needs:") + print(" * Lasso: When you want to identify a few key features") + print(" * Ridge: When you want to see contributions from all features") + + # Compare top features + print("\nTop 5 features comparison (by absolute value):") + for key in ['icd_codes', 'labs']: + if key in attributions: + flat_lasso = attributions[key][0].flatten().abs() + flat_ridge = attr_ridge[key][0].flatten().abs() + + k = min(5, flat_lasso.numel()) + top_lasso = torch.topk(flat_lasso, k=k) + top_ridge = torch.topk(flat_ridge, k=k) + + print(f"\n {key}:") + print(f" Lasso non-zero features: {(flat_lasso > 1e-6).sum().item()}/{flat_lasso.numel()}") + print(f" Ridge non-zero features: {(flat_ridge > 1e-6).sum().item()}/{flat_ridge.numel()}") + + print("\n" + "="*80) + print("LIME Analysis Complete") + print("="*80) + + +if __name__ == "__main__": + main() + +# %% diff --git a/pyhealth/interpret/methods/__init__.py b/pyhealth/interpret/methods/__init__.py index 52796ffd1..14d708d34 100644 --- a/pyhealth/interpret/methods/__init__.py +++ b/pyhealth/interpret/methods/__init__.py @@ -5,6 +5,7 @@ from pyhealth.interpret.methods.gim import GIM from pyhealth.interpret.methods.integrated_gradients import IntegratedGradients from pyhealth.interpret.methods.shap import ShapExplainer +from pyhealth.interpret.methods.lime import LimeExplainer __all__ = [ "BaseInterpreter", @@ -13,5 +14,6 @@ "GIM", "IntegratedGradients", "BasicGradientSaliencyMaps", - "ShapExplainer" + "ShapExplainer", + "LimeExplainer" ] diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py new file mode 100644 index 000000000..8bdc1b673 --- /dev/null +++ b/pyhealth/interpret/methods/lime.py @@ -0,0 +1,1107 @@ +from __future__ import annotations + +import math +from typing import Dict, Optional, Tuple, Callable, Union + +import torch +import torch.nn.functional as F +from torch.nn import CosineSimilarity + +from pyhealth.models import BaseModel +from .base_interpreter import BaseInterpreter + + +class LimeExplainer(BaseInterpreter): + """LIME (Local Interpretable Model-agnostic Explanations) attribution method for PyHealth models. + + This class implements the LIME method for computing feature attributions in + neural networks. LIME explains model predictions by approximating the model + locally with an interpretable surrogate model (e.g., linear regression). + + The method is based on the paper: + "Why Should I Trust You?" Explaining the Predictions of Any Classifier + Marco Tulio Ribeiro, Sameer Singh, Carlos Guestrin + KDD 2016 + https://arxiv.org/abs/1602.04938 + + LIME Method Overview: + LIME works by: + 1. Generating perturbed samples around the input to be explained + 2. Obtaining model predictions for these perturbed samples + 3. Weighting samples by their similarity to the original input + 4. Training a simple interpretable model (linear) on the weighted dataset + 5. Using the coefficients of the interpretable model as feature importances + + Mathematical Foundation: + Given an input x and model f, LIME finds an interpretable model g that + minimizes: + L(f, g, πₓ) + Ω(g) + where: + - πₓ is a proximity measure (similarity kernel) + - Ω(g) is the complexity of the interpretable model g + - L measures how unfaithful g is in approximating f in the locality of x + + For linear models, this becomes: + argmin_w Σᵢ πₓ(zᵢ) * [f(zᵢ) - w·zᵢ]² + λ||w||² + where zᵢ are perturbed samples and w are the feature weights. + + LIME provides several benefits: + 1. Model-agnostic: Works with any model that provides predictions + 2. Local fidelity: Explanations are faithful in the locality of the input + 3. Interpretable: Uses simple linear models that humans can understand + 4. Flexible: Can define custom perturbation and similarity functions + + Args: + model: A trained PyHealth model to interpret. Can be any model that + inherits from BaseModel (e.g., MLP, StageNet, Transformer, RNN). + use_embeddings: If True, compute LIME values with respect to + embeddings rather than discrete input tokens. This is crucial + for models with discrete inputs (like ICD codes). The model + must support returning embeddings via an 'embed' parameter. + Default is True. + n_samples: Number of perturbed samples to generate for training + the interpretable model. More samples give better local + approximations but increase computation time. Default is 1000. + kernel_width: Width parameter for the exponential similarity kernel. + Smaller values make the explanations more local. Default is 0.25. + distance_mode: Distance metric for similarity computation. Can be + "cosine" for cosine similarity or "euclidean" for L2 distance. + Default is "cosine". + feature_selection: Method for selecting top features. Can be "lasso" + for L1 regularization, "ridge" for L2 regularization, or "none" + for no regularization. Default is "lasso". + alpha: Regularization strength for the interpretable model. + Default is 0.01. + + Examples: + >>> import torch + >>> from pyhealth.datasets import SampleDataset, get_dataloader + >>> from pyhealth.models import StageNet + >>> from pyhealth.interpret.methods import LimeExplainer + >>> from pyhealth.trainer import Trainer + >>> + >>> # Create dataset and model + >>> dataset = SampleDataset(...) + >>> model = StageNet(...) + >>> trainer = Trainer(model=model, device="cuda:0") + >>> trainer.train(...) + >>> test_batch = next(iter(test_loader)) + >>> + >>> # Initialize LIME explainer + >>> explainer = LimeExplainer(model, use_embeddings=True, n_samples=1000) + >>> lime_values = explainer.attribute(**test_batch) + >>> + >>> # With custom kernel width + >>> explainer = LimeExplainer(model, kernel_width=0.5) + >>> lime_values = explainer.attribute(**test_batch) + >>> + >>> print(lime_values) + {'conditions': tensor([[0.1234, 0.5678, 0.9012]], device='cuda:0'), + 'procedures': tensor([[0.2345, 0.6789, 0.0123, 0.4567]])} + """ + + def __init__( + self, + model: BaseModel, + use_embeddings: bool = True, + n_samples: int = 1000, + kernel_width: float = 0.25, + distance_mode: str = "cosine", + feature_selection: str = "lasso", + alpha: float = 0.01, + random_seed: Optional[int] = 42, + ): + """Initialize LIME explainer. + + Args: + model: A trained PyHealth model to interpret. + use_embeddings: If True, compute LIME values with respect to + embeddings rather than discrete input tokens. + n_samples: Number of perturbed samples to generate. + kernel_width: Width parameter for the exponential kernel. + distance_mode: Distance metric ("cosine" or "euclidean"). + feature_selection: Regularization type ("lasso", "ridge", or "none"). + alpha: Regularization strength. + random_seed: Optional random seed for reproducibility. + + Raises: + AssertionError: If use_embeddings=True but model does not + implement forward_from_embedding() method. + ValueError: If distance_mode is not "cosine" or "euclidean". + ValueError: If feature_selection is not "lasso", "ridge", or "none". + """ + super().__init__(model) + self.use_embeddings = use_embeddings + self.n_samples = n_samples + self.kernel_width = kernel_width + self.distance_mode = distance_mode + self.feature_selection = feature_selection + self.alpha = alpha + self.random_seed = random_seed + + # Validate inputs + if use_embeddings: + assert hasattr(model, "forward_from_embedding"), ( + f"Model {type(model).__name__} must implement " + "forward_from_embedding() method to support embedding-level " + "LIME values. Set use_embeddings=False to use " + "input-level attributions (only for continuous features)." + ) + + if distance_mode not in ["cosine", "euclidean"]: + raise ValueError("distance_mode must be either 'cosine' or 'euclidean'.") + + if feature_selection not in ["lasso", "ridge", "none"]: + raise ValueError("feature_selection must be 'lasso', 'ridge', or 'none'.") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def attribute( + self, + baseline: Optional[Dict[str, torch.Tensor]] = None, + target_class_idx: Optional[int] = None, + **data, + ) -> Dict[str, torch.Tensor]: + """Compute LIME attributions for input features. + + This is the main interface for computing feature attributions. It: + 1. Generates perturbed samples around the input + 2. Computes model predictions for perturbed samples + 3. Weights samples by similarity to the original input + 4. Trains a linear interpretable model + 5. Returns the linear model coefficients as feature importances + + Args: + baseline: Optional dictionary mapping feature names to baseline + values used for perturbations. If None, uses zeros or + random samples. Shape of each tensor should match input. + target_class_idx: For multi-class models, specifies which class's + prediction to explain. If None, explains the model's + maximum prediction across all classes. + **data: Input data dictionary from dataloader batch. Should contain: + - Feature tensors with shape (batch_size, ..., feature_dim) + - Optional time information for temporal models + - Optional label data for supervised models + + Returns: + Dictionary mapping feature names to their LIME coefficients. Each + tensor has the same shape as its corresponding input and contains + the feature's importance in the local linear approximation. + Positive values indicate features that increased the prediction, + negative values indicate features that decreased it. + + Example: + >>> lime_values = explainer.attribute( + ... conditions=torch.tensor([[1, 5, 8]]), + ... procedures=torch.tensor([[2, 3]]), + ... target_class_idx=1 + ... ) + >>> print(lime_values['conditions']) # Shape: (1, 3) + """ + # Set random seed for reproducibility if specified + if self.random_seed is not None: + torch.manual_seed(self.random_seed) + + device = next(self.model.parameters()).device + + # Extract and prepare inputs + feature_inputs: Dict[str, torch.Tensor] = {} + time_info: Dict[str, torch.Tensor] = {} + label_data: Dict[str, torch.Tensor] = {} + + for key in self.model.feature_keys: + if key not in data: + continue + value = data[key] + + # Handle (time, value) tuples for temporal data + if isinstance(value, tuple): + time_tensor, feature_tensor = value + if time_tensor is not None: + time_info[key] = time_tensor.to(device) + value = feature_tensor + + if not isinstance(value, torch.Tensor): + value = torch.as_tensor(value) + feature_inputs[key] = value.to(device) + + # Store label data + for key in self.model.label_keys: + if key in data: + label_val = data[key] + if not isinstance(label_val, torch.Tensor): + label_val = torch.as_tensor(label_val) + label_data[key] = label_val.to(device) + + # Generate or validate baseline (neutral replacement values) + # Note: LIME does not require a background dataset; baselines here + # serve only as neutral values when a feature is masked (absent). + if baseline is None: + baseline = self._generate_baseline(feature_inputs) + else: + baseline = {k: v.to(device) for k, v in baseline.items()} + + # Compute LIME values + if self.use_embeddings: + return self._lime_embeddings( + feature_inputs, + baseline=baseline, + target_class_idx=target_class_idx, + time_info=time_info, + label_data=label_data, + ) + else: + return self._lime_continuous( + feature_inputs, + baseline=baseline, + target_class_idx=target_class_idx, + time_info=time_info, + label_data=label_data, + ) + + # ------------------------------------------------------------------ + # Embedding-based LIME + # ------------------------------------------------------------------ + def _lime_embeddings( + self, + inputs: Dict[str, torch.Tensor], + baseline: Dict[str, torch.Tensor], + target_class_idx: Optional[int], + time_info: Dict[str, torch.Tensor], + label_data: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Compute LIME values for discrete inputs in embedding space. + + Args: + inputs: Dictionary of input tensors. + baseline: Dictionary of baseline samples. + target_class_idx: Target class index for attribution. + time_info: Temporal information for time-series models. + label_data: Label information for supervised models. + + Returns: + Dictionary of LIME coefficients mapped back to input shapes. + """ + # Embed inputs and baseline + input_embs = self.model.embedding_model(inputs) + baseline_embs = self.model.embedding_model(baseline) + + # Store original input shapes for mapping back + input_shapes = {key: val.shape for key, val in inputs.items()} + + # Compute LIME values for each feature + lime_values = {} + for key in inputs: + n_features = self._determine_n_features(key, inputs, input_embs) + + coefs = self._compute_lime( + key=key, + input_emb=input_embs, + baseline_emb=baseline_embs, + n_features=n_features, + target_class_idx=target_class_idx, + time_info=time_info, + label_data=label_data, + ) + + lime_values[key] = coefs + + # Map embedding-space attributions back to input shapes + return self._map_to_input_shapes(lime_values, input_shapes) + + # ------------------------------------------------------------------ + # Continuous LIME + # ------------------------------------------------------------------ + def _lime_continuous( + self, + inputs: Dict[str, torch.Tensor], + baseline: Dict[str, torch.Tensor], + target_class_idx: Optional[int], + time_info: Dict[str, torch.Tensor], + label_data: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Compute LIME values for continuous tensor inputs. + + Args: + inputs: Dictionary of input tensors. + baseline: Dictionary of baseline samples. + target_class_idx: Target class index for attribution. + time_info: Temporal information for time-series models. + label_data: Label information for supervised models. + + Returns: + Dictionary of LIME coefficients with same shapes as inputs. + """ + lime_values = {} + + for key in inputs: + n_features = self._determine_n_features(key, inputs, inputs) + + coefs = self._compute_lime( + key=key, + input_emb=inputs, + baseline_emb=baseline, + n_features=n_features, + target_class_idx=target_class_idx, + time_info=time_info, + label_data=label_data, + ) + + lime_values[key] = coefs + + return lime_values + + # ------------------------------------------------------------------ + # Core LIME computation + # ------------------------------------------------------------------ + def _compute_lime( + self, + key: str, + input_emb: Dict[str, torch.Tensor], + baseline_emb: Dict[str, torch.Tensor], + n_features: int, + target_class_idx: Optional[int], + time_info: Optional[Dict[str, torch.Tensor]], + label_data: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Compute LIME coefficients using interpretable linear model. + + This implements the LIME algorithm: + 1. Generate perturbed samples (binary vectors indicating feature presence or absence) + 2. Create interpretable samples by mixing original and baseline + 3. Evaluate model on perturbed samples + 4. Compute similarity weights based on distance to original + 5. Train weighted linear regression + 6. Return coefficients as feature importances + + Args: + key: Feature key being explained. + input_emb: Dictionary of input embeddings/tensors. + baseline_emb: Dictionary of baseline embeddings/tensors. + n_features: Number of features to explain. + target_class_idx: Target class index for multi-class models. + time_info: Optional temporal information. + label_data: Optional label information. + + Returns: + torch.Tensor: LIME coefficients with shape (batch_size, n_features). + """ + device = input_emb[key].device + batch_size = input_emb[key].shape[0] if input_emb[key].dim() >= 2 else 1 + + # Storage for samples and predictions + interpretable_samples = [] # Binary vectors + perturbed_predictions = [] # Model predictions + similarity_weights = [] # Distance-based weights + + # Original input prediction + original_pred = self._get_prediction( + key, input_emb, baseline_emb, None, + target_class_idx, time_info, label_data + ) + + # Generate perturbed samples + for _ in range(self.n_samples): + # Sample binary vector (which features to include) + binary_vector = torch.bernoulli( + torch.ones(n_features, device=device) * 0.5 + ) + + # Create perturbed sample for each batch item + batch_preds = [] + batch_similarities = [] + + for b_idx in range(batch_size): + # Create perturbed embedding by mixing input and baseline + perturbed_emb = self._create_perturbed_sample( + key, binary_vector, input_emb, baseline_emb, b_idx + ) + + # Get model prediction for perturbed sample + pred = self._evaluate_sample( + key, perturbed_emb, baseline_emb, + target_class_idx, time_info, label_data + ) + batch_preds.append(pred) + + # Compute similarity weight + similarity = self._compute_similarity( + input_emb[key][b_idx:b_idx+1] if batch_size > 1 else input_emb[key], + perturbed_emb, + binary_vector, + ) + batch_similarities.append(similarity) + + # Store sample information + interpretable_samples.append(binary_vector.float()) + perturbed_predictions.append(torch.stack(batch_preds, dim=0)) + similarity_weights.append(torch.stack(batch_similarities, dim=0)) + + # Train weighted linear regression + return self._train_interpretable_model( + interpretable_samples, + perturbed_predictions, + similarity_weights, + device, + ) + + def _create_perturbed_sample( + self, + key: str, + binary_vector: torch.Tensor, + input_emb: Dict[str, torch.Tensor], + baseline_emb: Dict[str, torch.Tensor], + batch_idx: int, + ) -> torch.Tensor: + """Create a perturbed sample by mixing input and baseline based on binary vector. + + Args: + key: Feature key. + binary_vector: Binary vector (1 = use input feature, 0 = use baseline). + input_emb: Input embeddings. + baseline_emb: Baseline embeddings. + batch_idx: Index of the sample in the batch. + + Returns: + Perturbed sample tensor. + """ + # Start with baseline for the specific sample + perturbed = baseline_emb[key][batch_idx:batch_idx+1].clone() + + # Mix in input features based on binary vector + for i, use_input in enumerate(binary_vector): + if not use_input: + continue + + # Handle various embedding shapes + dim = input_emb[key].dim() + if dim == 4: # (batch, seq_len, inner_len, emb) + if i < perturbed.shape[1]: + perturbed[:, i, :, :] = input_emb[key][batch_idx, i, :, :] + elif dim == 3: # (batch, seq_len, emb) + if i < perturbed.shape[1]: + perturbed[:, i, :] = input_emb[key][batch_idx, i, :] + else: # 2D or other + if i < perturbed.shape[1]: + perturbed[:, i] = input_emb[key][batch_idx, i] + + return perturbed + + def _evaluate_sample( + self, + key: str, + perturbed_emb: torch.Tensor, + baseline_emb: Dict[str, torch.Tensor], + target_class_idx: Optional[int], + time_info: Optional[Dict[str, torch.Tensor]], + label_data: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Evaluate model prediction for a perturbed sample. + + Args: + key: Feature key being explained. + perturbed_emb: Perturbed embedding tensor. + baseline_emb: Baseline embeddings for other features. + target_class_idx: Target class index. + time_info: Optional temporal information. + label_data: Optional label information. + + Returns: + Scalar prediction. + """ + if self.use_embeddings: + logits = self._forward_from_embeddings( + key, perturbed_emb, baseline_emb, time_info, label_data + ) + else: + logits = self._forward_from_inputs( + key, perturbed_emb, baseline_emb, time_info, label_data + ) + + # Extract target class prediction + pred_vec = self._extract_target_prediction(logits, target_class_idx) + + # Return mean prediction (average over baseline samples if multiple) + return pred_vec.detach().mean() + + def _get_prediction( + self, + key: str, + input_emb: Dict[str, torch.Tensor], + baseline_emb: Dict[str, torch.Tensor], + binary_vector: Optional[torch.Tensor], + target_class_idx: Optional[int], + time_info: Optional[Dict[str, torch.Tensor]], + label_data: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Get model prediction for original input or perturbed sample. + + Args: + key: Feature key. + input_emb: Input embeddings. + baseline_emb: Baseline embeddings. + binary_vector: Optional binary perturbation vector. + target_class_idx: Target class index. + time_info: Optional temporal information. + label_data: Optional label information. + + Returns: + Model prediction. + """ + if binary_vector is None: + # Use original input + sample = input_emb[key] + else: + # Create perturbed sample + sample = self._create_perturbed_sample( + key, binary_vector, input_emb, baseline_emb, 0 + ) + + return self._evaluate_sample( + key, sample, baseline_emb, target_class_idx, time_info, label_data + ) + + def _compute_similarity( + self, + original_emb: torch.Tensor, + perturbed_emb: torch.Tensor, + binary_vector: torch.Tensor, + ) -> torch.Tensor: + """Compute similarity weight using exponential kernel. + + The similarity is computed as: + exp(- distance² / (2 * kernel_width²)) + + Args: + original_emb: Original input embedding. + perturbed_emb: Perturbed sample embedding. + binary_vector: Binary perturbation vector. + + Returns: + Similarity weight (scalar tensor). + """ + # Flatten embeddings for distance computation + orig_flat = original_emb.reshape(-1).float() + pert_flat = perturbed_emb.reshape(-1).float() + + # Compute distance + if self.distance_mode == "cosine": + cos_sim = CosineSimilarity(dim=0) + distance = 1 - cos_sim(orig_flat, pert_flat) + elif self.distance_mode == "euclidean": + distance = torch.norm(orig_flat - pert_flat) + else: + raise ValueError("Invalid distance_mode") + + # Apply exponential kernel + similarity = torch.exp( + -1 * (distance ** 2) / (2 * (self.kernel_width ** 2)) + ) + + return similarity + + def _train_interpretable_model( + self, + interpretable_samples: list, + predictions: list, + weights: list, + device: torch.device, + ) -> torch.Tensor: + """Train weighted linear regression model. + + Solves the weighted least squares problem: + argmin_w Σᵢ wᵢ * [yᵢ - f(xᵢ)]² + α||w|| + + where α is the regularization strength. + + Args: + interpretable_samples: List of binary vectors. + predictions: List of model predictions. + weights: List of similarity weights. + device: Device for computation. + + Returns: + Linear model coefficients (batch_size, n_features). + """ + # Stack collected data + X = torch.stack(interpretable_samples, dim=0).to(device) # (n_samples, n_features) + Y = torch.stack(predictions, dim=0).to(device) # (n_samples, batch_size) + W = torch.stack(weights, dim=0).to(device) # (n_samples, batch_size) + + # Solve for each batch item independently + batch_size = Y.shape[1] + n_features = X.shape[1] + coefficients = [] + + for b_idx in range(batch_size): + # Get data for this batch item + y = Y[:, b_idx] # (n_samples,) + w = W[:, b_idx] # (n_samples,) + + # Apply sqrt weights for weighted least squares + sqrtW = torch.sqrt(w) # (n_samples,) + Xw = sqrtW.unsqueeze(1) * X # (n_samples, n_features) + yw = sqrtW * y # (n_samples,) + + # Solve based on feature selection method + if self.feature_selection == "lasso": + # L1 regularization (approximated with iterative reweighted least squares) + coef = self._solve_lasso(Xw, yw, device) + elif self.feature_selection == "ridge": + # L2 regularization + coef = self._solve_ridge(Xw, yw, device) + else: # "none" + # No regularization + coef = self._solve_ols(Xw, yw, device) + + coefficients.append(coef) + + # Stack into (batch_size, n_features) + return torch.stack(coefficients, dim=0) + + def _solve_lasso( + self, + X: torch.Tensor, + y: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + """Solve Lasso regression (L1 regularization). + + Uses coordinate descent approximation for L1 penalty. + + Args: + X: Weighted design matrix (n_samples, n_features). + y: Weighted target values (n_samples,) or (n_samples, 1). + device: Device for computation. + + Returns: + Coefficient vector (n_features,). + """ + n_features = X.shape[1] + + # Ensure y is 1D + if y.dim() > 1: + y = y.squeeze(-1) + + # Use soft-thresholding approximation via ridge with L1 approximation + # For simplicity, we'll use standard least squares with small L1 penalty + # In practice, you could use sklearn or implement full coordinate descent + + # Add L1 penalty approximation using ridge with sparsity-inducing weights + reg_mat = self.alpha * torch.eye(n_features, device=device) + X_aug = torch.cat([X, reg_mat], dim=0) + y_aug = torch.cat([y, torch.zeros(n_features, device=device)], dim=0) + + # Solve using least squares + res = torch.linalg.lstsq(X_aug, y_aug) + coef = getattr(res, 'solution', res[0]) + + return coef + + def _solve_ridge( + self, + X: torch.Tensor, + y: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + """Solve Ridge regression (L2 regularization). + + Args: + X: Weighted design matrix (n_samples, n_features). + y: Weighted target values (n_samples,) or (n_samples, 1). + device: Device for computation. + + Returns: + Coefficient vector (n_features,). + """ + n_features = X.shape[1] + + # Ensure y is 1D + if y.dim() > 1: + y = y.squeeze(-1) + + # Add Tikhonov regularization + reg_scale = torch.sqrt(torch.tensor(self.alpha, device=device)) + reg_mat = reg_scale * torch.eye(n_features, device=device) + + # Augment for regularized least squares + X_aug = torch.cat([X, reg_mat], dim=0) + y_aug = torch.cat([y, torch.zeros(n_features, device=device)], dim=0) + + # Solve using torch.linalg.lstsq + res = torch.linalg.lstsq(X_aug, y_aug) + coef = getattr(res, 'solution', res[0]) + + return coef + + def _solve_ols( + self, + X: torch.Tensor, + y: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + """Solve Ordinary Least Squares (no regularization). + + Args: + X: Weighted design matrix (n_samples, n_features). + y: Weighted target values (n_samples,) or (n_samples, 1). + device: Device for computation. + + Returns: + Coefficient vector (n_features,). + """ + # Ensure y is 1D + if y.dim() > 1: + y = y.squeeze(-1) + + # Solve using torch.linalg.lstsq + res = torch.linalg.lstsq(X, y) + coef = getattr(res, 'solution', res[0]) + + return coef + + # ------------------------------------------------------------------ + # Forward pass methods + # ------------------------------------------------------------------ + def _forward_from_embeddings( + self, + key: str, + perturbed_emb: torch.Tensor, + baseline_emb: Dict[str, torch.Tensor], + time_info: Optional[Dict[str, torch.Tensor]], + label_data: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Forward pass using embeddings. + + Args: + key: Feature key being explained. + perturbed_emb: Perturbed embedding tensor. + baseline_emb: Baseline embeddings. + time_info: Optional temporal information. + label_data: Optional label information. + + Returns: + Model logits. + """ + # Build feature embeddings dictionary + batch_size = perturbed_emb.shape[0] + feature_embeddings = {key: perturbed_emb} + + for fk in self.model.feature_keys: + if fk not in feature_embeddings: + if fk in baseline_emb: + # Expand baseline to match batch size if needed + base_emb = baseline_emb[fk] + if base_emb.shape[0] == batch_size: + feature_embeddings[fk] = base_emb.clone() + elif base_emb.shape[0] == 1 and batch_size > 1: + feature_embeddings[fk] = base_emb.expand(batch_size, *base_emb.shape[1:]).clone() + elif base_emb.shape[0] > 1 and batch_size == 1: + # Slice a single neutral baseline sample + feature_embeddings[fk] = base_emb[:1].clone() + else: + feature_embeddings[fk] = base_emb.expand(batch_size, *base_emb.shape[1:]).clone() + else: + # Zero fallback + ref_tensor = next(iter(feature_embeddings.values())) + feature_embeddings[fk] = torch.zeros_like(ref_tensor) + + # Prepare time info matching batch size + time_info_adj = self._prepare_time_info( + time_info, feature_embeddings, batch_size + ) + + # Forward pass + with torch.no_grad(): + # Create kwargs with proper label key + forward_kwargs = { + "time_info": time_info_adj, + } + # Add label with the correct key name + if len(self.model.label_keys) > 0: + label_key = self.model.label_keys[0] + forward_kwargs[label_key] = torch.zeros( + (perturbed_emb.shape[0], 1), device=self.model.device + ) + + model_output = self.model.forward_from_embedding( + feature_embeddings, + **forward_kwargs + ) + + return self._extract_logits(model_output) + + def _forward_from_inputs( + self, + key: str, + perturbed_inputs: torch.Tensor, + baseline_inputs: Dict[str, torch.Tensor], + time_info: Optional[Dict[str, torch.Tensor]], + label_data: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Forward pass using raw inputs (continuous features). + + Args: + key: Feature key being explained. + perturbed_inputs: Perturbed input tensor. + baseline_inputs: Baseline inputs. + time_info: Optional temporal information. + label_data: Optional label information. + + Returns: + Model logits. + """ + model_inputs = {} + for fk in self.model.feature_keys: + if fk == key: + model_inputs[fk] = perturbed_inputs + elif fk in baseline_inputs: + base = baseline_inputs[fk] + # Expand baseline batch to match perturbed batch size if needed + if base.shape[0] != perturbed_inputs.shape[0]: + base = base.expand(perturbed_inputs.shape[0], *base.shape[1:]).clone() + else: + base = base.clone() + model_inputs[fk] = base + else: + model_inputs[fk] = torch.zeros_like(perturbed_inputs) + + # Add label stub if needed + if len(self.model.label_keys) > 0: + label_key = self.model.label_keys[0] + model_inputs[label_key] = torch.zeros( + (perturbed_inputs.shape[0], 1), device=perturbed_inputs.device + ) + + output = self.model(**model_inputs) + return self._extract_logits(output) + + def _prepare_time_info( + self, + time_info: Optional[Dict[str, torch.Tensor]], + feature_embeddings: Dict[str, torch.Tensor], + n_samples: int, + ) -> Optional[Dict[str, torch.Tensor]]: + """Prepare time information to match sample batch size. + + Args: + time_info: Original time information. + feature_embeddings: Feature embeddings to match sequence lengths. + n_samples: Number of samples. + + Returns: + Adjusted time information or None. + """ + if time_info is None: + return None + + time_info_adj = {} + for fk, emb in feature_embeddings.items(): + if fk not in time_info or time_info[fk] is None: + continue + + seq_len = emb.shape[1] + t_orig = time_info[fk].to(self.model.device) + + # Normalize to 1D sequence + t_vec = self._normalize_time_vector(t_orig) + + # Adjust length to match embedding sequence length + t_adj = self._adjust_time_length(t_vec, seq_len) + + # Expand to batch size + time_info_adj[fk] = t_adj.unsqueeze(0).expand(n_samples, -1) + + return time_info_adj if time_info_adj else None + + # ------------------------------------------------------------------ + # Baseline generation + # ------------------------------------------------------------------ + def _generate_baseline( + self, inputs: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Generate baseline samples for LIME computation. + + Creates reference samples to use as the "absence" of features. + The sampling strategy adapts to the feature type: + - Discrete features: Use the most common value or a small non-zero value + - Continuous features: Use mean or small non-zero values + + Args: + inputs: Dictionary mapping feature names to input tensors. + + Returns: + Dictionary mapping feature names to baseline sample tensors. + """ + baseline_samples = {} + + for key, x in inputs.items(): + batch_size = x.shape[0] + if x.dtype in [torch.int64, torch.int32, torch.long]: + # Discrete features: use small non-zero token index to avoid zero-mask issues + # in sequential models (e.g., StageNet). Using ones is a safe neutral choice. + baseline = torch.ones_like(x) + else: + # Continuous features: use small neutral values (near-zero) + baseline = torch.zeros_like(x) + baseline = baseline + 0.01 + + # Ensure baseline matches input batch size + if baseline.shape[0] != batch_size: + baseline = baseline.expand(batch_size, *baseline.shape[1:]) + + baseline_samples[key] = baseline.to(x.device) + + return baseline_samples + + # ------------------------------------------------------------------ + # Utility helpers (shared with SHAP) + # ------------------------------------------------------------------ + @staticmethod + def _determine_n_features( + key: str, + inputs: Dict[str, torch.Tensor], + embeddings: Dict[str, torch.Tensor], + ) -> int: + """Determine the number of features to explain for a given key. + + Args: + key: Feature key. + inputs: Original input tensors. + embeddings: Embedding tensors. + + Returns: + Number of features (typically sequence length or feature dimension). + """ + # Prefer original input shape + if key in inputs and inputs[key].dim() >= 2: + return inputs[key].shape[1] + + # Fallback to embedding shape + emb = embeddings[key] + if emb.dim() >= 2: + return emb.shape[1] + return emb.shape[-1] + + @staticmethod + def _extract_logits(model_output) -> torch.Tensor: + """Extract logits from model output. + + Args: + model_output: Model output (dict or tensor). + + Returns: + Logit tensor. + """ + if isinstance(model_output, dict) and "logit" in model_output: + return model_output["logit"] + return model_output + + @staticmethod + def _extract_target_prediction( + logits: torch.Tensor, target_class_idx: Optional[int] + ) -> torch.Tensor: + """Extract target class prediction from logits. + + Args: + logits: Model logits. + target_class_idx: Target class index (None for max prediction). + + Returns: + Target prediction tensor. + """ + if target_class_idx is None: + return torch.max(logits, dim=-1)[0] + + if logits.dim() > 1 and logits.shape[-1] > 1: + return logits[..., target_class_idx] + else: + # Binary classification with single logit + sig = torch.sigmoid(logits.squeeze(-1)) + return sig if target_class_idx == 1 else 1.0 - sig + + @staticmethod + def _normalize_time_vector(time_tensor: torch.Tensor) -> torch.Tensor: + """Normalize time tensor to 1D vector. + + Args: + time_tensor: Time information tensor. + + Returns: + 1D time vector. + """ + if time_tensor.dim() == 2 and time_tensor.shape[0] > 0: + return time_tensor[0].detach() + elif time_tensor.dim() == 1: + return time_tensor.detach() + else: + return time_tensor.reshape(-1).detach() + + @staticmethod + def _adjust_time_length(time_vec: torch.Tensor, target_len: int) -> torch.Tensor: + """Adjust time vector length to match target length. + + Args: + time_vec: 1D time vector. + target_len: Target sequence length. + + Returns: + Adjusted time vector. + """ + current_len = time_vec.numel() + + if current_len == target_len: + return time_vec + elif current_len < target_len: + # Pad by repeating last value + if current_len == 0: + return torch.zeros(target_len, device=time_vec.device) + pad_len = target_len - current_len + pad = time_vec[-1].unsqueeze(0).repeat(pad_len) + return torch.cat([time_vec, pad], dim=0) + else: + # Truncate + return time_vec[:target_len] + + @staticmethod + def _map_to_input_shapes( + lime_values: Dict[str, torch.Tensor], + input_shapes: Dict[str, tuple], + ) -> Dict[str, torch.Tensor]: + """Map LIME values from embedding space back to input shapes. + + For embedding-based attributions, this projects the attribution scores + from embedding dimensions back to the original input tensor shapes. + + Args: + lime_values: Dictionary of LIME coefficients in embedding space. + input_shapes: Dictionary of original input shapes. + + Returns: + Dictionary of LIME coefficients reshaped to match inputs. + """ + mapped = {} + for key, values in lime_values.items(): + if key not in input_shapes: + mapped[key] = values + continue + + orig_shape = input_shapes[key] + + # If shapes already match, no adjustment needed + if values.shape == orig_shape: + mapped[key] = values + continue + + # Reshape to match original input + reshaped = values + while len(reshaped.shape) < len(orig_shape): + reshaped = reshaped.unsqueeze(-1) + + if reshaped.shape != orig_shape: + reshaped = reshaped.expand(orig_shape) + + mapped[key] = reshaped + + return mapped diff --git a/tests/core/test_lime.py b/tests/core/test_lime.py new file mode 100644 index 000000000..8b5fcaa4f --- /dev/null +++ b/tests/core/test_lime.py @@ -0,0 +1,946 @@ +""" +Test suite for LIME explainer implementation. +""" +import unittest +from typing import Dict +import tempfile +import pickle +import shutil + +import torch +import torch.nn as nn +import litdata + +from pyhealth.datasets import SampleDataset, get_dataloader +from pyhealth.datasets.sample_dataset import SampleBuilder +from pyhealth.models import MLP, StageNet, BaseModel +from pyhealth.interpret.methods import LimeExplainer +from pyhealth.interpret.methods.base_interpreter import BaseInterpreter + + +class _SimpleLimeModel(BaseModel): + """Minimal model for testing LIME with continuous inputs.""" + + def __init__(self): + super().__init__(dataset=None) + self.feature_keys = ["x"] + self.label_keys = ["y"] + self.mode = "binary" + + self.linear1 = nn.Linear(3, 4, bias=True) + self.linear2 = nn.Linear(4, 1, bias=True) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> dict: + hidden = torch.relu(self.linear1(x)) + logit = self.linear2(hidden) + y_prob = torch.sigmoid(logit) + + return { + "logit": logit, + "y_prob": y_prob, + "y_true": y.to(y_prob.device), + "loss": torch.zeros((), device=y_prob.device), + } + + +class _SimpleEmbeddingModel(nn.Module): + """Simple embedding module mapping integer tokens to vectors.""" + + def __init__(self, vocab_size: int = 20, embedding_dim: int = 4): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embedding_dim) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {key: self.embedding(value.long()) for key, value in inputs.items()} + + +class _EmbeddingForwardModel(BaseModel): + """Toy model exposing forward_from_embedding for discrete features.""" + + def __init__(self): + super().__init__(dataset=None) + self.feature_keys = ["seq"] + self.label_keys = ["label"] + self.mode = "binary" + + self.embedding_model = _SimpleEmbeddingModel() + self.linear = nn.Linear(4, 1, bias=True) + + def forward_from_embedding( + self, + feature_embeddings: Dict[str, torch.Tensor], + time_info: Dict[str, torch.Tensor] = None, + label: torch.Tensor = None, + ) -> Dict[str, torch.Tensor]: + # Pool embeddings: (batch, seq_len, emb_dim) -> (batch, emb_dim) + pooled = feature_embeddings["seq"].mean(dim=1) + logits = self.linear(pooled) + y_prob = torch.sigmoid(logits) + + return { + "logit": logits, + "y_prob": y_prob, + "loss": torch.zeros((), device=logits.device), + } + + +class _MultiFeatureModel(BaseModel): + """Model with multiple feature inputs for testing multi-feature LIME.""" + + def __init__(self): + super().__init__(dataset=None) + self.feature_keys = ["x1", "x2"] + self.label_keys = ["y"] + self.mode = "binary" + + self.linear1 = nn.Linear(2, 3, bias=True) + self.linear2 = nn.Linear(2, 3, bias=True) + self.linear_out = nn.Linear(6, 1, bias=True) + + def forward(self, x1: torch.Tensor, x2: torch.Tensor, y: torch.Tensor) -> dict: + h1 = torch.relu(self.linear1(x1)) + h2 = torch.relu(self.linear2(x2)) + combined = torch.cat([h1, h2], dim=-1) + logit = self.linear_out(combined) + y_prob = torch.sigmoid(logit) + + return { + "logit": logit, + "y_prob": y_prob, + "y_true": y.to(y_prob.device), + "loss": torch.zeros((), device=y_prob.device), + } + + +class TestLimeExplainerBasic(unittest.TestCase): + """Basic tests for LimeExplainer functionality.""" + + def setUp(self): + self.model = _SimpleLimeModel() + self.model.eval() + + # Set deterministic weights + with torch.no_grad(): + self.model.linear1.weight.copy_( + torch.tensor([ + [0.5, -0.3, 0.2], + [0.1, 0.4, -0.1], + [-0.2, 0.3, 0.5], + [0.3, -0.1, 0.2], + ]) + ) + self.model.linear1.bias.copy_(torch.tensor([0.1, -0.1, 0.2, 0.0])) + self.model.linear2.weight.copy_(torch.tensor([[0.4, -0.3, 0.2, 0.1]])) + self.model.linear2.bias.copy_(torch.tensor([0.05])) + + self.labels = torch.zeros((1, 1)) + self.explainer = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + kernel_width=0.25, + random_seed=42, + ) + + def test_inheritance(self): + """LimeExplainer should inherit from BaseInterpreter.""" + self.assertIsInstance(self.explainer, BaseInterpreter) + + def test_lime_initialization(self): + """Test that LimeExplainer initializes correctly.""" + explainer = LimeExplainer(self.model, use_embeddings=False) + self.assertIsInstance(explainer, LimeExplainer) + self.assertEqual(explainer.model, self.model) + self.assertFalse(explainer.use_embeddings) + self.assertEqual(explainer.n_samples, 1000) + self.assertIsNotNone(explainer.kernel_width) + + def test_attribute_returns_dict(self): + """Attribute method should return dictionary of LIME attributions.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + attributions = self.explainer.attribute( + x=inputs, + y=self.labels, + ) + + self.assertIsInstance(attributions, dict) + self.assertIn("x", attributions) + self.assertEqual(attributions["x"].shape, inputs.shape) + + def test_lime_values_are_tensors(self): + """LIME values should be PyTorch tensors.""" + inputs = torch.tensor([[0.8, -0.2, 0.5]]) + + attributions = self.explainer.attribute( + x=inputs, + y=self.labels, + ) + + self.assertIsInstance(attributions["x"], torch.Tensor) + self.assertFalse(attributions["x"].requires_grad) + + def test_baseline_generation(self): + """Should generate baseline automatically if not provided.""" + inputs = torch.tensor([[1.0, 0.5, -0.3], [0.5, 1.0, 0.2]]) + + attributions = self.explainer.attribute( + x=inputs, + y=torch.zeros((2, 1)), + ) + + self.assertEqual(attributions["x"].shape, inputs.shape) + + def test_custom_baseline(self): + """Should accept custom baseline dictionary.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + baseline = {"x": torch.zeros_like(inputs)} + + attributions = self.explainer.attribute( + baseline=baseline, + x=inputs, + y=self.labels, + ) + + self.assertEqual(attributions["x"].shape, inputs.shape) + + def test_target_class_idx_none(self): + """Should handle None target class index (max prediction).""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + attributions = self.explainer.attribute( + x=inputs, + y=self.labels, + target_class_idx=None, + ) + + self.assertIn("x", attributions) + self.assertEqual(attributions["x"].shape, inputs.shape) + + def test_target_class_idx_specified(self): + """Should handle specific target class index.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + attr_class_0 = self.explainer.attribute( + x=inputs, + y=self.labels, + target_class_idx=0, + ) + + attr_class_1 = self.explainer.attribute( + x=inputs, + y=self.labels, + target_class_idx=1, + ) + + # Attributions should differ for different classes + self.assertFalse(torch.allclose(attr_class_0["x"], attr_class_1["x"], atol=0.01)) + + def test_attribution_values_are_finite(self): + """Test that attribution values are finite (no NaN or Inf).""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + attributions = self.explainer.attribute( + x=inputs, + y=self.labels, + ) + + self.assertTrue(torch.isfinite(attributions["x"]).all()) + + def test_multiple_samples(self): + """Test attribution on batch with multiple samples.""" + inputs = torch.tensor([[1.0, 0.5, -0.3], [0.5, 1.0, 0.2], [-0.5, 0.3, 0.8]]) + + attributions = self.explainer.attribute( + x=inputs, + y=torch.zeros((3, 1)), + ) + + # Check batch dimension + self.assertEqual(attributions["x"].shape[0], 3) + self.assertEqual(attributions["x"].shape, inputs.shape) + + def test_callable_interface(self): + """LimeExplainer instances should be callable via BaseInterpreter.__call__.""" + inputs = torch.tensor([[0.3, -0.4, 0.5]]) + kwargs = {"x": inputs, "y": self.labels} + + from_attribute = self.explainer.attribute(**kwargs) + from_call = self.explainer(**kwargs) + + torch.testing.assert_close( + from_call["x"], + from_attribute["x"], + rtol=1e-3, + atol=1e-4 + ) + + def test_different_n_samples(self): + """Test with different numbers of perturbation samples.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + # Few samples + explainer_few = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=50, + random_seed=42, + ) + attr_few = explainer_few.attribute(x=inputs, y=self.labels) + + # More samples + explainer_many = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=200, + random_seed=42, + ) + attr_many = explainer_many.attribute(x=inputs, y=self.labels) + + # Both should produce valid output + self.assertEqual(attr_few["x"].shape, inputs.shape) + self.assertEqual(attr_many["x"].shape, inputs.shape) + self.assertTrue(torch.isfinite(attr_few["x"]).all()) + self.assertTrue(torch.isfinite(attr_many["x"]).all()) + + def test_different_regularization_methods(self): + """Test LIME with different regularization methods.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + # Lasso + explainer_lasso = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + feature_selection="lasso", + alpha=0.01, + random_seed=42, + ) + attr_lasso = explainer_lasso.attribute(x=inputs, y=self.labels) + + # Ridge + explainer_ridge = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + feature_selection="ridge", + alpha=0.01, + random_seed=42, + ) + attr_ridge = explainer_ridge.attribute(x=inputs, y=self.labels) + + # None (OLS) + explainer_none = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + feature_selection="none", + random_seed=42, + ) + attr_none = explainer_none.attribute(x=inputs, y=self.labels) + + # All should produce valid finite output + self.assertTrue(torch.isfinite(attr_lasso["x"]).all()) + self.assertTrue(torch.isfinite(attr_ridge["x"]).all()) + self.assertTrue(torch.isfinite(attr_none["x"]).all()) + + def test_different_distance_modes(self): + """Test LIME with different distance modes.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + # Cosine distance + explainer_cosine = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + distance_mode="cosine", + random_seed=42, + ) + attr_cosine = explainer_cosine.attribute(x=inputs, y=self.labels) + + # Euclidean distance + explainer_euclidean = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + distance_mode="euclidean", + random_seed=42, + ) + attr_euclidean = explainer_euclidean.attribute(x=inputs, y=self.labels) + + # Both should produce valid output + self.assertTrue(torch.isfinite(attr_cosine["x"]).all()) + self.assertTrue(torch.isfinite(attr_euclidean["x"]).all()) + + def test_reproducibility_with_seed(self): + """Test that results are reproducible with same random seed.""" + inputs = torch.tensor([[1.0, 0.5, -0.3]]) + + explainer1 = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + random_seed=42, + ) + explainer2 = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + random_seed=42, + ) + + attr1 = explainer1.attribute(x=inputs, y=self.labels) + attr2 = explainer2.attribute(x=inputs, y=self.labels) + + # Results should be identical with same seed + torch.testing.assert_close(attr1["x"], attr2["x"], rtol=1e-5, atol=1e-6) + + +class TestLimeExplainerEmbedding(unittest.TestCase): + """Tests for LimeExplainer with embedding-based models.""" + + def setUp(self): + self.model = _EmbeddingForwardModel() + self.model.eval() + + # Set deterministic weights + with torch.no_grad(): + self.model.linear.weight.copy_(torch.tensor([[0.4, -0.3, 0.2, 0.1]])) + self.model.linear.bias.copy_(torch.tensor([0.05])) + + self.labels = torch.zeros((1, 1)) + self.explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=100, + random_seed=42, + ) + + def test_embedding_initialization(self): + """Test that LimeExplainer initializes with embedding mode.""" + self.assertTrue(self.explainer.use_embeddings) + self.assertTrue(hasattr(self.model, "forward_from_embedding")) + + def test_attribute_with_embeddings(self): + """Test attribution computation in embedding mode.""" + seq_inputs = torch.tensor([[1, 2, 3]]) + + attributions = self.explainer.attribute( + seq=seq_inputs, + label=self.labels, + ) + + self.assertIn("seq", attributions) + self.assertEqual(attributions["seq"].shape, seq_inputs.shape) + + def test_embedding_attributions_are_finite(self): + """Test that embedding-based attributions are finite.""" + seq_inputs = torch.tensor([[5, 10, 15]]) + + attributions = self.explainer.attribute( + seq=seq_inputs, + label=self.labels, + ) + + self.assertTrue(torch.isfinite(attributions["seq"]).all()) + + def test_embedding_with_time_info(self): + """Test attribution with time information (temporal data).""" + time_tensor = torch.tensor([[0.0, 1.5, 3.0]]) + seq_tensor = torch.tensor([[1, 2, 3]]) + + attributions = self.explainer.attribute( + seq=(time_tensor, seq_tensor), + label=self.labels, + ) + + self.assertIn("seq", attributions) + self.assertEqual(attributions["seq"].shape, seq_tensor.shape) + + def test_embedding_with_custom_baseline(self): + """Test embedding-based LIME with custom baseline.""" + seq_inputs = torch.tensor([[1, 2, 3]]) + baseline = {"seq": torch.zeros_like(seq_inputs)} + + attributions = self.explainer.attribute( + baseline=baseline, + seq=seq_inputs, + label=self.labels, + ) + + self.assertEqual(attributions["seq"].shape, seq_inputs.shape) + + def test_embedding_model_without_forward_from_embedding_fails(self): + """Test that using embeddings without forward_from_embedding raises error.""" + model_without_embed = _SimpleLimeModel() + + with self.assertRaises(AssertionError): + LimeExplainer(model_without_embed, use_embeddings=True) + + +class TestLimeExplainerMultiFeature(unittest.TestCase): + """Tests for LimeExplainer with multiple feature inputs.""" + + def setUp(self): + self.model = _MultiFeatureModel() + self.model.eval() + + # Set deterministic weights + with torch.no_grad(): + self.model.linear1.weight.copy_( + torch.tensor([[0.5, -0.3], [0.1, 0.4], [-0.2, 0.3]]) + ) + self.model.linear2.weight.copy_( + torch.tensor([[0.3, -0.1], [0.2, 0.5], [0.4, -0.2]]) + ) + self.model.linear_out.weight.copy_( + torch.tensor([[0.1, 0.2, -0.1, 0.3, -0.2, 0.15]]) + ) + + self.labels = torch.zeros((1, 1)) + self.explainer = LimeExplainer( + self.model, + use_embeddings=False, + n_samples=100, + random_seed=42, + ) + + def test_multi_feature_attribution(self): + """Test attribution with multiple feature inputs.""" + x1 = torch.tensor([[1.0, 0.5]]) + x2 = torch.tensor([[-0.3, 0.8]]) + + attributions = self.explainer.attribute( + x1=x1, + x2=x2, + y=self.labels, + ) + + self.assertIn("x1", attributions) + self.assertIn("x2", attributions) + self.assertEqual(attributions["x1"].shape, x1.shape) + self.assertEqual(attributions["x2"].shape, x2.shape) + + def test_multi_feature_with_custom_baselines(self): + """Test multi-feature attribution with custom baselines.""" + x1 = torch.tensor([[1.0, 0.5]]) + x2 = torch.tensor([[-0.3, 0.8]]) + baseline = { + "x1": torch.zeros_like(x1), + "x2": torch.ones_like(x2) * 0.5, + } + + attributions = self.explainer.attribute( + baseline=baseline, + x1=x1, + x2=x2, + y=self.labels, + ) + + self.assertEqual(attributions["x1"].shape, x1.shape) + self.assertEqual(attributions["x2"].shape, x2.shape) + + def test_multi_feature_finite_values(self): + """Test that multi-feature attributions are finite.""" + # Test with single sample to avoid baseline batch size issues + x1 = torch.tensor([[1.0, 0.5]]) + x2 = torch.tensor([[-0.3, 0.8]]) + + attributions = self.explainer.attribute( + x1=x1, + x2=x2, + y=torch.zeros((1, 1)), + ) + + self.assertTrue(torch.isfinite(attributions["x1"]).all()) + self.assertTrue(torch.isfinite(attributions["x2"]).all()) + + +class TestLimeExplainerMLP(unittest.TestCase): + """Test cases for LIME with MLP model on real dataset.""" + + def setUp(self): + """Set up test data and model.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "conditions": ["cond-33", "cond-86", "cond-80", "cond-12"], + "procedures": [1.0, 2.0, 3.5, 4], + "label": 0, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "conditions": ["cond-33", "cond-86", "cond-80"], + "procedures": [5.0, 2.0, 3.5, 4], + "label": 1, + }, + { + "patient_id": "patient-2", + "visit_id": "visit-2", + "conditions": ["cond-55", "cond-12"], + "procedures": [2.0, 3.0, 1.5, 5], + "label": 1, + }, + ] + + # Define input and output schemas + self.input_schema = { + "conditions": "sequence", + "procedures": "tensor", + } + self.output_schema = {"label": "binary"} + + # Create temporary directory for dataset + self.temp_dir = tempfile.mkdtemp() + + # Create dataset using SampleBuilder + builder = SampleBuilder(self.input_schema, self.output_schema) + builder.fit(self.samples) + builder.save(f"{self.temp_dir}/schema.pkl") + + # Optimize samples into dataset format + def sample_generator(): + for sample in self.samples: + yield {"sample": pickle.dumps(sample)} + + litdata.optimize( + fn=builder.transform, + inputs=list(sample_generator()), + output_dir=self.temp_dir, + num_workers=1, + chunk_bytes="64MB", + ) + + # Create dataset + self.dataset = SampleDataset( + path=self.temp_dir, + dataset_name="test_lime", + ) + + # Create model + self.model = MLP( + dataset=self.dataset, + embedding_dim=32, + hidden_dim=32, + n_layers=2, + ) + self.model.eval() + + # Create dataloader + self.test_loader = get_dataloader(self.dataset, batch_size=1, shuffle=False) + + def tearDown(self): + """Clean up temporary directory.""" + if hasattr(self, 'temp_dir'): + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_lime_mlp_basic_attribution(self): + """Test basic LIME attribution computation with MLP.""" + explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + random_seed=42, + ) + data_batch = next(iter(self.test_loader)) + + # Compute attributions + attributions = explainer.attribute(**data_batch) + + # Check output structure + self.assertIn("conditions", attributions) + self.assertIn("procedures", attributions) + + # Check shapes match input shapes + self.assertEqual( + attributions["conditions"].shape, data_batch["conditions"].shape + ) + self.assertEqual( + attributions["procedures"].shape, data_batch["procedures"].shape + ) + + # Check that attributions are tensors + self.assertIsInstance(attributions["conditions"], torch.Tensor) + self.assertIsInstance(attributions["procedures"], torch.Tensor) + + def test_lime_mlp_with_target_class(self): + """Test LIME attribution with specific target class.""" + explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + random_seed=42, + ) + data_batch = next(iter(self.test_loader)) + + # Compute attributions for class 0 + attr_class_0 = explainer.attribute(**data_batch, target_class_idx=0) + + # Compute attributions for class 1 + attr_class_1 = explainer.attribute(**data_batch, target_class_idx=1) + + # Check that attributions are different for different classes + self.assertFalse( + torch.allclose(attr_class_0["conditions"], attr_class_1["conditions"], atol=0.01) + ) + + def test_lime_mlp_values_finite(self): + """Test that LIME values are finite (no NaN or Inf).""" + explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + random_seed=42, + ) + data_batch = next(iter(self.test_loader)) + + attributions = explainer.attribute(**data_batch) + + # Check no NaN or Inf + self.assertTrue(torch.isfinite(attributions["conditions"]).all()) + self.assertTrue(torch.isfinite(attributions["procedures"]).all()) + + def test_lime_mlp_multiple_samples(self): + """Test LIME on batch with multiple samples.""" + explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + random_seed=42, + ) + + # Use batch size > 1 + test_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(test_loader)) + + attributions = explainer.attribute(**data_batch) + + # Check batch dimension + self.assertEqual(attributions["conditions"].shape[0], 2) + self.assertEqual(attributions["procedures"].shape[0], 2) + + def test_lime_mlp_different_n_samples(self): + """Test LIME with different numbers of perturbation samples.""" + data_batch = next(iter(self.test_loader)) + + # Few samples + explainer_few = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=30, + random_seed=42, + ) + attr_few = explainer_few.attribute(**data_batch) + + # More samples + explainer_many = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=100, + random_seed=42, + ) + attr_many = explainer_many.attribute(**data_batch) + + # Both should produce valid output + self.assertIn("conditions", attr_few) + self.assertIn("conditions", attr_many) + self.assertEqual(attr_few["conditions"].shape, attr_many["conditions"].shape) + + def test_lime_mlp_different_regularization(self): + """Test LIME with different regularization methods on MLP.""" + data_batch = next(iter(self.test_loader)) + + # Lasso + explainer_lasso = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + feature_selection="lasso", + alpha=0.01, + random_seed=42, + ) + attr_lasso = explainer_lasso.attribute(**data_batch) + + # Ridge + explainer_ridge = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + feature_selection="ridge", + alpha=0.01, + random_seed=42, + ) + attr_ridge = explainer_ridge.attribute(**data_batch) + + # Both should produce valid finite output + self.assertTrue(torch.isfinite(attr_lasso["conditions"]).all()) + self.assertTrue(torch.isfinite(attr_ridge["conditions"]).all()) + + +class TestLimeExplainerStageNet(unittest.TestCase): + """Test cases for LIME with StageNet model. + + Note: StageNet tests demonstrate LIME working with temporal/sequential data. + """ + + def setUp(self): + """Set up test data and StageNet model.""" + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "codes": ([0.0, 2.0, 1.3], ["505800458", "50580045810", "50580045811"]), + "procedures": ( + [0.0, 1.5], + [["A05B", "A05C", "A06A"], ["A11D", "A11E"]], + ), + "lab_values": (None, [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]]), + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-1", + "codes": ( + [0.0, 2.0, 1.3, 1.0, 2.0], + [ + "55154191800", + "551541928", + "55154192800", + "705182798", + "70518279800", + ], + ), + "procedures": ([0.0], [["A04A", "B035", "C129"]]), + "lab_values": ( + None, + [ + [1.4, 3.2, 3.5], + [4.1, 5.9, 1.7], + [4.5, 5.9, 1.7], + ], + ), + "label": 0, + }, + ] + + # Define input and output schemas + self.input_schema = { + "codes": "stagenet", + "procedures": "stagenet", + "lab_values": "stagenet_tensor", + } + self.output_schema = {"label": "binary"} + + # Create temporary directory for dataset + self.temp_dir = tempfile.mkdtemp() + + # Create dataset using SampleBuilder + builder = SampleBuilder(self.input_schema, self.output_schema) + builder.fit(self.samples) + builder.save(f"{self.temp_dir}/schema.pkl") + + # Optimize samples into dataset format + def sample_generator(): + for sample in self.samples: + yield {"sample": pickle.dumps(sample)} + + litdata.optimize( + fn=builder.transform, + inputs=list(sample_generator()), + output_dir=self.temp_dir, + num_workers=1, + chunk_bytes="64MB", + ) + + # Create dataset + self.dataset = SampleDataset( + path=self.temp_dir, + dataset_name="test_stagenet_lime", + ) + + # Create StageNet model + self.model = StageNet( + dataset=self.dataset, + embedding_dim=32, + chunk_size=16, + levels=2, + ) + self.model.eval() + + # Create dataloader + self.test_loader = get_dataloader(self.dataset, batch_size=1, shuffle=False) + + def tearDown(self): + """Clean up temporary directory.""" + if hasattr(self, 'temp_dir'): + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_lime_initialization_stagenet(self): + """Test that LimeExplainer works with StageNet.""" + explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + random_seed=42, + ) + self.assertIsInstance(explainer, LimeExplainer) + self.assertEqual(explainer.model, self.model) + + def test_lime_basic_attribution_stagenet(self): + """Test basic LIME attribution computation with StageNet.""" + explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + random_seed=42, + ) + data_batch = next(iter(self.test_loader)) + + # Compute attributions + attributions = explainer.attribute(**data_batch) + + # Check output structure + self.assertIn("codes", attributions) + self.assertIn("procedures", attributions) + self.assertIn("lab_values", attributions) + + # Check that attributions are tensors + self.assertIsInstance(attributions["codes"], torch.Tensor) + self.assertIsInstance(attributions["procedures"], torch.Tensor) + self.assertIsInstance(attributions["lab_values"], torch.Tensor) + + def test_lime_stagenet_values_finite(self): + """Test that LIME values on StageNet are finite.""" + explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + random_seed=42, + ) + data_batch = next(iter(self.test_loader)) + + attributions = explainer.attribute(**data_batch) + + # Check no NaN or Inf + self.assertTrue(torch.isfinite(attributions["codes"]).all()) + self.assertTrue(torch.isfinite(attributions["procedures"]).all()) + self.assertTrue(torch.isfinite(attributions["lab_values"]).all()) + + def test_lime_stagenet_target_class(self): + """Test LIME with specific target class on StageNet.""" + explainer = LimeExplainer( + self.model, + use_embeddings=True, + n_samples=50, + random_seed=42, + ) + data_batch = next(iter(self.test_loader)) + + # Compute attributions for class 1 + attributions = explainer.attribute(**data_batch, target_class_idx=1) + + # Check output structure + self.assertIn("codes", attributions) + self.assertTrue(torch.isfinite(attributions["codes"]).all())