Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions docs/api/interpret.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,18 @@ New to interpretability in PyHealth? Check out these complete examples:
- Compare different baseline strategies for background sample generation
- Decode attributions to human-readable medical codes and lab measurements

**ViT/Chefer Attribution Example:**

- ``examples/covid19_cxr_tutorial.py`` - Demonstrates Chefer's attention-based attribution for Vision Transformers:

- Train a ViT model on COVID-19 chest X-ray classification
- Use CheferRelevance for gradient-weighted attention attribution
- Visualize which image patches contribute to predictions

These examples provide end-to-end workflows from loading data to interpreting and evaluating attributions.

Available Methods
-----------------
Attribution Methods
-------------------

.. toctree::
:maxdepth: 4
Expand All @@ -64,4 +72,15 @@ Available Methods
interpret/pyhealth.interpret.methods.deeplift
interpret/pyhealth.interpret.methods.integrated_gradients
interpret/pyhealth.interpret.methods.shap


Visualization Utilities
-----------------------

The ``pyhealth.interpret.utils`` module provides visualization functions for
creating attribution overlays, heatmaps, and publication-ready figures.
Includes specialized support for Vision Transformer (ViT) attribution visualization.

.. toctree::
:maxdepth: 4

interpret/pyhealth.interpret.utils
100 changes: 100 additions & 0 deletions docs/api/interpret/pyhealth.interpret.utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
pyhealth.interpret.utils
========================

.. automodule:: pyhealth.interpret.utils
:members:
:undoc-members:
:show-inheritance:

Overview
--------

The ``pyhealth.interpret.utils`` module provides visualization utilities for
interpretability methods in PyHealth. These functions help create visual
explanations of model predictions, particularly useful for medical imaging.

Core Functions
--------------

**Overlay Visualization**

- :func:`show_cam_on_image` - Overlay a CAM/attribution map on an image
- :func:`visualize_attribution_on_image` - Generate complete attribution visualization

**Normalization & Processing**

- :func:`normalize_attribution` - Normalize attribution values for visualization
- :func:`interpolate_attribution_map` - Resize attribution to match image dimensions

**Figure Generation**

- :func:`create_attribution_figure` - Create publication-ready figure with overlays

ViT-Specific Functions
----------------------

These functions are specifically designed for Vision Transformer (ViT) models
using attention-based interpretability methods like :class:`~pyhealth.interpret.methods.CheferRelevance`.

- :func:`generate_vit_visualization` - Generate visualization components for ViT attribution
- :func:`create_vit_attribution_figure` - Create complete ViT attribution figure
- :func:`reshape_vit_attribution` - Reshape flat patch attribution to 2D spatial map

Example: Basic Attribution Visualization
----------------------------------------

.. code-block:: python

import numpy as np
from pyhealth.interpret.utils import show_cam_on_image, normalize_attribution

# Assume we have image and attribution from an interpreter
image = np.random.rand(224, 224, 3) # RGB image in [0, 1]
attribution = np.random.rand(224, 224) # Raw attribution values

# Normalize and overlay
attr_normalized = normalize_attribution(attribution)
overlay = show_cam_on_image(image, attr_normalized)

Example: ViT Attribution with CheferRelevance
---------------------------------------------

.. code-block:: python

from pyhealth.models import TorchvisionModel
from pyhealth.interpret.methods import CheferRelevance
from pyhealth.interpret.utils import (
generate_vit_visualization,
create_vit_attribution_figure,
)
import matplotlib.pyplot as plt

# Initialize ViT model and interpreter
model = TorchvisionModel(dataset, "vit_b_16", {"weights": "DEFAULT"})
# ... train model ...

interpreter = CheferRelevance(model)

# Generate visualization components
image, attr_map, overlay = generate_vit_visualization(
interpreter=interpreter,
**test_batch
)

# Or create a complete figure
fig = create_vit_attribution_figure(
interpreter=interpreter,
class_names={0: "Normal", 1: "COVID", 2: "Pneumonia"},
save_path="vit_attribution.png",
**test_batch
)

See Also
--------

- :mod:`pyhealth.interpret.methods` - Attribution methods (DeepLift, IntegratedGradients, CheferRelevance, etc.)
- :class:`pyhealth.interpret.methods.CheferRelevance` - Attention-based interpretability for Transformers
- :class:`pyhealth.models.TorchvisionModel` - ViT and other vision models



108 changes: 75 additions & 33 deletions docs/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,41 +62,47 @@ The ``examples/`` directory contains additional code examples demonstrating vari
Mortality Prediction
--------------------

These examples are located in ``examples/mortality_prediction/``.

.. list-table::
:widths: 50 50
:header-rows: 1

* - Example File
- Description
* - ``mortality_mimic3_rnn.py``
* - ``mortality_prediction/mortality_mimic3_rnn.py``
- RNN for mortality prediction on MIMIC-III
* - ``mortality_mimic3_stagenet.py``
* - ``mortality_prediction/mortality_mimic3_stagenet.py``
- StageNet for mortality prediction on MIMIC-III
* - ``mortality_mimic3_adacare.py``
- AdaCare for mortality prediction on MIMIC-III
* - ``mortality_mimic3_agent.py``
* - ``mortality_prediction/mortality_mimic3_adacare.ipynb``
- AdaCare for mortality prediction on MIMIC-III (notebook)
* - ``mortality_prediction/mortality_mimic3_agent.py``
- Agent model for mortality prediction on MIMIC-III
* - ``mortality_mimic3_concare.py``
* - ``mortality_prediction/mortality_mimic3_concare.py``
- ConCare for mortality prediction on MIMIC-III
* - ``mortality_mimic3_grasp.py``
* - ``mortality_prediction/mortality_mimic3_grasp.py``
- GRASP for mortality prediction on MIMIC-III
* - ``mortality_mimic3_tcn.py``
* - ``mortality_prediction/mortality_mimic3_tcn.py``
- Temporal Convolutional Network for mortality prediction
* - ``mortality_mimic4_stagenet_v2.py``
* - ``mortality_prediction/mortality_mimic4_stagenet_v2.py``
- StageNet for mortality prediction on MIMIC-IV (v2)
* - ``mortality_prediction/timeseries_mimic4.py``
- Time series analysis on MIMIC-IV

Readmission Prediction
----------------------

These examples are located in ``examples/readmission/``.

.. list-table::
:widths: 50 50
:header-rows: 1

* - Example File
- Description
* - ``readmission_mimic3_rnn.py``
* - ``readmission/readmission_mimic3_rnn.py``
- RNN for readmission prediction on MIMIC-III
* - ``readmission_mimic3_fairness.py``
* - ``readmission/readmission_mimic3_fairness.py``
- Fairness-aware readmission prediction on MIMIC-III

Survival Prediction
Expand All @@ -114,27 +120,29 @@ Survival Prediction
Drug Recommendation
-------------------

These examples are located in ``examples/drug_recommendation/``.

.. list-table::
:widths: 50 50
:header-rows: 1

* - Example File
- Description
* - ``drug_recommendation_mimic3_safedrug.py``
* - ``drug_recommendation/drug_recommendation_mimic3_safedrug.py``
- SafeDrug for drug recommendation on MIMIC-III
* - ``drug_recommendation_mimic3_molerec.py``
* - ``drug_recommendation/drug_recommendation_mimic3_molerec.py``
- MoleRec for drug recommendation on MIMIC-III
* - ``drug_recommendation_mimic3_gamenet.py``
* - ``drug_recommendation/drug_recommendation_mimic3_gamenet.py``
- GAMENet for drug recommendation on MIMIC-III
* - ``drug_recommendation_mimic3_transformer.py``
* - ``drug_recommendation/drug_recommendation_mimic3_transformer.py``
- Transformer for drug recommendation on MIMIC-III
* - ``drug_recommendation_mimic3_micron.py``
* - ``drug_recommendation/drug_recommendation_mimic3_micron.py``
- MICRON for drug recommendation on MIMIC-III
* - ``drug_recommendation_mimic4_gamenet.py``
* - ``drug_recommendation/drug_recommendation_mimic4_gamenet.py``
- GAMENet for drug recommendation on MIMIC-IV
* - ``drug_recommendation_mimic4_retain.py``
* - ``drug_recommendation/drug_recommendation_mimic4_retain.py``
- RETAIN for drug recommendation on MIMIC-IV
* - ``drug_recommendation_eICU_transformer.py``
* - ``drug_recommendation/drug_recommendation_eICU_transformer.py``
- Transformer for drug recommendation on eICU

EEG and Sleep Analysis
Expand All @@ -159,27 +167,39 @@ EEG and Sleep Analysis
* - ``cardiology_detection_isAR_SparcNet.py``
- SparcNet for cardiology arrhythmia detection

Image Analysis
--------------
Image Analysis (Chest X-Ray)
----------------------------

These examples are located in ``examples/cxr/``.

.. list-table::
:widths: 50 50
:header-rows: 1

* - Example File
- Description
* - ``covid19cxr_conformal.py``
* - ``cxr/covid19cxr_tutorial.py``
- ViT training, conformal prediction & interpretability for COVID-19 CXR
* - ``cxr/covid19cxr_conformal.py``
- Conformal prediction for COVID-19 CXR classification
* - ``cnn_cxr.ipynb``
* - ``cxr/cnn_cxr.ipynb``
- CNN for chest X-ray classification (notebook)
* - ``chestXray_image_generation_VAE.py``
* - ``cxr/chestxray14_binary_classification.ipynb``
- Binary classification on ChestX-ray14 dataset (notebook)
* - ``cxr/chestxray14_multilabel_classification.ipynb``
- Multi-label classification on ChestX-ray14 dataset (notebook)
* - ``cxr/ChestXrayClassificationWithSaliency.ipynb``
- Chest X-ray classification with saliency maps (notebook)
* - ``cxr/chextXray_image_generation_VAE.py``
- VAE for chest X-ray image generation
* - ``ChestXray-image-generation-GAN.ipynb``
* - ``cxr/ChestXray-image-generation-GAN.ipynb``
- GAN for chest X-ray image generation (notebook)

Interpretability
----------------

These examples are located in ``examples/interpretability/``.

.. list-table::
:widths: 50 50
:header-rows: 1
Expand All @@ -188,12 +208,20 @@ Interpretability
- Description
* - ``integrated_gradients_mortality_mimic4_stagenet.py``
- Integrated Gradients for StageNet interpretability
* - ``deeplift_stagenet_mimic4.py``
* - ``interpretability/deeplift_stagenet_mimic4.py``
- DeepLift attributions for StageNet on MIMIC-IV
* - ``interpretability_metrics.py``
* - ``interpretability/gim_stagenet_mimic4.py``
- GIM attributions for StageNet on MIMIC-IV
* - ``interpretability/gim_transformer_mimic4.py``
- GIM attributions for Transformer on MIMIC-IV
* - ``interpretability/shap_stagenet_mimic4.py``
- SHAP attributions for StageNet on MIMIC-IV
* - ``interpretability/interpretability_metrics.py``
- Evaluating attribution methods with metrics
* - ``interpret_demo.ipynb``
* - ``interpretability/interpret_demo.ipynb``
- Interactive interpretability demonstrations (notebook)
* - ``interpretability/shap_stagenet_mimic4.ipynb``
- SHAP attributions for StageNet (notebook)

Patient Linkage
---------------
Expand All @@ -207,6 +235,22 @@ Patient Linkage
* - ``patient_linkage_mimic3_medlink.py``
- MedLink for patient record linkage on MIMIC-III

Length of Stay
--------------

These examples are located in ``examples/length_of_stay/``.

.. list-table::
:widths: 50 50
:header-rows: 1

* - Example File
- Description
* - ``length_of_stay/length_of_stay_mimic3_rnn.py``
- RNN for length of stay prediction on MIMIC-III
* - ``length_of_stay/length_of_stay_mimic4_rnn.py``
- RNN for length of stay prediction on MIMIC-IV

Advanced Topics
---------------

Expand All @@ -216,13 +260,11 @@ Advanced Topics

* - Example File
- Description
* - ``length_of_stay_mimic3_rnn.py``
- RNN for length of stay prediction
* - ``omop_dataset_demo.py``
- Working with OMOP Common Data Model
* - ``medcode.py``
- Medical code vocabulary and mappings
* - ``benchmark_ehrshot.ipynb``
* - ``benchmark_ehrshot_xgboost.ipynb``
- EHRShot benchmark with XGBoost (notebook)

Notebooks (Interactive)
Expand All @@ -238,7 +280,7 @@ Notebooks (Interactive)
- Comprehensive StageNet tutorial
* - ``mimic3_mortality_prediction_cached.ipynb``
- Cached mortality prediction workflow
* - ``timeseries_mimic4.ipynb``
* - ``mortality_prediction/timeseries_mimic4.ipynb``
- Time series analysis on MIMIC-IV
* - ``transformer_mimic4.ipynb``
- Transformer models on MIMIC-IV
Expand All @@ -252,7 +294,7 @@ Notebooks (Interactive)
- SafeDrug interactive notebook
* - ``molerec_mimic3.ipynb``
- MoleRec interactive notebook
* - ``drug_recommendation_mimic3_micron.ipynb``
* - ``drug_recommendation/drug_recommendation_mimic3_micron.ipynb``
- MICRON interactive notebook
* - ``kg_embedding.ipynb``
- Knowledge graph embeddings
Expand Down
Loading