diff --git a/docs/api/interpret.rst b/docs/api/interpret.rst index 08d41ab41..409f9a5c3 100644 --- a/docs/api/interpret.rst +++ b/docs/api/interpret.rst @@ -50,10 +50,18 @@ New to interpretability in PyHealth? Check out these complete examples: - Compare different baseline strategies for background sample generation - Decode attributions to human-readable medical codes and lab measurements +**ViT/Chefer Attribution Example:** + +- ``examples/covid19_cxr_tutorial.py`` - Demonstrates Chefer's attention-based attribution for Vision Transformers: + + - Train a ViT model on COVID-19 chest X-ray classification + - Use CheferRelevance for gradient-weighted attention attribution + - Visualize which image patches contribute to predictions + These examples provide end-to-end workflows from loading data to interpreting and evaluating attributions. -Available Methods ------------------ +Attribution Methods +------------------- .. toctree:: :maxdepth: 4 @@ -64,4 +72,15 @@ Available Methods interpret/pyhealth.interpret.methods.deeplift interpret/pyhealth.interpret.methods.integrated_gradients interpret/pyhealth.interpret.methods.shap - \ No newline at end of file + +Visualization Utilities +----------------------- + +The ``pyhealth.interpret.utils`` module provides visualization functions for +creating attribution overlays, heatmaps, and publication-ready figures. +Includes specialized support for Vision Transformer (ViT) attribution visualization. + +.. toctree:: + :maxdepth: 4 + + interpret/pyhealth.interpret.utils diff --git a/docs/api/interpret/pyhealth.interpret.utils.rst b/docs/api/interpret/pyhealth.interpret.utils.rst new file mode 100644 index 000000000..ad480415d --- /dev/null +++ b/docs/api/interpret/pyhealth.interpret.utils.rst @@ -0,0 +1,100 @@ +pyhealth.interpret.utils +======================== + +.. automodule:: pyhealth.interpret.utils + :members: + :undoc-members: + :show-inheritance: + +Overview +-------- + +The ``pyhealth.interpret.utils`` module provides visualization utilities for +interpretability methods in PyHealth. These functions help create visual +explanations of model predictions, particularly useful for medical imaging. + +Core Functions +-------------- + +**Overlay Visualization** + +- :func:`show_cam_on_image` - Overlay a CAM/attribution map on an image +- :func:`visualize_attribution_on_image` - Generate complete attribution visualization + +**Normalization & Processing** + +- :func:`normalize_attribution` - Normalize attribution values for visualization +- :func:`interpolate_attribution_map` - Resize attribution to match image dimensions + +**Figure Generation** + +- :func:`create_attribution_figure` - Create publication-ready figure with overlays + +ViT-Specific Functions +---------------------- + +These functions are specifically designed for Vision Transformer (ViT) models +using attention-based interpretability methods like :class:`~pyhealth.interpret.methods.CheferRelevance`. + +- :func:`generate_vit_visualization` - Generate visualization components for ViT attribution +- :func:`create_vit_attribution_figure` - Create complete ViT attribution figure +- :func:`reshape_vit_attribution` - Reshape flat patch attribution to 2D spatial map + +Example: Basic Attribution Visualization +---------------------------------------- + +.. 
code-block:: python + + import numpy as np + from pyhealth.interpret.utils import show_cam_on_image, normalize_attribution + + # Assume we have image and attribution from an interpreter + image = np.random.rand(224, 224, 3) # RGB image in [0, 1] + attribution = np.random.rand(224, 224) # Raw attribution values + + # Normalize and overlay + attr_normalized = normalize_attribution(attribution) + overlay = show_cam_on_image(image, attr_normalized) + +Example: ViT Attribution with CheferRelevance +--------------------------------------------- + +.. code-block:: python + + from pyhealth.models import TorchvisionModel + from pyhealth.interpret.methods import CheferRelevance + from pyhealth.interpret.utils import ( + generate_vit_visualization, + create_vit_attribution_figure, + ) + import matplotlib.pyplot as plt + + # Initialize ViT model and interpreter + model = TorchvisionModel(dataset, "vit_b_16", {"weights": "DEFAULT"}) + # ... train model ... + + interpreter = CheferRelevance(model) + + # Generate visualization components + image, attr_map, overlay = generate_vit_visualization( + interpreter=interpreter, + **test_batch + ) + + # Or create a complete figure + fig = create_vit_attribution_figure( + interpreter=interpreter, + class_names={0: "Normal", 1: "COVID", 2: "Pneumonia"}, + save_path="vit_attribution.png", + **test_batch + ) + +See Also +-------- + +- :mod:`pyhealth.interpret.methods` - Attribution methods (DeepLift, IntegratedGradients, CheferRelevance, etc.) +- :class:`pyhealth.interpret.methods.CheferRelevance` - Attention-based interpretability for Transformers +- :class:`pyhealth.models.TorchvisionModel` - ViT and other vision models + + + diff --git a/docs/tutorials.rst b/docs/tutorials.rst index e21ee0002..bdedb189a 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -62,41 +62,47 @@ The ``examples/`` directory contains additional code examples demonstrating vari Mortality Prediction -------------------- +These examples are located in ``examples/mortality_prediction/``. + .. list-table:: :widths: 50 50 :header-rows: 1 * - Example File - Description - * - ``mortality_mimic3_rnn.py`` + * - ``mortality_prediction/mortality_mimic3_rnn.py`` - RNN for mortality prediction on MIMIC-III - * - ``mortality_mimic3_stagenet.py`` + * - ``mortality_prediction/mortality_mimic3_stagenet.py`` - StageNet for mortality prediction on MIMIC-III - * - ``mortality_mimic3_adacare.py`` - - AdaCare for mortality prediction on MIMIC-III - * - ``mortality_mimic3_agent.py`` + * - ``mortality_prediction/mortality_mimic3_adacare.ipynb`` + - AdaCare for mortality prediction on MIMIC-III (notebook) + * - ``mortality_prediction/mortality_mimic3_agent.py`` - Agent model for mortality prediction on MIMIC-III - * - ``mortality_mimic3_concare.py`` + * - ``mortality_prediction/mortality_mimic3_concare.py`` - ConCare for mortality prediction on MIMIC-III - * - ``mortality_mimic3_grasp.py`` + * - ``mortality_prediction/mortality_mimic3_grasp.py`` - GRASP for mortality prediction on MIMIC-III - * - ``mortality_mimic3_tcn.py`` + * - ``mortality_prediction/mortality_mimic3_tcn.py`` - Temporal Convolutional Network for mortality prediction - * - ``mortality_mimic4_stagenet_v2.py`` + * - ``mortality_prediction/mortality_mimic4_stagenet_v2.py`` - StageNet for mortality prediction on MIMIC-IV (v2) + * - ``mortality_prediction/timeseries_mimic4.py`` + - Time series analysis on MIMIC-IV Readmission Prediction ---------------------- +These examples are located in ``examples/readmission/``. + .. 
list-table:: :widths: 50 50 :header-rows: 1 * - Example File - Description - * - ``readmission_mimic3_rnn.py`` + * - ``readmission/readmission_mimic3_rnn.py`` - RNN for readmission prediction on MIMIC-III - * - ``readmission_mimic3_fairness.py`` + * - ``readmission/readmission_mimic3_fairness.py`` - Fairness-aware readmission prediction on MIMIC-III Survival Prediction @@ -114,27 +120,29 @@ Survival Prediction Drug Recommendation ------------------- +These examples are located in ``examples/drug_recommendation/``. + .. list-table:: :widths: 50 50 :header-rows: 1 * - Example File - Description - * - ``drug_recommendation_mimic3_safedrug.py`` + * - ``drug_recommendation/drug_recommendation_mimic3_safedrug.py`` - SafeDrug for drug recommendation on MIMIC-III - * - ``drug_recommendation_mimic3_molerec.py`` + * - ``drug_recommendation/drug_recommendation_mimic3_molerec.py`` - MoleRec for drug recommendation on MIMIC-III - * - ``drug_recommendation_mimic3_gamenet.py`` + * - ``drug_recommendation/drug_recommendation_mimic3_gamenet.py`` - GAMENet for drug recommendation on MIMIC-III - * - ``drug_recommendation_mimic3_transformer.py`` + * - ``drug_recommendation/drug_recommendation_mimic3_transformer.py`` - Transformer for drug recommendation on MIMIC-III - * - ``drug_recommendation_mimic3_micron.py`` + * - ``drug_recommendation/drug_recommendation_mimic3_micron.py`` - MICRON for drug recommendation on MIMIC-III - * - ``drug_recommendation_mimic4_gamenet.py`` + * - ``drug_recommendation/drug_recommendation_mimic4_gamenet.py`` - GAMENet for drug recommendation on MIMIC-IV - * - ``drug_recommendation_mimic4_retain.py`` + * - ``drug_recommendation/drug_recommendation_mimic4_retain.py`` - RETAIN for drug recommendation on MIMIC-IV - * - ``drug_recommendation_eICU_transformer.py`` + * - ``drug_recommendation/drug_recommendation_eICU_transformer.py`` - Transformer for drug recommendation on eICU EEG and Sleep Analysis @@ -159,8 +167,10 @@ EEG and Sleep Analysis * - ``cardiology_detection_isAR_SparcNet.py`` - SparcNet for cardiology arrhythmia detection -Image Analysis --------------- +Image Analysis (Chest X-Ray) +---------------------------- + +These examples are located in ``examples/cxr/``. .. list-table:: :widths: 50 50 @@ -168,18 +178,28 @@ Image Analysis * - Example File - Description - * - ``covid19cxr_conformal.py`` + * - ``cxr/covid19cxr_tutorial.py`` + - ViT training, conformal prediction & interpretability for COVID-19 CXR + * - ``cxr/covid19cxr_conformal.py`` - Conformal prediction for COVID-19 CXR classification - * - ``cnn_cxr.ipynb`` + * - ``cxr/cnn_cxr.ipynb`` - CNN for chest X-ray classification (notebook) - * - ``chestXray_image_generation_VAE.py`` + * - ``cxr/chestxray14_binary_classification.ipynb`` + - Binary classification on ChestX-ray14 dataset (notebook) + * - ``cxr/chestxray14_multilabel_classification.ipynb`` + - Multi-label classification on ChestX-ray14 dataset (notebook) + * - ``cxr/ChestXrayClassificationWithSaliency.ipynb`` + - Chest X-ray classification with saliency maps (notebook) + * - ``cxr/chextXray_image_generation_VAE.py`` - VAE for chest X-ray image generation - * - ``ChestXray-image-generation-GAN.ipynb`` + * - ``cxr/ChestXray-image-generation-GAN.ipynb`` - GAN for chest X-ray image generation (notebook) Interpretability ---------------- +These examples are located in ``examples/interpretability/``. + .. 
list-table:: :widths: 50 50 :header-rows: 1 @@ -188,12 +208,20 @@ Interpretability - Description * - ``integrated_gradients_mortality_mimic4_stagenet.py`` - Integrated Gradients for StageNet interpretability - * - ``deeplift_stagenet_mimic4.py`` + * - ``interpretability/deeplift_stagenet_mimic4.py`` - DeepLift attributions for StageNet on MIMIC-IV - * - ``interpretability_metrics.py`` + * - ``interpretability/gim_stagenet_mimic4.py`` + - GIM attributions for StageNet on MIMIC-IV + * - ``interpretability/gim_transformer_mimic4.py`` + - GIM attributions for Transformer on MIMIC-IV + * - ``interpretability/shap_stagenet_mimic4.py`` + - SHAP attributions for StageNet on MIMIC-IV + * - ``interpretability/interpretability_metrics.py`` - Evaluating attribution methods with metrics - * - ``interpret_demo.ipynb`` + * - ``interpretability/interpret_demo.ipynb`` - Interactive interpretability demonstrations (notebook) + * - ``interpretability/shap_stagenet_mimic4.ipynb`` + - SHAP attributions for StageNet (notebook) Patient Linkage --------------- @@ -207,6 +235,22 @@ Patient Linkage * - ``patient_linkage_mimic3_medlink.py`` - MedLink for patient record linkage on MIMIC-III +Length of Stay +-------------- + +These examples are located in ``examples/length_of_stay/``. + +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Example File + - Description + * - ``length_of_stay/length_of_stay_mimic3_rnn.py`` + - RNN for length of stay prediction on MIMIC-III + * - ``length_of_stay/length_of_stay_mimic4_rnn.py`` + - RNN for length of stay prediction on MIMIC-IV + Advanced Topics --------------- @@ -216,13 +260,11 @@ Advanced Topics * - Example File - Description - * - ``length_of_stay_mimic3_rnn.py`` - - RNN for length of stay prediction * - ``omop_dataset_demo.py`` - Working with OMOP Common Data Model * - ``medcode.py`` - Medical code vocabulary and mappings - * - ``benchmark_ehrshot.ipynb`` + * - ``benchmark_ehrshot_xgboost.ipynb`` - EHRShot benchmark with XGBoost (notebook) Notebooks (Interactive) @@ -238,7 +280,7 @@ Notebooks (Interactive) - Comprehensive StageNet tutorial * - ``mimic3_mortality_prediction_cached.ipynb`` - Cached mortality prediction workflow - * - ``timeseries_mimic4.ipynb`` + * - ``mortality_prediction/timeseries_mimic4.ipynb`` - Time series analysis on MIMIC-IV * - ``transformer_mimic4.ipynb`` - Transformer models on MIMIC-IV @@ -252,7 +294,7 @@ Notebooks (Interactive) - SafeDrug interactive notebook * - ``molerec_mimic3.ipynb`` - MoleRec interactive notebook - * - ``drug_recommendation_mimic3_micron.ipynb`` + * - ``drug_recommendation/drug_recommendation_mimic3_micron.ipynb`` - MICRON interactive notebook * - ``kg_embedding.ipynb`` - Knowledge graph embeddings diff --git a/examples/benchmark_perf/benchmark_pandas_los.py b/examples/benchmark_perf/benchmark_pandas_los.py new file mode 100644 index 000000000..dd72f0a71 --- /dev/null +++ b/examples/benchmark_perf/benchmark_pandas_los.py @@ -0,0 +1,394 @@ +""" +Benchmark script for MIMIC-IV length of stay prediction using pandas +(analogous to PyHealth LengthOfStayPredictionMIMIC4 task). + +This benchmark mimics the LengthOfStayPredictionMIMIC4 task: +1. Creates visit-level samples for each admission +2. For each visit, extracts conditions, procedures, and drugs +3. Calculates length of stay from admission to discharge +4. 
Categorizes LOS into 10 categories (0-9) + +Length of Stay Categories: +- 0: < 1 day +- 1-7: 1-7 days (each day is its own category) +- 8: 8-14 days (over one week, less than two) +- 9: > 14 days (over two weeks) +""" + +import argparse +import time +import os +import threading +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +import psutil + + +PEAK_MEM_USAGE = 0 +SELF_PROC = psutil.Process(os.getpid()) +STOP_TRACKING = False + + +def track_mem(): + """Background thread to track peak memory usage.""" + global PEAK_MEM_USAGE + while not STOP_TRACKING: + m = SELF_PROC.memory_info().rss + if m > PEAK_MEM_USAGE: + PEAK_MEM_USAGE = m + time.sleep(0.1) + + +def format_size(size_bytes: int) -> str: + """Format bytes to human-readable string.""" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def categorize_los(days: int) -> int: + """Categorizes length of stay into 10 categories. + + One for ICU stays shorter than a day, seven day-long categories for each day of + the first week, one for stays of over one week but less than two, + and one for stays of over two weeks. + + Args: + days: int, length of stay in days + + Returns: + category: int, category of length of stay (0-9) + """ + # ICU stays shorter than a day + if days < 1: + return 0 + # each day of the first week + elif 1 <= days <= 7: + return days + # stays of over one week but less than two + elif 7 < days <= 14: + return 8 + # stays of over two weeks + else: + return 9 + + +def process_patient_length_of_stay( + subject_id: int, + admissions_df: pd.DataFrame, + diagnoses_df: pd.DataFrame, + procedures_df: pd.DataFrame, + prescriptions_df: pd.DataFrame, +) -> List[Dict[str, Any]]: + """Process a single patient for length of stay prediction task. + + Creates visit-level samples with conditions, procedures, drugs, and LOS label. 
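+    Visits missing any of the three code types, or missing admission/discharge
+    times, are skipped, matching the exclusion checks in the loop below.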
+ + Args: + subject_id: Patient ID + admissions_df: Admission records (pre-filtered for this patient) + diagnoses_df: Diagnosis ICD codes (pre-filtered for this patient) + procedures_df: Procedure ICD codes (pre-filtered for this patient) + prescriptions_df: Prescription records (pre-filtered for this patient) + + Returns: + List of sample dictionaries, or empty list if patient doesn't qualify + """ + samples = [] + + # Get all admissions for this patient + patient_admissions = admissions_df[admissions_df["subject_id"] == subject_id] + + if len(patient_admissions) == 0: + return [] + + # Process each admission + for _, admission in patient_admissions.iterrows(): + hadm_id = admission["hadm_id"] + + # Get diagnosis codes for this admission + visit_diagnoses = diagnoses_df[diagnoses_df["hadm_id"] == hadm_id] + # Combine ICD version with code (e.g., "10_A123" or "9_456") + conditions = [] + for _, row in visit_diagnoses.iterrows(): + if pd.notna(row.get("icd_code")) and pd.notna(row.get("icd_version")): + conditions.append(f"{int(row['icd_version'])}_{row['icd_code']}") + + # Get procedure codes for this admission + visit_procedures = procedures_df[procedures_df["hadm_id"] == hadm_id] + procedures = [] + for _, row in visit_procedures.iterrows(): + if pd.notna(row.get("icd_code")) and pd.notna(row.get("icd_version")): + procedures.append(f"{int(row['icd_version'])}_{row['icd_code']}") + + # Get prescriptions for this admission + visit_prescriptions = prescriptions_df[prescriptions_df["hadm_id"] == hadm_id] + drugs = [] + for _, row in visit_prescriptions.iterrows(): + ndc = row.get("ndc") + if pd.notna(ndc) and ndc: + drugs.append(str(ndc)) + + # Exclude visits without condition, procedure, or drug code + if len(conditions) == 0 or len(procedures) == 0 or len(drugs) == 0: + continue + + # Calculate length of stay + admittime = admission["admittime"] + dischtime = admission["dischtime"] + + if pd.isna(admittime) or pd.isna(dischtime): + continue + + # Calculate LOS in days + los_days = (dischtime - admittime).days + los_category = categorize_los(los_days) + + samples.append({ + "visit_id": hadm_id, + "patient_id": subject_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "los": los_category, + "los_days": los_days, # Also store raw days for debugging + }) + + return samples + + +def benchmark_length_of_stay( + admissions_df: pd.DataFrame, + diagnoses_df: pd.DataFrame, + procedures_df: pd.DataFrame, + prescriptions_df: pd.DataFrame, + n_patients: Optional[int] = None, +) -> Tuple[List[Dict[str, Any]], float]: + """ + Benchmark MIMIC-IV length of stay processing. 
+ + Args: + admissions_df: Admissions dataframe + diagnoses_df: Diagnoses dataframe + procedures_df: Procedures dataframe + prescriptions_df: Prescriptions dataframe + n_patients: Number of patients to process (None = all patients) + + Returns: + Tuple of (list of samples, processing time in seconds) + """ + print("=" * 80) + print("BENCHMARK: Pandas Length of Stay Prediction (MIMIC-IV format)") + print("=" * 80) + + # Get unique patients + all_patients = admissions_df["subject_id"].unique().tolist() + + if n_patients is None: + patients_to_process = all_patients + print(f"Processing all {len(patients_to_process)} patients...") + else: + patients_to_process = all_patients[:n_patients] + print(f"Processing first {len(patients_to_process)} patients...") + + # Parse datetime columns + admissions_df = admissions_df.copy() + admissions_df["admittime"] = pd.to_datetime(admissions_df["admittime"]) + admissions_df["dischtime"] = pd.to_datetime(admissions_df["dischtime"]) + + # Start processing timer + start_time = time.perf_counter() + + samples = [] + processed_patients = 0 + patients_with_samples = 0 + + # Track LOS distribution + los_distribution = {i: 0 for i in range(10)} + + for subject_id in patients_to_process: + patient_samples = process_patient_length_of_stay( + subject_id, + admissions_df, + diagnoses_df, + procedures_df, + prescriptions_df, + ) + + if patient_samples: + samples.extend(patient_samples) + patients_with_samples += 1 + # Update LOS distribution + for sample in patient_samples: + los_distribution[sample["los"]] += 1 + + processed_patients += 1 + if processed_patients % 1000 == 0: + print(f"Processed {processed_patients} patients, " + f"{len(samples)} samples so far...") + + # End processing timer + processing_time = time.perf_counter() - start_time + + print("\nCompleted processing:") + print(f" - Total patients processed: {processed_patients}") + print(f" - Patients with valid samples: {patients_with_samples}") + print(f" - Total samples created: {len(samples)}") + print(f" - Processing time: {processing_time:.2f}s") + print("\nLOS Category Distribution:") + for cat, count in los_distribution.items(): + pct = (count / len(samples) * 100) if samples else 0 + label = { + 0: "<1 day", + 1: "1 day", 2: "2 days", 3: "3 days", 4: "4 days", + 5: "5 days", 6: "6 days", 7: "7 days", + 8: "8-14 days", + 9: ">14 days", + }.get(cat, str(cat)) + print(f" Category {cat} ({label}): {count} ({pct:.1f}%)") + print("=" * 80) + + return samples, processing_time + + +def load_mimic_data( + data_root: str = "/srv/local/data/physionet.org/files/mimiciv/2.2/hosp", +) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Load MIMIC-IV tables needed for length of stay prediction. 
+ + Args: + data_root: Root directory for MIMIC-IV hosp data + + Returns: + Tuple of dataframes: (admissions, diagnoses, procedures, prescriptions) + """ + print("Loading MIMIC-IV data tables...") + load_start = time.perf_counter() + + admissions_df = pd.read_csv(f"{data_root}/admissions.csv") + diagnoses_df = pd.read_csv(f"{data_root}/diagnoses_icd.csv.gz") + procedures_df = pd.read_csv(f"{data_root}/procedures_icd.csv.gz") + prescriptions_df = pd.read_csv(f"{data_root}/prescriptions.csv.gz", low_memory=False) + + load_time = time.perf_counter() - load_start + print(f"Data loaded in {load_time:.2f}s") + print(f" - Admissions: {len(admissions_df):,}") + print(f" - Diagnoses: {len(diagnoses_df):,}") + print(f" - Procedures: {len(procedures_df):,}") + print(f" - Prescriptions: {len(prescriptions_df):,}") + print() + + return ( + admissions_df, + diagnoses_df, + procedures_df, + prescriptions_df, + ) + + +def main(): + """Main function to run the benchmark.""" + global STOP_TRACKING + + parser = argparse.ArgumentParser( + description="Benchmark MIMIC-IV length of stay prediction with pandas" + ) + parser.add_argument( + "--data-root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.2/hosp", + help="Path to MIMIC-IV hosp directory", + ) + parser.add_argument( + "--n-patients", + type=int, + default=None, + help="Number of patients to process (default: all)", + ) + parser.add_argument( + "--output", + type=str, + default="benchmark_results_pandas_los.txt", + help="Output file for results", + ) + args = parser.parse_args() + + # Start memory tracking thread + mem_thread = threading.Thread(target=track_mem, daemon=True) + mem_thread.start() + + # Load data + total_start = time.perf_counter() + ( + admissions_df, + diagnoses_df, + procedures_df, + prescriptions_df, + ) = load_mimic_data(args.data_root) + load_time = time.perf_counter() - total_start + + # Run benchmark + samples, processing_time = benchmark_length_of_stay( + admissions_df, + diagnoses_df, + procedures_df, + prescriptions_df, + n_patients=args.n_patients, + ) + + total_time = time.perf_counter() - total_start + + # Stop memory tracking + STOP_TRACKING = True + time.sleep(0.2) # Allow final memory sample + + # Get peak memory + peak_mem = PEAK_MEM_USAGE + + # Print summary + print("\n" + "=" * 80) + print("FINAL SUMMARY") + print("=" * 80) + print(f"Data loading time: {load_time:.2f}s") + print(f"Processing time: {processing_time:.2f}s") + print(f"Total time: {total_time:.2f}s") + print(f"Total samples: {len(samples)}") + print(f"Peak memory usage: {format_size(peak_mem)}") + print("=" * 80) + + # Save results + with open(args.output, "w") as f: + f.write("BENCHMARK RESULTS: Pandas Length of Stay Prediction (MIMIC-IV)\n") + f.write("=" * 80 + "\n") + f.write(f"Data root: {args.data_root}\n") + f.write(f"N patients: {args.n_patients or 'all'}\n") + f.write("-" * 80 + "\n") + f.write(f"Data loading time: {load_time:.2f}s\n") + f.write(f"Processing time: {processing_time:.2f}s\n") + f.write(f"Total time: {total_time:.2f}s\n") + f.write(f"Total samples: {len(samples)}\n") + f.write(f"Peak memory usage: {format_size(peak_mem)}\n") + f.write("=" * 80 + "\n") + + print(f"\n✓ Results saved to {args.output}") + + # Show example sample + if samples: + print("\nExample sample (first sample):") + first_sample = samples[0] + print(f" Patient ID: {first_sample['patient_id']}") + print(f" Visit ID: {first_sample['visit_id']}") + print(f" Conditions: {first_sample['conditions'][:5]}..." 
if len(first_sample['conditions']) > 5 else f" Conditions: {first_sample['conditions']}") + print(f" Procedures: {first_sample['procedures'][:3]}..." if len(first_sample['procedures']) > 3 else f" Procedures: {first_sample['procedures']}") + print(f" Drugs: {first_sample['drugs'][:5]}..." if len(first_sample['drugs']) > 5 else f" Drugs: {first_sample['drugs']}") + print(f" LOS (days): {first_sample['los_days']}") + print(f" LOS (category): {first_sample['los']}") + + +if __name__ == "__main__": + main() + diff --git a/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py b/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py new file mode 100644 index 000000000..b8603e01f --- /dev/null +++ b/examples/benchmark_perf/benchmark_workers_n_length_of_stay.py @@ -0,0 +1,400 @@ +"""Benchmark script for MIMIC-IV length of stay prediction across multiple num_workers. + +This benchmark measures: +1. Time to load the base dataset (once) +2. Time to process the task for each num_workers value (optionally repeated) +3. Cache sizes for base dataset and each task run +4. Peak memory usage (RSS, includes child processes) + +Typical usage: + python benchmark_workers_n_length_of_stay.py + python benchmark_workers_n_length_of_stay.py --workers 1,4,8,12,16 --repeats 3 + python benchmark_workers_n_length_of_stay.py --dev --workers 1,2,4 + +Notes: +- The task cache directory is recreated for each run and deleted after measuring size. +- The base dataset cache is also deleted before and after each run, so every run + measures full dataset + task processing time from scratch. +- Peak memory is sampled in a background thread; it reports total RSS of the current + process plus all child processes. +""" + +from __future__ import annotations + +import argparse +import csv +import os +import shutil +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Iterable + +import psutil + +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.tasks import LengthOfStayPredictionMIMIC4 + +try: + import resource + + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +@dataclass +class RunResult: + num_workers: int + repeat_index: int + dataset_load_s: float + task_process_s: float + total_s: float + base_cache_bytes: int + task_cache_bytes: int + peak_rss_bytes: int + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def get_directory_size(path: str | Path) -> int: + total = 0 + p = Path(path) + if not p.exists(): + return 0 + try: + for entry in p.rglob("*"): + if entry.is_file(): + try: + total += entry.stat().st_size + except FileNotFoundError: + # File might disappear if something is concurrently modifying cache. + pass + except Exception as e: + print(f"Error calculating size for {p}: {e}") + return total + + +def set_memory_limit(max_memory_gb: float) -> None: + """Set a hard virtual memory limit for the process.""" + if not HAS_RESOURCE: + print( + "Warning: resource module not available (Windows?). " + "Memory limit not enforced." 
+ ) + return + + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +class PeakMemoryTracker: + """Tracks peak RSS for current process + children.""" + + def __init__(self, poll_interval_s: float = 0.1) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def parse_workers(value: str) -> list[int]: + parts = [p.strip() for p in value.split(",") if p.strip()] + workers: list[int] = [] + for p in parts: + w = int(p) + if w <= 0: + raise argparse.ArgumentTypeError("All worker counts must be > 0") + workers.append(w) + if not workers: + raise argparse.ArgumentTypeError("No workers provided") + return workers + + +def ensure_empty_dir(path: str | Path) -> None: + p = Path(path) + if p.exists(): + shutil.rmtree(p) + p.mkdir(parents=True, exist_ok=True) + + +def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: + """Remove a directory with retry logic for busy file handles.""" + p = Path(path) + if not p.exists(): + return + for attempt in range(retries): + try: + shutil.rmtree(p) + return + except OSError as e: + if attempt < retries - 1: + time.sleep(delay) + else: + print(f"Warning: Failed to delete {p} after {retries} attempts: {e}") + + +def median(values: Iterable[float]) -> float: + xs = sorted(values) + if not xs: + return 0.0 + mid = len(xs) // 2 + if len(xs) % 2 == 1: + return xs[mid] + return (xs[mid - 1] + xs[mid]) / 2.0 + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark MIMIC-IV length of stay prediction over multiple num_workers" + ) + parser.add_argument( + "--workers", + type=parse_workers, + default=[1, 4, 8, 12, 16], + help="Comma-separated list of num_workers values (default: 1,4,8,12,16)", + ) + parser.add_argument( + "--repeats", + type=int, + default=1, + help="Number of repeats per worker setting (default: 1)", + ) + parser.add_argument( + "--dev", + action="store_true", + help="Use dev mode dataset loading (smaller subset)", + ) + parser.add_argument( + "--ehr-root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.2/", + help="Path to MIMIC-IV root directory", + ) + parser.add_argument( + "--cache-root", + type=str, + default="/shared/eng/pyhealth/", + help="Root directory for benchmark caches (default: /shared/eng/pyhealth/)", + ) + parser.add_argument( + "--enable-memory-limit", + action="store_true", + help="Enforce a hard memory limit via 
resource.setrlimit (Unix only)", + ) + parser.add_argument( + "--max-memory-gb", + type=float, + default=None, + help=( + "Hard memory limit in GB (only used if --enable-memory-limit is set). " + "If omitted, no memory limit is applied by default." + ), + ) + parser.add_argument( + "--output-csv", + type=str, + default="../benchmark_results_los_workers_sweep.csv", + help="Where to write per-run results as CSV", + ) + args = parser.parse_args() + + if args.repeats <= 0: + raise SystemExit("--repeats must be > 0") + + if args.enable_memory_limit: + if args.max_memory_gb is None: + raise SystemExit( + "When using --enable-memory-limit, you must also pass " + "--max-memory-gb (e.g., --max-memory-gb 32)." + ) + set_memory_limit(args.max_memory_gb) + + tracker = PeakMemoryTracker(poll_interval_s=0.1) + tracker.start() + + print("=" * 80) + print("BENCHMARK: Length of Stay Prediction num_workers sweep") + print(f"workers={args.workers} repeats={args.repeats} dev={args.dev}") + if args.enable_memory_limit: + print(f"Memory Limit: {args.max_memory_gb} GB (ENFORCED)") + else: + print("Memory Limit: None (unrestricted)") + print(f"ehr_root: {args.ehr_root}") + print(f"cache_root: {args.cache_root}") + print("=" * 80) + + cache_root = Path(args.cache_root) + base_cache_dir = cache_root / ( + "base_dataset_los_dev" if args.dev else "base_dataset_los" + ) + + total_start = time.time() + + results: list[RunResult] = [] + + print("\n[1/1] Sweeping num_workers (each run reloads dataset + task)...") + for w in args.workers: + for r in range(args.repeats): + task_cache_dir = cache_root / f"task_samples_los_w{w}" + + # Ensure no cache artifacts before this run. + remove_dir(base_cache_dir) + ensure_empty_dir(task_cache_dir) + + tracker.reset() + run_start = time.time() + + dataset_start = time.time() + base_dataset = MIMIC4Dataset( + ehr_root=args.ehr_root, + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + ], + dev=args.dev, + cache_dir=str(base_cache_dir), + ) + dataset_load_s = time.time() - dataset_start + base_cache_bytes = get_directory_size(base_cache_dir) + + task_start = time.time() + sample_dataset = base_dataset.set_task( + LengthOfStayPredictionMIMIC4(), + num_workers=w, + cache_dir=str(task_cache_dir), + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss_bytes = tracker.peak_bytes() + task_cache_bytes = get_directory_size(task_cache_dir) + + # Capture sample count BEFORE cleaning up the cache (litdata needs it). + num_samples = len(sample_dataset) + + # Release the dataset reference to free file handles before cleanup. + del sample_dataset + del base_dataset + + # Clean up to avoid disk growth across an overnight sweep. 
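+            # Both caches are deleted here and rebuilt at the start of the next
+            # run, so every run measures dataset + task processing from scratch.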
+ remove_dir(task_cache_dir) + remove_dir(base_cache_dir) + + results.append( + RunResult( + num_workers=w, + repeat_index=r, + dataset_load_s=dataset_load_s, + task_process_s=task_process_s, + total_s=total_s, + base_cache_bytes=base_cache_bytes, + task_cache_bytes=task_cache_bytes, + peak_rss_bytes=peak_rss_bytes, + ) + ) + + print( + "✓ " + f"workers={w:>2} repeat={r+1:>2}/{args.repeats} " + f"samples={num_samples} " + f"dataset={dataset_load_s:.2f}s " + f"task={task_process_s:.2f}s " + f"total={total_s:.2f}s " + f"peak_rss={format_size(peak_rss_bytes)} " + f"base_cache={format_size(base_cache_bytes)} " + f"task_cache={format_size(task_cache_bytes)}" + ) + + total_sweep_s = time.time() - total_start + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(results[0]).keys())) + writer.writeheader() + for rr in results: + writer.writerow(asdict(rr)) + + # Print a compact summary per worker (median across repeats) + print("\n" + "=" * 80) + print("SUMMARY (median across repeats)") + print("=" * 80) + + for w in args.workers: + wrs = [rr for rr in results if rr.num_workers == w] + med_task = median([rr.task_process_s for rr in wrs]) + med_total = median([rr.total_s for rr in wrs]) + med_peak = median([float(rr.peak_rss_bytes) for rr in wrs]) + med_cache = median([float(rr.task_cache_bytes) for rr in wrs]) + print( + f"workers={w:>2} " + f"task_med={med_task:>8.2f}s " + f"total_med={med_total:>8.2f}s " + f"peak_rss_med={format_size(int(med_peak)):>10} " + f"task_cache_med={format_size(int(med_cache)):>10}" + ) + + print("\nArtifacts:") + print(f" - CSV: {out_csv}") + print(f" - Cache root: {cache_root}") + print("\nTotals:") + print(f" - Sweep wall time: {total_sweep_s:.2f}s") + print("=" * 80) + + +if __name__ == "__main__": + main() + diff --git a/examples/benchmark_perf/legacy_ver/benchmark_legacy_los.py b/examples/benchmark_perf/legacy_ver/benchmark_legacy_los.py new file mode 100644 index 000000000..7d036f40b --- /dev/null +++ b/examples/benchmark_perf/legacy_ver/benchmark_legacy_los.py @@ -0,0 +1,473 @@ +"""Legacy PyHealth 1.1.6 Benchmark script for MIMIC-IV length of stay prediction. + +This benchmark measures performance across multiple worker counts (via pandarallel): +1. Time to load the base dataset +2. Time to process the task for each num_workers value +3. Cache sizes for base dataset +4. Peak memory usage (RSS, includes child processes) + +Typical usage: + # First, install the legacy version: + pip install pyhealth==1.1.6 + + # Then run the benchmark: + python benchmark_legacy_los.py + python benchmark_legacy_los.py --workers 1,4,8,12,16 --repeats 3 + python benchmark_legacy_los.py --dev --workers 1,2,4 + +API differences from PyHealth 2.0: +- Uses `root` instead of `ehr_root` +- Uses `tables` instead of `ehr_tables` +- Uses `refresh_cache` instead of `cache_dir` +- set_task() takes a `task_fn` function instead of a task class +- Parallelization via pandarallel (not num_workers in set_task) + +Notes: +- This uses the PyHealth 1.1.6 legacy API +- Uses the built-in length_of_stay_prediction_mimic4_fn task function +- Pandarallel is re-initialized for each worker count +- Cache is cleared before each run to ensure fresh timing data +- Peak memory is sampled in a background thread; it reports total RSS of the current + process plus all child processes. 
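+
+Legacy API sketch (a minimal outline mirroring the calls made in main() below;
+the MIMIC-IV root path is a placeholder):
+
+    from pyhealth.datasets import MIMIC4Dataset
+    from pyhealth.tasks import length_of_stay_prediction_mimic4_fn
+
+    base_dataset = MIMIC4Dataset(
+        root="/path/to/mimiciv/2.0/hosp",  # placeholder path
+        tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
+        code_mapping={"ICD10PROC": "CCSPROC", "NDC": "ATC"},
+        refresh_cache=True,  # force reprocessing so timing reflects a cold cache
+    )
+    sample_dataset = base_dataset.set_task(task_fn=length_of_stay_prediction_mimic4_fn)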
+""" + +from __future__ import annotations + +import argparse +import csv +import os +import shutil +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Iterable + +import psutil + +# Import pandarallel - will be initialized before each run +from pandarallel import pandarallel + +# Global variable to store desired worker count for monkey-patching +_DESIRED_NB_WORKERS: int = 16 + +# Store the original pandarallel.initialize function +_original_pandarallel_initialize = pandarallel.initialize + + +def _patched_pandarallel_initialize(*args, **kwargs): + """Patched pandarallel.initialize that enforces our worker count. + + The legacy PyHealth code calls pandarallel.initialize() without nb_workers, + which defaults to all CPUs. This patch ensures our desired worker count is used. + """ + # Override nb_workers with our desired value + kwargs['nb_workers'] = _DESIRED_NB_WORKERS + return _original_pandarallel_initialize(*args, **kwargs) + + +# Apply the monkey-patch +pandarallel.initialize = _patched_pandarallel_initialize + +# Legacy PyHealth 1.1.6 imports (AFTER monkey-patching) +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.datasets.utils import MODULE_CACHE_PATH +from pyhealth.tasks import length_of_stay_prediction_mimic4_fn + +try: + import resource + + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +# ============================================================================= +# Benchmark Infrastructure +# ============================================================================= + + +@dataclass +class RunResult: + num_workers: int + repeat_index: int + dataset_load_s: float + task_process_s: float + total_s: float + base_cache_bytes: int + peak_rss_bytes: int + num_samples: int + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def get_directory_size(path: str | Path) -> int: + total = 0 + p = Path(path) + if not p.exists(): + return 0 + try: + for entry in p.rglob("*"): + if entry.is_file(): + try: + total += entry.stat().st_size + except FileNotFoundError: + pass + except Exception as e: + print(f"Error calculating size for {p}: {e}") + return total + + +def clear_pyhealth_cache(verbose: bool = True) -> int: + """Clear all PyHealth cache files. + + PyHealth 1.1.6 stores cache as .pkl files in MODULE_CACHE_PATH. + This function deletes all .pkl files in that directory. + + Args: + verbose: Whether to print information about deleted files. + + Returns: + Number of cache files deleted. + """ + cache_path = Path(MODULE_CACHE_PATH) + if not cache_path.exists(): + return 0 + + deleted_count = 0 + total_size = 0 + + # Find all .pkl cache files + for cache_file in cache_path.glob("*.pkl"): + try: + file_size = cache_file.stat().st_size + cache_file.unlink() + deleted_count += 1 + total_size += file_size + except OSError as e: + if verbose: + print(f" Warning: Could not delete {cache_file}: {e}") + + if verbose and deleted_count > 0: + print(f" Cleared {deleted_count} cache files ({format_size(total_size)})") + + return deleted_count + + +def set_memory_limit(max_memory_gb: float) -> None: + """Set a hard virtual memory limit for the process.""" + if not HAS_RESOURCE: + print( + "Warning: resource module not available (Windows?). " + "Memory limit not enforced." 
+ ) + return + + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +class PeakMemoryTracker: + """Tracks peak RSS for current process + children.""" + + def __init__(self, poll_interval_s: float = 0.1) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: + """Remove a directory with retry logic for busy file handles.""" + p = Path(path) + if not p.exists(): + return + for attempt in range(retries): + try: + shutil.rmtree(p) + return + except OSError as e: + if attempt < retries - 1: + time.sleep(delay) + else: + print(f"Warning: Failed to delete {p} after {retries} attempts: {e}") + + +def parse_workers(value: str) -> list[int]: + """Parse comma-separated list of worker counts.""" + parts = [p.strip() for p in value.split(",") if p.strip()] + workers: list[int] = [] + for p in parts: + w = int(p) + if w <= 0: + raise argparse.ArgumentTypeError("All worker counts must be > 0") + workers.append(w) + if not workers: + raise argparse.ArgumentTypeError("No workers provided") + return workers + + +def median(values: Iterable[float]) -> float: + xs = sorted(values) + if not xs: + return 0.0 + mid = len(xs) // 2 + if len(xs) % 2 == 1: + return xs[mid] + return (xs[mid - 1] + xs[mid]) / 2.0 + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Legacy PyHealth 1.1.6 Benchmark for MIMIC-IV length of stay prediction" + ) + parser.add_argument( + "--workers", + type=parse_workers, + default=[1, 4, 8, 12, 16], + help="Comma-separated list of num_workers values (default: 1,4,8,12,16)", + ) + parser.add_argument( + "--repeats", + type=int, + default=1, + help="Number of repeats per worker setting (default: 1)", + ) + parser.add_argument( + "--dev", + action="store_true", + help="Use dev mode dataset loading (smaller subset)", + ) + parser.add_argument( + "--root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", + help="Path to MIMIC-IV hosp directory (legacy uses single root)", + ) + parser.add_argument( + "--enable-memory-limit", + action="store_true", + help="Enforce a hard memory limit via resource.setrlimit (Unix only)", + ) + parser.add_argument( + "--max-memory-gb", + type=float, + default=None, + help=( + "Hard memory limit in GB (only used if --enable-memory-limit is set). 
" + "If omitted, no memory limit is applied by default." + ), + ) + parser.add_argument( + "--output-csv", + type=str, + default="benchmark_legacy_los_workers_sweep.csv", + help="Where to write per-run results as CSV", + ) + parser.add_argument( + "--no-clear-cache", + action="store_true", + help="Do not clear PyHealth cache before each run (default: clear cache)", + ) + args = parser.parse_args() + + if args.repeats <= 0: + raise SystemExit("--repeats must be > 0") + + if args.enable_memory_limit: + if args.max_memory_gb is None: + raise SystemExit( + "When using --enable-memory-limit, you must also pass " + "--max-memory-gb (e.g., --max-memory-gb 32)." + ) + set_memory_limit(args.max_memory_gb) + + tracker = PeakMemoryTracker(poll_interval_s=0.1) + tracker.start() + + print("=" * 80) + print("LEGACY BENCHMARK: PyHealth 1.1.6 API - Length of Stay Prediction (Worker Sweep)") + print(f"workers={args.workers} repeats={args.repeats} dev={args.dev}") + print(f"clear_cache={not args.no_clear_cache}") + if args.enable_memory_limit: + print(f"Memory Limit: {args.max_memory_gb} GB (ENFORCED)") + else: + print("Memory Limit: None (unrestricted)") + print(f"root: {args.root}") + print(f"cache_path: {MODULE_CACHE_PATH}") + print("=" * 80) + + # Determine cache directory based on PyHealth's default location + cache_dir = Path(MODULE_CACHE_PATH) + + total_start = time.time() + results: list[RunResult] = [] + + print("\n[1/1] Sweeping num_workers (pandarallel)...") + + for w in args.workers: + for r in range(args.repeats): + # Clear cache before each run to ensure fresh timing data + if not args.no_clear_cache: + print(f"\n Clearing PyHealth cache...") + clear_pyhealth_cache(verbose=True) + + # Set the desired worker count for pandarallel + # The monkey-patched initialize() will enforce this when PyHealth calls it + global _DESIRED_NB_WORKERS + _DESIRED_NB_WORKERS = w + print(f" Set pandarallel worker count to {w} (will be enforced via monkey-patch)") + + tracker.reset() + run_start = time.time() + + # Step 1: Load base dataset using legacy API + print(f" workers={w} repeat={r + 1}/{args.repeats}: Loading dataset...") + dataset_start = time.time() + + # Legacy PyHealth 1.1.6 API: + # - Uses `root` instead of `ehr_root` + # - Uses `tables` instead of `ehr_tables` + # - Always use refresh_cache=True to force reprocessing (for accurate timing) + # Length of stay uses: diagnoses_icd, procedures_icd, prescriptions + # The task requires drugs from prescriptions table + base_dataset = MIMIC4Dataset( + root=args.root, + tables=["diagnoses_icd", "procedures_icd", "prescriptions"], + dev=args.dev, + code_mapping={"ICD10PROC": "CCSPROC", "NDC": "ATC"}, + refresh_cache=True, # Always refresh to measure processing time + ) + dataset_load_s = time.time() - dataset_start + base_cache_bytes = get_directory_size(cache_dir) + + # Step 2: Set task using legacy API (built-in length_of_stay_prediction_mimic4_fn) + print(" Processing task...") + task_start = time.time() + + sample_dataset = base_dataset.set_task( + task_fn=length_of_stay_prediction_mimic4_fn + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss_bytes = tracker.peak_bytes() + + # Get sample count + num_samples = len(sample_dataset.samples) + + results.append( + RunResult( + num_workers=w, + repeat_index=r, + dataset_load_s=dataset_load_s, + task_process_s=task_process_s, + total_s=total_s, + base_cache_bytes=base_cache_bytes, + peak_rss_bytes=peak_rss_bytes, + num_samples=num_samples, + ) + ) + + print( + f" ✓ 
workers={w:>2} repeat={r + 1:>2}/{args.repeats} " + f"samples={num_samples} " + f"dataset={dataset_load_s:.2f}s " + f"task={task_process_s:.2f}s " + f"total={total_s:.2f}s " + f"peak_rss={format_size(peak_rss_bytes)} " + f"cache={format_size(base_cache_bytes)}" + ) + + # Clean up references + del sample_dataset + del base_dataset + + total_sweep_s = time.time() - total_start + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(results[0]).keys())) + writer.writeheader() + for rr in results: + writer.writerow(asdict(rr)) + + # Print a compact summary per worker (median across repeats) + print("\n" + "=" * 80) + print("SUMMARY (median across repeats)") + print("=" * 80) + + for w in args.workers: + wrs = [rr for rr in results if rr.num_workers == w] + med_dataset = median([rr.dataset_load_s for rr in wrs]) + med_task = median([rr.task_process_s for rr in wrs]) + med_total = median([rr.total_s for rr in wrs]) + med_peak = median([float(rr.peak_rss_bytes) for rr in wrs]) + print( + f"workers={w:>2} " + f"dataset_med={med_dataset:>8.2f}s " + f"task_med={med_task:>8.2f}s " + f"total_med={med_total:>8.2f}s " + f"peak_rss_med={format_size(int(med_peak)):>10}" + ) + + print("\nArtifacts:") + print(f" - CSV: {out_csv}") + print(f" - Cache dir: {cache_dir}") + print("\nTotals:") + print(f" - Sweep wall time: {total_sweep_s:.2f}s") + print("=" * 80) + + +if __name__ == "__main__": + main() + diff --git a/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_drug_rec.py b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_drug_rec.py new file mode 100644 index 000000000..b81a6d3e2 --- /dev/null +++ b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_drug_rec.py @@ -0,0 +1,847 @@ +"""Benchmark script for MIMIC-IV drug recommendation using meds_reader. + +This benchmark measures performance across multiple thread counts: +1. Time for MEDS ETL conversion (MIMIC-IV -> MEDS format) +2. Time for meds_reader database conversion (MEDS -> meds_reader format) +3. Time to process the task +4. Peak memory usage (RSS, includes child processes) +5. Number of samples generated + +IMPORTANT: For fair comparison with PyHealth, conversion time MUST be included. +PyHealth's dataset loading includes parsing raw MIMIC-IV CSVs, so we must +account for the equivalent preprocessing time in meds_reader. + +This script uses meds_etl for data conversion: +- Converts MIMIC-IV directly to MEDS format via meds_etl_mimic +- Runs meds_reader_convert to prepare the database +- Then runs the benchmark + +Typical usage: + # First install dependencies: + pip install meds_etl meds_reader + + # Run benchmark (includes conversion time by default): + python benchmark_meds_reader_drug_rec.py + python benchmark_meds_reader_drug_rec.py --threads 1,4,8,12,16 --repeats 3 + + # Skip conversion (only for debugging, not fair benchmarking): + python benchmark_meds_reader_drug_rec.py --skip-conversion +""" + +from __future__ import annotations + +import argparse +import collections +import csv +import os +import shutil +import subprocess +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List + +import psutil +import torch +from torch.utils.data import Dataset + +try: + import meds_reader +except ImportError: + raise ImportError( + "meds_reader not found. 
Install with: pip install meds_reader\n" + "Or from source: pip install -e /path/to/meds_reader" + ) + + +# ============================================================================= +# PyTorch Dataset Wrapper +# ============================================================================= + +class MedsReaderSampleDataset(Dataset): + """PyTorch Dataset wrapper for meds_reader samples. + + Provides a standard PyTorch Dataset interface for model training. + """ + + def __init__( + self, + samples: List[Dict[str, Any]], + input_schema: Dict[str, str], + output_schema: Dict[str, str], + input_processors: Dict[str, Any], + output_processors: Dict[str, Any], + dataset_name: str = "", + task_name: str = "", + ): + self.samples = samples + self.input_schema = input_schema + self.output_schema = output_schema + self.input_processors = input_processors + self.output_processors = output_processors + self.dataset_name = dataset_name + self.task_name = task_name + + self.patient_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + self.record_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + + for idx, sample in enumerate(samples): + if "patient_id" in sample: + self.patient_to_index[sample["patient_id"]].append(idx) + if "visit_id" in sample: + self.record_to_index[sample["visit_id"]].append(idx) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, Any]: + return self.samples[index] + + def __repr__(self) -> str: + return f"MedsReaderSampleDataset({self.dataset_name}, {self.task_name}, n={len(self)})" + + +# ============================================================================= +# Processor Classes (matching PyHealth's SequenceProcessor for fair comparison) +# ============================================================================= + +class SequenceProcessor: + """Matches PyHealth's SequenceProcessor for vocabulary building and tokenization.""" + + def __init__(self): + self.code_vocab = {"": 0} + self._next_index = 1 + + def fit(self, samples, field): + """Build vocabulary from all samples (first pass through data).""" + for sample in samples: + if field not in sample: + continue + # Handle nested sequences (list of lists) + values = sample[field] + if values and isinstance(values[0], list): + for visit_values in values: + for token in visit_values: + if token is None: + continue + if token not in self.code_vocab: + self.code_vocab[token] = self._next_index + self._next_index += 1 + else: + for token in values: + if token is None: + continue + if token not in self.code_vocab: + self.code_vocab[token] = self._next_index + self._next_index += 1 + self.code_vocab[""] = len(self.code_vocab) + + def process(self, value): + """Convert code strings to tensor of indices.""" + indices = [] + for token in value: + if token in self.code_vocab: + indices.append(self.code_vocab[token]) + else: + indices.append(self.code_vocab[""]) + return torch.tensor(indices, dtype=torch.long) + + def process_nested(self, values): + """Convert nested sequences to list of tensors.""" + return [self.process(v) for v in values] + + def size(self): + return len(self.code_vocab) + + +class MultilabelProcessor: + """Processor for multilabel outputs (matching PyHealth's MultiLabelProcessor).""" + + def __init__(self): + self.label_vocab = {} + + def fit(self, samples, field): + """Build vocabulary from all label values.""" + for sample in samples: + if field in sample: + for val in sample[field]: + if val not in self.label_vocab: + 
self.label_vocab[val] = len(self.label_vocab) + + def process(self, value): + """Convert label list to multi-hot tensor.""" + multi_hot = torch.zeros(len(self.label_vocab), dtype=torch.float32) + for val in value: + if val in self.label_vocab: + multi_hot[self.label_vocab[val]] = 1.0 + return multi_hot + + def size(self): + return len(self.label_vocab) + + +try: + import resource + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +# ============================================================================= +# Data Conversion (MIMIC-IV -> MEDS -> meds_reader via meds_etl) +# ============================================================================= + +def run_meds_etl_mimic( + src_mimic: str, + output_dir: str, + num_shards: int = 100, + num_proc: int = 1, + backend: str = "polars", +) -> float: + """Run meds_etl_mimic to convert MIMIC-IV to MEDS format. + + Args: + src_mimic: Path to MIMIC-IV root (containing 2.2/ subdirectory) + output_dir: Path to output MEDS dataset + num_shards: Number of shards for processing + num_proc: Number of processes to use + backend: Backend to use (polars or cpp) + + Returns: + Time taken in seconds + """ + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + print(f" Running meds_etl_mimic (shards={num_shards}, proc={num_proc}, backend={backend})...") + print(f" Source: {src_mimic}") + print(f" Destination: {output_dir}") + + start = time.time() + result = subprocess.run( + [ + "meds_etl_mimic", + src_mimic, + output_dir, + "--num_shards", str(num_shards), + "--num_proc", str(num_proc), + "--backend", backend, + ], + capture_output=True, + text=True, + ) + elapsed = time.time() - start + + if result.returncode != 0: + print(f" STDOUT: {result.stdout}") + print(f" STDERR: {result.stderr}") + raise RuntimeError(f"meds_etl_mimic failed with code {result.returncode}") + + print(f" meds_etl_mimic completed in {elapsed:.2f}s") + return elapsed + + +def run_meds_reader_convert(input_dir: str, output_dir: str, num_threads: int = 10) -> float: + """Run meds_reader_convert CLI tool. Returns time taken.""" + print(f" Running meds_reader_convert (threads={num_threads})...") + print(f" {input_dir} -> {output_dir}") + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + start = time.time() + try: + result = subprocess.run( + ["meds_reader_convert", input_dir, output_dir, "--num_threads", str(num_threads)], + capture_output=True, + text=True, + check=True, + ) + elapsed = time.time() - start + print(f" meds_reader_convert completed in {elapsed:.2f}s") + return elapsed + except subprocess.CalledProcessError as e: + print(f" ERROR: meds_reader_convert failed:") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + raise + except FileNotFoundError: + print(f" ERROR: meds_reader_convert not found in PATH") + raise + + +@dataclass +class ConversionResult: + """Holds timing information for the MEDS conversion process.""" + meds_etl_s: float + meds_reader_convert_s: float + total_conversion_s: float + was_cached: bool # True if conversion was skipped due to existing cache + + +def run_meds_conversion( + mimic_root: str, + meds_dir: str, + meds_reader_dir: str, + num_shards: int, + num_proc: int, + backend: str, + force_reconvert: bool, + skip_conversion: bool, +) -> ConversionResult: + """Run MEDS conversion and return timing information. 
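+
+    If the meds_reader database already exists and force_reconvert is False (or
+    skip_conversion is True), conversion is skipped and all timings are reported
+    as 0.0 with was_cached=True.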
+ + Args: + mimic_root: Path to MIMIC-IV root directory + meds_dir: Path for intermediate MEDS output + meds_reader_dir: Path for final meds_reader database + num_shards: Number of shards for meds_etl + num_proc: Number of processes for meds_etl + backend: Backend for meds_etl (polars or cpp) + force_reconvert: If True, always reconvert even if cache exists + skip_conversion: If True, skip conversion (for debugging only) + + Returns: + ConversionResult with timing information + """ + # Check if we should skip conversion + if skip_conversion: + if not Path(meds_reader_dir).exists(): + raise SystemExit( + f"Cannot skip conversion: MEDS database does not exist at {meds_reader_dir}\n" + "Run without --skip-conversion first." + ) + print(f"✓ Skipping conversion (using cached MEDS database: {meds_reader_dir})") + print(" WARNING: For fair benchmarking, conversion time should be included!") + return ConversionResult( + meds_etl_s=0.0, + meds_reader_convert_s=0.0, + total_conversion_s=0.0, + was_cached=True, + ) + + # Check if we can reuse existing cache + if Path(meds_reader_dir).exists() and not force_reconvert: + print(f"✓ MEDS database exists: {meds_reader_dir}") + print(" NOTE: Using cached data. Use --force-reconvert for fresh timing.") + return ConversionResult( + meds_etl_s=0.0, + meds_reader_convert_s=0.0, + total_conversion_s=0.0, + was_cached=True, + ) + + print(f"\n{'='*60}") + print(f"Converting MIMIC-IV to MEDS format") + print(f"{'='*60}") + + # Clear existing cache directories to avoid interference + if Path(meds_dir).exists(): + print(f" Clearing existing MEDS cache: {meds_dir}") + shutil.rmtree(meds_dir) + if Path(meds_reader_dir).exists(): + print(f" Clearing existing meds_reader cache: {meds_reader_dir}") + shutil.rmtree(meds_reader_dir) + + # Verify MIMIC-IV structure + mimic_version_path = os.path.join(mimic_root, "2.2") + if not os.path.exists(mimic_version_path): + raise SystemExit( + f"ERROR: Expected MIMIC-IV version directory not found: {mimic_version_path}\n" + f"meds_etl_mimic expects the MIMIC-IV data to be in {{mimic_root}}/2.2/" + ) + + # Step 1: Convert MIMIC-IV -> MEDS using meds_etl + print(f"\n[Step 1/2] Converting MIMIC-IV to MEDS format using meds_etl...") + meds_etl_s = run_meds_etl_mimic( + src_mimic=mimic_root, + output_dir=meds_dir, + num_shards=num_shards, + num_proc=num_proc, + backend=backend, + ) + + # Step 2: Run meds_reader_convert + print(f"\n[Step 2/2] Running meds_reader_convert...") + meds_reader_convert_s = run_meds_reader_convert( + meds_dir, meds_reader_dir, num_threads=num_proc + ) + + total_conversion_s = meds_etl_s + meds_reader_convert_s + print(f"\n✓ MEDS database ready: {meds_reader_dir}") + print(f" Total conversion time: {total_conversion_s:.2f}s") + + return ConversionResult( + meds_etl_s=meds_etl_s, + meds_reader_convert_s=meds_reader_convert_s, + total_conversion_s=total_conversion_s, + was_cached=False, + ) + + +# ============================================================================= +# Task Function - Drug Recommendation +# ============================================================================= + +def get_drug_rec_samples(subjects: Iterator[meds_reader.Subject]): + """Process subjects for drug recommendation task. + + Uses MEDS-ETL code conventions: + - Admission codes are like "MIMIC_IV_Admission/..." + - Diagnosis codes are like "ICD10CM/..." or "ICD9CM/..." + - Procedure codes are like "ICD10PCS/..." or "ICD9Proc/..." + - Prescriptions are like "NDC/..." or "MIMIC_IV_Drug/..." 
+ + Drug recommendation predicts drugs for current visit based on + cumulative history of conditions, procedures, and past drugs. + """ + samples = [] + + for subject in subjects: + # Collect all admissions with their data + admissions = {} # visit_id -> {time, conditions, procedures, drugs} + + # First pass: identify admissions + for event in subject.events: + if event.code.startswith("MIMIC_IV_Admission/"): + visit_id = getattr(event, 'visit_id', None) + if visit_id is not None and event.time is not None: + admissions[visit_id] = { + 'time': event.time, + 'conditions': set(), + 'procedures': set(), + 'drugs': set(), + } + + # Second pass: collect features per admission + for event in subject.events: + visit_id = getattr(event, 'visit_id', None) + if visit_id is None or visit_id not in admissions: + continue + + code = event.code + if code.startswith("ICD"): # ICD9CM, ICD10CM, ICD9Proc, ICD10PCS + if "CM" in code: + admissions[visit_id]['conditions'].add(code) + else: + admissions[visit_id]['procedures'].add(code) + elif code.startswith("NDC/") or code.startswith("MIMIC_IV_Drug/"): + admissions[visit_id]['drugs'].add(code) + + # Sort admissions by time + sorted_visits = sorted( + [(vid, data) for vid, data in admissions.items()], + key=lambda x: x[1]['time'] + ) + + # Filter to visits with complete data + valid_visits = [ + (vid, data) for vid, data in sorted_visits + if len(data['conditions']) > 0 and len(data['procedures']) > 0 and len(data['drugs']) > 0 + ] + + # Need at least 2 visits for drug recommendation + if len(valid_visits) < 2: + continue + + # Create samples with cumulative history + for i, (visit_id, data) in enumerate(valid_visits): + # Build cumulative history + conditions_hist = [] + procedures_hist = [] + drugs_hist = [] + + for j in range(i + 1): + _, hist_data = valid_visits[j] + conditions_hist.append(list(hist_data['conditions'])) + procedures_hist.append(list(hist_data['procedures'])) + # For drugs_hist, current visit's drugs are emptied (target) + if j < i: + drugs_hist.append(list(hist_data['drugs'])) + else: + drugs_hist.append([]) # Empty for current visit + + samples.append({ + "visit_id": visit_id, + "patient_id": subject.subject_id, + "conditions": conditions_hist, # Nested: [[codes_v1], [codes_v2], ...] + "procedures": procedures_hist, # Nested + "drugs_hist": drugs_hist, # Nested (current visit empty) + "drugs": list(data['drugs']), # Target drugs (flat list) + }) + + return samples + + +# ============================================================================= +# Benchmark Infrastructure +# ============================================================================= + +@dataclass +class RunResult: + num_threads: int + repeat_index: int + meds_etl_s: float # Time for MIMIC-IV -> MEDS conversion + meds_reader_convert_s: float # Time for MEDS -> meds_reader conversion + task_process_s: float # Time to run the ML task + total_s: float # Total time (conversion + task) + peak_rss_bytes: int + num_samples: int + conversion_cached: bool # True if conversion was skipped + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def set_memory_limit(max_memory_gb: float) -> None: + if not HAS_RESOURCE: + print("Warning: resource module not available. 
Memory limit not enforced.") + return + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +class PeakMemoryTracker: + def __init__(self, poll_interval_s: float = 0.1) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def parse_threads(value: str) -> list[int]: + parts = [p.strip() for p in value.split(",") if p.strip()] + threads = [] + for p in parts: + t = int(p) + if t <= 0: + raise argparse.ArgumentTypeError("All thread counts must be > 0") + threads.append(t) + if not threads: + raise argparse.ArgumentTypeError("No threads provided") + return threads + + +def median(values: Iterable[float]) -> float: + xs = sorted(values) + if not xs: + return 0.0 + mid = len(xs) // 2 + if len(xs) % 2 == 1: + return xs[mid] + return (xs[mid - 1] + xs[mid]) / 2.0 + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark meds_reader for MIMIC-IV drug recommendation" + ) + parser.add_argument( + "--threads", type=parse_threads, default=[1, 4, 8, 12, 16], + help="Comma-separated list of num_threads values (default: 1,4,8,12,16)", + ) + parser.add_argument( + "--repeats", type=int, default=1, + help="Number of repeats per thread setting (default: 1)", + ) + parser.add_argument( + "--mimic-root", type=str, + default="/srv/local/data/physionet.org/files/mimiciv", + help="Path to MIMIC-IV root directory (containing 2.2/ subdirectory)", + ) + parser.add_argument( + "--cache-dir", type=str, default="/srv/local/data/johnwu3/meds_reader", + help="Directory for MEDS cache", + ) + parser.add_argument( + "--num-shards", type=int, default=100, + help="Number of shards for meds_etl_mimic (default: 100)", + ) + parser.add_argument( + "--num-proc", type=int, default=8, + help="Number of processes for meds_etl_mimic (default: 8)", + ) + parser.add_argument( + "--backend", type=str, default="polars", choices=["polars", "cpp"], + help="Backend for meds_etl_mimic (default: polars)", + ) + parser.add_argument( + "--force-reconvert", action="store_true", + help="Force reconversion even if MEDS database exists (recommended for benchmarking)", + ) + parser.add_argument( + "--skip-conversion", action="store_true", + help="Skip conversion entirely (for debugging only - NOT fair benchmarking)", + ) + parser.add_argument( + "--enable-memory-limit", action="store_true", + help="Enforce a hard memory limit via resource.setrlimit (Unix only)", + ) + 
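As a quick orientation for the helpers above, here is a minimal sketch of how ``PeakMemoryTracker`` is meant to be driven around one measured phase (it sums the RSS of the benchmark process and its children, so meds_reader worker processes are counted). The workload is a placeholder ``time.sleep``, and the snippet assumes the ``PeakMemoryTracker`` and ``format_size`` helpers defined in this script are in scope.

.. code-block:: python

   import time

   # Minimal sketch, assuming PeakMemoryTracker and format_size from this
   # script are in scope. The sleep is only a placeholder workload.
   tracker = PeakMemoryTracker(poll_interval_s=0.1)
   tracker.start()              # launches the background polling thread

   tracker.reset()              # clear the running peak before a measured phase
   time.sleep(1.0)              # placeholder for one benchmark repeat
   print(f"peak_rss={format_size(tracker.peak_bytes())}")

   tracker.stop()               # stop polling once the sweep is done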
parser.add_argument( + "--max-memory-gb", type=float, default=None, + help="Hard memory limit in GB (only used if --enable-memory-limit is set)", + ) + parser.add_argument( + "--output-csv", type=str, + default="benchmark_meds_reader_drug_rec_threads_sweep.csv", + help="Where to write per-run results as CSV", + ) + args = parser.parse_args() + + if args.repeats <= 0: + raise SystemExit("--repeats must be > 0") + + if args.enable_memory_limit: + if args.max_memory_gb is None: + raise SystemExit( + "When using --enable-memory-limit, you must also pass --max-memory-gb" + ) + set_memory_limit(args.max_memory_gb) + + # MEDS paths + # Use task-specific cache directories to avoid interference between tasks + meds_dir = f"{args.cache_dir}/mimic4_meds_drug_rec" + meds_reader_dir = f"{args.cache_dir}/mimic4_meds_reader_drug_rec" + + print("=" * 80) + print("BENCHMARK: meds_reader - Drug Recommendation (Thread Sweep)") + print(f"threads={args.threads} repeats={args.repeats}") + print(f"mimic_root: {args.mimic_root}") + print(f"backend: {args.backend}, num_proc: {args.num_proc}, num_shards: {args.num_shards}") + if args.skip_conversion: + print("WARNING: --skip-conversion is set. Conversion time will NOT be included.") + print(" This is NOT a fair comparison with PyHealth!") + print("=" * 80) + + tracker = PeakMemoryTracker(poll_interval_s=0.1) + tracker.start() + + total_start = time.time() + results: list[RunResult] = [] + + print(f"\n{'='*60}") + print("Running benchmark...") + print(f"{'='*60}") + + for t in args.threads: + for r in range(args.repeats): + tracker.reset() + run_start = time.time() + + # Step 0: Convert MIMIC-IV to MEDS format (part of total time) + # For fair comparison with PyHealth, we must include this conversion time + # since PyHealth's dataset loading includes parsing raw MIMIC-IV CSVs. 
+ conversion = run_meds_conversion( + mimic_root=args.mimic_root, + meds_dir=meds_dir, + meds_reader_dir=meds_reader_dir, + num_shards=args.num_shards, + num_proc=args.num_proc, + backend=args.backend, + force_reconvert=args.force_reconvert and r == 0, # Only reconvert on first repeat + skip_conversion=args.skip_conversion or r > 0, # Reuse on subsequent repeats + ) + + print(f"\n threads={t} repeat={r + 1}/{args.repeats}: Processing task...") + task_start = time.time() + + # Step 1: Extract samples using meds_reader (parallel) + samples = [] + with meds_reader.SubjectDatabase(meds_reader_dir, num_threads=t) as database: + for s in database.map(get_drug_rec_samples): + samples.extend(s) + + # Step 2: Build vocabularies (matching PyHealth's processor.fit()) + conditions_processor = SequenceProcessor() + procedures_processor = SequenceProcessor() + drugs_processor = SequenceProcessor() + drugs_label_processor = MultilabelProcessor() + + conditions_processor.fit(samples, "conditions") # Nested + procedures_processor.fit(samples, "procedures") # Nested + drugs_processor.fit(samples, "drugs_hist") # Nested + drugs_processor.fit(samples, "drugs") # Flat (adds more vocab) + drugs_label_processor.fit(samples, "drugs") # For multilabel output + + # Step 3: Tokenize samples (matching PyHealth's processor.process()) + processed_samples = [] + for sample in samples: + processed_sample = { + "visit_id": sample["visit_id"], + "patient_id": sample["patient_id"], + "conditions": conditions_processor.process_nested(sample["conditions"]), + "procedures": procedures_processor.process_nested(sample["procedures"]), + "drugs_hist": drugs_processor.process_nested(sample["drugs_hist"]), + "drugs": drugs_label_processor.process(sample["drugs"]), # Multilabel + } + processed_samples.append(processed_sample) + + # Step 4: Wrap in PyTorch Dataset for model training compatibility + dataset = MedsReaderSampleDataset( + samples=processed_samples, + input_schema={ + "conditions": "sequence", + "procedures": "sequence", + "drugs_hist": "sequence", + }, + output_schema={"drugs": "multilabel"}, + input_processors={ + "conditions": conditions_processor, + "procedures": procedures_processor, + "drugs_hist": drugs_processor, + }, + output_processors={"drugs": drugs_label_processor}, + dataset_name="MIMIC-IV", + task_name="DrugRecommendation", + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss_bytes = tracker.peak_bytes() + num_samples = len(dataset) + + results.append( + RunResult( + num_threads=t, + repeat_index=r, + meds_etl_s=conversion.meds_etl_s, + meds_reader_convert_s=conversion.meds_reader_convert_s, + task_process_s=task_process_s, + total_s=total_s, + peak_rss_bytes=peak_rss_bytes, + num_samples=num_samples, + conversion_cached=conversion.was_cached, + ) + ) + + # Build output message + timing_str = f"task={task_process_s:.2f}s" + if not conversion.was_cached: + timing_str = ( + f"meds_etl={conversion.meds_etl_s:.2f}s " + f"convert={conversion.meds_reader_convert_s:.2f}s " + + timing_str + f" total={total_s:.2f}s" + ) + + print( + f" ✓ threads={t:>2} repeat={r + 1:>2}/{args.repeats} " + f"samples={num_samples} " + f"{timing_str} " + f"peak_rss={format_size(peak_rss_bytes)} " + f"vocab_sizes=({conditions_processor.size()},{procedures_processor.size()},{drugs_processor.size()})" + ) + + total_sweep_s = time.time() - total_start + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + 
writer = csv.DictWriter(f, fieldnames=list(asdict(results[0]).keys())) + writer.writeheader() + for rr in results: + writer.writerow(asdict(rr)) + + # Print summary + print("\n" + "=" * 80) + print("SUMMARY (median across repeats)") + print("=" * 80) + + # Check if any results have conversion times + has_conversion = any(not rr.conversion_cached for rr in results) + + if has_conversion: + print("\n NOTE: Conversion time included for fair comparison with PyHealth.") + print(" PyHealth's dataset_load_s ≈ meds_etl_s + meds_reader_convert_s") + else: + print("\n WARNING: Conversion was cached. For fair benchmarking, use --force-reconvert") + + print() + for t in args.threads: + trs = [rr for rr in results if rr.num_threads == t] + med_task = median([rr.task_process_s for rr in trs]) + med_total = median([rr.total_s for rr in trs]) + med_peak = median([float(rr.peak_rss_bytes) for rr in trs]) + + # Get conversion times (from first repeat which has them if --force-reconvert) + first_run = [rr for rr in trs if rr.repeat_index == 0][0] + + if not first_run.conversion_cached: + print( + f"threads={t:>2} " + f"meds_etl={first_run.meds_etl_s:>7.2f}s " + f"convert={first_run.meds_reader_convert_s:>7.2f}s " + f"task_med={med_task:>7.2f}s " + f"total={med_total:>7.2f}s " + f"peak_rss={format_size(int(med_peak)):>10}" + ) + else: + print( + f"threads={t:>2} " + f"task_med={med_task:>8.2f}s " + f"(conversion cached) " + f"peak_rss_med={format_size(int(med_peak)):>10}" + ) + + print("\nArtifacts:") + print(f" - CSV: {out_csv}") + print(f" - MEDS database: {meds_reader_dir}") + print("\nTotals:") + print(f" - Sweep wall time: {total_sweep_s:.2f}s") + + # Print comparison note + print("\nFor comparison with PyHealth:") + print(" PyHealth total_s = dataset_load_s + task_process_s") + print(" meds_reader total_s = meds_etl_s + meds_reader_convert_s + task_process_s") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_drug_rec_pyhealth_etl.py b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_drug_rec_pyhealth_etl.py new file mode 100644 index 000000000..c0f283897 --- /dev/null +++ b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_drug_rec_pyhealth_etl.py @@ -0,0 +1,744 @@ +"""Benchmark script for MIMIC-IV drug recommendation using meds_reader. + +This is a FALLBACK version that uses PyHealth 1.1.6 for ETL instead of meds_etl_mimic. +Use this if meds_etl_mimic fails to run properly. + +Pipeline: +1. Load MIMIC-IV data using PyHealth 1.1.6 (MIMIC4Dataset) +2. Convert PyHealth data structures to MEDS format (parquet files) +3. Run meds_reader_convert to create meds_reader database +4. Process the task using meds_reader +5. Return samples in a PyTorch-compatible Dataset + +IMPORTANT: For fair comparison with PyHealth, conversion time MUST be included. 
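To make the drug-recommendation sample layout concrete, here is a hedged toy example (all codes are invented, not real MIMIC-IV values) of one sample in the shape produced by ``get_drug_rec_samples``, and of how the ``SequenceProcessor`` and ``MultilabelProcessor`` defined in the script tokenize it. The ``drugs_hist`` entry for the current visit is intentionally empty because the current visit's drugs are the prediction target.

.. code-block:: python

   # Toy sample; assumes the SequenceProcessor and MultilabelProcessor
   # classes defined in the script above are in scope (torch is imported there).
   sample = {
       "visit_id": 1,
       "patient_id": 42,
       "conditions": [["ICD10CM/E11"], ["ICD10CM/E11", "ICD10CM/I10"]],  # one list per visit
       "procedures": [["ICD10PCS/0DJ08ZZ"], ["ICD10PCS/0DJ08ZZ"]],       # one list per visit
       "drugs_hist": [["NDC/0002"], []],    # current visit left empty (it is the target)
       "drugs": ["NDC/0002", "NDC/0003"],   # flat multilabel target
   }

   cond_proc = SequenceProcessor()
   cond_proc.fit([sample], "conditions")            # builds the code vocabulary
   drug_label_proc = MultilabelProcessor()
   drug_label_proc.fit([sample], "drugs")

   cond_tensors = cond_proc.process_nested(sample["conditions"])  # list of LongTensors
   target = drug_label_proc.process(sample["drugs"])              # multi-hot FloatTensor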
+ +Typical usage: + # First install dependencies: + pip install pyhealth==1.1.6 meds_reader pyarrow + + # Run benchmark: + python benchmark_meds_reader_drug_rec_pyhealth_etl.py + python benchmark_meds_reader_drug_rec_pyhealth_etl.py --threads 1,4,8,12,16 +""" + +from __future__ import annotations + +import argparse +import collections +import csv +import datetime +import os +import shutil +import subprocess +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List + +import numpy as np +import psutil +import pyarrow as pa +import pyarrow.parquet as pq +import torch +from torch.utils.data import Dataset + +try: + import meds_reader +except ImportError: + raise ImportError( + "meds_reader not found. Install with: pip install meds_reader\n" + "Or from source: pip install -e /path/to/meds_reader" + ) + +# Import PyHealth 1.1.6 +try: + from pyhealth.datasets import MIMIC4Dataset +except ImportError: + raise ImportError( + "PyHealth not found. Install with: pip install pyhealth==1.1.6" + ) + + +# ============================================================================= +# PyTorch Dataset Wrapper +# ============================================================================= + +class MedsReaderSampleDataset(Dataset): + """PyTorch Dataset wrapper for meds_reader samples.""" + + def __init__( + self, + samples: List[Dict[str, Any]], + input_schema: Dict[str, str], + output_schema: Dict[str, str], + input_processors: Dict[str, Any], + output_processors: Dict[str, Any], + dataset_name: str = "", + task_name: str = "", + ): + self.samples = samples + self.input_schema = input_schema + self.output_schema = output_schema + self.input_processors = input_processors + self.output_processors = output_processors + self.dataset_name = dataset_name + self.task_name = task_name + + self.patient_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + self.record_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + + for idx, sample in enumerate(samples): + if "patient_id" in sample: + self.patient_to_index[sample["patient_id"]].append(idx) + if "visit_id" in sample: + self.record_to_index[sample["visit_id"]].append(idx) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, Any]: + return self.samples[index] + + def __repr__(self) -> str: + return f"MedsReaderSampleDataset({self.dataset_name}, {self.task_name}, n={len(self)})" + + +# ============================================================================= +# Processor Classes +# ============================================================================= + +class SequenceProcessor: + """Matches PyHealth's SequenceProcessor for vocabulary building.""" + + def __init__(self): + self.code_vocab = {"": 0} + self._next_index = 1 + + def fit(self, samples, field): + for sample in samples: + if field not in sample: + continue + values = sample[field] + if values and isinstance(values[0], list): + for visit_values in values: + for token in visit_values: + if token is None: + continue + if token not in self.code_vocab: + self.code_vocab[token] = self._next_index + self._next_index += 1 + else: + for token in values: + if token is None: + continue + if token not in self.code_vocab: + self.code_vocab[token] = self._next_index + self._next_index += 1 + self.code_vocab[""] = len(self.code_vocab) + + def process(self, value): + indices = [] + for token in value: + if token in self.code_vocab: + 
indices.append(self.code_vocab[token]) + else: + indices.append(self.code_vocab[""]) + return torch.tensor(indices, dtype=torch.long) + + def process_nested(self, values): + return [self.process(v) for v in values] + + def size(self): + return len(self.code_vocab) + + +class MultilabelProcessor: + """Processor for multilabel outputs.""" + + def __init__(self): + self.label_vocab = {} + + def fit(self, samples, field): + for sample in samples: + if field in sample: + for val in sample[field]: + if val not in self.label_vocab: + self.label_vocab[val] = len(self.label_vocab) + + def process(self, value): + multi_hot = torch.zeros(len(self.label_vocab), dtype=torch.float32) + for val in value: + if val in self.label_vocab: + multi_hot[self.label_vocab[val]] = 1.0 + return multi_hot + + def size(self): + return len(self.label_vocab) + + +# ============================================================================= +# Data Conversion (PyHealth 1.1.6 -> MEDS -> meds_reader) +# ============================================================================= + +def pyhealth_to_meds( + pyhealth_root: str, + output_dir: str, + tables: List[str], + dev: bool = False, + num_shards: int = 100, +) -> float: + """Convert MIMIC-IV data via PyHealth 1.1.6 to MEDS format.""" + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + print(" Loading MIMIC-IV via PyHealth 1.1.6...") + print(f" Root: {pyhealth_root}") + print(f" Tables: {tables}") + print(f" Dev mode: {dev}") + + start = time.time() + + dataset = MIMIC4Dataset( + root=pyhealth_root, + tables=tables, + dev=dev, + refresh_cache=True, + ) + + pyhealth_load_time = time.time() - start + print(f" PyHealth load completed in {pyhealth_load_time:.2f}s") + + print(" Converting to MEDS format...") + convert_start = time.time() + + results = collections.defaultdict(list) + + for patient_id, patient in dataset.patients.items(): + subject_id = int(patient_id) + + # Birth event + if patient.birth_datetime is not None: + birth_obj = { + 'subject_id': subject_id, + 'code': 'meds/birth', + 'time': patient.birth_datetime, + } + if hasattr(patient, 'gender') and patient.gender: + birth_obj['gender'] = patient.gender + if hasattr(patient, 'ethnicity') and patient.ethnicity: + birth_obj['ethnicity'] = patient.ethnicity + results[subject_id].append(birth_obj) + + # Death event + if patient.death_datetime is not None: + results[subject_id].append({ + 'subject_id': subject_id, + 'code': 'meds/death', + 'time': patient.death_datetime, + }) + + # Process visits + for visit_id, visit in patient.visits.items(): + visit_id_int = int(visit_id) + + visit_event = { + 'subject_id': subject_id, + 'code': 'MIMIC_IV_Admission/unknown', + 'time': visit.encounter_time, + 'visit_id': visit_id_int, + } + if visit.discharge_time: + visit_event['end'] = visit.discharge_time + if hasattr(visit, 'discharge_status'): + visit_event['discharge_status'] = visit.discharge_status + + results[subject_id].append(visit_event) + + for table in visit.available_tables: + for event in visit.get_event_list(table): + event_obj = { + 'subject_id': subject_id, + 'visit_id': visit_id_int, + 'code': f'{event.vocabulary}/{event.code}', + 'time': event.timestamp or visit.discharge_time, + } + + if hasattr(event, 'attr_dict') and event.attr_dict: + for k, v in event.attr_dict.items(): + if v == v: # Skip NaN + event_obj[k] = v + + results[subject_id].append(event_obj) + + results[subject_id].sort( + key=lambda a: a['time'] if a['time'] else datetime.datetime.min + ) + + # Write to parquet shards + 
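For readers unfamiliar with the MEDS layout being produced here, the following sketch is illustrative only (the values are made up, and the schema is a simplified, hand-written version of the one ``pyhealth_to_meds`` infers dynamically); it shows how a handful of event rows end up in one parquet shard.

.. code-block:: python

   import datetime

   import pyarrow as pa
   import pyarrow.parquet as pq

   # Illustrative MEDS-style event rows: one row per admission/diagnosis event,
   # keyed by subject_id, code, and time, with extra attributes such as visit_id.
   rows = [
       {"subject_id": 42, "time": datetime.datetime(2180, 5, 1, 8, 0),
        "code": "MIMIC_IV_Admission/unknown", "visit_id": 7},
       {"subject_id": 42, "time": datetime.datetime(2180, 5, 1, 9, 30),
        "code": "ICD10CM/E11", "visit_id": 7},
   ]

   schema = pa.schema([
       ("subject_id", pa.int64()),
       ("time", pa.timestamp("us")),
       ("code", pa.string()),
       ("visit_id", pa.int64()),
   ])

   pq.write_table(pa.Table.from_pylist(rows, schema=schema), "0.parquet")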
os.makedirs(output_dir, exist_ok=True) + os.makedirs(f"{output_dir}/metadata", exist_ok=True) + os.makedirs(f"{output_dir}/data", exist_ok=True) + + all_subjects = list(results.keys()) + subject_ids_per_shard = np.array_split(all_subjects, num_shards) + + attr_map = { + str: pa.string(), + int: pa.int64(), + np.int64: pa.int64(), + float: pa.float64(), + datetime.datetime: pa.timestamp('us'), + } + + attr_schema = {} + for subject_values in results.values(): + for row in subject_values: + for k, v in row.items(): + if k not in {'subject_id', 'time'} and v is not None: + pa_type = attr_map.get(type(v), pa.string()) + if k not in attr_schema: + attr_schema[k] = pa_type + + schema = pa.schema([ + ('subject_id', pa.int64()), + ('time', pa.timestamp('us')), + ] + [(k, v) for k, v in sorted(attr_schema.items())]) + + for i, subject_ids in enumerate(subject_ids_per_shard): + if len(subject_ids) == 0: + continue + rows = [v for subject_id in subject_ids for v in results[subject_id]] + if rows: + table = pa.Table.from_pylist(rows, schema=schema) + pq.write_table(table, f"{output_dir}/data/{i}.parquet") + + convert_time = time.time() - convert_start + total_time = time.time() - start + + print(f" MEDS conversion completed in {convert_time:.2f}s") + print(f" Total PyHealth ETL time: {total_time:.2f}s") + + return total_time + + +def run_meds_reader_convert( + input_dir: str, output_dir: str, num_threads: int = 10 +) -> float: + """Run meds_reader_convert CLI tool.""" + print(f" Running meds_reader_convert (threads={num_threads})...") + print(f" {input_dir} -> {output_dir}") + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + start = time.time() + try: + subprocess.run( + ["meds_reader_convert", input_dir, output_dir, + "--num_threads", str(num_threads)], + capture_output=True, + text=True, + check=True, + ) + elapsed = time.time() - start + print(f" meds_reader_convert completed in {elapsed:.2f}s") + return elapsed + except subprocess.CalledProcessError as e: + print(" ERROR: meds_reader_convert failed:") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + raise + except FileNotFoundError: + print(" ERROR: meds_reader_convert not found in PATH") + raise + + +@dataclass +class ConversionResult: + """Holds timing information for the conversion process.""" + pyhealth_etl_s: float + meds_reader_convert_s: float + total_conversion_s: float + was_cached: bool + + +def run_pyhealth_meds_conversion( + pyhealth_root: str, + meds_dir: str, + meds_reader_dir: str, + tables: List[str], + dev: bool, + num_shards: int, + num_threads: int, + force_reconvert: bool, + skip_conversion: bool, +) -> ConversionResult: + """Run PyHealth-based MEDS conversion.""" + + if skip_conversion: + if not Path(meds_reader_dir).exists(): + raise SystemExit( + f"Cannot skip conversion: MEDS database does not exist at " + f"{meds_reader_dir}\nRun without --skip-conversion first." 
+ ) + print("✓ Skipping conversion (using cached MEDS database)") + return ConversionResult(0.0, 0.0, 0.0, True) + + if Path(meds_reader_dir).exists() and not force_reconvert: + print(f"✓ MEDS database exists: {meds_reader_dir}") + return ConversionResult(0.0, 0.0, 0.0, True) + + print("\n" + "=" * 60) + print("Converting MIMIC-IV to MEDS format via PyHealth 1.1.6") + print("=" * 60) + + if Path(meds_dir).exists(): + print(f" Clearing existing MEDS cache: {meds_dir}") + shutil.rmtree(meds_dir) + if Path(meds_reader_dir).exists(): + print(f" Clearing existing meds_reader cache: {meds_reader_dir}") + shutil.rmtree(meds_reader_dir) + + print("\n[Step 1/2] Loading via PyHealth and converting to MEDS...") + pyhealth_etl_s = pyhealth_to_meds( + pyhealth_root=pyhealth_root, + output_dir=meds_dir, + tables=tables, + dev=dev, + num_shards=num_shards, + ) + + print("\n[Step 2/2] Running meds_reader_convert...") + meds_reader_convert_s = run_meds_reader_convert( + meds_dir, meds_reader_dir, num_threads=num_threads + ) + + total = pyhealth_etl_s + meds_reader_convert_s + print(f"\n✓ MEDS database ready. Total conversion: {total:.2f}s") + + return ConversionResult(pyhealth_etl_s, meds_reader_convert_s, total, False) + + +# ============================================================================= +# Task Function - Drug Recommendation +# ============================================================================= + +def get_drug_rec_samples(subjects: Iterator[meds_reader.Subject]): + """Process subjects for drug recommendation task.""" + samples = [] + + for subject in subjects: + admissions = {} + + for event in subject.events: + if event.code.startswith("MIMIC_IV_Admission/"): + visit_id = getattr(event, 'visit_id', None) + if visit_id is not None and event.time is not None: + admissions[visit_id] = { + 'time': event.time, + 'conditions': set(), + 'procedures': set(), + 'drugs': set(), + } + + for event in subject.events: + visit_id = getattr(event, 'visit_id', None) + if visit_id is None or visit_id not in admissions: + continue + + code = event.code + if code.startswith("ICD"): + if "CM" in code: + admissions[visit_id]['conditions'].add(code) + else: + admissions[visit_id]['procedures'].add(code) + elif code.startswith("NDC/") or code.startswith("MIMIC_IV_Drug/"): + admissions[visit_id]['drugs'].add(code) + + sorted_visits = sorted( + [(vid, data) for vid, data in admissions.items()], + key=lambda x: x[1]['time'] + ) + + if len(sorted_visits) < 2: + continue + + for i, (visit_id, current) in enumerate(sorted_visits): + conditions = list(current['conditions']) + procedures = list(current['procedures']) + drugs = list(current['drugs']) + + if len(conditions) == 0 or len(procedures) == 0 or len(drugs) == 0: + continue + + conditions_hist = [] + procedures_hist = [] + drugs_hist = [] + + for j in range(i + 1): + conditions_hist.append(list(sorted_visits[j][1]['conditions'])) + procedures_hist.append(list(sorted_visits[j][1]['procedures'])) + if j < i: + drugs_hist.append(list(sorted_visits[j][1]['drugs'])) + else: + drugs_hist.append([]) + + samples.append({ + "visit_id": visit_id, + "patient_id": subject.subject_id, + "conditions": conditions_hist, + "procedures": procedures_hist, + "drugs_hist": drugs_hist, + "drugs": drugs, + }) + + return samples + + +# ============================================================================= +# Benchmark Infrastructure +# ============================================================================= + +@dataclass +class RunResult: + num_threads: int + 
repeat_index: int + pyhealth_etl_s: float + meds_reader_convert_s: float + task_process_s: float + total_s: float + peak_rss_bytes: int + num_samples: int + conversion_cached: bool + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +class PeakMemoryTracker: + def __init__(self, poll_interval_s: float = 0.1): + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self): + self._thread.start() + + def reset(self): + with self._lock: + self._peak = 0 + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self): + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def parse_threads(value: str) -> List[int]: + parts = [p.strip() for p in value.split(",") if p.strip()] + return [int(p) for p in parts if int(p) > 0] + + +def median(values: Iterable[float]) -> float: + xs = sorted(values) + if not xs: + return 0.0 + mid = len(xs) // 2 + return xs[mid] if len(xs) % 2 == 1 else (xs[mid - 1] + xs[mid]) / 2.0 + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark meds_reader Drug Rec using PyHealth 1.1.6 ETL" + ) + parser.add_argument( + "--threads", type=parse_threads, default=[1, 4, 8, 12, 16], + help="Comma-separated list of thread counts", + ) + parser.add_argument("--repeats", type=int, default=1) + parser.add_argument( + "--pyhealth-root", type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", + help="Path to MIMIC-IV hosp directory (for PyHealth 1.1.6)", + ) + parser.add_argument("--cache-dir", type=str, default="/srv/local/data/johnwu3/meds_reader") + parser.add_argument("--num-shards", type=int, default=100) + parser.add_argument("--num-threads", type=int, default=8) + parser.add_argument("--dev", action="store_true") + parser.add_argument("--force-reconvert", action="store_true") + parser.add_argument("--skip-conversion", action="store_true") + parser.add_argument( + "--output-csv", type=str, + default="benchmark_meds_reader_drug_rec_pyhealth_etl.csv", + ) + args = parser.parse_args() + + meds_dir = f"{args.cache_dir}/mimic4_meds_drug_rec_pyhealth" + meds_reader_dir = f"{args.cache_dir}/mimic4_meds_reader_drug_rec_pyhealth" + + print("=" * 80) + print("BENCHMARK: meds_reader Drug Rec (PyHealth 1.1.6 ETL - Fallback)") + print(f"threads={args.threads} repeats={args.repeats} dev={args.dev}") + print(f"pyhealth_root: {args.pyhealth_root}") + print("=" * 80) + + tracker = PeakMemoryTracker() + tracker.start() + + total_start = time.time() + results: List[RunResult] = [] + + # Tables needed for drug recommendation task + tables = ["diagnoses_icd", "procedures_icd", "prescriptions"] + + for t in args.threads: + for r in range(args.repeats): + tracker.reset() + run_start = time.time() + + conversion = run_pyhealth_meds_conversion( + 
pyhealth_root=args.pyhealth_root, + meds_dir=meds_dir, + meds_reader_dir=meds_reader_dir, + tables=tables, + dev=args.dev, + num_shards=args.num_shards, + num_threads=args.num_threads, + force_reconvert=args.force_reconvert and r == 0, + skip_conversion=args.skip_conversion or r > 0, + ) + + print(f"\n threads={t} repeat={r + 1}/{args.repeats}: Processing...") + task_start = time.time() + + samples = [] + with meds_reader.SubjectDatabase(meds_reader_dir, num_threads=t) as db: + for s in db.map(get_drug_rec_samples): + samples.extend(s) + + conditions_proc = SequenceProcessor() + procedures_proc = SequenceProcessor() + drugs_proc = SequenceProcessor() + drugs_label_proc = MultilabelProcessor() + + conditions_proc.fit(samples, "conditions") + procedures_proc.fit(samples, "procedures") + drugs_proc.fit(samples, "drugs_hist") + drugs_proc.fit(samples, "drugs") + drugs_label_proc.fit(samples, "drugs") + + processed = [] + for sample in samples: + processed.append({ + "visit_id": sample["visit_id"], + "patient_id": sample["patient_id"], + "conditions": conditions_proc.process_nested(sample["conditions"]), + "procedures": procedures_proc.process_nested(sample["procedures"]), + "drugs_hist": drugs_proc.process_nested(sample["drugs_hist"]), + "drugs": drugs_label_proc.process(sample["drugs"]), + }) + + dataset = MedsReaderSampleDataset( + samples=processed, + input_schema={ + "conditions": "sequence", + "procedures": "sequence", + "drugs_hist": "sequence", + }, + output_schema={"drugs": "multilabel"}, + input_processors={ + "conditions": conditions_proc, + "procedures": procedures_proc, + "drugs_hist": drugs_proc, + }, + output_processors={"drugs": drugs_label_proc}, + dataset_name="MIMIC-IV", + task_name="DrugRecommendation", + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss = tracker.peak_bytes() + + results.append(RunResult( + num_threads=t, + repeat_index=r, + pyhealth_etl_s=conversion.pyhealth_etl_s, + meds_reader_convert_s=conversion.meds_reader_convert_s, + task_process_s=task_process_s, + total_s=total_s, + peak_rss_bytes=peak_rss, + num_samples=len(dataset), + conversion_cached=conversion.was_cached, + )) + + timing = f"task={task_process_s:.2f}s" + if not conversion.was_cached: + timing = (f"pyhealth_etl={conversion.pyhealth_etl_s:.2f}s " + f"convert={conversion.meds_reader_convert_s:.2f}s " + f"{timing} total={total_s:.2f}s") + + print(f" ✓ threads={t:>2} samples={len(dataset)} {timing} " + f"peak_rss={format_size(peak_rss)}") + + total_sweep_s = time.time() - total_start + + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(results[0]).keys())) + writer.writeheader() + for rr in results: + writer.writerow(asdict(rr)) + + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + for t in args.threads: + trs = [rr for rr in results if rr.num_threads == t] + med_task = median([rr.task_process_s for rr in trs]) + first = [rr for rr in trs if rr.repeat_index == 0][0] + if not first.conversion_cached: + print(f"threads={t:>2} pyhealth_etl={first.pyhealth_etl_s:.2f}s " + f"convert={first.meds_reader_convert_s:.2f}s " + f"task_med={med_task:.2f}s") + else: + print(f"threads={t:>2} task_med={med_task:.2f}s (cached)") + + print(f"\nSweep time: {total_sweep_s:.2f}s") + print(f"CSV: {out_csv}") + print("=" * 80) + + +if __name__ == "__main__": + main() + diff --git 
a/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_los.py b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_los.py new file mode 100644 index 000000000..4735b123d --- /dev/null +++ b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_los.py @@ -0,0 +1,839 @@ +"""Benchmark script for MIMIC-IV length of stay prediction using meds_reader. + +This benchmark measures performance across multiple thread counts: +1. Time for MEDS ETL conversion (MIMIC-IV -> MEDS format) +2. Time for meds_reader database conversion (MEDS -> meds_reader format) +3. Time to process the task +4. Peak memory usage (RSS, includes child processes) +5. Number of samples generated + +IMPORTANT: For fair comparison with PyHealth, conversion time MUST be included. +PyHealth's dataset loading includes parsing raw MIMIC-IV CSVs, so we must +account for the equivalent preprocessing time in meds_reader. + +This script uses meds_etl for data conversion: +- Converts MIMIC-IV directly to MEDS format via meds_etl_mimic +- Runs meds_reader_convert to prepare the database +- Then runs the benchmark + +Typical usage: + # First install dependencies: + pip install meds_etl meds_reader + + # Run benchmark (includes conversion time by default): + python benchmark_meds_reader_los.py + python benchmark_meds_reader_los.py --threads 1,4,8,12,16 --repeats 3 + + # Skip conversion (only for debugging, not fair benchmarking): + python benchmark_meds_reader_los.py --skip-conversion +""" + +from __future__ import annotations + +import argparse +import collections +import csv +import os +import shutil +import subprocess +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List + +import psutil +import torch +from torch.utils.data import Dataset + +try: + import meds_reader +except ImportError: + raise ImportError( + "meds_reader not found. Install with: pip install meds_reader\n" + "Or from source: pip install -e /path/to/meds_reader" + ) + + +# ============================================================================= +# PyTorch Dataset Wrapper +# ============================================================================= + +class MedsReaderSampleDataset(Dataset): + """PyTorch Dataset wrapper for meds_reader samples. + + This provides a standard PyTorch Dataset interface for the processed samples, + making them compatible with PyTorch DataLoader for model training. 
+ + Attributes: + samples: List of processed sample dictionaries + input_schema: Schema describing input features + output_schema: Schema describing output features + input_processors: Fitted processors for input features + output_processors: Fitted processors for output features + """ + + def __init__( + self, + samples: List[Dict[str, Any]], + input_schema: Dict[str, str], + output_schema: Dict[str, str], + input_processors: Dict[str, Any], + output_processors: Dict[str, Any], + dataset_name: str = "", + task_name: str = "", + ): + self.samples = samples + self.input_schema = input_schema + self.output_schema = output_schema + self.input_processors = input_processors + self.output_processors = output_processors + self.dataset_name = dataset_name + self.task_name = task_name + + # Build patient and record indices for train/val/test splitting + self.patient_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + self.record_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + + for idx, sample in enumerate(samples): + if "patient_id" in sample: + self.patient_to_index[sample["patient_id"]].append(idx) + if "visit_id" in sample: + self.record_to_index[sample["visit_id"]].append(idx) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, Any]: + return self.samples[index] + + def __repr__(self) -> str: + return ( + f"MedsReaderSampleDataset(dataset={self.dataset_name}, " + f"task={self.task_name}, n_samples={len(self)})" + ) + + def get_all_tokens(self, key: str) -> set: + """Get all unique tokens for a given key across all samples.""" + tokens = set() + for sample in self.samples: + if key in sample: + val = sample[key] + if isinstance(val, torch.Tensor): + tokens.update(val.tolist()) + elif isinstance(val, list): + tokens.update(val) + return tokens + + +# ============================================================================= +# Processor Classes (matching PyHealth's SequenceProcessor for fair comparison) +# ============================================================================= + +class SequenceProcessor: + """Matches PyHealth's SequenceProcessor for vocabulary building and tokenization.""" + + def __init__(self): + self.code_vocab = {"": 0} + self._next_index = 1 + + def fit(self, samples, field): + """Build vocabulary from all samples (first pass through data).""" + for sample in samples: + if field not in sample: + continue + for token in sample[field]: + if token is None: + continue + if token not in self.code_vocab: + self.code_vocab[token] = self._next_index + self._next_index += 1 + self.code_vocab[""] = len(self.code_vocab) + + def process(self, value): + """Convert code strings to tensor of indices.""" + indices = [] + for token in value: + if token in self.code_vocab: + indices.append(self.code_vocab[token]) + else: + indices.append(self.code_vocab[""]) + return torch.tensor(indices, dtype=torch.long) + + def size(self): + return len(self.code_vocab) + + +class MulticlassLabelProcessor: + """Processor for multiclass labels (matching PyHealth's MultiClassLabelProcessor).""" + + def __init__(self): + self.label_vocab = {} + + def fit(self, samples, field): + """Build vocabulary from all label values.""" + for sample in samples: + if field in sample: + val = sample[field] + if val not in self.label_vocab: + self.label_vocab[val] = len(self.label_vocab) + + def process(self, value): + """Convert label to tensor.""" + return torch.tensor(self.label_vocab.get(value, 0), dtype=torch.long) + + def 
size(self): + return len(self.label_vocab) + + +try: + import resource + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +# ============================================================================= +# Data Conversion (MIMIC-IV -> MEDS -> meds_reader via meds_etl) +# ============================================================================= + +def run_meds_etl_mimic( + src_mimic: str, + output_dir: str, + num_shards: int = 100, + num_proc: int = 1, + backend: str = "polars", +) -> float: + """Run meds_etl_mimic to convert MIMIC-IV to MEDS format. + + Args: + src_mimic: Path to MIMIC-IV root (containing 2.2/ subdirectory) + output_dir: Path to output MEDS dataset + num_shards: Number of shards for processing + num_proc: Number of processes to use + backend: Backend to use (polars or cpp) + + Returns: + Time taken in seconds + """ + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + print(f" Running meds_etl_mimic (shards={num_shards}, proc={num_proc}, backend={backend})...") + print(f" Source: {src_mimic}") + print(f" Destination: {output_dir}") + + start = time.time() + result = subprocess.run( + [ + "meds_etl_mimic", + src_mimic, + output_dir, + "--num_shards", str(num_shards), + "--num_proc", str(num_proc), + "--backend", backend, + ], + capture_output=True, + text=True, + ) + elapsed = time.time() - start + + if result.returncode != 0: + print(f" STDOUT: {result.stdout}") + print(f" STDERR: {result.stderr}") + raise RuntimeError(f"meds_etl_mimic failed with code {result.returncode}") + + print(f" meds_etl_mimic completed in {elapsed:.2f}s") + return elapsed + + +def run_meds_reader_convert(input_dir: str, output_dir: str, num_threads: int = 10) -> float: + """Run meds_reader_convert CLI tool. Returns time taken.""" + print(f" Running meds_reader_convert (threads={num_threads})...") + print(f" {input_dir} -> {output_dir}") + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + start = time.time() + try: + result = subprocess.run( + ["meds_reader_convert", input_dir, output_dir, "--num_threads", str(num_threads)], + capture_output=True, + text=True, + check=True, + ) + elapsed = time.time() - start + print(f" meds_reader_convert completed in {elapsed:.2f}s") + return elapsed + except subprocess.CalledProcessError as e: + print(f" ERROR: meds_reader_convert failed:") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + raise + except FileNotFoundError: + print(f" ERROR: meds_reader_convert not found in PATH") + raise + + +@dataclass +class ConversionResult: + """Holds timing information for the MEDS conversion process.""" + meds_etl_s: float + meds_reader_convert_s: float + total_conversion_s: float + was_cached: bool # True if conversion was skipped due to existing cache + + +def run_meds_conversion( + mimic_root: str, + meds_dir: str, + meds_reader_dir: str, + num_shards: int, + num_proc: int, + backend: str, + force_reconvert: bool, + skip_conversion: bool, +) -> ConversionResult: + """Run MEDS conversion and return timing information. 
+ + Args: + mimic_root: Path to MIMIC-IV root directory + meds_dir: Path for intermediate MEDS output + meds_reader_dir: Path for final meds_reader database + num_shards: Number of shards for meds_etl + num_proc: Number of processes for meds_etl + backend: Backend for meds_etl (polars or cpp) + force_reconvert: If True, always reconvert even if cache exists + skip_conversion: If True, skip conversion (for debugging only) + + Returns: + ConversionResult with timing information + """ + # Check if we should skip conversion + if skip_conversion: + if not Path(meds_reader_dir).exists(): + raise SystemExit( + f"Cannot skip conversion: MEDS database does not exist at {meds_reader_dir}\n" + "Run without --skip-conversion first." + ) + print(f"✓ Skipping conversion (using cached MEDS database: {meds_reader_dir})") + print(" WARNING: For fair benchmarking, conversion time should be included!") + return ConversionResult( + meds_etl_s=0.0, + meds_reader_convert_s=0.0, + total_conversion_s=0.0, + was_cached=True, + ) + + # Check if we can reuse existing cache + if Path(meds_reader_dir).exists() and not force_reconvert: + print(f"✓ MEDS database exists: {meds_reader_dir}") + print(" NOTE: Using cached data. Use --force-reconvert for fresh timing.") + return ConversionResult( + meds_etl_s=0.0, + meds_reader_convert_s=0.0, + total_conversion_s=0.0, + was_cached=True, + ) + + print(f"\n{'='*60}") + print(f"Converting MIMIC-IV to MEDS format") + print(f"{'='*60}") + + # Clear existing cache directories to avoid interference + if Path(meds_dir).exists(): + print(f" Clearing existing MEDS cache: {meds_dir}") + shutil.rmtree(meds_dir) + if Path(meds_reader_dir).exists(): + print(f" Clearing existing meds_reader cache: {meds_reader_dir}") + shutil.rmtree(meds_reader_dir) + + # Verify MIMIC-IV structure + mimic_version_path = os.path.join(mimic_root, "2.2") + if not os.path.exists(mimic_version_path): + raise SystemExit( + f"ERROR: Expected MIMIC-IV version directory not found: {mimic_version_path}\n" + f"meds_etl_mimic expects the MIMIC-IV data to be in {{mimic_root}}/2.2/" + ) + + # Step 1: Convert MIMIC-IV -> MEDS using meds_etl + print(f"\n[Step 1/2] Converting MIMIC-IV to MEDS format using meds_etl...") + meds_etl_s = run_meds_etl_mimic( + src_mimic=mimic_root, + output_dir=meds_dir, + num_shards=num_shards, + num_proc=num_proc, + backend=backend, + ) + + # Step 2: Run meds_reader_convert + print(f"\n[Step 2/2] Running meds_reader_convert...") + meds_reader_convert_s = run_meds_reader_convert( + meds_dir, meds_reader_dir, num_threads=num_proc + ) + + total_conversion_s = meds_etl_s + meds_reader_convert_s + print(f"\n✓ MEDS database ready: {meds_reader_dir}") + print(f" Total conversion time: {total_conversion_s:.2f}s") + + return ConversionResult( + meds_etl_s=meds_etl_s, + meds_reader_convert_s=meds_reader_convert_s, + total_conversion_s=total_conversion_s, + was_cached=False, + ) + + +# ============================================================================= +# Task Function - Length of Stay Prediction +# ============================================================================= + +def get_los_samples(subjects: Iterator[meds_reader.Subject]): + """Process subjects for length of stay prediction task. + + Uses MEDS-ETL code conventions: + - Admission codes are like "MIMIC_IV_Admission/..." + - Diagnosis codes are like "ICD10CM/..." or "ICD9CM/..." + - Procedure codes are like "ICD10PCS/..." or "ICD9Proc/..." + - Prescriptions are like "NDC/..." or "MIMIC_IV_Drug/..." 
+ """ + samples = [] + + for subject in subjects: + admission_data = {} + + # First pass: identify admissions and their discharge times + for event in subject.events: + if event.code.startswith("MIMIC_IV_Admission/"): + # Get admission metadata + visit_id = getattr(event, 'visit_id', None) + end_time = getattr(event, 'end', None) + + if visit_id is not None and event.time is not None and end_time is not None: + los_days = (end_time - event.time).days + + # Categorize LOS (matching PyHealth's categorization) + if los_days < 1: los_label = 0 + elif los_days <= 7: los_label = los_days + elif los_days <= 14: los_label = 8 + else: los_label = 9 + + admission_data[visit_id] = { + 'start': event.time, + 'los_days': los_days, + 'label': los_label, + 'conditions': set(), + 'procedures': set(), + 'drugs': set(), + } + + # Second pass: collect features per admission + for event in subject.events: + visit_id = getattr(event, 'visit_id', None) + if visit_id is None or visit_id not in admission_data: + continue + + code = event.code + if code.startswith("ICD"): # ICD9CM, ICD10CM, ICD9Proc, ICD10PCS + if "CM" in code: + admission_data[visit_id]['conditions'].add(code) + else: + admission_data[visit_id]['procedures'].add(code) + elif code.startswith("NDC/") or code.startswith("MIMIC_IV_Drug/"): + admission_data[visit_id]['drugs'].add(code) + + # Create samples for admissions with sufficient data + for visit_id, data in admission_data.items(): + conditions = list(data['conditions']) + procedures = list(data['procedures']) + drugs = list(data['drugs']) + + # Match PyHealth's filtering: require conditions, procedures, and drugs + if len(conditions) == 0 or len(procedures) == 0 or len(drugs) == 0: + continue + + samples.append({ + "visit_id": visit_id, + "patient_id": subject.subject_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "label": data['label'], + "los_days": data['los_days'], + }) + + return samples + + +# ============================================================================= +# Benchmark Infrastructure +# ============================================================================= + +@dataclass +class RunResult: + num_threads: int + repeat_index: int + meds_etl_s: float # Time for MIMIC-IV -> MEDS conversion + meds_reader_convert_s: float # Time for MEDS -> meds_reader conversion + task_process_s: float # Time to run the ML task + total_s: float # Total time (conversion + task) + peak_rss_bytes: int + num_samples: int + conversion_cached: bool # True if conversion was skipped + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def set_memory_limit(max_memory_gb: float) -> None: + if not HAS_RESOURCE: + print("Warning: resource module not available. 
Memory limit not enforced.") + return + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +class PeakMemoryTracker: + def __init__(self, poll_interval_s: float = 0.1) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def parse_threads(value: str) -> list[int]: + parts = [p.strip() for p in value.split(",") if p.strip()] + threads = [] + for p in parts: + t = int(p) + if t <= 0: + raise argparse.ArgumentTypeError("All thread counts must be > 0") + threads.append(t) + if not threads: + raise argparse.ArgumentTypeError("No threads provided") + return threads + + +def median(values: Iterable[float]) -> float: + xs = sorted(values) + if not xs: + return 0.0 + mid = len(xs) // 2 + if len(xs) % 2 == 1: + return xs[mid] + return (xs[mid - 1] + xs[mid]) / 2.0 + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark meds_reader for MIMIC-IV length of stay prediction" + ) + parser.add_argument( + "--threads", type=parse_threads, default=[1, 4, 8, 12, 16], + help="Comma-separated list of num_threads values (default: 1,4,8,12,16)", + ) + parser.add_argument( + "--repeats", type=int, default=1, + help="Number of repeats per thread setting (default: 1)", + ) + parser.add_argument( + "--mimic-root", type=str, + default="/srv/local/data/physionet.org/files/mimiciv", + help="Path to MIMIC-IV root directory (containing 2.2/ subdirectory)", + ) + parser.add_argument( + "--cache-dir", type=str, default="/srv/local/data/johnwu3/meds_reader", + help="Directory for MEDS cache", + ) + parser.add_argument( + "--num-shards", type=int, default=100, + help="Number of shards for meds_etl_mimic (default: 100)", + ) + parser.add_argument( + "--num-proc", type=int, default=8, + help="Number of processes for meds_etl_mimic (default: 8)", + ) + parser.add_argument( + "--backend", type=str, default="polars", choices=["polars", "cpp"], + help="Backend for meds_etl_mimic (default: polars)", + ) + parser.add_argument( + "--force-reconvert", action="store_true", + help="Force reconversion even if MEDS database exists (recommended for benchmarking)", + ) + parser.add_argument( + "--skip-conversion", action="store_true", + help="Skip conversion entirely (for debugging only - NOT fair benchmarking)", + ) + parser.add_argument( + "--enable-memory-limit", action="store_true", + help="Enforce a hard memory limit via resource.setrlimit (Unix only)", + ) + 
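The 10-class length-of-stay bucketing applied in ``get_los_samples`` above can be sanity-checked in isolation; this is a standalone sketch of the same thresholds, not an import from PyHealth.

.. code-block:: python

   # Standalone sketch of the 10-class bucketing used in get_los_samples above:
   # <1 day -> 0, 1-7 days -> classes 1-7, 8-14 days -> 8, >14 days -> 9.
   def los_label(los_days: int) -> int:
       if los_days < 1:
           return 0
       if los_days <= 7:
           return los_days
       if los_days <= 14:
           return 8
       return 9

   assert [los_label(d) for d in (0, 3, 7, 10, 21)] == [0, 3, 7, 8, 9]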
parser.add_argument( + "--max-memory-gb", type=float, default=None, + help="Hard memory limit in GB (only used if --enable-memory-limit is set)", + ) + parser.add_argument( + "--output-csv", type=str, + default="benchmark_meds_reader_los_threads_sweep.csv", + help="Where to write per-run results as CSV", + ) + args = parser.parse_args() + + if args.repeats <= 0: + raise SystemExit("--repeats must be > 0") + + if args.enable_memory_limit: + if args.max_memory_gb is None: + raise SystemExit( + "When using --enable-memory-limit, you must also pass --max-memory-gb" + ) + set_memory_limit(args.max_memory_gb) + + # MEDS paths + # Use task-specific cache directories to avoid interference between tasks + meds_dir = f"{args.cache_dir}/mimic4_meds_los" + meds_reader_dir = f"{args.cache_dir}/mimic4_meds_reader_los" + + print("=" * 80) + print("BENCHMARK: meds_reader - Length of Stay Prediction (Thread Sweep)") + print(f"threads={args.threads} repeats={args.repeats}") + print(f"mimic_root: {args.mimic_root}") + print(f"backend: {args.backend}, num_proc: {args.num_proc}, num_shards: {args.num_shards}") + if args.skip_conversion: + print("WARNING: --skip-conversion is set. Conversion time will NOT be included.") + print(" This is NOT a fair comparison with PyHealth!") + print("=" * 80) + + tracker = PeakMemoryTracker(poll_interval_s=0.1) + tracker.start() + + total_start = time.time() + results: list[RunResult] = [] + + print(f"\n{'='*60}") + print("Running benchmark...") + print(f"{'='*60}") + + for t in args.threads: + for r in range(args.repeats): + tracker.reset() + run_start = time.time() + + # Step 0: Convert MIMIC-IV to MEDS format (part of total time) + # For fair comparison with PyHealth, we must include this conversion time + # since PyHealth's dataset loading includes parsing raw MIMIC-IV CSVs. 
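+            # At most the first repeat (r == 0) performs a fresh conversion; later
+            # repeats force skip_conversion=True and reuse the cached database, so
+            # run_meds_conversion returns zeroed timings with conversion_cached=True
+            # and total_s for those rows is effectively task-processing time only.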
+ conversion = run_meds_conversion( + mimic_root=args.mimic_root, + meds_dir=meds_dir, + meds_reader_dir=meds_reader_dir, + num_shards=args.num_shards, + num_proc=args.num_proc, + backend=args.backend, + force_reconvert=args.force_reconvert and r == 0, # Only reconvert on first repeat + skip_conversion=args.skip_conversion or r > 0, # Reuse on subsequent repeats + ) + + print(f"\n threads={t} repeat={r + 1}/{args.repeats}: Processing task...") + task_start = time.time() + + # Step 1: Extract samples using meds_reader (parallel) + samples = [] + with meds_reader.SubjectDatabase(meds_reader_dir, num_threads=t) as database: + for s in database.map(get_los_samples): + samples.extend(s) + + # Step 2: Build vocabularies (matching PyHealth's processor.fit()) + conditions_processor = SequenceProcessor() + procedures_processor = SequenceProcessor() + drugs_processor = SequenceProcessor() + label_processor = MulticlassLabelProcessor() + + conditions_processor.fit(samples, "conditions") + procedures_processor.fit(samples, "procedures") + drugs_processor.fit(samples, "drugs") + label_processor.fit(samples, "label") + + # Step 3: Tokenize samples (matching PyHealth's processor.process()) + processed_samples = [] + for sample in samples: + processed_sample = { + "visit_id": sample["visit_id"], + "patient_id": sample["patient_id"], + "conditions": conditions_processor.process(sample["conditions"]), + "procedures": procedures_processor.process(sample["procedures"]), + "drugs": drugs_processor.process(sample["drugs"]), + "label": label_processor.process(sample["label"]), + "los_days": sample["los_days"], + } + processed_samples.append(processed_sample) + + # Step 4: Wrap in PyTorch Dataset for model training compatibility + dataset = MedsReaderSampleDataset( + samples=processed_samples, + input_schema={ + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + }, + output_schema={"label": "multiclass"}, + input_processors={ + "conditions": conditions_processor, + "procedures": procedures_processor, + "drugs": drugs_processor, + }, + output_processors={"label": label_processor}, + dataset_name="MIMIC-IV", + task_name="LengthOfStayPrediction", + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss_bytes = tracker.peak_bytes() + num_samples = len(dataset) + + results.append( + RunResult( + num_threads=t, + repeat_index=r, + meds_etl_s=conversion.meds_etl_s, + meds_reader_convert_s=conversion.meds_reader_convert_s, + task_process_s=task_process_s, + total_s=total_s, + peak_rss_bytes=peak_rss_bytes, + num_samples=num_samples, + conversion_cached=conversion.was_cached, + ) + ) + + # Build output message + timing_str = f"task={task_process_s:.2f}s" + if not conversion.was_cached: + timing_str = ( + f"meds_etl={conversion.meds_etl_s:.2f}s " + f"convert={conversion.meds_reader_convert_s:.2f}s " + + timing_str + f" total={total_s:.2f}s" + ) + + print( + f" ✓ threads={t:>2} repeat={r + 1:>2}/{args.repeats} " + f"samples={num_samples} " + f"{timing_str} " + f"peak_rss={format_size(peak_rss_bytes)} " + f"vocab_sizes=({conditions_processor.size()},{procedures_processor.size()},{drugs_processor.size()})" + ) + + total_sweep_s = time.time() - total_start + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(results[0]).keys())) + writer.writeheader() + for rr in results: + writer.writerow(asdict(rr)) + + # Print 
summary + print("\n" + "=" * 80) + print("SUMMARY (median across repeats)") + print("=" * 80) + + # Check if any results have conversion times + has_conversion = any(not rr.conversion_cached for rr in results) + + if has_conversion: + print("\n NOTE: Conversion time included for fair comparison with PyHealth.") + print(" PyHealth's dataset_load_s ≈ meds_etl_s + meds_reader_convert_s") + else: + print("\n WARNING: Conversion was cached. For fair benchmarking, use --force-reconvert") + + print() + for t in args.threads: + trs = [rr for rr in results if rr.num_threads == t] + med_task = median([rr.task_process_s for rr in trs]) + med_total = median([rr.total_s for rr in trs]) + med_peak = median([float(rr.peak_rss_bytes) for rr in trs]) + + # Get conversion times (from first repeat which has them if --force-reconvert) + first_run = [rr for rr in trs if rr.repeat_index == 0][0] + + if not first_run.conversion_cached: + print( + f"threads={t:>2} " + f"meds_etl={first_run.meds_etl_s:>7.2f}s " + f"convert={first_run.meds_reader_convert_s:>7.2f}s " + f"task_med={med_task:>7.2f}s " + f"total={med_total:>7.2f}s " + f"peak_rss={format_size(int(med_peak)):>10}" + ) + else: + print( + f"threads={t:>2} " + f"task_med={med_task:>8.2f}s " + f"(conversion cached) " + f"peak_rss_med={format_size(int(med_peak)):>10}" + ) + + print("\nArtifacts:") + print(f" - CSV: {out_csv}") + print(f" - MEDS database: {meds_reader_dir}") + print("\nTotals:") + print(f" - Sweep wall time: {total_sweep_s:.2f}s") + + # Print comparison note + print("\nFor comparison with PyHealth:") + print(" PyHealth total_s = dataset_load_s + task_process_s") + print(" meds_reader total_s = meds_etl_s + meds_reader_convert_s + task_process_s") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_los_pyhealth_etl.py b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_los_pyhealth_etl.py new file mode 100644 index 000000000..b8fa94ee1 --- /dev/null +++ b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_los_pyhealth_etl.py @@ -0,0 +1,780 @@ +"""Benchmark script for MIMIC-IV length of stay prediction using meds_reader. + +This is a FALLBACK version that uses PyHealth 1.1.6 for ETL instead of meds_etl_mimic. +Use this if meds_etl_mimic fails to run properly. + +Pipeline: +1. Load MIMIC-IV data using PyHealth 1.1.6 (MIMIC4Dataset) +2. Convert PyHealth data structures to MEDS format (parquet files) +3. Run meds_reader_convert to create meds_reader database +4. Process the task using meds_reader +5. Return samples in a PyTorch-compatible Dataset + +IMPORTANT: For fair comparison with PyHealth, conversion time MUST be included. + +Typical usage: + # First install dependencies: + pip install pyhealth==1.1.6 meds_reader pyarrow + + # Run benchmark: + python benchmark_meds_reader_los_pyhealth_etl.py + python benchmark_meds_reader_los_pyhealth_etl.py --threads 1,4,8,12,16 --repeats 3 +""" + +from __future__ import annotations + +import argparse +import collections +import csv +import datetime +import os +import shutil +import subprocess +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List + +import numpy as np +import psutil +import pyarrow as pa +import pyarrow.parquet as pq +import torch +from torch.utils.data import Dataset + +try: + import meds_reader +except ImportError: + raise ImportError( + "meds_reader not found. 
Install with: pip install meds_reader\n" + "Or from source: pip install -e /path/to/meds_reader" + ) + +# Import PyHealth 1.1.6 +try: + from pyhealth.datasets import MIMIC4Dataset +except ImportError: + raise ImportError( + "PyHealth not found. Install with: pip install pyhealth==1.1.6" + ) + + +# ============================================================================= +# PyTorch Dataset Wrapper +# ============================================================================= + +class MedsReaderSampleDataset(Dataset): + """PyTorch Dataset wrapper for meds_reader samples. + + This provides a standard PyTorch Dataset interface for the processed samples, + making them compatible with PyTorch DataLoader for model training. + + Attributes: + samples: List of processed sample dictionaries + input_schema: Schema describing input features + output_schema: Schema describing output features + input_processors: Fitted processors for input features + output_processors: Fitted processors for output features + """ + + def __init__( + self, + samples: List[Dict[str, Any]], + input_schema: Dict[str, str], + output_schema: Dict[str, str], + input_processors: Dict[str, Any], + output_processors: Dict[str, Any], + dataset_name: str = "", + task_name: str = "", + ): + self.samples = samples + self.input_schema = input_schema + self.output_schema = output_schema + self.input_processors = input_processors + self.output_processors = output_processors + self.dataset_name = dataset_name + self.task_name = task_name + + # Build patient and record indices for splitting + self.patient_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + self.record_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + + for idx, sample in enumerate(samples): + if "patient_id" in sample: + self.patient_to_index[sample["patient_id"]].append(idx) + if "visit_id" in sample: + self.record_to_index[sample["visit_id"]].append(idx) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, Any]: + return self.samples[index] + + def __repr__(self) -> str: + return f"MedsReaderSampleDataset({self.dataset_name}, {self.task_name}, n={len(self)})" + + def get_all_tokens(self, key: str) -> set: + """Get all unique tokens for a given key across all samples.""" + tokens = set() + for sample in self.samples: + if key in sample: + val = sample[key] + if isinstance(val, torch.Tensor): + tokens.update(val.tolist()) + elif isinstance(val, list): + tokens.update(val) + return tokens + + +# ============================================================================= +# Processor Classes (matching PyHealth's SequenceProcessor for fair comparison) +# ============================================================================= + +class SequenceProcessor: + """Matches PyHealth's SequenceProcessor for vocabulary building and tokenization.""" + + def __init__(self): + self.code_vocab = {"": 0} + self._next_index = 1 + + def fit(self, samples, field): + """Build vocabulary from all samples.""" + for sample in samples: + if field not in sample: + continue + for token in sample[field]: + if token is None: + continue + if token not in self.code_vocab: + self.code_vocab[token] = self._next_index + self._next_index += 1 + self.code_vocab[""] = len(self.code_vocab) + + def process(self, value): + """Convert code strings to tensor of indices.""" + indices = [] + for token in value: + if token in self.code_vocab: + indices.append(self.code_vocab[token]) + else: + 
indices.append(self.code_vocab[""]) + return torch.tensor(indices, dtype=torch.long) + + def size(self): + return len(self.code_vocab) + + +class MulticlassLabelProcessor: + """Processor for multiclass labels (matching PyHealth).""" + + def __init__(self): + self.label_vocab = {} + + def fit(self, samples, field): + """Build vocabulary from all samples.""" + for sample in samples: + if field in sample: + val = sample[field] + if val not in self.label_vocab: + self.label_vocab[val] = len(self.label_vocab) + + def process(self, value): + """Convert label to tensor.""" + return torch.tensor(self.label_vocab.get(value, 0), dtype=torch.long) + + def size(self): + return len(self.label_vocab) + + +# ============================================================================= +# Data Conversion (PyHealth 1.1.6 -> MEDS -> meds_reader) +# ============================================================================= + +def pyhealth_to_meds( + pyhealth_root: str, + output_dir: str, + tables: List[str], + dev: bool = False, + num_shards: int = 100, +) -> float: + """Convert MIMIC-IV data via PyHealth 1.1.6 to MEDS format. + + This loads the data using PyHealth's MIMIC4Dataset, then converts + the internal data structures to MEDS parquet format. + + Args: + pyhealth_root: Path to MIMIC-IV hosp directory + output_dir: Path to output MEDS dataset + tables: List of table names to load + dev: Whether to use dev mode (smaller subset) + num_shards: Number of output shards + + Returns: + Time taken in seconds + """ + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + print(" Loading MIMIC-IV via PyHealth 1.1.6...") + print(f" Root: {pyhealth_root}") + print(f" Tables: {tables}") + print(f" Dev mode: {dev}") + + start = time.time() + + # Step 1: Load data using PyHealth 1.1.6 + dataset = MIMIC4Dataset( + root=pyhealth_root, + tables=tables, + dev=dev, + refresh_cache=True, # Always refresh for timing + ) + + pyhealth_load_time = time.time() - start + print(f" PyHealth load completed in {pyhealth_load_time:.2f}s") + + # Step 2: Convert to MEDS format + print(" Converting to MEDS format...") + convert_start = time.time() + + results = collections.defaultdict(list) + + for patient_id, patient in dataset.patients.items(): + subject_id = int(patient_id) + + # Birth event + if patient.birth_datetime is not None: + birth_obj = { + 'subject_id': subject_id, + 'code': 'meds/birth', + 'time': patient.birth_datetime, + } + if hasattr(patient, 'gender') and patient.gender: + birth_obj['gender'] = patient.gender + if hasattr(patient, 'ethnicity') and patient.ethnicity: + birth_obj['ethnicity'] = patient.ethnicity + results[subject_id].append(birth_obj) + + # Death event + if patient.death_datetime is not None: + results[subject_id].append({ + 'subject_id': subject_id, + 'code': 'meds/death', + 'time': patient.death_datetime, + }) + + # Process visits + for visit_id, visit in patient.visits.items(): + visit_id_int = int(visit_id) + + # Visit/Admission event + visit_event = { + 'subject_id': subject_id, + 'code': f'MIMIC_IV_Admission/{visit.visit_type if hasattr(visit, "visit_type") else "unknown"}', + 'time': visit.encounter_time, + 'visit_id': visit_id_int, + } + if visit.discharge_time: + visit_event['end'] = visit.discharge_time + if hasattr(visit, 'discharge_status'): + visit_event['discharge_status'] = visit.discharge_status + + results[subject_id].append(visit_event) + + # Process events from each table + for table in visit.available_tables: + for event in visit.get_event_list(table): + event_obj = { + 
'subject_id': subject_id, + 'visit_id': visit_id_int, + 'code': f'{event.vocabulary}/{event.code}', + 'time': event.timestamp or visit.discharge_time, + } + + # Add any extra attributes + if hasattr(event, 'attr_dict') and event.attr_dict: + for k, v in event.attr_dict.items(): + if v == v: # Skip NaN + event_obj[k] = v + + results[subject_id].append(event_obj) + + # Sort events by time + results[subject_id].sort(key=lambda a: a['time'] if a['time'] else datetime.datetime.min) + + # Write to parquet shards + os.makedirs(output_dir, exist_ok=True) + os.makedirs(f"{output_dir}/metadata", exist_ok=True) + os.makedirs(f"{output_dir}/data", exist_ok=True) + + all_subjects = list(results.keys()) + subject_ids_per_shard = np.array_split(all_subjects, num_shards) + + # Build schema dynamically + attr_map = { + str: pa.string(), + int: pa.int64(), + np.int64: pa.int64(), + float: pa.float64(), + datetime.datetime: pa.timestamp('us'), + } + + # Collect all attribute types + attr_schema = {} + for subject_values in results.values(): + for row in subject_values: + for k, v in row.items(): + if k not in {'subject_id', 'time'} and v is not None: + pa_type = attr_map.get(type(v), pa.string()) + if k not in attr_schema: + attr_schema[k] = pa_type + + schema = pa.schema([ + ('subject_id', pa.int64()), + ('time', pa.timestamp('us')), + ] + [(k, v) for k, v in sorted(attr_schema.items())]) + + for i, subject_ids in enumerate(subject_ids_per_shard): + if len(subject_ids) == 0: + continue + rows = [v for subject_id in subject_ids for v in results[subject_id]] + if rows: + table = pa.Table.from_pylist(rows, schema=schema) + pq.write_table(table, f"{output_dir}/data/{i}.parquet") + + convert_time = time.time() - convert_start + total_time = time.time() - start + + print(f" MEDS conversion completed in {convert_time:.2f}s") + print(f" Total PyHealth ETL time: {total_time:.2f}s") + + return total_time + + +def run_meds_reader_convert(input_dir: str, output_dir: str, num_threads: int = 10) -> float: + """Run meds_reader_convert CLI tool. 
Returns time taken.""" + print(f" Running meds_reader_convert (threads={num_threads})...") + print(f" {input_dir} -> {output_dir}") + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + start = time.time() + try: + subprocess.run( + ["meds_reader_convert", input_dir, output_dir, + "--num_threads", str(num_threads)], + capture_output=True, + text=True, + check=True, + ) + elapsed = time.time() - start + print(f" meds_reader_convert completed in {elapsed:.2f}s") + return elapsed + except subprocess.CalledProcessError as e: + print(" ERROR: meds_reader_convert failed:") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + raise + except FileNotFoundError: + print(" ERROR: meds_reader_convert not found in PATH") + raise + + +@dataclass +class ConversionResult: + """Holds timing information for the conversion process.""" + pyhealth_etl_s: float # Time for PyHealth load + MEDS conversion + meds_reader_convert_s: float + total_conversion_s: float + was_cached: bool + + +def run_pyhealth_meds_conversion( + pyhealth_root: str, + meds_dir: str, + meds_reader_dir: str, + tables: List[str], + dev: bool, + num_shards: int, + num_threads: int, + force_reconvert: bool, + skip_conversion: bool, +) -> ConversionResult: + """Run PyHealth-based MEDS conversion and return timing information.""" + + if skip_conversion: + if not Path(meds_reader_dir).exists(): + raise SystemExit( + f"Cannot skip conversion: MEDS database does not exist at {meds_reader_dir}\n" + "Run without --skip-conversion first." + ) + print("✓ Skipping conversion (using cached MEDS database)") + return ConversionResult(0.0, 0.0, 0.0, True) + + if Path(meds_reader_dir).exists() and not force_reconvert: + print(f"✓ MEDS database exists: {meds_reader_dir}") + return ConversionResult(0.0, 0.0, 0.0, True) + + print(f"\n{'='*60}") + print("Converting MIMIC-IV to MEDS format via PyHealth 1.1.6") + print(f"{'='*60}") + + # Clear existing caches + if Path(meds_dir).exists(): + print(f" Clearing existing MEDS cache: {meds_dir}") + shutil.rmtree(meds_dir) + if Path(meds_reader_dir).exists(): + print(f" Clearing existing meds_reader cache: {meds_reader_dir}") + shutil.rmtree(meds_reader_dir) + + # Step 1: PyHealth ETL + print("\n[Step 1/2] Loading via PyHealth and converting to MEDS...") + pyhealth_etl_s = pyhealth_to_meds( + pyhealth_root=pyhealth_root, + output_dir=meds_dir, + tables=tables, + dev=dev, + num_shards=num_shards, + ) + + # Step 2: meds_reader_convert + print("\n[Step 2/2] Running meds_reader_convert...") + meds_reader_convert_s = run_meds_reader_convert( + meds_dir, meds_reader_dir, num_threads=num_threads + ) + + total = pyhealth_etl_s + meds_reader_convert_s + print(f"\n✓ MEDS database ready. 
Total conversion: {total:.2f}s") + + return ConversionResult(pyhealth_etl_s, meds_reader_convert_s, total, False) + + +# ============================================================================= +# Task Function - Length of Stay Prediction +# ============================================================================= + +def get_los_samples(subjects: Iterator[meds_reader.Subject]): + """Process subjects for length of stay prediction task.""" + samples = [] + + for subject in subjects: + admission_data = {} + + for event in subject.events: + if event.code.startswith("MIMIC_IV_Admission/"): + visit_id = getattr(event, 'visit_id', None) + end_time = getattr(event, 'end', None) + + if visit_id is not None and event.time is not None and end_time is not None: + los_days = (end_time - event.time).days + + # Categorize LOS (matching PyHealth) + if los_days < 1: + los_label = 0 + elif los_days <= 7: + los_label = los_days + elif los_days <= 14: + los_label = 8 + else: + los_label = 9 + + admission_data[visit_id] = { + 'start': event.time, + 'los_days': los_days, + 'label': los_label, + 'conditions': set(), + 'procedures': set(), + 'drugs': set(), + } + + for event in subject.events: + visit_id = getattr(event, 'visit_id', None) + if visit_id is None or visit_id not in admission_data: + continue + + code = event.code + if code.startswith("ICD"): + if "CM" in code: + admission_data[visit_id]['conditions'].add(code) + else: + admission_data[visit_id]['procedures'].add(code) + elif code.startswith("NDC/") or code.startswith("MIMIC_IV_Drug/"): + admission_data[visit_id]['drugs'].add(code) + + for visit_id, data in admission_data.items(): + conditions = list(data['conditions']) + procedures = list(data['procedures']) + drugs = list(data['drugs']) + + if len(conditions) == 0 or len(procedures) == 0 or len(drugs) == 0: + continue + + samples.append({ + "visit_id": visit_id, + "patient_id": subject.subject_id, + "conditions": conditions, + "procedures": procedures, + "drugs": drugs, + "label": data['label'], + "los_days": data['los_days'], + }) + + return samples + + +# ============================================================================= +# Benchmark Infrastructure +# ============================================================================= + +@dataclass +class RunResult: + num_threads: int + repeat_index: int + pyhealth_etl_s: float + meds_reader_convert_s: float + task_process_s: float + total_s: float + peak_rss_bytes: int + num_samples: int + conversion_cached: bool + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +class PeakMemoryTracker: + def __init__(self, poll_interval_s: float = 0.1): + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self): + self._thread.start() + + def reset(self): + with self._lock: + self._peak = 0 + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def 
_run(self): + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def parse_threads(value: str) -> List[int]: + parts = [p.strip() for p in value.split(",") if p.strip()] + return [int(p) for p in parts if int(p) > 0] + + +def median(values: Iterable[float]) -> float: + xs = sorted(values) + if not xs: + return 0.0 + mid = len(xs) // 2 + return xs[mid] if len(xs) % 2 == 1 else (xs[mid - 1] + xs[mid]) / 2.0 + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark meds_reader LOS using PyHealth 1.1.6 ETL (fallback)" + ) + parser.add_argument( + "--threads", type=parse_threads, default=[1, 4, 8, 12, 16], + help="Comma-separated list of thread counts", + ) + parser.add_argument("--repeats", type=int, default=1) + parser.add_argument( + "--pyhealth-root", type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", + help="Path to MIMIC-IV hosp directory (for PyHealth 1.1.6)", + ) + parser.add_argument("--cache-dir", type=str, default="/srv/local/data/johnwu3/meds_reader") + parser.add_argument("--num-shards", type=int, default=100) + parser.add_argument("--num-threads", type=int, default=8) + parser.add_argument("--dev", action="store_true") + parser.add_argument("--force-reconvert", action="store_true") + parser.add_argument("--skip-conversion", action="store_true") + parser.add_argument( + "--output-csv", type=str, + default="benchmark_meds_reader_los_pyhealth_etl.csv", + ) + args = parser.parse_args() + + # Task-specific cache directories + meds_dir = f"{args.cache_dir}/mimic4_meds_los_pyhealth" + meds_reader_dir = f"{args.cache_dir}/mimic4_meds_reader_los_pyhealth" + + print("=" * 80) + print("BENCHMARK: meds_reader LOS (PyHealth 1.1.6 ETL - Fallback)") + print(f"threads={args.threads} repeats={args.repeats} dev={args.dev}") + print(f"pyhealth_root: {args.pyhealth_root}") + print("=" * 80) + + tracker = PeakMemoryTracker() + tracker.start() + + total_start = time.time() + results: List[RunResult] = [] + + # Tables needed for LOS task + tables = ["diagnoses_icd", "procedures_icd", "prescriptions"] + + for t in args.threads: + for r in range(args.repeats): + tracker.reset() + run_start = time.time() + + conversion = run_pyhealth_meds_conversion( + pyhealth_root=args.pyhealth_root, + meds_dir=meds_dir, + meds_reader_dir=meds_reader_dir, + tables=tables, + dev=args.dev, + num_shards=args.num_shards, + num_threads=args.num_threads, + force_reconvert=args.force_reconvert and r == 0, + skip_conversion=args.skip_conversion or r > 0, + ) + + print(f"\n threads={t} repeat={r + 1}/{args.repeats}: Processing task...") + task_start = time.time() + + # Extract samples + samples = [] + with meds_reader.SubjectDatabase(meds_reader_dir, num_threads=t) as db: + for s in db.map(get_los_samples): + samples.extend(s) + + # Build processors + conditions_proc = SequenceProcessor() + procedures_proc = SequenceProcessor() + drugs_proc = SequenceProcessor() + label_proc = MulticlassLabelProcessor() + + conditions_proc.fit(samples, "conditions") + procedures_proc.fit(samples, "procedures") + drugs_proc.fit(samples, "drugs") + label_proc.fit(samples, "label") + + # Process samples + processed = [] + for sample in samples: + processed.append({ + "visit_id": sample["visit_id"], + "patient_id": sample["patient_id"], + "conditions": conditions_proc.process(sample["conditions"]), + "procedures": procedures_proc.process(sample["procedures"]), + "drugs": 
drugs_proc.process(sample["drugs"]), + "label": label_proc.process(sample["label"]), + "los_days": sample["los_days"], + }) + + # Create PyTorch Dataset + dataset = MedsReaderSampleDataset( + samples=processed, + input_schema={ + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + }, + output_schema={"label": "multiclass"}, + input_processors={ + "conditions": conditions_proc, + "procedures": procedures_proc, + "drugs": drugs_proc, + }, + output_processors={"label": label_proc}, + dataset_name="MIMIC-IV", + task_name="LengthOfStayPrediction", + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss = tracker.peak_bytes() + + results.append(RunResult( + num_threads=t, + repeat_index=r, + pyhealth_etl_s=conversion.pyhealth_etl_s, + meds_reader_convert_s=conversion.meds_reader_convert_s, + task_process_s=task_process_s, + total_s=total_s, + peak_rss_bytes=peak_rss, + num_samples=len(dataset), + conversion_cached=conversion.was_cached, + )) + + timing = f"task={task_process_s:.2f}s" + if not conversion.was_cached: + timing = (f"pyhealth_etl={conversion.pyhealth_etl_s:.2f}s " + f"convert={conversion.meds_reader_convert_s:.2f}s " + f"{timing} total={total_s:.2f}s") + + print(f" ✓ threads={t:>2} samples={len(dataset)} {timing} " + f"peak_rss={format_size(peak_rss)}") + + total_sweep_s = time.time() - total_start + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(results[0]).keys())) + writer.writeheader() + for rr in results: + writer.writerow(asdict(rr)) + + # Summary + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + for t in args.threads: + trs = [rr for rr in results if rr.num_threads == t] + med_task = median([rr.task_process_s for rr in trs]) + first = [rr for rr in trs if rr.repeat_index == 0][0] + if not first.conversion_cached: + print(f"threads={t:>2} pyhealth_etl={first.pyhealth_etl_s:.2f}s " + f"convert={first.meds_reader_convert_s:.2f}s " + f"task_med={med_task:.2f}s") + else: + print(f"threads={t:>2} task_med={med_task:.2f}s (cached)") + + print(f"\nSweep time: {total_sweep_s:.2f}s") + print(f"CSV: {out_csv}") + print("=" * 80) + + +if __name__ == "__main__": + main() + diff --git a/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_mortality.py b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_mortality.py new file mode 100644 index 000000000..dcff6c460 --- /dev/null +++ b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_mortality.py @@ -0,0 +1,850 @@ +"""Benchmark script for MIMIC-IV mortality prediction using meds_reader. + +This benchmark measures performance across multiple thread counts: +1. Time for MEDS ETL conversion (MIMIC-IV -> MEDS format) +2. Time for meds_reader database conversion (MEDS -> meds_reader format) +3. Time to process the task +4. Peak memory usage (RSS, includes child processes) +5. Number of samples generated + +IMPORTANT: For fair comparison with PyHealth, conversion time MUST be included. +PyHealth's dataset loading includes parsing raw MIMIC-IV CSVs, so we must +account for the equivalent preprocessing time in meds_reader. 
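+
+For reporting, each run's wall time is decomposed as
+
+    total_s = meds_etl_s + meds_reader_convert_s + task_process_s
+
+which is the quantity to compare against PyHealth's
+total_s = dataset_load_s + task_process_s.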
+ +This script uses meds_etl for data conversion: +- Converts MIMIC-IV directly to MEDS format via meds_etl_mimic +- Runs meds_reader_convert to prepare the database +- Then runs the benchmark + +Typical usage: + # First install dependencies: + pip install meds_etl meds_reader + + # Run benchmark (includes conversion time by default): + python benchmark_meds_reader_mortality.py + python benchmark_meds_reader_mortality.py --threads 1,4,8,12,16 --repeats 3 + + # Skip conversion (only for debugging, not fair benchmarking): + python benchmark_meds_reader_mortality.py --skip-conversion +""" + +from __future__ import annotations + +import argparse +import collections +import csv +import os +import shutil +import subprocess +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List + +import psutil +import torch +from torch.utils.data import Dataset + +try: + import meds_reader +except ImportError: + raise ImportError( + "meds_reader not found. Install with: pip install meds_reader\n" + "Or from source: pip install -e /path/to/meds_reader" + ) + + +# ============================================================================= +# PyTorch Dataset Wrapper +# ============================================================================= + +class MedsReaderSampleDataset(Dataset): + """PyTorch Dataset wrapper for meds_reader samples. + + Provides a standard PyTorch Dataset interface for model training. + """ + + def __init__( + self, + samples: List[Dict[str, Any]], + input_schema: Dict[str, str], + output_schema: Dict[str, str], + input_processors: Dict[str, Any], + output_processors: Dict[str, Any], + dataset_name: str = "", + task_name: str = "", + ): + self.samples = samples + self.input_schema = input_schema + self.output_schema = output_schema + self.input_processors = input_processors + self.output_processors = output_processors + self.dataset_name = dataset_name + self.task_name = task_name + + self.patient_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + self.record_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + + for idx, sample in enumerate(samples): + if "patient_id" in sample: + self.patient_to_index[sample["patient_id"]].append(idx) + if "visit_id" in sample: + self.record_to_index[sample["visit_id"]].append(idx) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, Any]: + return self.samples[index] + + def __repr__(self) -> str: + return f"MedsReaderSampleDataset({self.dataset_name}, {self.task_name}, n={len(self)})" + + +# ============================================================================= +# Processor Classes (matching PyHealth's SequenceProcessor for fair comparison) +# ============================================================================= + +class SequenceProcessor: + """Matches PyHealth's SequenceProcessor for vocabulary building and tokenization.""" + + def __init__(self): + self.code_vocab = {"": 0} + self._next_index = 1 + + def fit(self, samples, field): + """Build vocabulary from all samples.""" + for sample in samples: + if field not in sample: + continue + for token in sample[field]: + if token is None: + continue + if token not in self.code_vocab: + self.code_vocab[token] = self._next_index + self._next_index += 1 + self.code_vocab[""] = len(self.code_vocab) + + def process(self, value): + """Convert code strings to tensor of indices.""" + indices = [] + for token in value: + 
if token in self.code_vocab: + indices.append(self.code_vocab[token]) + else: + indices.append(self.code_vocab[""]) + return torch.tensor(indices, dtype=torch.long) + + def size(self): + return len(self.code_vocab) + + +class BinaryLabelProcessor: + """Processor for binary labels (matching PyHealth's BinaryLabelProcessor).""" + + def __init__(self): + self.label_vocab = {0: 0, 1: 1} + + def fit(self, samples, field): + """Build vocabulary from all label values.""" + for sample in samples: + if field in sample: + val = sample[field] + if val not in self.label_vocab: + self.label_vocab[val] = len(self.label_vocab) + + def process(self, value): + """Convert label to tensor.""" + return torch.tensor([self.label_vocab.get(value, 0)], dtype=torch.float32) + + def size(self): + return 1 + + +try: + import resource + HAS_RESOURCE = True +except ImportError: + HAS_RESOURCE = False + + +# Lab item IDs for StageNet (matching PyHealth's implementation) +LAB_ITEM_IDS = { + "50824", "52455", "50983", "52623", # Sodium + "50822", "52452", "50971", "52610", # Potassium + "50806", "52434", "50902", "52535", # Chloride + "50803", "50804", # Bicarbonate + "50809", "52027", "50931", "52569", # Glucose + "50808", "51624", # Calcium + "50960", # Magnesium + "50868", "52500", # Anion Gap + "52031", "50964", "51701", # Osmolality + "50970", # Phosphate +} + + +# ============================================================================= +# Data Conversion (MIMIC-IV -> MEDS -> meds_reader via meds_etl) +# ============================================================================= + +def run_meds_etl_mimic( + src_mimic: str, + output_dir: str, + num_shards: int = 100, + num_proc: int = 1, + backend: str = "polars", +) -> float: + """Run meds_etl_mimic to convert MIMIC-IV to MEDS format. + + Args: + src_mimic: Path to MIMIC-IV root (containing 2.2/ subdirectory) + output_dir: Path to output MEDS dataset + num_shards: Number of shards for processing + num_proc: Number of processes to use + backend: Backend to use (polars or cpp) + + Returns: + Time taken in seconds + """ + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + print(f" Running meds_etl_mimic (shards={num_shards}, proc={num_proc}, backend={backend})...") + print(f" Source: {src_mimic}") + print(f" Destination: {output_dir}") + + start = time.time() + result = subprocess.run( + [ + "meds_etl_mimic", + src_mimic, + output_dir, + "--num_shards", str(num_shards), + "--num_proc", str(num_proc), + "--backend", backend, + ], + capture_output=True, + text=True, + ) + elapsed = time.time() - start + + if result.returncode != 0: + print(f" STDOUT: {result.stdout}") + print(f" STDERR: {result.stderr}") + raise RuntimeError(f"meds_etl_mimic failed with code {result.returncode}") + + print(f" meds_etl_mimic completed in {elapsed:.2f}s") + return elapsed + + +def run_meds_reader_convert(input_dir: str, output_dir: str, num_threads: int = 10) -> float: + """Run meds_reader_convert CLI tool. 
Returns time taken.""" + print(f" Running meds_reader_convert (threads={num_threads})...") + print(f" {input_dir} -> {output_dir}") + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + start = time.time() + try: + result = subprocess.run( + ["meds_reader_convert", input_dir, output_dir, "--num_threads", str(num_threads)], + capture_output=True, + text=True, + check=True, + ) + elapsed = time.time() - start + print(f" meds_reader_convert completed in {elapsed:.2f}s") + return elapsed + except subprocess.CalledProcessError as e: + print(f" ERROR: meds_reader_convert failed:") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + raise + except FileNotFoundError: + print(f" ERROR: meds_reader_convert not found in PATH") + raise + + +@dataclass +class ConversionResult: + """Holds timing information for the MEDS conversion process.""" + meds_etl_s: float + meds_reader_convert_s: float + total_conversion_s: float + was_cached: bool # True if conversion was skipped due to existing cache + + +def run_meds_conversion( + mimic_root: str, + meds_dir: str, + meds_reader_dir: str, + num_shards: int, + num_proc: int, + backend: str, + force_reconvert: bool, + skip_conversion: bool, +) -> ConversionResult: + """Run MEDS conversion and return timing information. + + Args: + mimic_root: Path to MIMIC-IV root directory + meds_dir: Path for intermediate MEDS output + meds_reader_dir: Path for final meds_reader database + num_shards: Number of shards for meds_etl + num_proc: Number of processes for meds_etl + backend: Backend for meds_etl (polars or cpp) + force_reconvert: If True, always reconvert even if cache exists + skip_conversion: If True, skip conversion (for debugging only) + + Returns: + ConversionResult with timing information + """ + # Check if we should skip conversion + if skip_conversion: + if not Path(meds_reader_dir).exists(): + raise SystemExit( + f"Cannot skip conversion: MEDS database does not exist at {meds_reader_dir}\n" + "Run without --skip-conversion first." + ) + print(f"✓ Skipping conversion (using cached MEDS database: {meds_reader_dir})") + print(" WARNING: For fair benchmarking, conversion time should be included!") + return ConversionResult( + meds_etl_s=0.0, + meds_reader_convert_s=0.0, + total_conversion_s=0.0, + was_cached=True, + ) + + # Check if we can reuse existing cache + if Path(meds_reader_dir).exists() and not force_reconvert: + print(f"✓ MEDS database exists: {meds_reader_dir}") + print(" NOTE: Using cached data. 
Use --force-reconvert for fresh timing.") + return ConversionResult( + meds_etl_s=0.0, + meds_reader_convert_s=0.0, + total_conversion_s=0.0, + was_cached=True, + ) + + print(f"\n{'='*60}") + print(f"Converting MIMIC-IV to MEDS format") + print(f"{'='*60}") + + # Clear existing cache directories to avoid interference + if Path(meds_dir).exists(): + print(f" Clearing existing MEDS cache: {meds_dir}") + shutil.rmtree(meds_dir) + if Path(meds_reader_dir).exists(): + print(f" Clearing existing meds_reader cache: {meds_reader_dir}") + shutil.rmtree(meds_reader_dir) + + # Verify MIMIC-IV structure + mimic_version_path = os.path.join(mimic_root, "2.2") + if not os.path.exists(mimic_version_path): + raise SystemExit( + f"ERROR: Expected MIMIC-IV version directory not found: {mimic_version_path}\n" + f"meds_etl_mimic expects the MIMIC-IV data to be in {{mimic_root}}/2.2/" + ) + + # Step 1: Convert MIMIC-IV -> MEDS using meds_etl + print(f"\n[Step 1/2] Converting MIMIC-IV to MEDS format using meds_etl...") + meds_etl_s = run_meds_etl_mimic( + src_mimic=mimic_root, + output_dir=meds_dir, + num_shards=num_shards, + num_proc=num_proc, + backend=backend, + ) + + # Step 2: Run meds_reader_convert + print(f"\n[Step 2/2] Running meds_reader_convert...") + meds_reader_convert_s = run_meds_reader_convert( + meds_dir, meds_reader_dir, num_threads=num_proc + ) + + total_conversion_s = meds_etl_s + meds_reader_convert_s + print(f"\n✓ MEDS database ready: {meds_reader_dir}") + print(f" Total conversion time: {total_conversion_s:.2f}s") + + return ConversionResult( + meds_etl_s=meds_etl_s, + meds_reader_convert_s=meds_reader_convert_s, + total_conversion_s=total_conversion_s, + was_cached=False, + ) + + +# ============================================================================= +# Task Function - Mortality Prediction +# ============================================================================= + +def get_mortality_samples(subjects: Iterator[meds_reader.Subject]): + """Process subjects for mortality prediction with lab events. + + Uses MEDS-ETL code conventions: + - Admission codes are like "MIMIC_IV_Admission/..." + - Diagnosis codes are like "ICD10CM/..." or "ICD9CM/..." + - Procedure codes are like "ICD10PCS/..." or "ICD9Proc/..." + - Lab events are like "MIMIC_IV_LABITEM/..." with itemid + + Mortality prediction predicts death at next visit based on + current visit's conditions, procedures, and labs. 
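+
+    Each emitted sample is a plain dict. An illustrative example is shown below;
+    the field values are hypothetical, only the keys and code prefixes match the
+    construction in the function body:
+
+        {
+            "visit_id": 2112233,
+            "patient_id": 10000032,
+            "conditions": ["ICD10CM/I10", "ICD9CM/4280"],
+            "procedures": ["ICD10PCS/..."],
+            "labs": ["MIMIC_IV_LABITEM/50983"],
+            "label": 0,  # mortality at the NEXT visit (0 = alive, 1 = died)
+        }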
+ """ + samples = [] + + for subject in subjects: + # Collect all admissions with their data + admissions = {} # visit_id -> {time, conditions, procedures, labs, discharge_status} + + # Track death events + death_time = None + for event in subject.events: + if event.code == "meds/death": + death_time = event.time + break + + # First pass: identify admissions + for event in subject.events: + if event.code.startswith("MIMIC_IV_Admission/"): + visit_id = getattr(event, 'visit_id', None) + end_time = getattr(event, 'end', None) + if visit_id is not None and event.time is not None: + # Check if patient died during this admission + discharge_status = 0 # Alive + if death_time is not None and end_time is not None: + if death_time <= end_time: + discharge_status = 1 # Died + + admissions[visit_id] = { + 'time': event.time, + 'end': end_time, + 'conditions': set(), + 'procedures': set(), + 'labs': set(), + 'discharge_status': discharge_status, + } + + # Second pass: collect features per admission + for event in subject.events: + visit_id = getattr(event, 'visit_id', None) + if visit_id is None or visit_id not in admissions: + continue + + code = event.code + if code.startswith("ICD"): # ICD9CM, ICD10CM, ICD9Proc, ICD10PCS + if "CM" in code: + admissions[visit_id]['conditions'].add(code) + else: + admissions[visit_id]['procedures'].add(code) + elif code.startswith("MIMIC_IV_LABITEM/"): + # Extract itemid and check if it's in our StageNet lab set + item_id = code.split("/")[-1] if "/" in code else "" + if item_id in LAB_ITEM_IDS: + admissions[visit_id]['labs'].add(code) + + # Sort admissions by time + sorted_visits = sorted( + [(vid, data) for vid, data in admissions.items()], + key=lambda x: x[1]['time'] + ) + + # Create samples - predicting mortality at NEXT visit + for i in range(len(sorted_visits) - 1): + visit_id, current = sorted_visits[i] + _, next_visit = sorted_visits[i + 1] + + conditions = list(current['conditions']) + procedures = list(current['procedures']) + labs = list(current['labs']) + + # Target: mortality at next visit + mortality_label = next_visit['discharge_status'] + + # Match PyHealth's filtering: require conditions and labs + if len(conditions) == 0 or len(labs) == 0: + continue + + samples.append({ + "visit_id": visit_id, + "patient_id": subject.subject_id, + "conditions": conditions, + "procedures": procedures, + "labs": labs, + "label": mortality_label, + }) + + return samples + + +# ============================================================================= +# Benchmark Infrastructure +# ============================================================================= + +@dataclass +class RunResult: + num_threads: int + repeat_index: int + meds_etl_s: float # Time for MIMIC-IV -> MEDS conversion + meds_reader_convert_s: float # Time for MEDS -> meds_reader conversion + task_process_s: float # Time to run the ML task + total_s: float # Total time (conversion + task) + peak_rss_bytes: int + num_samples: int + conversion_cached: bool # True if conversion was skipped + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def set_memory_limit(max_memory_gb: float) -> None: + if not HAS_RESOURCE: + print("Warning: resource module not available. 
Memory limit not enforced.") + return + max_memory_bytes = int(max_memory_gb * 1024**3) + try: + resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes)) + print(f"✓ Memory limit set to {max_memory_gb} GB") + except Exception as e: + print(f"Warning: Failed to set memory limit: {e}") + + +class PeakMemoryTracker: + def __init__(self, poll_interval_s: float = 0.1) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def parse_threads(value: str) -> list[int]: + parts = [p.strip() for p in value.split(",") if p.strip()] + threads = [] + for p in parts: + t = int(p) + if t <= 0: + raise argparse.ArgumentTypeError("All thread counts must be > 0") + threads.append(t) + if not threads: + raise argparse.ArgumentTypeError("No threads provided") + return threads + + +def median(values: Iterable[float]) -> float: + xs = sorted(values) + if not xs: + return 0.0 + mid = len(xs) // 2 + if len(xs) % 2 == 1: + return xs[mid] + return (xs[mid - 1] + xs[mid]) / 2.0 + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark meds_reader for MIMIC-IV mortality prediction" + ) + parser.add_argument( + "--threads", type=parse_threads, default=[1, 4, 8, 12, 16], + help="Comma-separated list of num_threads values (default: 1,4,8,12,16)", + ) + parser.add_argument( + "--repeats", type=int, default=1, + help="Number of repeats per thread setting (default: 1)", + ) + parser.add_argument( + "--mimic-root", type=str, + default="/srv/local/data/physionet.org/files/mimiciv", + help="Path to MIMIC-IV root directory (containing 2.2/ subdirectory)", + ) + parser.add_argument( + "--cache-dir", type=str, default="/srv/local/data/johnwu3/meds_reader", + help="Directory for MEDS cache", + ) + parser.add_argument( + "--num-shards", type=int, default=100, + help="Number of shards for meds_etl_mimic (default: 100)", + ) + parser.add_argument( + "--num-proc", type=int, default=8, + help="Number of processes for meds_etl_mimic (default: 8)", + ) + parser.add_argument( + "--backend", type=str, default="polars", choices=["polars", "cpp"], + help="Backend for meds_etl_mimic (default: polars)", + ) + parser.add_argument( + "--force-reconvert", action="store_true", + help="Force reconversion even if MEDS database exists (recommended for benchmarking)", + ) + parser.add_argument( + "--skip-conversion", action="store_true", + help="Skip conversion entirely (for debugging only - NOT fair benchmarking)", + ) + parser.add_argument( + "--enable-memory-limit", action="store_true", + help="Enforce a hard memory limit via resource.setrlimit (Unix only)", + ) + 
parser.add_argument( + "--max-memory-gb", type=float, default=None, + help="Hard memory limit in GB (only used if --enable-memory-limit is set)", + ) + parser.add_argument( + "--output-csv", type=str, + default="benchmark_meds_reader_mortality_threads_sweep.csv", + help="Where to write per-run results as CSV", + ) + args = parser.parse_args() + + if args.repeats <= 0: + raise SystemExit("--repeats must be > 0") + + if args.enable_memory_limit: + if args.max_memory_gb is None: + raise SystemExit( + "When using --enable-memory-limit, you must also pass --max-memory-gb" + ) + set_memory_limit(args.max_memory_gb) + + # MEDS paths + # Use task-specific cache directories to avoid interference between tasks + meds_dir = f"{args.cache_dir}/mimic4_meds_mortality" + meds_reader_dir = f"{args.cache_dir}/mimic4_meds_reader_mortality" + + print("=" * 80) + print("BENCHMARK: meds_reader - Mortality Prediction (Thread Sweep)") + print(f"threads={args.threads} repeats={args.repeats}") + print(f"mimic_root: {args.mimic_root}") + print(f"backend: {args.backend}, num_proc: {args.num_proc}, num_shards: {args.num_shards}") + if args.skip_conversion: + print("WARNING: --skip-conversion is set. Conversion time will NOT be included.") + print(" This is NOT a fair comparison with PyHealth!") + print("=" * 80) + + tracker = PeakMemoryTracker(poll_interval_s=0.1) + tracker.start() + + total_start = time.time() + results: list[RunResult] = [] + + print(f"\n{'='*60}") + print("Running benchmark...") + print(f"{'='*60}") + + for t in args.threads: + for r in range(args.repeats): + tracker.reset() + run_start = time.time() + + # Step 0: Convert MIMIC-IV to MEDS format (part of total time) + # For fair comparison with PyHealth, we must include this conversion time + # since PyHealth's dataset loading includes parsing raw MIMIC-IV CSVs. 
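+            # Note: on repeats after the first, skip_conversion is forced to True,
+            # so the cached MEDS database is reused and those CSV rows record
+            # meds_etl_s = meds_reader_convert_s = 0.0 with conversion_cached=True.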
+ conversion = run_meds_conversion( + mimic_root=args.mimic_root, + meds_dir=meds_dir, + meds_reader_dir=meds_reader_dir, + num_shards=args.num_shards, + num_proc=args.num_proc, + backend=args.backend, + force_reconvert=args.force_reconvert and r == 0, # Only reconvert on first repeat + skip_conversion=args.skip_conversion or r > 0, # Reuse on subsequent repeats + ) + + print(f"\n threads={t} repeat={r + 1}/{args.repeats}: Processing task...") + task_start = time.time() + + # Step 1: Extract samples using meds_reader (parallel) + samples = [] + with meds_reader.SubjectDatabase(meds_reader_dir, num_threads=t) as database: + for s in database.map(get_mortality_samples): + samples.extend(s) + + # Step 2: Build vocabularies (matching PyHealth's processor.fit()) + conditions_processor = SequenceProcessor() + procedures_processor = SequenceProcessor() + labs_processor = SequenceProcessor() + label_processor = BinaryLabelProcessor() + + conditions_processor.fit(samples, "conditions") + procedures_processor.fit(samples, "procedures") + labs_processor.fit(samples, "labs") + label_processor.fit(samples, "label") + + # Step 3: Tokenize samples (matching PyHealth's processor.process()) + processed_samples = [] + for sample in samples: + processed_sample = { + "visit_id": sample["visit_id"], + "patient_id": sample["patient_id"], + "conditions": conditions_processor.process(sample["conditions"]), + "procedures": procedures_processor.process(sample["procedures"]), + "labs": labs_processor.process(sample["labs"]), + "label": label_processor.process(sample["label"]), + } + processed_samples.append(processed_sample) + + # Step 4: Wrap in PyTorch Dataset for model training compatibility + dataset = MedsReaderSampleDataset( + samples=processed_samples, + input_schema={ + "conditions": "sequence", + "procedures": "sequence", + "labs": "sequence", + }, + output_schema={"label": "binary"}, + input_processors={ + "conditions": conditions_processor, + "procedures": procedures_processor, + "labs": labs_processor, + }, + output_processors={"label": label_processor}, + dataset_name="MIMIC-IV", + task_name="MortalityPrediction", + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss_bytes = tracker.peak_bytes() + num_samples = len(dataset) + + results.append( + RunResult( + num_threads=t, + repeat_index=r, + meds_etl_s=conversion.meds_etl_s, + meds_reader_convert_s=conversion.meds_reader_convert_s, + task_process_s=task_process_s, + total_s=total_s, + peak_rss_bytes=peak_rss_bytes, + num_samples=num_samples, + conversion_cached=conversion.was_cached, + ) + ) + + # Build output message + timing_str = f"task={task_process_s:.2f}s" + if not conversion.was_cached: + timing_str = ( + f"meds_etl={conversion.meds_etl_s:.2f}s " + f"convert={conversion.meds_reader_convert_s:.2f}s " + + timing_str + f" total={total_s:.2f}s" + ) + + print( + f" ✓ threads={t:>2} repeat={r + 1:>2}/{args.repeats} " + f"samples={num_samples} " + f"{timing_str} " + f"peak_rss={format_size(peak_rss_bytes)} " + f"vocab_sizes=({conditions_processor.size()},{procedures_processor.size()},{labs_processor.size()})" + ) + + total_sweep_s = time.time() - total_start + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(results[0]).keys())) + writer.writeheader() + for rr in results: + writer.writerow(asdict(rr)) + + # Print summary + print("\n" + "=" * 80) + print("SUMMARY 
(median across repeats)") + print("=" * 80) + + # Check if any results have conversion times + has_conversion = any(not rr.conversion_cached for rr in results) + + if has_conversion: + print("\n NOTE: Conversion time included for fair comparison with PyHealth.") + print(" PyHealth's dataset_load_s ≈ meds_etl_s + meds_reader_convert_s") + else: + print("\n WARNING: Conversion was cached. For fair benchmarking, use --force-reconvert") + + print() + for t in args.threads: + trs = [rr for rr in results if rr.num_threads == t] + med_task = median([rr.task_process_s for rr in trs]) + med_total = median([rr.total_s for rr in trs]) + med_peak = median([float(rr.peak_rss_bytes) for rr in trs]) + + # Get conversion times (from first repeat which has them if --force-reconvert) + first_run = [rr for rr in trs if rr.repeat_index == 0][0] + + if not first_run.conversion_cached: + print( + f"threads={t:>2} " + f"meds_etl={first_run.meds_etl_s:>7.2f}s " + f"convert={first_run.meds_reader_convert_s:>7.2f}s " + f"task_med={med_task:>7.2f}s " + f"total={med_total:>7.2f}s " + f"peak_rss={format_size(int(med_peak)):>10}" + ) + else: + print( + f"threads={t:>2} " + f"task_med={med_task:>8.2f}s " + f"(conversion cached) " + f"peak_rss_med={format_size(int(med_peak)):>10}" + ) + + print("\nArtifacts:") + print(f" - CSV: {out_csv}") + print(f" - MEDS database: {meds_reader_dir}") + print("\nTotals:") + print(f" - Sweep wall time: {total_sweep_s:.2f}s") + + # Print comparison note + print("\nFor comparison with PyHealth:") + print(" PyHealth total_s = dataset_load_s + task_process_s") + print(" meds_reader total_s = meds_etl_s + meds_reader_convert_s + task_process_s") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_mortality_pyhealth_etl.py b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_mortality_pyhealth_etl.py new file mode 100644 index 000000000..89e9c9fe5 --- /dev/null +++ b/examples/benchmark_perf/meds_reader_ver/benchmark_meds_reader_mortality_pyhealth_etl.py @@ -0,0 +1,746 @@ +"""Benchmark script for MIMIC-IV mortality prediction using meds_reader. + +This is a FALLBACK version that uses PyHealth 1.1.6 for ETL instead of meds_etl_mimic. +Use this if meds_etl_mimic fails to run properly. + +Pipeline: +1. Load MIMIC-IV data using PyHealth 1.1.6 (MIMIC4Dataset) +2. Convert PyHealth data structures to MEDS format (parquet files) +3. Run meds_reader_convert to create meds_reader database +4. Process the task using meds_reader +5. Return samples in a PyTorch-compatible Dataset + +IMPORTANT: For fair comparison with PyHealth, conversion time MUST be included. + +Typical usage: + # First install dependencies: + pip install pyhealth==1.1.6 meds_reader pyarrow + + # Run benchmark: + python benchmark_meds_reader_mortality_pyhealth_etl.py + python benchmark_meds_reader_mortality_pyhealth_etl.py --threads 1,4,8,12,16 +""" + +from __future__ import annotations + +import argparse +import collections +import csv +import datetime +import os +import shutil +import subprocess +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List + +import numpy as np +import psutil +import pyarrow as pa +import pyarrow.parquet as pq +import torch +from torch.utils.data import Dataset + +try: + import meds_reader +except ImportError: + raise ImportError( + "meds_reader not found. 
Install with: pip install meds_reader\n" + "Or from source: pip install -e /path/to/meds_reader" + ) + +# Import PyHealth 1.1.6 +try: + from pyhealth.datasets import MIMIC4Dataset +except ImportError: + raise ImportError( + "PyHealth not found. Install with: pip install pyhealth==1.1.6" + ) + + +# ============================================================================= +# PyTorch Dataset Wrapper +# ============================================================================= + +class MedsReaderSampleDataset(Dataset): + """PyTorch Dataset wrapper for meds_reader samples.""" + + def __init__( + self, + samples: List[Dict[str, Any]], + input_schema: Dict[str, str], + output_schema: Dict[str, str], + input_processors: Dict[str, Any], + output_processors: Dict[str, Any], + dataset_name: str = "", + task_name: str = "", + ): + self.samples = samples + self.input_schema = input_schema + self.output_schema = output_schema + self.input_processors = input_processors + self.output_processors = output_processors + self.dataset_name = dataset_name + self.task_name = task_name + + self.patient_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + self.record_to_index: Dict[Any, List[int]] = collections.defaultdict(list) + + for idx, sample in enumerate(samples): + if "patient_id" in sample: + self.patient_to_index[sample["patient_id"]].append(idx) + if "visit_id" in sample: + self.record_to_index[sample["visit_id"]].append(idx) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, Any]: + return self.samples[index] + + def __repr__(self) -> str: + return f"MedsReaderSampleDataset({self.dataset_name}, {self.task_name}, n={len(self)})" + + +# ============================================================================= +# Processor Classes +# ============================================================================= + +class SequenceProcessor: + """Matches PyHealth's SequenceProcessor for vocabulary building.""" + + def __init__(self): + self.code_vocab = {"": 0} + self._next_index = 1 + + def fit(self, samples, field): + for sample in samples: + if field not in sample: + continue + for token in sample[field]: + if token is None: + continue + if token not in self.code_vocab: + self.code_vocab[token] = self._next_index + self._next_index += 1 + self.code_vocab[""] = len(self.code_vocab) + + def process(self, value): + indices = [] + for token in value: + if token in self.code_vocab: + indices.append(self.code_vocab[token]) + else: + indices.append(self.code_vocab[""]) + return torch.tensor(indices, dtype=torch.long) + + def size(self): + return len(self.code_vocab) + + +class BinaryLabelProcessor: + """Processor for binary labels.""" + + def __init__(self): + self.label_vocab = {0: 0, 1: 1} + + def fit(self, samples, field): + for sample in samples: + if field in sample: + val = sample[field] + if val not in self.label_vocab: + self.label_vocab[val] = len(self.label_vocab) + + def process(self, value): + return torch.tensor([self.label_vocab.get(value, 0)], dtype=torch.float32) + + def size(self): + return 1 + + +# Lab item IDs for StageNet (matching PyHealth's implementation) +LAB_ITEM_IDS = { + "50824", "52455", "50983", "52623", # Sodium + "50822", "52452", "50971", "52610", # Potassium + "50806", "52434", "50902", "52535", # Chloride + "50803", "50804", # Bicarbonate + "50809", "52027", "50931", "52569", # Glucose + "50808", "51624", # Calcium + "50960", # Magnesium + "50868", "52500", # Anion Gap + "52031", "50964", "51701", # 
Osmolality + "50970", # Phosphate +} + + +# ============================================================================= +# Data Conversion (PyHealth 1.1.6 -> MEDS -> meds_reader) +# ============================================================================= + +def pyhealth_to_meds( + pyhealth_root: str, + output_dir: str, + tables: List[str], + dev: bool = False, + num_shards: int = 100, +) -> float: + """Convert MIMIC-IV data via PyHealth 1.1.6 to MEDS format.""" + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + print(" Loading MIMIC-IV via PyHealth 1.1.6...") + print(f" Root: {pyhealth_root}") + print(f" Tables: {tables}") + print(f" Dev mode: {dev}") + + start = time.time() + + dataset = MIMIC4Dataset( + root=pyhealth_root, + tables=tables, + dev=dev, + refresh_cache=True, + ) + + pyhealth_load_time = time.time() - start + print(f" PyHealth load completed in {pyhealth_load_time:.2f}s") + + print(" Converting to MEDS format...") + convert_start = time.time() + + results = collections.defaultdict(list) + + for patient_id, patient in dataset.patients.items(): + subject_id = int(patient_id) + + # Birth event + if patient.birth_datetime is not None: + birth_obj = { + 'subject_id': subject_id, + 'code': 'meds/birth', + 'time': patient.birth_datetime, + } + if hasattr(patient, 'gender') and patient.gender: + birth_obj['gender'] = patient.gender + if hasattr(patient, 'ethnicity') and patient.ethnicity: + birth_obj['ethnicity'] = patient.ethnicity + results[subject_id].append(birth_obj) + + # Death event + if patient.death_datetime is not None: + results[subject_id].append({ + 'subject_id': subject_id, + 'code': 'meds/death', + 'time': patient.death_datetime, + }) + + # Process visits + for visit_id, visit in patient.visits.items(): + visit_id_int = int(visit_id) + + visit_event = { + 'subject_id': subject_id, + 'code': 'MIMIC_IV_Admission/unknown', + 'time': visit.encounter_time, + 'visit_id': visit_id_int, + } + if visit.discharge_time: + visit_event['end'] = visit.discharge_time + if hasattr(visit, 'discharge_status'): + visit_event['discharge_status'] = visit.discharge_status + + results[subject_id].append(visit_event) + + for table in visit.available_tables: + for event in visit.get_event_list(table): + event_obj = { + 'subject_id': subject_id, + 'visit_id': visit_id_int, + 'code': f'{event.vocabulary}/{event.code}', + 'time': event.timestamp or visit.discharge_time, + } + + if hasattr(event, 'attr_dict') and event.attr_dict: + for k, v in event.attr_dict.items(): + if v == v: # Skip NaN + event_obj[k] = v + + results[subject_id].append(event_obj) + + results[subject_id].sort( + key=lambda a: a['time'] if a['time'] else datetime.datetime.min + ) + + # Write to parquet shards + os.makedirs(output_dir, exist_ok=True) + os.makedirs(f"{output_dir}/metadata", exist_ok=True) + os.makedirs(f"{output_dir}/data", exist_ok=True) + + all_subjects = list(results.keys()) + subject_ids_per_shard = np.array_split(all_subjects, num_shards) + + attr_map = { + str: pa.string(), + int: pa.int64(), + np.int64: pa.int64(), + float: pa.float64(), + datetime.datetime: pa.timestamp('us'), + } + + attr_schema = {} + for subject_values in results.values(): + for row in subject_values: + for k, v in row.items(): + if k not in {'subject_id', 'time'} and v is not None: + pa_type = attr_map.get(type(v), pa.string()) + if k not in attr_schema: + attr_schema[k] = pa_type + + schema = pa.schema([ + ('subject_id', pa.int64()), + ('time', pa.timestamp('us')), + ] + [(k, v) for k, v in 
sorted(attr_schema.items())]) + + for i, subject_ids in enumerate(subject_ids_per_shard): + if len(subject_ids) == 0: + continue + rows = [v for subject_id in subject_ids for v in results[subject_id]] + if rows: + table = pa.Table.from_pylist(rows, schema=schema) + pq.write_table(table, f"{output_dir}/data/{i}.parquet") + + convert_time = time.time() - convert_start + total_time = time.time() - start + + print(f" MEDS conversion completed in {convert_time:.2f}s") + print(f" Total PyHealth ETL time: {total_time:.2f}s") + + return total_time + + +def run_meds_reader_convert( + input_dir: str, output_dir: str, num_threads: int = 10 +) -> float: + """Run meds_reader_convert CLI tool.""" + print(f" Running meds_reader_convert (threads={num_threads})...") + print(f" {input_dir} -> {output_dir}") + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + start = time.time() + try: + subprocess.run( + ["meds_reader_convert", input_dir, output_dir, + "--num_threads", str(num_threads)], + capture_output=True, + text=True, + check=True, + ) + elapsed = time.time() - start + print(f" meds_reader_convert completed in {elapsed:.2f}s") + return elapsed + except subprocess.CalledProcessError as e: + print(" ERROR: meds_reader_convert failed:") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + raise + except FileNotFoundError: + print(" ERROR: meds_reader_convert not found in PATH") + raise + + +@dataclass +class ConversionResult: + """Holds timing information for the conversion process.""" + pyhealth_etl_s: float + meds_reader_convert_s: float + total_conversion_s: float + was_cached: bool + + +def run_pyhealth_meds_conversion( + pyhealth_root: str, + meds_dir: str, + meds_reader_dir: str, + tables: List[str], + dev: bool, + num_shards: int, + num_threads: int, + force_reconvert: bool, + skip_conversion: bool, +) -> ConversionResult: + """Run PyHealth-based MEDS conversion.""" + + if skip_conversion: + if not Path(meds_reader_dir).exists(): + raise SystemExit( + f"Cannot skip conversion: MEDS database does not exist at " + f"{meds_reader_dir}\nRun without --skip-conversion first." + ) + print("✓ Skipping conversion (using cached MEDS database)") + return ConversionResult(0.0, 0.0, 0.0, True) + + if Path(meds_reader_dir).exists() and not force_reconvert: + print(f"✓ MEDS database exists: {meds_reader_dir}") + return ConversionResult(0.0, 0.0, 0.0, True) + + print("\n" + "=" * 60) + print("Converting MIMIC-IV to MEDS format via PyHealth 1.1.6") + print("=" * 60) + + if Path(meds_dir).exists(): + print(f" Clearing existing MEDS cache: {meds_dir}") + shutil.rmtree(meds_dir) + if Path(meds_reader_dir).exists(): + print(f" Clearing existing meds_reader cache: {meds_reader_dir}") + shutil.rmtree(meds_reader_dir) + + print("\n[Step 1/2] Loading via PyHealth and converting to MEDS...") + pyhealth_etl_s = pyhealth_to_meds( + pyhealth_root=pyhealth_root, + output_dir=meds_dir, + tables=tables, + dev=dev, + num_shards=num_shards, + ) + + print("\n[Step 2/2] Running meds_reader_convert...") + meds_reader_convert_s = run_meds_reader_convert( + meds_dir, meds_reader_dir, num_threads=num_threads + ) + + total = pyhealth_etl_s + meds_reader_convert_s + print(f"\n✓ MEDS database ready. 
Total conversion: {total:.2f}s") + + return ConversionResult(pyhealth_etl_s, meds_reader_convert_s, total, False) + + +# ============================================================================= +# Task Function - Mortality Prediction +# ============================================================================= + +def get_mortality_samples(subjects: Iterator[meds_reader.Subject]): + """Process subjects for mortality prediction with lab events.""" + samples = [] + + for subject in subjects: + admissions = {} + death_time = None + + for event in subject.events: + if event.code == "meds/death": + death_time = event.time + break + + for event in subject.events: + if event.code.startswith("MIMIC_IV_Admission/"): + visit_id = getattr(event, 'visit_id', None) + end_time = getattr(event, 'end', None) + if visit_id is not None and event.time is not None: + discharge_status = 0 + if death_time is not None and end_time is not None: + if death_time <= end_time: + discharge_status = 1 + + admissions[visit_id] = { + 'time': event.time, + 'end': end_time, + 'conditions': set(), + 'procedures': set(), + 'labs': set(), + 'discharge_status': discharge_status, + } + + for event in subject.events: + visit_id = getattr(event, 'visit_id', None) + if visit_id is None or visit_id not in admissions: + continue + + code = event.code + if code.startswith("ICD"): + if "CM" in code: + admissions[visit_id]['conditions'].add(code) + else: + admissions[visit_id]['procedures'].add(code) + elif "LABITEM" in code or code.startswith("MIMIC_IV_LABITEM/"): + item_id = code.split("/")[-1] if "/" in code else "" + if item_id in LAB_ITEM_IDS: + admissions[visit_id]['labs'].add(code) + + sorted_visits = sorted( + [(vid, data) for vid, data in admissions.items()], + key=lambda x: x[1]['time'] + ) + + for i in range(len(sorted_visits) - 1): + visit_id, current = sorted_visits[i] + _, next_visit = sorted_visits[i + 1] + + conditions = list(current['conditions']) + procedures = list(current['procedures']) + labs = list(current['labs']) + mortality_label = next_visit['discharge_status'] + + if len(conditions) == 0 or len(labs) == 0: + continue + + samples.append({ + "visit_id": visit_id, + "patient_id": subject.subject_id, + "conditions": conditions, + "procedures": procedures, + "labs": labs, + "label": mortality_label, + }) + + return samples + + +# ============================================================================= +# Benchmark Infrastructure +# ============================================================================= + +@dataclass +class RunResult: + num_threads: int + repeat_index: int + pyhealth_etl_s: float + meds_reader_convert_s: float + task_process_s: float + total_s: float + peak_rss_bytes: int + num_samples: int + conversion_cached: bool + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +class PeakMemoryTracker: + def __init__(self, poll_interval_s: float = 0.1): + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self): + self._thread.start() + + def reset(self): + with self._lock: + self._peak = 0 + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += 
self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self): + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def parse_threads(value: str) -> List[int]: + parts = [p.strip() for p in value.split(",") if p.strip()] + return [int(p) for p in parts if int(p) > 0] + + +def median(values: Iterable[float]) -> float: + xs = sorted(values) + if not xs: + return 0.0 + mid = len(xs) // 2 + return xs[mid] if len(xs) % 2 == 1 else (xs[mid - 1] + xs[mid]) / 2.0 + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark meds_reader Mortality using PyHealth 1.1.6 ETL" + ) + parser.add_argument( + "--threads", type=parse_threads, default=[1, 4, 8, 12, 16], + help="Comma-separated list of thread counts", + ) + parser.add_argument("--repeats", type=int, default=1) + parser.add_argument( + "--pyhealth-root", type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", + help="Path to MIMIC-IV hosp directory (for PyHealth 1.1.6)", + ) + parser.add_argument("--cache-dir", type=str, default="/srv/local/data/johnwu3/meds_reader") + parser.add_argument("--num-shards", type=int, default=100) + parser.add_argument("--num-threads", type=int, default=8) + parser.add_argument("--dev", action="store_true") + parser.add_argument("--force-reconvert", action="store_true") + parser.add_argument("--skip-conversion", action="store_true") + parser.add_argument( + "--output-csv", type=str, + default="benchmark_meds_reader_mortality_pyhealth_etl.csv", + ) + args = parser.parse_args() + + meds_dir = f"{args.cache_dir}/mimic4_meds_mortality_pyhealth" + meds_reader_dir = f"{args.cache_dir}/mimic4_meds_reader_mortality_pyhealth" + + print("=" * 80) + print("BENCHMARK: meds_reader Mortality (PyHealth 1.1.6 ETL - Fallback)") + print(f"threads={args.threads} repeats={args.repeats} dev={args.dev}") + print(f"pyhealth_root: {args.pyhealth_root}") + print("=" * 80) + + tracker = PeakMemoryTracker() + tracker.start() + + total_start = time.time() + results: List[RunResult] = [] + + # Tables needed for mortality task + tables = ["diagnoses_icd", "procedures_icd", "labevents"] + + for t in args.threads: + for r in range(args.repeats): + tracker.reset() + run_start = time.time() + + conversion = run_pyhealth_meds_conversion( + pyhealth_root=args.pyhealth_root, + meds_dir=meds_dir, + meds_reader_dir=meds_reader_dir, + tables=tables, + dev=args.dev, + num_shards=args.num_shards, + num_threads=args.num_threads, + force_reconvert=args.force_reconvert and r == 0, + skip_conversion=args.skip_conversion or r > 0, + ) + + print(f"\n threads={t} repeat={r + 1}/{args.repeats}: Processing...") + task_start = time.time() + + samples = [] + with meds_reader.SubjectDatabase(meds_reader_dir, num_threads=t) as db: + for s in db.map(get_mortality_samples): + samples.extend(s) + + conditions_proc = SequenceProcessor() + procedures_proc = SequenceProcessor() + labs_proc = SequenceProcessor() + label_proc = BinaryLabelProcessor() + + conditions_proc.fit(samples, "conditions") + procedures_proc.fit(samples, "procedures") + labs_proc.fit(samples, "labs") + label_proc.fit(samples, "label") + + processed = [] + for sample in samples: + processed.append({ + "visit_id": sample["visit_id"], + 
"patient_id": sample["patient_id"], + "conditions": conditions_proc.process(sample["conditions"]), + "procedures": procedures_proc.process(sample["procedures"]), + "labs": labs_proc.process(sample["labs"]), + "label": label_proc.process(sample["label"]), + }) + + dataset = MedsReaderSampleDataset( + samples=processed, + input_schema={ + "conditions": "sequence", + "procedures": "sequence", + "labs": "sequence", + }, + output_schema={"label": "binary"}, + input_processors={ + "conditions": conditions_proc, + "procedures": procedures_proc, + "labs": labs_proc, + }, + output_processors={"label": label_proc}, + dataset_name="MIMIC-IV", + task_name="MortalityPrediction", + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss = tracker.peak_bytes() + + results.append(RunResult( + num_threads=t, + repeat_index=r, + pyhealth_etl_s=conversion.pyhealth_etl_s, + meds_reader_convert_s=conversion.meds_reader_convert_s, + task_process_s=task_process_s, + total_s=total_s, + peak_rss_bytes=peak_rss, + num_samples=len(dataset), + conversion_cached=conversion.was_cached, + )) + + timing = f"task={task_process_s:.2f}s" + if not conversion.was_cached: + timing = (f"pyhealth_etl={conversion.pyhealth_etl_s:.2f}s " + f"convert={conversion.meds_reader_convert_s:.2f}s " + f"{timing} total={total_s:.2f}s") + + print(f" ✓ threads={t:>2} samples={len(dataset)} {timing} " + f"peak_rss={format_size(peak_rss)}") + + total_sweep_s = time.time() - total_start + + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(results[0]).keys())) + writer.writeheader() + for rr in results: + writer.writerow(asdict(rr)) + + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + for t in args.threads: + trs = [rr for rr in results if rr.num_threads == t] + med_task = median([rr.task_process_s for rr in trs]) + first = [rr for rr in trs if rr.repeat_index == 0][0] + if not first.conversion_cached: + print(f"threads={t:>2} pyhealth_etl={first.pyhealth_etl_s:.2f}s " + f"convert={first.meds_reader_convert_s:.2f}s " + f"task_med={med_task:.2f}s") + else: + print(f"threads={t:>2} task_med={med_task:.2f}s (cached)") + + print(f"\nSweep time: {total_sweep_s:.2f}s") + print(f"CSV: {out_csv}") + print("=" * 80) + + +if __name__ == "__main__": + main() + diff --git a/examples/benchmark_perf/patient_exploration/benchmark_patient_access_legacy.py b/examples/benchmark_perf/patient_exploration/benchmark_patient_access_legacy.py new file mode 100644 index 000000000..50a280797 --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/benchmark_patient_access_legacy.py @@ -0,0 +1,316 @@ +"""Benchmark: PyHealth 1.1.6 (Legacy) - Data Loading & Single Patient Access + +Measures: +1. Time to load/initialize the dataset from raw MIMIC-IV data +2. Time to access a single patient after loading +3. 
Total time (load + access) + +Usage: + # Activate legacy environment first + pip install pyhealth==1.1.6 + + # Run benchmark + python benchmark_patient_access_legacy.py + python benchmark_patient_access_legacy.py --patient-id 10014729 --workers 8 + python benchmark_patient_access_legacy.py --dev +""" + +from __future__ import annotations + +import argparse +import csv +import os +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +import psutil + +# Import pandarallel - will be initialized before each run +from pandarallel import pandarallel + +# Global variable to store desired worker count for monkey-patching +_DESIRED_NB_WORKERS: int = 8 + +# Store the original pandarallel.initialize function +_original_pandarallel_initialize = pandarallel.initialize + + +def _patched_pandarallel_initialize(*args, **kwargs): + """Patched pandarallel.initialize that enforces our worker count.""" + kwargs['nb_workers'] = _DESIRED_NB_WORKERS + return _original_pandarallel_initialize(*args, **kwargs) + + +# Apply the monkey-patch +pandarallel.initialize = _patched_pandarallel_initialize + +# Legacy PyHealth 1.1.6 imports (AFTER monkey-patching) +from pyhealth.datasets import MIMIC4Dataset +from pyhealth.datasets.utils import MODULE_CACHE_PATH + + +# ============================================================================= +# Benchmark Result +# ============================================================================= + +@dataclass +class BenchmarkResult: + approach: str + data_load_s: float + patient_access_1st_s: float # First access (cold cache) + patient_access_2nd_s: float # Second access (warm cache) + total_s: float + peak_rss_bytes: int + patient_found: bool + num_events: int + num_visits: int + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +class PeakMemoryTracker: + """Tracks peak RSS for current process + children.""" + + def __init__(self, poll_interval_s: float = 0.05) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def clear_pyhealth_cache(verbose: bool = True) -> int: + """Clear all PyHealth cache files.""" + cache_path = Path(MODULE_CACHE_PATH) + if not cache_path.exists(): + return 0 + + deleted_count = 0 + total_size = 0 + + for cache_file in cache_path.glob("*.pkl"): + try: + file_size = cache_file.stat().st_size + cache_file.unlink() + deleted_count += 1 + total_size += file_size + except OSError as e: + if verbose: + print(f" Warning: 
Could not delete {cache_file}: {e}") + + if verbose and deleted_count > 0: + print(f" Cleared {deleted_count} cache files ({format_size(total_size)})") + + return deleted_count + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark PyHealth 1.1.6 data loading and single patient access" + ) + parser.add_argument( + "--patient-id", + type=str, + default="10014729", + help="Patient ID to access (default: 10014729)", + ) + parser.add_argument( + "--workers", + type=int, + default=8, + help="Number of workers (default: 8)", + ) + parser.add_argument( + "--dev", + action="store_true", + help="Use dev mode (smaller subset)", + ) + parser.add_argument( + "--root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", + help="Path to MIMIC-IV hosp directory", + ) + parser.add_argument( + "--no-clear-cache", + action="store_true", + help="Do not clear cache before benchmark", + ) + parser.add_argument( + "--output-csv", + type=str, + default="benchmark_patient_access_legacy.csv", + help="Output CSV file", + ) + args = parser.parse_args() + + # Set worker count + global _DESIRED_NB_WORKERS + _DESIRED_NB_WORKERS = args.workers + + print("=" * 80) + print("BENCHMARK: PyHealth 1.1.6 (Legacy) - Data Loading & Patient Access") + print("=" * 80) + print(f"Patient ID: {args.patient_id}") + print(f"Workers: {args.workers}") + print(f"Dev mode: {args.dev}") + print(f"Root: {args.root}") + print(f"Cache path: {MODULE_CACHE_PATH}") + print("=" * 80) + + # Clear cache + if not args.no_clear_cache: + print("\nClearing PyHealth cache...") + clear_pyhealth_cache(verbose=True) + + tracker = PeakMemoryTracker(poll_interval_s=0.05) + tracker.start() + tracker.reset() + + # Step 1: Load dataset + print("\n[Step 1] Loading dataset...") + load_start = time.time() + + dataset = MIMIC4Dataset( + root=args.root, + tables=["diagnoses_icd", "procedures_icd", "labevents"], + dev=args.dev, + refresh_cache=True, + ) + + data_load_s = time.time() - load_start + print(f" Dataset loaded in {data_load_s:.2f}s") + print(f" Number of patients: {len(dataset.patients)}") + + # Step 2: First patient access (cold cache) + print(f"\n[Step 2] First access to patient {args.patient_id} (cold cache)...") + access_1_start = time.time() + + patient_dict = dataset.patients + patient_found = args.patient_id in patient_dict + + if patient_found: + patient = patient_dict[args.patient_id] + # Count events across all visits + num_events = 0 + num_visits = len(patient.visits) + for visit in patient.visits.values(): + for table in visit.available_tables: + num_events += len(visit.get_event_list(table)) + print(f" Patient found!") + print(f" Number of visits: {num_visits}") + print(f" Number of events: {num_events}") + else: + num_events = 0 + num_visits = 0 + print(f" Patient NOT found!") + # List available patient IDs (first 10) + available_ids = list(patient_dict.keys())[:10] + print(f" Available patient IDs (first 10): {available_ids}") + + patient_access_1st_s = time.time() - access_1_start + + # Step 3: Second patient access (warm cache) + print(f"\n[Step 3] Second access to patient {args.patient_id} (warm cache)...") + access_2_start = time.time() + + if patient_found: + patient = patient_dict[args.patient_id] + # Re-count events to ensure we're actually accessing the data + count = 0 + for visit in patient.visits.values(): + for table in visit.available_tables: + count += len(visit.get_event_list(table)) + + patient_access_2nd_s = time.time() - access_2_start + + total_s = data_load_s + 
patient_access_1st_s + patient_access_2nd_s + peak_rss = tracker.peak_bytes() + + tracker.stop() + + result = BenchmarkResult( + approach="pyhealth_1.1.6", + data_load_s=data_load_s, + patient_access_1st_s=patient_access_1st_s, + patient_access_2nd_s=patient_access_2nd_s, + total_s=total_s, + peak_rss_bytes=peak_rss, + patient_found=patient_found, + num_events=num_events, + num_visits=num_visits, + ) + + # Summary + print("\n" + "=" * 80) + print("SUMMARY: PyHealth 1.1.6 (Legacy)") + print("=" * 80) + access_1_str = f"{patient_access_1st_s*1000:.2f}ms" if patient_access_1st_s < 1 else f"{patient_access_1st_s:.2f}s" + access_2_str = f"{patient_access_2nd_s*1000:.2f}ms" if patient_access_2nd_s < 1 else f"{patient_access_2nd_s:.2f}s" + print(f" Data load time: {data_load_s:.2f}s") + print(f" Patient access (1st/cold): {access_1_str}") + print(f" Patient access (2nd/warm): {access_2_str}") + print(f" Total time: {total_s:.2f}s") + print(f" Peak RSS: {format_size(peak_rss)}") + print(f" Patient found: {patient_found}") + print(f" Visits: {num_visits}") + print(f" Events: {num_events}") + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(result).keys())) + writer.writeheader() + writer.writerow(asdict(result)) + print(f"\nResults saved to: {out_csv}") + print("=" * 80) + + +if __name__ == "__main__": + main() + diff --git a/examples/benchmark_perf/patient_exploration/benchmark_patient_access_meds_reader.py b/examples/benchmark_perf/patient_exploration/benchmark_patient_access_meds_reader.py new file mode 100644 index 000000000..961615ddf --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/benchmark_patient_access_meds_reader.py @@ -0,0 +1,404 @@ +"""Benchmark: meds_reader - Data Loading & Single Patient Access + +Measures: +1. Time to convert MIMIC-IV to MEDS format using meds_etl +2. Time to convert MEDS to meds_reader database +3. Time to access a single patient after loading +4. Total time (load + access) + +For meds_reader, "data loading" includes: +- meds_etl_mimic: Convert MIMIC-IV directly to MEDS format +- meds_reader_convert: Convert MEDS to meds_reader database + +Usage: + # Activate meds_reader environment (with meds_etl installed) + pip install meds_etl meds_reader + + # Run benchmark (uses existing DB if available) + python benchmark_patient_access_meds_reader.py + + # Force reconversion of database + python benchmark_patient_access_meds_reader.py --force-reconvert + + # Custom settings + python benchmark_patient_access_meds_reader.py --patient-id 10014729 --threads 8 +""" + +from __future__ import annotations + +import argparse +import csv +import os +import shutil +import subprocess +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +import psutil + +try: + import meds_reader +except ImportError: + raise ImportError( + "meds_reader not found. 
Install with: pip install meds_reader\n" + "Or from source: pip install -e /path/to/meds_reader" + ) + + +# ============================================================================= +# Benchmark Result +# ============================================================================= + +@dataclass +class BenchmarkResult: + approach: str + data_load_s: float # Full conversion time (or 0 if using cached DB) + meds_etl_s: float # meds_etl_mimic conversion time + meds_reader_convert_s: float # meds_reader_convert time + patient_access_1st_s: float # First access (cold cache) + patient_access_2nd_s: float # Second access (warm cache) + total_s: float + peak_rss_bytes: int + patient_found: bool + num_events: int + used_cached_db: bool + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +class PeakMemoryTracker: + """Tracks peak RSS for current process + children.""" + + def __init__(self, poll_interval_s: float = 0.05) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +# ============================================================================= +# Data Conversion Functions +# ============================================================================= + +def run_meds_etl_mimic( + src_mimic: str, + output_dir: str, + num_shards: int = 100, + num_proc: int = 1, + backend: str = "polars", +) -> float: + """Run meds_etl_mimic to convert MIMIC-IV to MEDS format. 
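+
+    This is a thin wrapper around the meds_etl_mimic CLI: any existing output
+    directory is removed first, then the tool is invoked via subprocess, roughly
+    (shown with this function's default argument values):
+
+        meds_etl_mimic <src_mimic> <output_dir> \
+            --num_shards 100 --num_proc 1 --backend polars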
+ + Args: + src_mimic: Path to MIMIC-IV root (containing 2.2/ subdirectory) + output_dir: Path to output MEDS dataset + num_shards: Number of shards for processing + num_proc: Number of processes to use + backend: Backend to use (polars or cpp) + + Returns: + Time taken in seconds + """ + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + print(f" Running meds_etl_mimic (shards={num_shards}, proc={num_proc}, backend={backend})...") + print(f" Source: {src_mimic}") + print(f" Destination: {output_dir}") + + start = time.time() + result = subprocess.run( + [ + "meds_etl_mimic", + src_mimic, + output_dir, + "--num_shards", str(num_shards), + "--num_proc", str(num_proc), + "--backend", backend, + ], + capture_output=True, + text=True, + ) + elapsed = time.time() - start + + if result.returncode != 0: + print(f" STDOUT: {result.stdout}") + print(f" STDERR: {result.stderr}") + raise RuntimeError(f"meds_etl_mimic failed with code {result.returncode}") + + print(f" meds_etl_mimic completed in {elapsed:.2f}s") + return elapsed + + +def run_meds_reader_convert(input_dir: str, output_dir: str, num_threads: int = 10) -> float: + """Run meds_reader_convert. Returns time taken.""" + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + print(f" Running meds_reader_convert (threads={num_threads})...") + start = time.time() + result = subprocess.run( + ["meds_reader_convert", input_dir, output_dir, "--num_threads", str(num_threads)], + capture_output=True, text=True, + ) + elapsed = time.time() - start + + if result.returncode != 0: + print(f" ERROR: {result.stderr}") + raise RuntimeError(f"meds_reader_convert failed: {result.stderr}") + + print(f" meds_reader_convert completed in {elapsed:.2f}s") + return elapsed + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark meds_reader data loading and single patient access" + ) + parser.add_argument( + "--patient-id", + type=str, + default="10014729", + help="Patient ID to access (default: 10014729)", + ) + parser.add_argument( + "--threads", + type=int, + default=8, + help="Number of threads for meds_reader (default: 8)", + ) + parser.add_argument( + "--num-proc", + type=int, + default=8, + help="Number of processes for meds_etl_mimic (default: 8)", + ) + parser.add_argument( + "--num-shards", + type=int, + default=100, + help="Number of shards for meds_etl_mimic (default: 100)", + ) + parser.add_argument( + "--backend", + type=str, + default="polars", + choices=["polars", "cpp"], + help="Backend for meds_etl_mimic (default: polars)", + ) + parser.add_argument( + "--mimic-root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv", + help="Path to MIMIC-IV root directory (containing 2.2/ subdirectory)", + ) + parser.add_argument( + "--cache-dir", + type=str, + default="/shared/eng/pyhealth", + help="Cache directory for MEDS databases", + ) + parser.add_argument( + "--force-reconvert", + action="store_true", + help="Force reconversion even if database exists", + ) + parser.add_argument( + "--output-csv", + type=str, + default="benchmark_patient_access_meds_reader.csv", + help="Output CSV file", + ) + args = parser.parse_args() + + meds_dir = f"{args.cache_dir}/mimic4_meds" + meds_reader_dir = f"{args.cache_dir}/mimic4_meds_reader" + + print("=" * 80) + print("BENCHMARK: meds_reader - Data Loading & Patient Access") + print("=" * 80) + print(f"Patient ID: {args.patient_id}") + print(f"Threads: {args.threads}") + print(f"Num proc: {args.num_proc}") + print(f"Num shards: {args.num_shards}") + 
print(f"Backend: {args.backend}") + print(f"MIMIC root: {args.mimic_root}") + print(f"MEDS dir: {meds_dir}") + print(f"meds_reader dir: {meds_reader_dir}") + print("=" * 80) + + # Verify MIMIC-IV structure + mimic_version_path = os.path.join(args.mimic_root, "2.2") + if not os.path.exists(mimic_version_path): + print(f"\nWARNING: Expected MIMIC-IV version directory not found: {mimic_version_path}") + print("meds_etl_mimic expects the MIMIC-IV data to be in {mimic_root}/2.2/") + print("Please ensure the directory structure is correct.") + + tracker = PeakMemoryTracker(poll_interval_s=0.05) + tracker.start() + tracker.reset() + + # Step 1: Data loading (conversion if needed) + need_convert = args.force_reconvert or not Path(meds_reader_dir).exists() + used_cached_db = not need_convert + + meds_etl_s = 0.0 + meds_reader_convert_s = 0.0 + + if need_convert: + print("\n[Step 1] Converting MIMIC-IV -> MEDS -> meds_reader database...") + load_start = time.time() + + # Step 1a: meds_etl_mimic + meds_etl_s = run_meds_etl_mimic( + src_mimic=args.mimic_root, + output_dir=meds_dir, + num_shards=args.num_shards, + num_proc=args.num_proc, + backend=args.backend, + ) + + # Step 1b: meds_reader_convert + meds_reader_convert_s = run_meds_reader_convert( + meds_dir, meds_reader_dir, num_threads=args.threads + ) + + data_load_s = time.time() - load_start + else: + print("\n[Step 1] Using existing meds_reader database") + print(f" (use --force-reconvert to rebuild)") + data_load_s = 0.0 + + # Convert patient_id to integer for meds_reader + subject_id = int(args.patient_id) if args.patient_id.isdigit() else hash(args.patient_id) % (10**9) + + with meds_reader.SubjectDatabase(meds_reader_dir, num_threads=args.threads) as database: + print(f" Database opened with {len(database)} subjects") + + # Step 2: First patient access (cold cache) + print(f"\n[Step 2] First access to patient {args.patient_id} (cold cache)...") + access_1_start = time.time() + + try: + subject = database[subject_id] + patient_found = True + num_events = len(subject.events) + print(f" Patient found!") + print(f" Subject ID: {subject.subject_id}") + print(f" Number of events: {num_events}") + except KeyError: + patient_found = False + num_events = 0 + print(f" Patient NOT found with subject_id={subject_id}") + # List some available subject IDs + available_ids = list(database)[:10] + print(f" Available subject IDs (first 10): {available_ids}") + + patient_access_1st_s = time.time() - access_1_start + + # Step 3: Second patient access (warm cache) + print(f"\n[Step 3] Second access to patient {args.patient_id} (warm cache)...") + access_2_start = time.time() + + if patient_found: + subject = database[subject_id] + # Re-count events to ensure we're actually accessing the data + count = len(subject.events) + + patient_access_2nd_s = time.time() - access_2_start + + total_s = data_load_s + patient_access_1st_s + patient_access_2nd_s + peak_rss = tracker.peak_bytes() + + tracker.stop() + + result = BenchmarkResult( + approach="meds_reader", + data_load_s=data_load_s, + meds_etl_s=meds_etl_s, + meds_reader_convert_s=meds_reader_convert_s, + patient_access_1st_s=patient_access_1st_s, + patient_access_2nd_s=patient_access_2nd_s, + total_s=total_s, + peak_rss_bytes=peak_rss, + patient_found=patient_found, + num_events=num_events, + used_cached_db=used_cached_db, + ) + + # Summary + print("\n" + "=" * 80) + print("SUMMARY: meds_reader") + print("=" * 80) + access_1_str = f"{patient_access_1st_s*1000:.2f}ms" if patient_access_1st_s < 1 else 
f"{patient_access_1st_s:.2f}s" + access_2_str = f"{patient_access_2nd_s*1000:.2f}ms" if patient_access_2nd_s < 1 else f"{patient_access_2nd_s:.2f}s" + print(f" Used cached DB: {used_cached_db}") + if not used_cached_db: + print(f" meds_etl_mimic: {meds_etl_s:.2f}s") + print(f" meds_reader_convert: {meds_reader_convert_s:.2f}s") + print(f" Total data load: {data_load_s:.2f}s") + print(f" Patient access (1st/cold): {access_1_str}") + print(f" Patient access (2nd/warm): {access_2_str}") + print(f" Total time: {total_s:.2f}s") + print(f" Peak RSS: {format_size(peak_rss)}") + print(f" Patient found: {patient_found}") + print(f" Events: {num_events}") + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(result).keys())) + writer.writeheader() + writer.writerow(asdict(result)) + print(f"\nResults saved to: {out_csv}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/patient_exploration/benchmark_patient_access_pandas.py b/examples/benchmark_perf/patient_exploration/benchmark_patient_access_pandas.py new file mode 100644 index 000000000..47fd434bf --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/benchmark_patient_access_pandas.py @@ -0,0 +1,583 @@ +"""Benchmark: Pure Pandas - Data Loading & Single Patient Access (with Parquet Caching) + +Measures: +1. Time to load raw MIMIC-IV CSV tables with pandas +2. Time to cache tables as parquet files +3. Time to reload from parquet cache +4. Time to join tables and access a single patient's events +5. Total time + +This benchmark mimics a realistic workflow where: +- Raw CSV data is loaded once +- Data is cached as parquet for faster subsequent access +- Patient queries are performed on the cached data + +Usage: + python benchmark_patient_access_pandas.py + python benchmark_patient_access_pandas.py --patient-id 10014729 + python benchmark_patient_access_pandas.py --data-root /path/to/mimiciv/hosp + python benchmark_patient_access_pandas.py --skip-parquet # Skip parquet caching step + python benchmark_patient_access_pandas.py --use-temp-dir # Use temp dir (auto-cleaned) +""" + +from __future__ import annotations + +import argparse +import csv +import os +import shutil +import tempfile +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Dict, List, Any, Optional + +import pandas as pd +import psutil + + +# ============================================================================= +# Benchmark Result +# ============================================================================= + +@dataclass +class BenchmarkResult: + approach: str + csv_load_s: float # Time to load raw CSVs + parquet_write_s: float # Time to write parquet cache + parquet_read_s: float # Time to reload from parquet + patient_access_1st_s: float # First access (includes joins) + patient_access_2nd_s: float # Second access (warm cache) + total_s: float + peak_rss_bytes: int + patient_found: bool + num_events: int + num_visits: int + num_tables_loaded: int + parquet_cache_bytes: int # Size of parquet cache + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +class PeakMemoryTracker: + """Tracks peak RSS for current process + children.""" + + def __init__(self, 
poll_interval_s: float = 0.05) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +# ============================================================================= +# Data Loading and Parquet Caching +# ============================================================================= + +def get_directory_size(path: str | Path) -> int: + """Calculate total size of a directory.""" + total = 0 + p = Path(path) + if not p.exists(): + return 0 + for entry in p.rglob("*"): + if entry.is_file(): + try: + total += entry.stat().st_size + except FileNotFoundError: + pass + return total + + +def write_tables_to_parquet( + tables: Dict[str, pd.DataFrame], + cache_dir: str, +) -> float: + """Write DataFrames to parquet files for caching. + + Args: + tables: Dictionary mapping table name to DataFrame + cache_dir: Directory to write parquet files + + Returns: + Time taken in seconds + """ + start = time.time() + + cache_path = Path(cache_dir) + cache_path.mkdir(parents=True, exist_ok=True) + + for table_name, df in tables.items(): + parquet_path = cache_path / f"{table_name}.parquet" + df.to_parquet(parquet_path, index=False, engine="pyarrow") + + return time.time() - start + + +def load_tables_from_parquet( + cache_dir: str, + tables: List[str], +) -> Dict[str, pd.DataFrame]: + """Load tables from parquet cache. + + Args: + cache_dir: Directory containing parquet files + tables: List of table names to load + + Returns: + Dictionary mapping table name to DataFrame + """ + loaded = {} + cache_path = Path(cache_dir) + + for table in tables: + parquet_path = cache_path / f"{table}.parquet" + if parquet_path.exists(): + df = pd.read_parquet(parquet_path, engine="pyarrow") + loaded[table] = df + + return loaded + + +def load_mimic_tables( + data_root: str, + tables: List[str], +) -> Dict[str, pd.DataFrame]: + """Load MIMIC-IV tables from CSV files. 
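+
+    Both plain and gzipped CSVs are supported: {table}.csv.gz is preferred when
+    present, otherwise {table}.csv, and tables found at neither path are skipped
+    with a warning. A minimal usage sketch (the path is this script's default
+    --data-root):
+
+        tables = load_mimic_tables(
+            "/srv/local/data/physionet.org/files/mimiciv/2.2/hosp",
+            ["patients", "admissions", "diagnoses_icd"],
+        )
+        print({name: len(df) for name, df in tables.items()})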
+ + Args: + data_root: Path to MIMIC-IV hosp directory + tables: List of table names to load + + Returns: + Dictionary mapping table name to DataFrame + """ + loaded = {} + + for table in tables: + # Try both .csv and .csv.gz extensions + csv_path = os.path.join(data_root, f"{table}.csv") + csv_gz_path = os.path.join(data_root, f"{table}.csv.gz") + + if os.path.exists(csv_gz_path): + path = csv_gz_path + elif os.path.exists(csv_path): + path = csv_path + else: + print(f" WARNING: Table {table} not found at {csv_path} or {csv_gz_path}") + continue + + print(f" Loading {table}...") + start = time.time() + + # Use low_memory=False for tables that might have mixed types + df = pd.read_csv(path, low_memory=False) + + elapsed = time.time() - start + print(f" -> {len(df):,} rows in {elapsed:.2f}s") + + loaded[table] = df + + return loaded + + +def get_patient_events_with_joins( + patient_id: str, + tables: Dict[str, pd.DataFrame], +) -> Dict[str, Any]: + """Get all events for a patient, joining tables to build visit hierarchy. + + This mimics what PyHealth does: building a patient -> visits -> events structure + by joining clinical tables with admissions to get visit context. + + Args: + patient_id: Patient ID (subject_id in MIMIC-IV) + tables: Dictionary of loaded DataFrames + + Returns: + Dictionary with patient data organized by visits + """ + subject_id = int(patient_id) + + patient_data = { + "subject_id": subject_id, + "visits": {}, # hadm_id -> visit data with events + "demographics": None, + "total_events": 0, + } + + # Get patient demographics + if "patients" in tables: + patients_df = tables["patients"] + patient_demo = patients_df[patients_df["subject_id"] == subject_id] + if len(patient_demo) > 0: + patient_data["demographics"] = patient_demo.iloc[0].to_dict() + + # Get admissions (visits) for this patient + if "admissions" not in tables: + return patient_data + + admissions_df = tables["admissions"] + patient_admissions = admissions_df[admissions_df["subject_id"] == subject_id].copy() + + if len(patient_admissions) == 0: + return patient_data + + # Parse datetime columns for admissions + patient_admissions["admittime"] = pd.to_datetime(patient_admissions["admittime"]) + patient_admissions["dischtime"] = pd.to_datetime(patient_admissions["dischtime"]) + + # Initialize visit structure + for _, admission in patient_admissions.iterrows(): + hadm_id = admission["hadm_id"] + patient_data["visits"][hadm_id] = { + "hadm_id": hadm_id, + "admittime": admission["admittime"], + "dischtime": admission["dischtime"], + "events": {}, + } + + hadm_ids = set(patient_admissions["hadm_id"].tolist()) + + # Join diagnoses_icd with admissions context + if "diagnoses_icd" in tables: + diagnoses_df = tables["diagnoses_icd"] + + # Filter to patient first, then join with admissions + patient_diagnoses = diagnoses_df[diagnoses_df["subject_id"] == subject_id].copy() + + # Join to get admission context (admittime, dischtime) + patient_diagnoses = patient_diagnoses.merge( + patient_admissions[["hadm_id", "admittime", "dischtime"]], + on="hadm_id", + how="inner" + ) + + # Organize by visit + for hadm_id in hadm_ids: + visit_diagnoses = patient_diagnoses[patient_diagnoses["hadm_id"] == hadm_id] + patient_data["visits"][hadm_id]["events"]["diagnoses_icd"] = visit_diagnoses + patient_data["total_events"] += len(visit_diagnoses) + + # Join procedures_icd with admissions context + if "procedures_icd" in tables: + procedures_df = tables["procedures_icd"] + + patient_procedures = procedures_df[procedures_df["subject_id"] 
== subject_id].copy() + + patient_procedures = patient_procedures.merge( + patient_admissions[["hadm_id", "admittime", "dischtime"]], + on="hadm_id", + how="inner" + ) + + for hadm_id in hadm_ids: + visit_procedures = patient_procedures[patient_procedures["hadm_id"] == hadm_id] + patient_data["visits"][hadm_id]["events"]["procedures_icd"] = visit_procedures + patient_data["total_events"] += len(visit_procedures) + + # Join labevents with admissions context + if "labevents" in tables: + labevents_df = tables["labevents"] + + # Filter to patient (labevents is large, so filter first) + patient_labs = labevents_df[labevents_df["subject_id"] == subject_id].copy() + + # Join with admissions to get visit context + # Note: Some lab events may not have hadm_id (outpatient) + patient_labs = patient_labs.merge( + patient_admissions[["hadm_id", "admittime", "dischtime"]], + on="hadm_id", + how="inner" # Only keep labs with valid admission + ) + + for hadm_id in hadm_ids: + visit_labs = patient_labs[patient_labs["hadm_id"] == hadm_id] + patient_data["visits"][hadm_id]["events"]["labevents"] = visit_labs + patient_data["total_events"] += len(visit_labs) + + # Join prescriptions with admissions context + if "prescriptions" in tables: + prescriptions_df = tables["prescriptions"] + + patient_prescriptions = prescriptions_df[prescriptions_df["subject_id"] == subject_id].copy() + + patient_prescriptions = patient_prescriptions.merge( + patient_admissions[["hadm_id", "admittime", "dischtime"]], + on="hadm_id", + how="inner" + ) + + for hadm_id in hadm_ids: + visit_prescriptions = patient_prescriptions[patient_prescriptions["hadm_id"] == hadm_id] + patient_data["visits"][hadm_id]["events"]["prescriptions"] = visit_prescriptions + patient_data["total_events"] += len(visit_prescriptions) + + return patient_data + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark pure pandas data loading and single patient access (with parquet caching)" + ) + parser.add_argument( + "--patient-id", + type=str, + default="10014729", + help="Patient ID to access (default: 10014729)", + ) + parser.add_argument( + "--data-root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.2/hosp", + help="Path to MIMIC-IV hosp directory", + ) + parser.add_argument( + "--output-csv", + type=str, + default="benchmark_patient_access_pandas.csv", + help="Output CSV file", + ) + parser.add_argument( + "--skip-parquet", + action="store_true", + help="Skip the parquet caching step", + ) + parser.add_argument( + "--cache-dir", + type=str, + default="/shared/eng/pyhealth/pandas_parquet_cache", + help="Directory for parquet cache (default: /shared/eng/pyhealth/pandas_parquet_cache)", + ) + parser.add_argument( + "--use-temp-dir", + action="store_true", + help="Use a temporary directory for parquet cache (auto-cleaned after benchmark)", + ) + args = parser.parse_args() + + # Tables to load (matching what PyHealth loads for mortality/LOS tasks) + tables_to_load = [ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ] + + # Set up parquet cache directory + use_temp_dir = args.use_temp_dir + if use_temp_dir: + temp_dir = tempfile.mkdtemp(prefix="mimic_parquet_cache_") + cache_dir = temp_dir + else: + cache_dir = args.cache_dir + temp_dir = None + + print("=" * 80) + print("BENCHMARK: Pure Pandas - Data Loading & Patient Access (with Parquet Caching)") + print("=" * 80) + print(f"Patient ID: {args.patient_id}") + print(f"Data root: {args.data_root}") + print(f"Tables: 
{tables_to_load}") + print(f"Skip parquet: {args.skip_parquet}") + print(f"Parquet cache: {cache_dir} {'(temp, will be deleted)' if use_temp_dir else ''}") + print("=" * 80) + + tracker = PeakMemoryTracker(poll_interval_s=0.05) + tracker.start() + tracker.reset() + + try: + # Step 1: Load raw CSV tables + print("\n[Step 1] Loading MIMIC-IV tables from CSV...") + csv_load_start = time.time() + + tables = load_mimic_tables(args.data_root, tables_to_load) + + csv_load_s = time.time() - csv_load_start + num_tables_loaded = len(tables) + + total_rows = sum(len(df) for df in tables.values()) + print(f"\n Loaded {num_tables_loaded} tables ({total_rows:,} total rows) in {csv_load_s:.2f}s") + + # Step 2: Write to parquet cache (if not skipping) + parquet_write_s = 0.0 + parquet_read_s = 0.0 + parquet_cache_bytes = 0 + + if not args.skip_parquet: + print("\n[Step 2] Writing tables to parquet cache...") + parquet_write_start = time.time() + + write_tables_to_parquet(tables, cache_dir) + + parquet_write_s = time.time() - parquet_write_start + parquet_cache_bytes = get_directory_size(cache_dir) + + print(f" Wrote {num_tables_loaded} parquet files in {parquet_write_s:.2f}s") + print(f" Cache size: {format_size(parquet_cache_bytes)}") + + # Step 3: Reload from parquet cache (simulating future access) + print("\n[Step 3] Reloading tables from parquet cache...") + + # Clear the in-memory tables to simulate a fresh load + del tables + + parquet_read_start = time.time() + + tables = load_tables_from_parquet(cache_dir, tables_to_load) + + parquet_read_s = time.time() - parquet_read_start + print(f" Reloaded {len(tables)} tables from parquet in {parquet_read_s:.2f}s") + + next_step = 4 + else: + print("\n[Step 2] Skipping parquet caching (--skip-parquet)") + next_step = 3 + + # Step N: First patient access (cold - includes joining tables) + print(f"\n[Step {next_step}] First access to patient {args.patient_id} (with table joins)...") + access_1_start = time.time() + + patient_data = get_patient_events_with_joins(args.patient_id, tables) + + patient_found = patient_data["total_events"] > 0 or len(patient_data["visits"]) > 0 + num_events = patient_data["total_events"] + num_visits = len(patient_data["visits"]) + + if patient_found: + print(f" Patient found!") + print(f" Number of visits: {num_visits}") + print(f" Total events: {num_events}") + + # Show events per visit + for hadm_id, visit in patient_data["visits"].items(): + visit_events = sum(len(df) for df in visit["events"].values()) + print(f" Visit {hadm_id}: {visit_events} events") + for table_name, events_df in visit["events"].items(): + if len(events_df) > 0: + print(f" - {table_name}: {len(events_df)} rows") + else: + print(f" Patient NOT found!") + if "patients" in tables: + available_ids = tables["patients"]["subject_id"].head(10).tolist() + print(f" Available patient IDs (first 10): {available_ids}") + + patient_access_1st_s = time.time() - access_1_start + + # Step N+1: Second patient access (warm) + print(f"\n[Step {next_step + 1}] Second access to patient {args.patient_id} (repeat with joins)...") + access_2_start = time.time() + + if patient_found: + patient_data_2 = get_patient_events_with_joins(args.patient_id, tables) + count = patient_data_2["total_events"] + print(f" Verified {count} events") + + patient_access_2nd_s = time.time() - access_2_start + + # Calculate total time + total_s = csv_load_s + parquet_write_s + parquet_read_s + patient_access_1st_s + patient_access_2nd_s + peak_rss = tracker.peak_bytes() + + finally: + tracker.stop() + 
+ # Clean up temporary parquet cache + if use_temp_dir and temp_dir and os.path.exists(temp_dir): + print(f"\n[Cleanup] Removing temporary parquet cache: {temp_dir}") + shutil.rmtree(temp_dir) + + result = BenchmarkResult( + approach="pandas_with_parquet_cache", + csv_load_s=csv_load_s, + parquet_write_s=parquet_write_s, + parquet_read_s=parquet_read_s, + patient_access_1st_s=patient_access_1st_s, + patient_access_2nd_s=patient_access_2nd_s, + total_s=total_s, + peak_rss_bytes=peak_rss, + patient_found=patient_found, + num_events=num_events, + num_visits=num_visits, + num_tables_loaded=num_tables_loaded, + parquet_cache_bytes=parquet_cache_bytes, + ) + + # Summary + print("\n" + "=" * 80) + print("SUMMARY: Pure Pandas (with Parquet Caching)") + print("=" * 80) + access_1_str = f"{patient_access_1st_s*1000:.2f}ms" if patient_access_1st_s < 1 else f"{patient_access_1st_s:.2f}s" + access_2_str = f"{patient_access_2nd_s*1000:.2f}ms" if patient_access_2nd_s < 1 else f"{patient_access_2nd_s:.2f}s" + print(f" CSV load time: {csv_load_s:.2f}s") + if not args.skip_parquet: + print(f" Parquet write time: {parquet_write_s:.2f}s") + print(f" Parquet read time: {parquet_read_s:.2f}s") + print(f" Parquet cache size: {format_size(parquet_cache_bytes)}") + print(f" Patient access (1st/cold): {access_1_str}") + print(f" Patient access (2nd/warm): {access_2_str}") + print(f" Total time: {total_s:.2f}s") + print(f" Peak RSS: {format_size(peak_rss)}") + print(f" Patient found: {patient_found}") + print(f" Visits: {num_visits}") + print(f" Events: {num_events}") + print(f" Tables loaded: {num_tables_loaded}") + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(result).keys())) + writer.writeheader() + writer.writerow(asdict(result)) + print(f"\nResults saved to: {out_csv}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_perf/patient_exploration/benchmark_patient_access_pyhealth2.py b/examples/benchmark_perf/patient_exploration/benchmark_patient_access_pyhealth2.py new file mode 100644 index 000000000..736ae3023 --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/benchmark_patient_access_pyhealth2.py @@ -0,0 +1,268 @@ +"""Benchmark: PyHealth 2.0 - Data Loading & Single Patient Access + +Measures: +1. Time to load/initialize the dataset from raw MIMIC-IV data +2. Time to access a single patient after loading +3. 
Total time (load + access) + +Usage: + # Activate PyHealth 2.0 environment first + source activate pyhealth312 + + # Run benchmark + python benchmark_patient_access_pyhealth2.py + python benchmark_patient_access_pyhealth2.py --patient-id 10014729 --workers 8 + python benchmark_patient_access_pyhealth2.py --dev +""" + +from __future__ import annotations + +import argparse +import csv +import os +import shutil +import threading +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +import psutil + +from pyhealth.datasets import MIMIC4Dataset + + +# ============================================================================= +# Benchmark Result +# ============================================================================= + +@dataclass +class BenchmarkResult: + approach: str + data_load_s: float + patient_access_1st_s: float # First access (cold cache) + patient_access_2nd_s: float # Second access (warm cache) + total_s: float + peak_rss_bytes: int + patient_found: bool + num_events: int + + +def format_size(size_bytes: int) -> str: + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +class PeakMemoryTracker: + """Tracks peak RSS for current process + children.""" + + def __init__(self, poll_interval_s: float = 0.05) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark PyHealth 2.0 data loading and single patient access" + ) + parser.add_argument( + "--patient-id", + type=str, + default="10014729", + help="Patient ID to access (default: 10014729)", + ) + parser.add_argument( + "--workers", + type=int, + default=8, + help="Number of workers (default: 8)", + ) + parser.add_argument( + "--dev", + action="store_true", + help="Use dev mode (smaller subset)", + ) + parser.add_argument( + "--ehr-root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.2/", + help="Path to MIMIC-IV root directory", + ) + parser.add_argument( + "--cache-dir", + type=str, + default="/shared/eng/pyhealth/benchmark_patient_access_2.0", + help="Cache directory", + ) + parser.add_argument( + "--no-clear-cache", + action="store_true", + help="Do not clear cache before benchmark", + ) + parser.add_argument( + "--output-csv", + type=str, + default="benchmark_patient_access_pyhealth2.csv", + help="Output CSV file", + ) + args = parser.parse_args() + + print("=" * 80) + print("BENCHMARK: PyHealth 2.0 - Data Loading & Patient Access") + print("=" * 80) + 
print(f"Patient ID: {args.patient_id}") + print(f"Workers: {args.workers}") + print(f"Dev mode: {args.dev}") + print(f"EHR Root: {args.ehr_root}") + print(f"Cache dir: {args.cache_dir}") + print("=" * 80) + + # Clear cache for accurate timing + cache_path = Path(args.cache_dir) + if not args.no_clear_cache: + if cache_path.exists(): + print("\nClearing cache...") + shutil.rmtree(cache_path) + cache_path.mkdir(parents=True, exist_ok=True) + + tracker = PeakMemoryTracker(poll_interval_s=0.05) + tracker.start() + tracker.reset() + + # Step 1: Load dataset + print("\n[Step 1] Loading dataset...") + load_start = time.time() + + dataset = MIMIC4Dataset( + ehr_root=args.ehr_root, + ehr_tables=["patients", "admissions", "diagnoses_icd", "procedures_icd", "labevents"], + dev=args.dev, + cache_dir=str(cache_path), + num_workers=args.workers, + ) + + data_load_s = time.time() - load_start + print(f" Dataset loaded in {data_load_s:.2f}s") + + # Step 2: First patient access (cold cache) + print(f"\n[Step 2] First access to patient {args.patient_id} (cold cache)...") + access_1_start = time.time() + + patient_ids = dataset.unique_patient_ids + print(f" Number of patients: {len(patient_ids)}") + patient_found = args.patient_id in patient_ids + + if patient_found: + patient = dataset.get_patient(args.patient_id) + num_events = len(patient.get_events()) + print(f" Patient found!") + print(f" Number of events: {num_events}") + else: + num_events = 0 + print(f" Patient NOT found!") + # List available patient IDs (first 10) + print(f" Available patient IDs (first 10): {patient_ids[:10]}") + + patient_access_1st_s = time.time() - access_1_start + + # Step 3: Second patient access (warm cache) + print(f"\n[Step 3] Second access to patient {args.patient_id} (warm cache)...") + access_2_start = time.time() + + if patient_found: + patient = dataset.get_patient(args.patient_id) + # Re-count events to ensure we're actually accessing the data + count = len(patient.get_events()) + + patient_access_2nd_s = time.time() - access_2_start + + total_s = data_load_s + patient_access_1st_s + patient_access_2nd_s + peak_rss = tracker.peak_bytes() + + tracker.stop() + + # Cleanup + del dataset + if not args.no_clear_cache and cache_path.exists(): + shutil.rmtree(cache_path) + + result = BenchmarkResult( + approach="pyhealth_2.0", + data_load_s=data_load_s, + patient_access_1st_s=patient_access_1st_s, + patient_access_2nd_s=patient_access_2nd_s, + total_s=total_s, + peak_rss_bytes=peak_rss, + patient_found=patient_found, + num_events=num_events, + ) + + # Summary + print("\n" + "=" * 80) + print("SUMMARY: PyHealth 2.0") + print("=" * 80) + access_1_str = f"{patient_access_1st_s*1000:.2f}ms" if patient_access_1st_s < 1 else f"{patient_access_1st_s:.2f}s" + access_2_str = f"{patient_access_2nd_s*1000:.2f}ms" if patient_access_2nd_s < 1 else f"{patient_access_2nd_s:.2f}s" + print(f" Data load time: {data_load_s:.2f}s") + print(f" Patient access (1st/cold): {access_1_str}") + print(f" Patient access (2nd/warm): {access_2_str}") + print(f" Total time: {total_s:.2f}s") + print(f" Peak RSS: {format_size(peak_rss)}") + print(f" Patient found: {patient_found}") + print(f" Events: {num_events}") + + # Write CSV + out_csv = Path(args.output_csv) + out_csv.parent.mkdir(parents=True, exist_ok=True) + with out_csv.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=list(asdict(result).keys())) + writer.writeheader() + writer.writerow(asdict(result)) + print(f"\nResults saved to: {out_csv}") + print("=" * 80) + + +if __name__ 
== "__main__": + main() + diff --git a/examples/benchmark_perf/patient_exploration/clean_patient_access_legacy.py b/examples/benchmark_perf/patient_exploration/clean_patient_access_legacy.py new file mode 100644 index 000000000..4114ad120 --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/clean_patient_access_legacy.py @@ -0,0 +1,14 @@ +from pyhealth.datasets import MIMIC4Dataset +MIMIC_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.0/hosp" +PATIENT_ID = "10014729" +if __name__ == "__main__": + dataset = MIMIC4Dataset( + root=MIMIC_ROOT, + tables=["diagnoses_icd", "procedures_icd", "labevents"], + refresh_cache=True, + ) + patient = dataset.patients[PATIENT_ID] + events = [] + for visit in patient.visits.values(): + for table in visit.available_tables: + events.extend(visit.get_event_list(table)) diff --git a/examples/benchmark_perf/patient_exploration/clean_patient_access_meds_etl.py b/examples/benchmark_perf/patient_exploration/clean_patient_access_meds_etl.py new file mode 100644 index 000000000..9e209f251 --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/clean_patient_access_meds_etl.py @@ -0,0 +1,12 @@ +import subprocess +import meds_reader +MIMIC_ROOT = "/srv/local/data/physionet.org/files/mimiciv" +MEDS_DIR = "/tmp/mimic4_meds" +MEDS_READER_DIR = "/tmp/mimic4_meds_reader" +PATIENT_ID = 10014729 +if __name__ == "__main__": + subprocess.run(["meds_etl_mimic", MIMIC_ROOT, MEDS_DIR], check=True) + subprocess.run(["meds_reader_convert", MEDS_DIR, MEDS_READER_DIR], check=True) + with meds_reader.SubjectDatabase(MEDS_READER_DIR) as db: + patient = db[PATIENT_ID] + events = list(patient.events) \ No newline at end of file diff --git a/examples/benchmark_perf/patient_exploration/clean_patient_access_meds_reader.py b/examples/benchmark_perf/patient_exploration/clean_patient_access_meds_reader.py new file mode 100644 index 000000000..d5772ec4b --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/clean_patient_access_meds_reader.py @@ -0,0 +1,82 @@ +import os +import shutil +import subprocess +import datetime +import collections +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +from pyhealth.datasets import MIMIC4Dataset +import meds_reader +MIMIC_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.0/hosp" +MEDS_DIR = "/tmp/pyhealth_meds" +MEDS_READER_DIR = "/tmp/pyhealth_meds_reader" +PATIENT_ID = 10014729 +if __name__ == "__main__": + dataset = MIMIC4Dataset( + root=MIMIC_ROOT, + tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"], + refresh_cache=True, + ) + results = collections.defaultdict(list) + for patient_id, patient in dataset.patients.items(): + subject_id = int(patient_id) + birth_obj = {"subject_id": subject_id, "code": "Birth", "time": patient.birth_datetime} + birth_obj["gender"] = patient.gender + birth_obj["ethnicity"] = patient.ethnicity + for k, v in patient.attr_dict.items(): + if v != v: # Skip NaN + continue + birth_obj[k] = v + results[subject_id].append(birth_obj) + if patient.death_datetime is not None: + results[subject_id].append({"subject_id": subject_id, "code": "Death", "time": patient.death_datetime}) + for visit_id, visit in patient.visits.items(): + vid = int(visit_id) + visit_event = { + "subject_id": subject_id, + "code": "Visit", + "time": visit.encounter_time, + "visit_id": vid, + "discharge_time": visit.discharge_time, + "discharge_status": visit.discharge_status, + } + for k, v in visit.attr_dict.items(): + if v != v: + continue + visit_event[k] = v + 
results[subject_id].append(visit_event) + for table in visit.available_tables: + for event in visit.get_event_list(table): + event_obj = { + "subject_id": subject_id, + "visit_id": vid, + "code": f"{event.vocabulary}/{event.code}", + "time": event.timestamp or visit.discharge_time, + } + for k, v in event.attr_dict.items(): + if v != v: + continue + event_obj[k] = v + results[subject_id].append(event_obj) + results[subject_id].sort(key=lambda a: a["time"]) + attr_map = {str: pa.string(), int: pa.int64(), np.int64: pa.int64(), float: pa.float64(), np.float64: pa.float64(), datetime.datetime: pa.timestamp("us")} + attr_schema = set() + for subject_values in results.values(): + for row in subject_values: + for k, v in row.items(): + if k not in {"subject_id", "time", "numeric_value"} and type(v) in attr_map: + attr_schema.add((k, attr_map[type(v)])) + schema = pa.schema([("subject_id", pa.int64()), ("time", pa.timestamp("us"))] + sorted(list(attr_schema))) + shutil.rmtree(MEDS_DIR, ignore_errors=True) + os.makedirs(f"{MEDS_DIR}/data") + os.makedirs(f"{MEDS_DIR}/metadata") + all_subjects = list(results.keys()) + for i, subject_ids in enumerate(np.array_split(all_subjects, 100)): + rows = [v for sid in subject_ids for v in results[sid]] + pq.write_table(pa.Table.from_pylist(rows, schema=schema), f"{MEDS_DIR}/data/{i}.parquet") + shutil.rmtree(MEDS_READER_DIR, ignore_errors=True) + subprocess.run(["meds_reader_convert", MEDS_DIR, MEDS_READER_DIR], check=True) + with meds_reader.SubjectDatabase(MEDS_READER_DIR) as db: + patient = db[PATIENT_ID] + events = list(patient.events) diff --git a/examples/benchmark_perf/patient_exploration/clean_patient_access_pandas.py b/examples/benchmark_perf/patient_exploration/clean_patient_access_pandas.py new file mode 100644 index 000000000..ccce05961 --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/clean_patient_access_pandas.py @@ -0,0 +1,16 @@ +import pandas as pd +MIMIC_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.2/hosp" +PATIENT_ID = 10014729 +if __name__ == "__main__": + patients = pd.read_csv(f"{MIMIC_ROOT}/patients.csv.gz") + admissions = pd.read_csv(f"{MIMIC_ROOT}/admissions.csv.gz") + diagnoses = pd.read_csv(f"{MIMIC_ROOT}/diagnoses_icd.csv.gz") + procedures = pd.read_csv(f"{MIMIC_ROOT}/procedures_icd.csv.gz") + labevents = pd.read_csv(f"{MIMIC_ROOT}/labevents.csv.gz", low_memory=False) + patient_info = patients[patients["subject_id"] == PATIENT_ID] + patient_hadm_ids = admissions[admissions["subject_id"] == PATIENT_ID]["hadm_id"] + events = pd.concat([ + diagnoses[diagnoses["hadm_id"].isin(patient_hadm_ids)], + procedures[procedures["hadm_id"].isin(patient_hadm_ids)], + labevents[labevents["hadm_id"].isin(patient_hadm_ids)], + ]) \ No newline at end of file diff --git a/examples/benchmark_perf/patient_exploration/clean_patient_access_pyhealth2.py b/examples/benchmark_perf/patient_exploration/clean_patient_access_pyhealth2.py new file mode 100644 index 000000000..3981a656d --- /dev/null +++ b/examples/benchmark_perf/patient_exploration/clean_patient_access_pyhealth2.py @@ -0,0 +1,10 @@ +from pyhealth.datasets import MIMIC4Dataset +MIMIC_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.2/" +PATIENT_ID = "10014729" +if __name__ == "__main__": + dataset = MIMIC4Dataset( + ehr_root=MIMIC_ROOT, + ehr_tables=["patients", "admissions", "diagnoses_icd", "procedures_icd", "labevents"], + ) + patient = dataset.get_patient(PATIENT_ID) + events = patient.get_events() \ No newline at end of file diff --git 
a/examples/ChestXray-image-generation-GAN.ipynb b/examples/cxr/ChestXray-image-generation-GAN.ipynb similarity index 100% rename from examples/ChestXray-image-generation-GAN.ipynb rename to examples/cxr/ChestXray-image-generation-GAN.ipynb diff --git a/examples/ChestXrayClassificationWithSaliency.ipynb b/examples/cxr/ChestXrayClassificationWithSaliency.ipynb similarity index 100% rename from examples/ChestXrayClassificationWithSaliency.ipynb rename to examples/cxr/ChestXrayClassificationWithSaliency.ipynb diff --git a/examples/chestxray14_binary_classification.ipynb b/examples/cxr/chestxray14_binary_classification.ipynb similarity index 100% rename from examples/chestxray14_binary_classification.ipynb rename to examples/cxr/chestxray14_binary_classification.ipynb diff --git a/examples/chestxray14_multilabel_classification.ipynb b/examples/cxr/chestxray14_multilabel_classification.ipynb similarity index 100% rename from examples/chestxray14_multilabel_classification.ipynb rename to examples/cxr/chestxray14_multilabel_classification.ipynb diff --git a/examples/chextXray_image_generation_VAE.py b/examples/cxr/chextXray_image_generation_VAE.py similarity index 100% rename from examples/chextXray_image_generation_VAE.py rename to examples/cxr/chextXray_image_generation_VAE.py diff --git a/examples/cnn_cxr.ipynb b/examples/cxr/cnn_cxr.ipynb similarity index 100% rename from examples/cnn_cxr.ipynb rename to examples/cxr/cnn_cxr.ipynb diff --git a/examples/covid19cxr_conformal.py b/examples/cxr/covid19cxr_conformal.py similarity index 100% rename from examples/covid19cxr_conformal.py rename to examples/cxr/covid19cxr_conformal.py diff --git a/examples/cxr/covid19cxr_tutorial.ipynb b/examples/cxr/covid19cxr_tutorial.ipynb new file mode 100644 index 000000000..4eaf64aeb --- /dev/null +++ b/examples/cxr/covid19cxr_tutorial.ipynb @@ -0,0 +1,1397 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c5ad4b00", + "metadata": {}, + "source": [ + "#" + ] + }, + { + "cell_type": "markdown", + "id": "b3c5dddd", + "metadata": {}, + "source": [ + "## Covid19CXR Comprehensive Tutorial\n", + "\n", + "This tutorial takes you from start to finish with the COVID19CXR dataset and explores many post-hoc deployment aspects of PyHealth\n", + "\n", + "Namely, we go through:\n", + "- Loading the data\n", + "- Training a model\n", + "- Doing Conformal Prediction\n", + "- Running Interpretability on Predicted Samples" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f135c3ac", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/johnwu3/projects/PyHealth_Branch_Testing/PyHealth/pyhealth/sampler/sage_sampler.py:3: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " import pkg_resources\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "This is a warning of potentially slow compute. 
You could uncomment this line and use the Python implementation instead of Cython.\n" + ] + } + ], + "source": [ + "import os\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from pyhealth.calib.predictionset import LABEL\n", + "from pyhealth.datasets import (\n", + " COVID19CXRDataset,\n", + " get_dataloader,\n", + " split_by_sample_conformal,\n", + ")\n", + "from pyhealth.models import TorchvisionModel\n", + "from pyhealth.trainer import Trainer, get_metrics_fn\n", + "from pyhealth.interpret.methods import CheferRelevance\n", + "from pyhealth.interpret.utils import visualize_image_attr\n", + "\n", + "torch.manual_seed(42)\n", + "np.random.seed(42)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b30e89a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n", + "Initializing covid19_cxr dataset from /home/johnwu3/projects/PyHealth_Branch_Testing/datasets/COVID-19_Radiography_Dataset (dev mode: False)\n", + "Setting task COVID19CXRClassification for covid19_cxr base dataset...\n", + "Fitting processors on the dataset...\n", + "Label disease vocab: {'COVID': 0, 'Lung Opacity': 1, 'Normal': 2, 'Viral Pneumonia': 3}\n", + "Processing samples and saving to /home/johnwu3/projects/covid19cxr_task_cache/samples_38b176c9-d393-4251-99cc-7de0b0557c39.ld...\n", + "Applying processors on data with 4 workers...\n", + "Detected Jupyter notebook environment, setting num_workers to 1\n", + "Single worker mode, processing sequentially\n", + "Worker 0 started processing 21165 samples. (0 to 21165)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/21165 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# ============================================================================\n", + "# STEP 5: Interpretability Visualization\n", + "# ============================================================================\n", + "print(\"\\n\" + \"=\" * 80)\n", + "print(\"STEP 5: Interpretability Visualization\")\n", + "print(\"=\" * 80)\n", + "\n", + "single_loader = get_dataloader(test_data, batch_size=1, shuffle=False)\n", + "n_viz = 3\n", + "print(f\"\\nGenerating Chefer attention attribution for {n_viz} samples...\")\n", + "\n", + "model.eval()\n", + "viz_samples = [batch for i, batch in enumerate(single_loader) if i < n_viz]\n", + "\n", + "fig, axes = plt.subplots(n_viz, 3, figsize=(15, 5 * n_viz))\n", + "\n", + "# Initialize Chefer interpreter (auto-detects ViT)\n", + "chefer_gen = CheferRelevance(model)\n", + "\n", + "for idx, batch in enumerate(viz_samples):\n", + " image = batch[\"image\"]\n", + " true_label = batch[\"disease\"].item()\n", + "\n", + " with torch.no_grad():\n", + " output = model(**batch)\n", + " pred_prob = output[\"y_prob\"][0]\n", + " pred_class = pred_prob.argmax().item()\n", + "\n", + " # Get attribution map using attribute()\n", + " # Returns dict keyed by feature key (e.g., {\"image\": tensor})\n", + " # Input size is inferred automatically from image dimensions\n", + " result = chefer_gen.attribute(\n", + " interpolate=True,\n", + " class_index=pred_class,\n", + " **batch\n", + " )\n", + " attr_map = result[\"image\"] # Keyed by task schema's feature key\n", + " \n", + " img_display, vit_attr_display, attention_overlay = visualize_image_attr(\n", + " image=image[0],\n", + " attribution=attr_map[0, 0],\n", + " interpolate=True,\n", + " )\n", 
+ "\n", + " # Plot\n", + " ax1 = axes[idx, 0]\n", + " ax1.imshow(img_display, cmap='gray' if img_display.ndim == 2 else None)\n", + " ax1.set_title(f\"Original\\nTrue: {id2label.get(true_label)}\")\n", + " ax1.axis('off')\n", + "\n", + " ax2 = axes[idx, 1]\n", + " ax2.imshow(vit_attr_display, cmap='hot')\n", + " ax2.set_title(f\"Attribution\\nPred: {id2label.get(pred_class)}\")\n", + " ax2.axis('off')\n", + "\n", + " ax3 = axes[idx, 2]\n", + " ax3.imshow(attention_overlay)\n", + " ax3.set_title(f\"Overlay\\nConf: {pred_prob[pred_class]:.1%}\")\n", + " ax3.axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(\"covid19_cxr_interpretability.png\", dpi=150, bbox_inches='tight')\n", + "print(\"✓ Saved to: covid19_cxr_interpretability.png\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pyhealth312", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/cxr/covid19cxr_tutorial.py b/examples/cxr/covid19cxr_tutorial.py new file mode 100644 index 000000000..f0b11457b --- /dev/null +++ b/examples/cxr/covid19cxr_tutorial.py @@ -0,0 +1,166 @@ +""" +COVID-19 CXR Tutorial: ViT, Conformal Prediction & Interpretability. + +Demonstrates: dataset loading, ViT training, LABEL conformal prediction, +and Chefer attention-based interpretability for uncertain samples. +""" + +import os +import numpy as np +import torch +import matplotlib.pyplot as plt + +from pyhealth.calib.predictionset import LABEL +from pyhealth.datasets import ( + COVID19CXRDataset, + get_dataloader, + split_by_sample_conformal, +) +from pyhealth.metrics.prediction_set import size, miscoverage_overall_ps +from pyhealth.models import TorchvisionModel +from pyhealth.trainer import Trainer +from pyhealth.interpret.methods import CheferRelevance +from pyhealth.interpret.utils import visualize_image_attr + +# Configuration +DATA_ROOT = "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets" +ROOT = f"{DATA_ROOT}/COVID-19_Radiography_Dataset" +CACHE = "/home/johnwu3/projects/covid19cxr_base_cache" +TASK_CACHE = "/home/johnwu3/projects/covid19cxr_task_cache" +CKPT = "/home/johnwu3/projects/covid19cxr_vit_model.ckpt" +SEED = 42 + +if __name__ == "__main__": + # Set seeds for reproducibility + torch.manual_seed(SEED) + np.random.seed(SEED) + + # ========================================================================= + # Cell 1: Data Loading & Model Training + # ========================================================================= + + # Load dataset and create train/val/calibration/test splits + dataset = COVID19CXRDataset(ROOT, cache_dir=CACHE, num_workers=8) + sample_dataset = dataset.set_task(cache_dir=TASK_CACHE, num_workers=8) + train_data, val_data, cal_data, test_data = split_by_sample_conformal( + sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15] + ) + + # Create dataloaders + train_loader = get_dataloader(train_data, batch_size=64, shuffle=True) + val_loader = get_dataloader(val_data, batch_size=64, shuffle=False) + test_loader = get_dataloader(test_data, batch_size=64, shuffle=False) + + # Initialize ViT model + model = TorchvisionModel( + dataset=sample_dataset, + model_name="vit_b_16", + model_config={"weights": "DEFAULT"}, + ) + device = "cuda:4" if torch.cuda.is_available() else "cpu" + 
+ # Load or train model + trainer = Trainer(model=model, device=device, enable_logging=False) + if os.path.exists(CKPT): + trainer.load_ckpt(CKPT) + else: + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=20, + monitor="accuracy", + ) + trainer.save_ckpt(CKPT) + + # Evaluate test performance + test_metrics = trainer.evaluate(test_loader) + print(f"Test Performance: {test_metrics}") + + # ========================================================================= + # Cell 2: Conformal Prediction with LABEL + # ========================================================================= + + # Label mapping for visualization + label_vocab = sample_dataset.output_processors["disease"].label_vocab + id2label = {v: k for k, v in label_vocab.items()} + + # Calibrate LABEL predictor (90% coverage target) + label_predictor = LABEL(model=model, alpha=0.01) + label_predictor.calibrate(cal_dataset=cal_data) + + # Run inference to get prediction sets + cal_trainer = Trainer(model=label_predictor, device=device) + results = cal_trainer.inference( + test_loader, additional_outputs=["y_predset"] + ) + y_true, y_predset = results[0], results[3]["y_predset"] + + # Compute and print coverage metrics + coverage = 1 - miscoverage_overall_ps(y_predset, y_true) + avg_set_size = size(y_predset) + print(f"Coverage: {coverage:.1%}, Avg set size: {avg_set_size:.2f}") + + # Use sample index 0 (known to be uncertain with SEED=42) + sample_idx = 0 + single_loader = get_dataloader(test_data, batch_size=1, shuffle=False) + batch = next(iter(single_loader)) + + # Get model prediction and prediction set for this sample + model.eval() + with torch.no_grad(): + pred_prob = model(**batch)["y_prob"][0] + pred_class = pred_prob.argmax().item() + true_label = batch["disease"].item() + sample_predset = y_predset[sample_idx] + predset_class_indices = [i for i, v in enumerate(sample_predset) if v] + predset_classes = [id2label[i] for i in predset_class_indices] + + # Print sample details + true_name = id2label[true_label] + pred_name = id2label[pred_class] + set_size = len(predset_classes) + print(f"Sample {sample_idx}: True={true_name}, Pred={pred_name}, " + f"Set={predset_classes} (size={set_size})") + + # ========================================================================= + # Cell 3: Interpretability (attribution for each class in prediction set) + # ========================================================================= + # Initialize Chefer/AttentionGrad interpreter + chefer = CheferRelevance(model) + n_classes = len(predset_class_indices) + + # Compute attribution for each class in the prediction set + overlays = [] + for class_idx in predset_class_indices: + attr_map = chefer.attribute(class_index=class_idx, **batch)["image"] + _, _, overlay = visualize_image_attr( + image=batch["image"][0], + attribution=attr_map[0, 0], + ) + overlays.append((class_idx, overlay)) + + # Create figure: ground truth + attribution for each class + figsize = (5 * (n_classes + 1), 5) + fig, axes = plt.subplots(1, n_classes + 1, figsize=figsize) + + # Ground truth image + img, _, _ = visualize_image_attr( + image=batch["image"][0], + attribution=torch.zeros_like(batch["image"][0, 0]), + ) + axes[0].imshow(img, cmap='gray') + axes[0].set_title(f"Ground Truth: {true_name}", fontsize=12) + axes[0].axis('off') + + # Plot attributions + for i, (class_idx, overlay) in enumerate(overlays): + axes[i + 1].imshow(overlay) + prob = pred_prob[class_idx].item() + class_name = id2label[class_idx] + axes[i + 
1].set_title(f"{class_name} ({prob:.1%})", fontsize=12) + axes[i + 1].axis('off') + + plt.suptitle("Uncertain Prediction: Multiple Classes", fontsize=14, y=1.02) + plt.tight_layout() + plt.savefig("covid19_cxr_interpretability.png", dpi=150) + print("Saved visualization to covid19_cxr_interpretability.png") diff --git a/examples/cxr/covid19cxr_tutorial_display.py b/examples/cxr/covid19cxr_tutorial_display.py new file mode 100644 index 000000000..093ff286f --- /dev/null +++ b/examples/cxr/covid19cxr_tutorial_display.py @@ -0,0 +1,163 @@ +""" +COVID-19 CXR Tutorial: ViT, Conformal Prediction & Interpretability. + +Demonstrates: dataset loading, ViT training, LABEL conformal prediction, +and Chefer attention-based interpretability for uncertain samples. +""" + +import os +import numpy as np +import torch +import matplotlib.pyplot as plt + +from pyhealth.calib.predictionset import LABEL +from pyhealth.datasets import ( + COVID19CXRDataset, + get_dataloader, + split_by_sample_conformal, +) +from pyhealth.metrics.prediction_set import size, miscoverage_overall_ps +from pyhealth.models import TorchvisionModel +from pyhealth.trainer import Trainer +from pyhealth.interpret.methods import CheferRelevance +from pyhealth.interpret.utils import visualize_image_attr + +# Configuration +DATA_ROOT = "/home/johnwu3/projects/PyHealth_Branch_Testing/datasets" +ROOT = f"{DATA_ROOT}/COVID-19_Radiography_Dataset" +CACHE = "/home/johnwu3/projects/covid19cxr_base_cache" +TASK_CACHE = "/home/johnwu3/projects/covid19cxr_task_cache" +CKPT = "/home/johnwu3/projects/covid19cxr_vit_model.ckpt" +SEED = 42 + +if __name__ == "__main__": + # Set seeds for reproducibility + torch.manual_seed(SEED) + np.random.seed(SEED) + + # ========================================================================= + # Cell 1: Data Loading & Model Training + # ========================================================================= + + # Load dataset and create train/val/calibration/test splits + dataset = COVID19CXRDataset(ROOT, cache_dir=CACHE, num_workers=8) + sample_dataset = dataset.set_task(cache_dir=TASK_CACHE, num_workers=8) + train_data, val_data, cal_data, test_data = split_by_sample_conformal( + sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15] + ) + + # Create dataloaders + train_loader = get_dataloader(train_data, batch_size=64, shuffle=True) + val_loader = get_dataloader(val_data, batch_size=64, shuffle=False) + test_loader = get_dataloader(test_data, batch_size=64, shuffle=False) + + # Initialize ViT model + model = TorchvisionModel( + dataset=sample_dataset, + model_name="vit_b_16", + model_config={"weights": "DEFAULT"}, + ) + device = "cuda:4" if torch.cuda.is_available() else "cpu" + + # Train model + trainer = Trainer(model=model, device=device, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=20, + monitor="accuracy", + ) + trainer.save_ckpt(CKPT) + + # Evaluate test performance + test_metrics = trainer.evaluate(test_loader) + print(f"Test Performance: {test_metrics}") + + # ========================================================================= + # Cell 2: Conformal Prediction with LABEL + # ========================================================================= + + # Label mapping for visualization + label_vocab = sample_dataset.output_processors["disease"].label_vocab + id2label = {v: k for k, v in label_vocab.items()} + + # Calibrate LABEL predictor (90% coverage target) + label_predictor = LABEL(model=model, alpha=0.01) + 
label_predictor.calibrate(cal_dataset=cal_data) + + # Run inference to get prediction sets + cal_trainer = Trainer(model=label_predictor, device=device) + results = cal_trainer.inference( + test_loader, additional_outputs=["y_predset"] + ) + y_true, y_predset = results[0], results[3]["y_predset"] + + # Compute and print coverage metrics + coverage = 1 - miscoverage_overall_ps(y_predset, y_true) + avg_set_size = size(y_predset) + print(f"Coverage: {coverage:.1%}, Avg set size: {avg_set_size:.2f}") + + # Use sample index 0 (known to be uncertain with SEED=42) + sample_idx = 0 + single_loader = get_dataloader(test_data, batch_size=1, shuffle=False) + batch = next(iter(single_loader)) + + # Get model prediction and prediction set for this sample + model.eval() + with torch.no_grad(): + pred_prob = model(**batch)["y_prob"][0] + pred_class = pred_prob.argmax().item() + true_label = batch["disease"].item() + sample_predset = y_predset[sample_idx] + predset_class_indices = [i for i, v in enumerate(sample_predset) if v] + predset_classes = [id2label[i] for i in predset_class_indices] + + # Print sample details + true_name = id2label[true_label] + pred_name = id2label[pred_class] + set_size = len(predset_classes) + print(f"Sample {sample_idx}: True={true_name}, Pred={pred_name}, " + f"Set={predset_classes} (size={set_size})") + + # ========================================================================= + # Cell 3: Interpretability (attribution for each class in prediction set) + # ========================================================================= + # Initialize Chefer/AttentionGrad interpreter + chefer = CheferRelevance(model) + n_classes = len(predset_class_indices) + + # Compute attribution for each class in the prediction set + overlays = [] + for class_idx in predset_class_indices: + attr_map = chefer.attribute(class_index=class_idx, **batch)["image"] + _, _, overlay = visualize_image_attr( + image=batch["image"][0], + attribution=attr_map[0, 0], + ) + overlays.append((class_idx, overlay)) + + # Create figure: ground truth + attribution for each class + figsize = (5 * (n_classes + 1), 5) + fig, axes = plt.subplots(1, n_classes + 1, figsize=figsize) + + # Ground truth image + img, _, _ = visualize_image_attr( + image=batch["image"][0], + attribution=torch.zeros_like(batch["image"][0, 0]), + ) + axes[0].imshow(img, cmap='gray') + axes[0].set_title(f"Ground Truth: {true_name}", fontsize=12) + axes[0].axis('off') + + # Plot attributions + for i, (class_idx, overlay) in enumerate(overlays): + axes[i + 1].imshow(overlay) + prob = pred_prob[class_idx].item() + class_name = id2label[class_idx] + axes[i + 1].set_title(f"{class_name} ({prob:.1%})", fontsize=12) + axes[i + 1].axis('off') + + plt.suptitle("Uncertain Prediction: Multiple Classes", fontsize=14, y=1.02) + plt.tight_layout() + plt.savefig("covid19_cxr_interpretability.png", dpi=150) + print("Saved visualization to covid19_cxr_interpretability.png") diff --git a/examples/drug_recommendation_eICU_transformer.py b/examples/drug_recommendation/drug_recommendation_eICU_transformer.py similarity index 100% rename from examples/drug_recommendation_eICU_transformer.py rename to examples/drug_recommendation/drug_recommendation_eICU_transformer.py diff --git a/examples/drug_recommendation_mimic3_gamenet.py b/examples/drug_recommendation/drug_recommendation_mimic3_gamenet.py similarity index 100% rename from examples/drug_recommendation_mimic3_gamenet.py rename to examples/drug_recommendation/drug_recommendation_mimic3_gamenet.py diff --git 
a/examples/drug_recommendation_mimic3_micron.ipynb b/examples/drug_recommendation/drug_recommendation_mimic3_micron.ipynb similarity index 100% rename from examples/drug_recommendation_mimic3_micron.ipynb rename to examples/drug_recommendation/drug_recommendation_mimic3_micron.ipynb diff --git a/examples/drug_recommendation_mimic3_micron.py b/examples/drug_recommendation/drug_recommendation_mimic3_micron.py similarity index 100% rename from examples/drug_recommendation_mimic3_micron.py rename to examples/drug_recommendation/drug_recommendation_mimic3_micron.py diff --git a/examples/drug_recommendation_mimic3_molerec.py b/examples/drug_recommendation/drug_recommendation_mimic3_molerec.py similarity index 100% rename from examples/drug_recommendation_mimic3_molerec.py rename to examples/drug_recommendation/drug_recommendation_mimic3_molerec.py diff --git a/examples/drug_recommendation_mimic3_safedrug.py b/examples/drug_recommendation/drug_recommendation_mimic3_safedrug.py similarity index 100% rename from examples/drug_recommendation_mimic3_safedrug.py rename to examples/drug_recommendation/drug_recommendation_mimic3_safedrug.py diff --git a/examples/drug_recommendation_mimic3_transformer.py b/examples/drug_recommendation/drug_recommendation_mimic3_transformer.py similarity index 100% rename from examples/drug_recommendation_mimic3_transformer.py rename to examples/drug_recommendation/drug_recommendation_mimic3_transformer.py diff --git a/examples/drug_recommendation_mimic4_gamenet.py b/examples/drug_recommendation/drug_recommendation_mimic4_gamenet.py similarity index 100% rename from examples/drug_recommendation_mimic4_gamenet.py rename to examples/drug_recommendation/drug_recommendation_mimic4_gamenet.py diff --git a/examples/drug_recommendation_mimic4_retain.py b/examples/drug_recommendation/drug_recommendation_mimic4_retain.py similarity index 100% rename from examples/drug_recommendation_mimic4_retain.py rename to examples/drug_recommendation/drug_recommendation_mimic4_retain.py diff --git a/examples/EEG_events_SparcNet.py b/examples/eeg/seizure_detection/EEG_events_SparcNet.py similarity index 100% rename from examples/EEG_events_SparcNet.py rename to examples/eeg/seizure_detection/EEG_events_SparcNet.py diff --git a/examples/EEG_isAbnormal_SparcNet.py b/examples/eeg/seizure_detection/EEG_isAbnormal_SparcNet.py similarity index 100% rename from examples/EEG_isAbnormal_SparcNet.py rename to examples/eeg/seizure_detection/EEG_isAbnormal_SparcNet.py diff --git a/examples/contrawr_sleepedf.ipynb b/examples/eeg/sleep_staging/contrawr_sleepedf.ipynb similarity index 100% rename from examples/contrawr_sleepedf.ipynb rename to examples/eeg/sleep_staging/contrawr_sleepedf.ipynb diff --git a/examples/sleep_staging_ISRUC_SparcNet.py b/examples/eeg/sleep_staging/sleep_staging_ISRUC_SparcNet.py similarity index 100% rename from examples/sleep_staging_ISRUC_SparcNet.py rename to examples/eeg/sleep_staging/sleep_staging_ISRUC_SparcNet.py diff --git a/examples/sleep_staging_shhs_contrawr.py b/examples/eeg/sleep_staging/sleep_staging_shhs_contrawr.py similarity index 100% rename from examples/sleep_staging_shhs_contrawr.py rename to examples/eeg/sleep_staging/sleep_staging_shhs_contrawr.py diff --git a/examples/sleep_staging_sleepEDF_contrawr.py b/examples/eeg/sleep_staging/sleep_staging_sleepEDF_contrawr.py similarity index 100% rename from examples/sleep_staging_sleepEDF_contrawr.py rename to examples/eeg/sleep_staging/sleep_staging_sleepEDF_contrawr.py diff --git 
a/examples/deeplift_stagenet_mimic4.py b/examples/interpretability/deeplift_stagenet_mimic4.py similarity index 100% rename from examples/deeplift_stagenet_mimic4.py rename to examples/interpretability/deeplift_stagenet_mimic4.py diff --git a/examples/gim_stagenet_mimic4.py b/examples/interpretability/gim_stagenet_mimic4.py similarity index 100% rename from examples/gim_stagenet_mimic4.py rename to examples/interpretability/gim_stagenet_mimic4.py diff --git a/examples/gim_transformer_mimic4.py b/examples/interpretability/gim_transformer_mimic4.py similarity index 100% rename from examples/gim_transformer_mimic4.py rename to examples/interpretability/gim_transformer_mimic4.py diff --git a/examples/integrated_gradients_mortality_mimic4_stagenet.py b/examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py similarity index 100% rename from examples/integrated_gradients_mortality_mimic4_stagenet.py rename to examples/interpretability/integrated_gradients_mortality_mimic4_stagenet.py diff --git a/examples/interpret_demo.ipynb b/examples/interpretability/interpret_demo.ipynb similarity index 100% rename from examples/interpret_demo.ipynb rename to examples/interpretability/interpret_demo.ipynb diff --git a/examples/interpretability_metrics.py b/examples/interpretability/interpretability_metrics.py similarity index 100% rename from examples/interpretability_metrics.py rename to examples/interpretability/interpretability_metrics.py diff --git a/examples/shap_stagenet_mimic4.ipynb b/examples/interpretability/shap_stagenet_mimic4.ipynb similarity index 100% rename from examples/shap_stagenet_mimic4.ipynb rename to examples/interpretability/shap_stagenet_mimic4.ipynb diff --git a/examples/shap_stagenet_mimic4.py b/examples/interpretability/shap_stagenet_mimic4.py similarity index 100% rename from examples/shap_stagenet_mimic4.py rename to examples/interpretability/shap_stagenet_mimic4.py diff --git a/examples/length_of_stay_mimic3_rnn.py b/examples/length_of_stay/length_of_stay_mimic3_rnn.py similarity index 100% rename from examples/length_of_stay_mimic3_rnn.py rename to examples/length_of_stay/length_of_stay_mimic3_rnn.py diff --git a/examples/length_of_stay_mimic4_rnn.py b/examples/length_of_stay/length_of_stay_mimic4_rnn.py similarity index 100% rename from examples/length_of_stay_mimic4_rnn.py rename to examples/length_of_stay/length_of_stay_mimic4_rnn.py diff --git a/examples/mortality_mimic3_adacare.ipynb b/examples/mortality_prediction/mortality_mimic3_adacare.ipynb similarity index 100% rename from examples/mortality_mimic3_adacare.ipynb rename to examples/mortality_prediction/mortality_mimic3_adacare.ipynb diff --git a/examples/mortality_mimic3_agent.py b/examples/mortality_prediction/mortality_mimic3_agent.py similarity index 100% rename from examples/mortality_mimic3_agent.py rename to examples/mortality_prediction/mortality_mimic3_agent.py diff --git a/examples/mortality_mimic3_concare.py b/examples/mortality_prediction/mortality_mimic3_concare.py similarity index 100% rename from examples/mortality_mimic3_concare.py rename to examples/mortality_prediction/mortality_mimic3_concare.py diff --git a/examples/mortality_mimic3_grasp.py b/examples/mortality_prediction/mortality_mimic3_grasp.py similarity index 100% rename from examples/mortality_mimic3_grasp.py rename to examples/mortality_prediction/mortality_mimic3_grasp.py diff --git a/examples/mortality_mimic3_rnn.py b/examples/mortality_prediction/mortality_mimic3_rnn.py similarity index 100% rename from 
examples/mortality_mimic3_rnn.py rename to examples/mortality_prediction/mortality_mimic3_rnn.py diff --git a/examples/mortality_mimic3_stagenet.py b/examples/mortality_prediction/mortality_mimic3_stagenet.py similarity index 100% rename from examples/mortality_mimic3_stagenet.py rename to examples/mortality_prediction/mortality_mimic3_stagenet.py diff --git a/examples/mortality_mimic3_tcn.py b/examples/mortality_prediction/mortality_mimic3_tcn.py similarity index 100% rename from examples/mortality_mimic3_tcn.py rename to examples/mortality_prediction/mortality_mimic3_tcn.py diff --git a/examples/mortality_mimic4_stagenet_v2.py b/examples/mortality_prediction/mortality_mimic4_stagenet_v2.py similarity index 100% rename from examples/mortality_mimic4_stagenet_v2.py rename to examples/mortality_prediction/mortality_mimic4_stagenet_v2.py diff --git a/examples/mortality_prediction/multimodal_mimic4_demo.py b/examples/mortality_prediction/multimodal_mimic4_demo.py new file mode 100644 index 000000000..9a2d2e3dd --- /dev/null +++ b/examples/mortality_prediction/multimodal_mimic4_demo.py @@ -0,0 +1,589 @@ +""" +PyHealth Multimodal MIMIC-IV Demo: Benchmark + Showcase + +This script demonstrates PyHealth's capability to load and process +multimodal medical data from MIMIC-IV, including: +- EHR codes (ICD-10 diagnoses and procedures) +- Clinical notes (discharge summaries, radiology reports) +- Lab events (time-series lab values) +- Chest X-ray images + +It also benchmarks memory usage and processing time with 16 workers. + +Usage: + python multimodal_mimic4_demo.py + python multimodal_mimic4_demo.py --ehr-root /path/to/mimic-iv + python multimodal_mimic4_demo.py --dev +""" + +from __future__ import annotations + +import argparse +import os +import shutil +import textwrap +import threading +import time +from pathlib import Path +from typing import Dict, List, Any, Optional + +import numpy as np +import psutil + + +# ============================================================================= +# Utility Functions +# ============================================================================= + +def format_size(size_bytes: int) -> str: + """Format bytes to human-readable string.""" + for unit in ["B", "KB", "MB", "GB", "TB"]: + if size_bytes < 1024.0: + return f"{size_bytes:.2f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.2f} PB" + + +def get_directory_size(path: str | Path) -> int: + """Calculate total size of a directory.""" + total = 0 + p = Path(path) + if not p.exists(): + return 0 + try: + for entry in p.rglob("*"): + if entry.is_file(): + try: + total += entry.stat().st_size + except FileNotFoundError: + pass + except Exception as e: + print(f"Error calculating size for {p}: {e}") + return total + + +def ensure_empty_dir(path: str | Path) -> None: + """Ensure directory exists and is empty.""" + p = Path(path) + if p.exists(): + shutil.rmtree(p) + p.mkdir(parents=True, exist_ok=True) + + +def remove_dir(path: str | Path, retries: int = 3, delay: float = 1.0) -> None: + """Remove a directory with retry logic.""" + p = Path(path) + if not p.exists(): + return + for attempt in range(retries): + try: + shutil.rmtree(p) + return + except OSError as e: + if attempt < retries - 1: + time.sleep(delay) + else: + print(f"Warning: Failed to delete {p}: {e}") + + +def truncate_text(text: str, max_words: int = 100) -> str: + """Truncate text to max_words with '...' 
suffix.""" + if not text: + return "[No text available]" + words = text.split() + if len(words) <= max_words: + return text + return " ".join(words[:max_words]) + "..." + + +def print_section(title: str, width: int = 80) -> None: + """Print a section header.""" + print("\n" + "=" * width) + print(f" {title}") + print("=" * width) + + +def print_subsection(title: str) -> None: + """Print a subsection header.""" + print(f"\n--- {title} ---") + + +# ============================================================================= +# Memory Tracking +# ============================================================================= + +class PeakMemoryTracker: + """Tracks peak RSS for current process + children.""" + + def __init__(self, poll_interval_s: float = 0.1) -> None: + self._proc = psutil.Process(os.getpid()) + self._poll_interval_s = poll_interval_s + self._stop = threading.Event() + self._lock = threading.Lock() + self._peak = 0 + self._thread = threading.Thread(target=self._run, daemon=True) + + def start(self) -> None: + self._thread.start() + + def reset(self) -> None: + with self._lock: + self._peak = 0 + + def stop(self) -> None: + self._stop.set() + + def peak_bytes(self) -> int: + with self._lock: + return self._peak + + def _total_rss_bytes(self) -> int: + total = 0 + try: + total += self._proc.memory_info().rss + for child in self._proc.children(recursive=True): + try: + total += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + return total + + def _run(self) -> None: + while not self._stop.is_set(): + rss = self._total_rss_bytes() + with self._lock: + if rss > self._peak: + self._peak = rss + time.sleep(self._poll_interval_s) + + +# ============================================================================= +# MedCode Lookup +# ============================================================================= + +def lookup_icd_codes( + codes: List[str], + code_system: str = "ICD10CM", + max_display: int = 10, +) -> List[Dict[str, str]]: + """Look up ICD code names using PyHealth medcode. 
+ + Args: + codes: List of ICD codes + code_system: Either "ICD10CM" or "ICD9CM" + max_display: Maximum number of codes to look up + + Returns: + List of dicts with code and name + """ + try: + from pyhealth.medcode import InnerMap + + icd_map = InnerMap.load(code_system) + + results = [] + for code in codes[:max_display]: + try: + name = icd_map.lookup(code) + results.append({"code": code, "name": name}) + except (KeyError, Exception): + try: + clean_code = code.replace(".", "") + name = icd_map.lookup(clean_code) + results.append({"code": code, "name": name}) + except (KeyError, Exception): + results.append({"code": code, "name": "[Unknown code]"}) + + return results + + except ImportError: + return [{"code": c, "name": "[medcode not available]"} + for c in codes[:max_display]] + except Exception as e: + print(f"Warning: Could not load {code_system}: {e}") + return [{"code": c, "name": f"[{code_system} unavailable]"} + for c in codes[:max_display]] + + +# ============================================================================= +# Display Functions +# ============================================================================= + +def display_lab_stats(labs_data: tuple) -> None: + """Display lab event statistics.""" + lab_times, lab_values = labs_data + + lab_categories = [ + "Sodium", "Potassium", "Chloride", "Bicarbonate", "Glucose", + "Calcium", "Magnesium", "Anion Gap", "Osmolality", "Phosphate" + ] + + print(f" Total lab measurements: {len(lab_times)}") + print(f" Time span: {min(lab_times):.1f}h to {max(lab_times):.1f}h") + + print("\n Lab Category Statistics:") + print(" " + "-" * 56) + print(f" {'Category':<15} {'Count':>8} {'Mean':>10} {'Min':>10} {'Max':>10}") + print(" " + "-" * 56) + + for idx, category in enumerate(lab_categories): + values = [v[idx] for v in lab_values if v[idx] is not None] + if values: + arr = np.array(values) + print( + f" {category:<15} {len(values):>8} " + f"{np.mean(arr):>10.1f} {np.min(arr):>10.1f} {np.max(arr):>10.1f}" + ) + else: + print(f" {category:<15} {'N/A':>8} " + f"{'N/A':>10} {'N/A':>10} {'N/A':>10}") + + +def display_image(image_path: str, output_path: Optional[str] = None) -> bool: + """Display chest X-ray image if available.""" + try: + import matplotlib.pyplot as plt + from PIL import Image + + if not os.path.exists(image_path): + print(f" [Image not found: {image_path}]") + return False + + img = Image.open(image_path) + + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) + ax.imshow(img, cmap="gray" if img.mode == "L" else None) + ax.set_title("Chest X-Ray", fontsize=14) + ax.axis("off") + + if output_path: + plt.savefig(output_path, dpi=150, bbox_inches="tight") + print(f" ✓ Image saved to: {output_path}") + else: + plt.show() + + plt.close() + return True + + except ImportError: + print(" [matplotlib/PIL not available for image display]") + return False + except Exception as e: + print(f" [Error displaying image: {e}]") + return False + + +def showcase_sample(sample: Dict[str, Any], cxr_root: Optional[str] = None, + save_image: Optional[str] = None) -> None: + """Display a multimodal sample with all its components.""" + + # ========================================================================= + # EHR Codes with MedCode Lookup + # ========================================================================= + print_section("EHR Codes (ICD-10)") + + # Diagnoses + conditions = sample.get("conditions", []) + print_subsection(f"Diagnosis Codes ({len(conditions)} total)") + + if conditions: + diagnosis_info = lookup_icd_codes(conditions, 
"ICD10CM", max_display=10) + if all(info["name"] == "[Unknown code]" for info in diagnosis_info): + diagnosis_info = lookup_icd_codes(conditions, "ICD9CM", max_display=10) + + for info in diagnosis_info: + print(f" • {info['code']}: {info['name']}") + + if len(conditions) > 10: + print(f" ... and {len(conditions) - 10} more codes") + + # Procedures + procedures = sample.get("procedures", []) + print_subsection(f"Procedure Codes ({len(procedures)} total)") + + if procedures: + procedure_info = lookup_icd_codes(procedures, "ICD10PROC", max_display=5) + if all(info["name"] in ["[Unknown code]", "[ICD10PROC unavailable]"] + for info in procedure_info): + procedure_info = lookup_icd_codes(procedures, "ICD9PROC", max_display=5) + + for info in procedure_info: + print(f" • {info['code']}: {info['name']}") + + if len(procedures) > 5: + print(f" ... and {len(procedures) - 5} more codes") + + # Drugs + drugs = sample.get("drugs", []) + print_subsection(f"Drug Prescriptions ({len(drugs)} total)") + if drugs: + for drug in drugs[:5]: + print(f" • {drug}") + if len(drugs) > 5: + print(f" ... and {len(drugs) - 5} more drugs") + + # ========================================================================= + # Clinical Notes + # ========================================================================= + print_section("Clinical Notes") + + # Radiology report (truncated to <100 words) + print_subsection("Radiology Report Summary (<100 words)") + radiology_text = sample.get("radiology", "") + truncated_radiology = truncate_text(radiology_text, max_words=100) + wrapped = textwrap.fill(truncated_radiology, width=70, + initial_indent=" ", subsequent_indent=" ") + print(wrapped) + + # Discharge summary (truncated) + print_subsection("Discharge Summary Excerpt (<100 words)") + discharge_text = sample.get("discharge", "") + truncated_discharge = truncate_text(discharge_text, max_words=100) + wrapped = textwrap.fill(truncated_discharge, width=70, + initial_indent=" ", subsequent_indent=" ") + print(wrapped) + + # ========================================================================= + # Lab Events + # ========================================================================= + print_section("Lab Events (Time-Series)") + + labs_data = sample.get("labs") + if labs_data and isinstance(labs_data, tuple) and len(labs_data) == 2: + display_lab_stats(labs_data) + else: + print(" [No lab data available]") + + # ========================================================================= + # X-Ray Data + # ========================================================================= + print_section("Chest X-Ray Data") + + # X-ray NegBio findings + xray_findings = sample.get("xrays_negbio", []) + print_subsection(f"NegBio Findings ({len(xray_findings)} detected)") + if xray_findings: + unique_findings = list(set(xray_findings)) + for finding in unique_findings[:10]: + count = xray_findings.count(finding) + print(f" • {finding.title()} (×{count})") + else: + print(" [No X-ray findings detected]") + + # X-ray image + image_path = sample.get("image") + print_subsection("X-Ray Image") + if image_path: + print(f" Image path: {image_path}") + full_path = image_path + if cxr_root and not os.path.isabs(image_path): + full_path = os.path.join(cxr_root, image_path) + + if os.path.exists(full_path): + display_image(full_path, save_image) + else: + print(" [Image file not accessible - set --cxr-root to view]") + else: + print(" [No X-ray image available for this sample]") + + # 
========================================================================= + # Summary + # ========================================================================= + print_section("Sample Modality Summary") + + print(f"\n ✓ EHR Codes: {len(conditions)} diagnoses, " + f"{len(procedures)} procedures, {len(drugs)} drugs") + print(f" ✓ Clinical Notes: Discharge ({len(discharge_text.split())} words), " + f"Radiology ({len(radiology_text.split())} words)") + if labs_data: + print(f" ✓ Lab Events: {len(labs_data[0])} measurements, 10 dimensions") + print(f" ✓ X-Ray: {len(xray_findings)} NegBio findings, " + f"Image: {'Available' if image_path else 'N/A'}") + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser( + description="PyHealth Multimodal MIMIC-IV Demo: Benchmark + Showcase" + ) + parser.add_argument( + "--dev", + action="store_true", + help="Use dev mode (smaller subset)", + ) + parser.add_argument( + "--ehr-root", + type=str, + default="/srv/local/data/physionet.org/files/mimiciv/2.2/", + help="Path to MIMIC-IV EHR root", + ) + parser.add_argument( + "--cxr-root", + type=str, + default=None, + help="Path to MIMIC-CXR root for images", + ) + parser.add_argument( + "--cache-dir", + type=str, + default="/tmp/pyhealth_multimodal_demo/", + help="Cache directory for processed data", + ) + parser.add_argument( + "--save-image", + type=str, + default=None, + help="Path to save X-ray visualization", + ) + parser.add_argument( + "--sample-idx", + type=int, + default=0, + help="Index of sample to display (default: 0)", + ) + parser.add_argument( + "--skip-benchmark", + action="store_true", + help="Skip benchmarking, only show sample", + ) + args = parser.parse_args() + + num_workers = 16 + + print_section("PyHealth: Multimodal Medical Data Loading Demo") + print("\nThis demo showcases PyHealth's ability to load and process") + print("multimodal medical data from MIMIC-IV dataset.") + print(f"\nConfiguration:") + print(f" EHR root: {args.ehr_root}") + print(f" CXR root: {args.cxr_root}") + print(f" Dev mode: {args.dev}") + print(f" Workers: {num_workers}") + + # ========================================================================= + # Benchmark: Load Dataset + # ========================================================================= + print_section("Step 1: Loading MIMIC-IV Multimodal Dataset") + + from pyhealth.datasets import MIMIC4Dataset + from pyhealth.tasks import MultimodalMortalityPredictionMIMIC4 + + cache_root = Path(args.cache_dir) + base_cache_dir = cache_root / ("base_dataset_dev" if args.dev else "base_dataset") + task_cache_dir = cache_root / "task_samples" + + # Initialize memory tracker + tracker = PeakMemoryTracker(poll_interval_s=0.1) + tracker.start() + tracker.reset() + + run_start = time.time() + + # Load base dataset + print("\n[1/2] Loading base dataset...") + dataset_start = time.time() + + base_dataset = MIMIC4Dataset( + ehr_root=args.ehr_root, + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "prescriptions", + "labevents", + "discharge", + "radiology", + ], + dev=args.dev, + cache_dir=str(base_cache_dir), + ) + + dataset_load_s = time.time() - dataset_start + base_cache_bytes = get_directory_size(base_cache_dir) + print(f" ✓ Base dataset loaded in {dataset_load_s:.2f}s") + print(f" ✓ Base cache size: {format_size(base_cache_bytes)}") + + # Apply multimodal task + 
print("\n[2/2] Applying MultimodalMortalityPredictionMIMIC4 task...") + task_start = time.time() + + task = MultimodalMortalityPredictionMIMIC4(cxr_root=args.cxr_root) + sample_dataset = base_dataset.set_task( + task, + num_workers=num_workers, + cache_dir=str(task_cache_dir), + ) + + task_process_s = time.time() - task_start + total_s = time.time() - run_start + peak_rss_bytes = tracker.peak_bytes() + task_cache_bytes = get_directory_size(task_cache_dir) + num_samples = len(sample_dataset) + + print(f" ✓ Task completed in {task_process_s:.2f}s") + print(f" ✓ Task cache size: {format_size(task_cache_bytes)}") + print(f" ✓ Total samples: {num_samples}") + + # ========================================================================= + # Benchmark Results + # ========================================================================= + if not args.skip_benchmark: + print_section("Benchmark Results") + + print(f"\n Dataset load time: {dataset_load_s:>10.2f}s") + print(f" Task processing time: {task_process_s:>10.2f}s") + print(f" Total time: {total_s:>10.2f}s") + print(f" Peak memory (RSS): {format_size(peak_rss_bytes):>10}") + print(f" Base cache size: {format_size(base_cache_bytes):>10}") + print(f" Task cache size: {format_size(task_cache_bytes):>10}") + print(f" Number of samples: {num_samples:>10}") + + print("\n Dataset Schema:") + print(f" Input: {list(sample_dataset.input_schema.keys())}") + print(f" Output: {list(sample_dataset.output_schema.keys())}") + + tracker.stop() + + # ========================================================================= + # Showcase: Display First Sample + # ========================================================================= + if num_samples == 0: + print("\n⚠ No samples generated. This may be because:") + print(" - The dataset is too small (try without --dev)") + print(" - Missing required modalities (all must be present)") + return + + print_section("Step 2: Showcasing First Multimodal Sample") + + sample_idx = min(args.sample_idx, num_samples - 1) + sample = sample_dataset.samples[sample_idx] + + print(f"\n Sample index: {sample_idx}") + print(f" Patient ID: {sample.get('patient_id', 'N/A')}") + print(f" Visit ID: {sample.get('visit_id', 'N/A')}") + print(f" Mortality label: {sample.get('mortality', 'N/A')}") + + showcase_sample(sample, cxr_root=args.cxr_root, save_image=args.save_image) + + # ========================================================================= + # Final Summary + # ========================================================================= + print_section("Demo Complete") + print("\n PyHealth: Your one-stop solution for multimodal medical data!") + print("\n Key features demonstrated:") + print(" • EHR code loading with medcode lookup") + print(" • Clinical note extraction (discharge, radiology)") + print(" • Time-series lab event processing") + print(" • Chest X-ray image integration") + print(" • Memory-efficient parallel processing (16 workers)") + print("=" * 80) + + +if __name__ == "__main__": + main() + diff --git a/examples/timeseries_mimic4.ipynb b/examples/mortality_prediction/timeseries_mimic4.ipynb similarity index 100% rename from examples/timeseries_mimic4.ipynb rename to examples/mortality_prediction/timeseries_mimic4.ipynb diff --git a/examples/timeseries_mimic4.py b/examples/mortality_prediction/timeseries_mimic4.py similarity index 100% rename from examples/timeseries_mimic4.py rename to examples/mortality_prediction/timeseries_mimic4.py diff --git a/examples/readmission_mimic3_fairness.py 
b/examples/readmission/readmission_mimic3_fairness.py similarity index 100% rename from examples/readmission_mimic3_fairness.py rename to examples/readmission/readmission_mimic3_fairness.py diff --git a/examples/readmission_mimic3_rnn.py b/examples/readmission/readmission_mimic3_rnn.py similarity index 100% rename from examples/readmission_mimic3_rnn.py rename to examples/readmission/readmission_mimic3_rnn.py diff --git a/pyhealth/calib/predictionset/label.py b/pyhealth/calib/predictionset/label.py index c88eaeefd..435a0e9d1 100644 --- a/pyhealth/calib/predictionset/label.py +++ b/pyhealth/calib/predictionset/label.py @@ -130,7 +130,6 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: pred["y_predset"] = pred["y_prob"] > self.t return pred - if __name__ == "__main__": from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader @@ -168,3 +167,4 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: y_predset=extra_output["y_predset"], ) ) + diff --git a/pyhealth/datasets/covid19_cxr.py b/pyhealth/datasets/covid19_cxr.py index 42f5671b5..17704b22d 100644 --- a/pyhealth/datasets/covid19_cxr.py +++ b/pyhealth/datasets/covid19_cxr.py @@ -67,6 +67,9 @@ class COVID19CXRDataset(BaseDataset): dataset_name: Optional name of the dataset. Defaults to "covid19_cxr". config_path: Optional path to the configuration file. If not provided, uses the default config in the configs directory. + cache_dir: Optional directory for caching processed data. + num_workers: Number of parallel workers for data processing. Defaults to 1. + dev: If True, only loads a small subset of data for development/testing. Attributes: root: Root directory of the raw data. @@ -88,6 +91,9 @@ def __init__( root: str, dataset_name: Optional[str] = None, config_path: Optional[str] = None, + cache_dir: Optional[str] = None, + num_workers: int = 1, + dev: bool = False, ) -> None: if config_path is None: logger.info("No config path provided, using default config") @@ -100,6 +106,9 @@ def __init__( tables=default_tables, dataset_name=dataset_name or "covid19_cxr", config_path=config_path, + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, ) return @@ -149,33 +158,6 @@ def prepare_metadata(self, root: str) -> None: df.to_csv(os.path.join(root, "covid19_cxr-metadata-pyhealth.csv"), index=False) return - def set_task( - self, - task: BaseTask | None = None, - num_workers: int = 1, - cache_dir: str | None = None, - cache_format: str = "parquet", - input_processors: Dict[str, FeatureProcessor] | None = None, - output_processors: Dict[str, FeatureProcessor] | None = None, - ) -> SampleDataset: - if input_processors is None or "image" not in input_processors: - image_processor = ImageProcessor( - image_size=299, # The image size for COVID-19 CXR dataset - mode="L", # Grayscale images - ) - if input_processors is None: - input_processors = {} - input_processors["image"] = image_processor - - return super().set_task( - task, - num_workers, - cache_dir, - cache_format, - input_processors, - output_processors, - ) - @property def default_task(self) -> COVID19CXRClassification: """Returns the default task for this dataset. 
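With the ``set_task`` override removed above, ``COVID19CXRDataset`` no longer injects the 299x299 grayscale ``ImageProcessor`` on its own. A minimal sketch of how a caller could restore that preprocessing explicitly when applying a task — assuming ``ImageProcessor`` is importable from ``pyhealth.processors`` and that the base ``set_task`` still accepts an ``input_processors`` mapping, as the removed override did:

>>> from pyhealth.datasets import COVID19CXRDataset
>>> from pyhealth.processors import ImageProcessor  # import path assumed
>>>
>>> base_dataset = COVID19CXRDataset(root="/path/to/COVID-19_Radiography_Dataset")
>>> sample_dataset = base_dataset.set_task(
...     input_processors={
...         # 299x299 grayscale, matching the kwargs used by the removed override
...         "image": ImageProcessor(image_size=299, mode="L"),
...     },
... )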
diff --git a/pyhealth/interpret/__init__.py b/pyhealth/interpret/__init__.py index 29a6e5666..06828b47e 100644 --- a/pyhealth/interpret/__init__.py +++ b/pyhealth/interpret/__init__.py @@ -1 +1,9 @@ -from pyhealth.interpret import methods \ No newline at end of file +from pyhealth.interpret import methods +from pyhealth.interpret import utils +from pyhealth.interpret.utils import ( + # Core visualization functions + show_cam_on_image, + interpolate_attribution_map, + normalize_attribution, + visualize_image_attr, +) diff --git a/pyhealth/interpret/methods/base_interpreter.py b/pyhealth/interpret/methods/base_interpreter.py index 9d7fce065..18d0ee088 100644 --- a/pyhealth/interpret/methods/base_interpreter.py +++ b/pyhealth/interpret/methods/base_interpreter.py @@ -3,10 +3,14 @@ This module defines the interface that all interpretability/attribution methods must implement. It ensures consistency across different methods and makes it easy to swap between different attribution techniques. + +The key API contract is that ``attribute()`` returns a dictionary keyed by +the model's feature keys (as defined by the task schema), making it easy to +map attributions back to specific input modalities. """ from abc import ABC, abstractmethod -from typing import Dict, Optional +from typing import Dict import torch import torch.nn as nn @@ -22,53 +26,78 @@ class BaseInterpreter(ABC): features, explaining which features contributed most to the model's prediction. + **API Contract:** + All interpretability methods should: + 1. Take a trained model in their constructor - 2. Implement the `attribute()` method - 3. Return attributions as a dictionary matching input shapes - 4. Work with any PyHealth model (or at least clearly document - compatibility requirements) - - The `attribute()` method is the core interface that: - - Takes model inputs (as would be passed to model.forward()) - - Computes attribution scores for each input feature - - Returns a dictionary mapping feature keys to attribution tensors - - Attribution tensors have the same shape as input tensors - - Higher absolute values indicate more important features + 2. Implement the ``attribute()`` method + 3. Return attributions as a dictionary **keyed by the model's feature keys** + (as defined by the task's ``input_schema``) + 4. Work with any PyHealth model (or clearly document compatibility) + + The ``attribute()`` method returns a dictionary that mirrors the task schema: + + - For EHR tasks with ``input_schema={"conditions": "sequence", "procedures": "sequence"}``, + returns ``{"conditions": attr_tensor, "procedures": attr_tensor}`` + - For image tasks with ``input_schema={"image": "image"}``, + returns ``{"image": attr_tensor}`` + + This design ensures attributions are dynamically tied to dataset feature keys, + making the API consistent across CXR datasets, EHR datasets, or any custom + task schema. Subclasses should implement: - - __init__(self, model, **kwargs): Initialize with model and + - ``__init__(self, model, **kwargs)``: Initialize with model and method-specific parameters - - attribute(self, **data): Compute attributions for given inputs + - ``attribute(self, **data)``: Compute attributions for given inputs Args: - model (BaseModel or nn.Module): A trained PyHealth model to - interpret. Should be in evaluation mode during attribution - computation. - - Examples: - >>> # Example of implementing a new attribution method - >>> class MyAttributionMethod(BaseInterpreter): - ... def __init__(self, model, some_param=1.0): - ... 
super().__init__(model) - ... self.some_param = some_param - ... - ... def attribute(self, **data): - ... # Implement attribution computation - ... attributions = {} - ... for key in self.model.feature_keys: - ... # Compute importance scores - ... attributions[key] = compute_scores(data[key]) - ... return attributions + model (BaseModel or nn.Module): A trained PyHealth model to interpret. + The model must have ``feature_keys`` (list of input feature names) + derived from the dataset's task schema. Should be in evaluation + mode during attribution computation. + + Example: + >>> # Example 1: EHR data with multiple feature keys + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader + >>> from pyhealth.models import Transformer + >>> + >>> samples = [ + ... { + ... "patient_id": "p0", + ... "visit_id": "v0", + ... "conditions": ["A05B", "A05C", "A06A"], + ... "procedures": ["P01", "P02"], + ... "label": 1, + ... }, + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"conditions": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="ehr_example", + ... ) + >>> model = Transformer(dataset=dataset) + >>> # model.feature_keys == ["conditions", "procedures"] + >>> + >>> interpreter = CheferRelevance(model) + >>> batch = next(iter(get_dataloader(dataset, batch_size=1))) + >>> attributions = interpreter.attribute(**batch) + >>> # Returns: {"conditions": tensor(...), "procedures": tensor(...)} + >>> print(attributions.keys()) # dict_keys(['conditions', 'procedures']) + >>> + >>> # Example 2: Image data (CXR) with single feature key + >>> # Given task schema: input_schema={"image": "image"} + >>> # model.feature_keys == ["image"] (or model.feature_key == "image") >>> - >>> # Using the attribution method - >>> model = StageNet(dataset=dataset) - >>> interpreter = MyAttributionMethod(model) - >>> batch = next(iter(dataloader)) + >>> interpreter = CheferRelevance(vit_model) >>> attributions = interpreter.attribute(**batch) + >>> # Returns: {"image": tensor(...)} - keyed by the task's feature key + >>> print(attributions["image"].shape) # [batch, 1, H, W] """ - def __init__(self, model: nn.Module): + def __init__(self, model: BaseModel): """Initialize the base interpreter. Args: @@ -95,28 +124,46 @@ def attribute( indicating which features were most important for the model's prediction. + **Important:** The returned dictionary must be keyed by the model's + feature keys (from ``model.feature_keys`` or ``model.feature_key``), + which are derived from the task's ``input_schema``. This ensures + attributions map directly to the input modalities defined in the task. + Args: **data: Input data dictionary from a dataloader batch. Should contain at minimum: - - Feature keys (e.g., 'conditions', 'procedures', 'icd_codes'): - Input tensors or sequences for each modality. The exact - keys depend on the model's feature_keys. - - 'label' (optional): Ground truth labels, may be needed by - some methods but not used in attribution computation. - - Additional method-specific parameters can be passed here - (e.g., target_class_idx, baseline, steps). + + - Feature keys (e.g., ``"conditions"``, ``"procedures"``, + ``"image"``): Input tensors for each modality as defined + by the task's ``input_schema``. + - Label key (optional): Ground truth labels, may be needed + by some methods for loss computation. + - ``class_index`` (optional): Target class for attribution. + If not provided, uses the predicted class. 
+ - Additional method-specific parameters (e.g., ``baseline``, + ``steps``, ``interpolate``). The data dictionary should match what would be passed to - the model's forward() method. + the model's ``forward()`` method. Returns: - Dict[str, torch.Tensor]: Dictionary mapping each feature key to - its attribution tensor. Each attribution tensor: - - Has the same shape as the corresponding input tensor + Dict[str, torch.Tensor]: Dictionary mapping **each feature key** + to its attribution tensor. The keys must match the model's + feature keys from the task schema. + + For EHR tasks:: + + {"conditions": tensor, "procedures": tensor, ...} + + For image tasks:: + + {"image": tensor} # Shape: [batch, 1, H, W] for spatial + + Each attribution tensor: + - Contains real-valued importance scores - Higher absolute values = more important features - - Can be positive (increases prediction) or negative - (decreases prediction) depending on the method + - Can be positive or negative depending on the method - Should be on the same device as the input Raises: @@ -125,51 +172,38 @@ def attribute( Note: **Attribution Properties:** - Different attribution methods may produce scores with different + Different attribution methods produce scores with different properties: 1. **Sign**: Some methods produce only positive scores (e.g., - attention weights), while others can produce both positive - and negative scores (e.g., Integrated Gradients). + attention weights), while others produce both positive and + negative scores (e.g., Integrated Gradients, DeepLift). - 2. **Magnitude**: Scores may be: - - Normalized to sum to 1 (probability-like) - - Unnormalized gradients or relevance scores - - Relative importance within each feature modality + 2. **Magnitude**: Scores may be normalized (sum to 1) or + unnormalized (raw gradients/relevance). - 3. **Interpretation**: Higher absolute values generally mean - more important, but the exact interpretation depends on the - method. + 3. **Shape**: For sequential data, shape matches input tokens. + For images, shape is typically ``[batch, 1, H, W]`` for + spatial attribution maps. **Common Patterns:** - - Gradient-based methods (IG, Saliency): Can be positive or - negative, represent contribution to output change - - Attention-based methods (Chefer): Usually positive, represent - relevance or importance - - Perturbation-based methods (LIME, SHAP): Can be positive or - negative, represent feature contribution + - Gradient-based (IG, Saliency): +/- scores, contribution to output + - Attention-based (Chefer): Usually positive, relevance/importance + - Perturbation-based (LIME, SHAP): +/- scores, feature contribution - Examples: - >>> # Basic usage - >>> interpreter = IntegratedGradients(model) + Example: + >>> # EHR model with multiple feature keys + >>> interpreter = DeepLift(model) >>> batch = next(iter(test_loader)) >>> attributions = interpreter.attribute(**batch) - >>> print(attributions.keys()) # Feature keys - >>> print(attributions['conditions'].shape) # Same as input + >>> print(attributions.keys()) # ['conditions', 'procedures'] >>> - >>> # With method-specific parameters - >>> attributions = interpreter.attribute( - ... **batch, - ... target_class_idx=1, # Attribute to specific class - ... steps=50 # Method-specific parameter - ... 
) - >>> - >>> # Analyze most important features - >>> cond_attr = attributions['conditions'][0] # First sample - >>> top_features = torch.topk(torch.abs(cond_attr), k=5) - >>> print(f"Top 5 features: {top_features.indices}") - >>> print(f"Importance scores: {top_features.values}") + >>> # Image model (CXR) with single feature key + >>> interpreter = CheferRelevance(vit_model) + >>> attributions = interpreter.attribute(**batch) + >>> print(attributions.keys()) # ['image'] + >>> print(attributions['image'].shape) # [1, 1, 224, 224] """ pass diff --git a/pyhealth/interpret/methods/chefer.py b/pyhealth/interpret/methods/chefer.py index 4cba0bb83..5bef2fc45 100644 --- a/pyhealth/interpret/methods/chefer.py +++ b/pyhealth/interpret/methods/chefer.py @@ -1,310 +1,241 @@ +from typing import Dict + import torch import torch.nn.functional as F from pyhealth.models import Transformer +from pyhealth.models.base_model import BaseModel from .base_interpreter import BaseInterpreter +# Import TorchvisionModel conditionally to avoid circular imports +try: + from pyhealth.models import TorchvisionModel + HAS_TORCHVISION_MODEL = True +except ImportError: + HAS_TORCHVISION_MODEL = False + TorchvisionModel = None + def apply_self_attention_rules(R_ss, cam_ss): """Apply Chefer's self-attention rules for relevance propagation. - This function propagates relevance scores through an attention layer by - multiplying the current relevance matrix with the attention weights. - Args: - R_ss (torch.Tensor): Relevance matrix of shape ``[batch, seq_len, seq_len]`` - representing token-to-token relevance from previous layers. - cam_ss (torch.Tensor): Attention weight matrix of shape ``[batch, seq_len, seq_len]`` - representing the current layer's attention scores. + R_ss: Relevance matrix [batch, seq_len, seq_len]. + cam_ss: Attention weight matrix [batch, seq_len, seq_len]. Returns: - torch.Tensor: Updated relevance matrix of shape ``[batch, seq_len, seq_len]`` - after propagating through the attention layer. + Updated relevance matrix after propagating through attention layer. """ return torch.matmul(cam_ss, R_ss) def avg_heads(cam, grad): - """Average attention scores weighted by their gradients across multiple heads. - - This function computes gradient-weighted attention scores and averages them - across attention heads. The gradients indicate how much each attention weight - contributed to the final prediction, providing a measure of importance. + """Average attention scores weighted by gradients across heads. Args: - cam (torch.Tensor): Attention weights. Shape ``[batch, heads, seq_len, seq_len]`` - for multi-head attention or ``[batch, seq_len, seq_len]`` for single-head. - grad (torch.Tensor): Gradients of the loss with respect to attention weights. - Same shape as ``cam``. + cam: Attention weights [batch, heads, seq_len, seq_len] or [batch, seq_len, seq_len]. + grad: Gradients w.r.t. attention weights. Same shape as cam. Returns: - torch.Tensor: Gradient-weighted attention scores, averaged across heads. - Shape ``[batch, seq_len, seq_len]``. Negative values are clamped to zero. - - Note: - If input tensors have fewer than 4 dimensions (single-head case), no - averaging is performed and the element-wise product is returned directly. + Gradient-weighted attention [batch, seq_len, seq_len]. """ - # force shapes of cam and grad to be the same order - if ( - len(cam.size()) < 4 and len(grad.size()) < 4 - ): # check if no averaging needed. 
i.e single head + if len(cam.size()) < 4 and len(grad.size()) < 4: return (grad * cam).clamp(min=0) - cam = grad * cam # elementwise mult - cam = cam.clamp(min=0).mean(dim=1) # average across heads + cam = grad * cam + cam = cam.clamp(min=0).mean(dim=1) return cam.clone() class CheferRelevance(BaseInterpreter): - """Transformer Self Attention Token Relevance Computation using Chefer's Method. + """Chefer's gradient-weighted attention method for transformer interpretability. - This class computes the relevance of each token in the input sequence for a given - class prediction. The relevance is computed using Chefer's Self Attention Rules, - which provide interpretability for transformer models by propagating relevance - scores through attention layers. + This class implements the relevance propagation method from Chefer et al. for + explaining transformer model predictions. It computes relevance scores for each + input token (for text/EHR transformers) or patch (for Vision Transformers) by + combining attention weights with their gradients. - The method is based on the paper: - Generic Attention-model Explainability for Interpreting Bi-Modal and - Encoder-Decoder Transformers - Hila Chefer, Shir Gur, Lior Wolf - https://arxiv.org/abs/2103.15679 - Implementation based on https://github.com/hila-chefer/Transformer-Explainability + The method works by: + 1. Performing a forward pass to capture attention maps from each layer + 2. Computing gradients of the target class w.r.t. attention weights + 3. Combining attention and gradients using element-wise multiplication + 4. Propagating relevance through layers using attention rollout rules + + This approach provides more faithful explanations than raw attention weights + alone, as it accounts for how attention contributes to the final prediction. + + Paper: + Chefer, Hila, Shir Gur, and Lior Wolf. + "Generic Attention-model Explainability for Interpreting Bi-Modal and + Encoder-Decoder Transformers." + Proceedings of the IEEE/CVF International Conference on Computer Vision + (ICCV), 2021. + + Supported Models: + - PyHealth Transformer: For sequential/EHR data with multiple feature keys + - TorchvisionModel (ViT variants): vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14 Args: - model (Transformer): A trained PyHealth Transformer model to interpret. + model (BaseModel): A trained PyHealth model to interpret. Must be either: + - A ``Transformer`` model for sequential/EHR data + - A ``TorchvisionModel`` with a ViT architecture for image data - Examples: - >>> import torch - >>> from pyhealth.datasets import SampleDataset, split_by_patient, get_dataloader + Example: + >>> # Example 1: PyHealth Transformer for EHR data + >>> from pyhealth.datasets import create_sample_dataset, get_dataloader >>> from pyhealth.models import Transformer >>> from pyhealth.interpret.methods import CheferRelevance - >>> from pyhealth.trainer import Trainer >>> - >>> # Define sample data >>> samples = [ ... { - ... "patient_id": "patient-0", - ... "visit_id": "visit-0", - ... "conditions": ["D001", "D002", "D003"], - ... "procedures": ["P001", "P002"], - ... "drugs": ["M001", "M002"], + ... "patient_id": "p0", + ... "visit_id": "v0", + ... "conditions": ["A05B", "A05C", "A06A"], + ... "procedures": ["P01", "P02"], ... "label": 1, ... }, ... { - ... "patient_id": "patient-1", - ... "visit_id": "visit-1", - ... "conditions": ["D004", "D005"], - ... "procedures": ["P003"], - ... "drugs": ["M003"], + ... "patient_id": "p0", + ... "visit_id": "v1", + ... 
"conditions": ["A05B"], + ... "procedures": ["P01"], ... "label": 0, ... }, - ... # ... more samples ... ] - >>> - >>> # Create dataset with schema - >>> input_schema = { - ... "conditions": "sequence", - ... "procedures": "sequence", - ... "drugs": "sequence" - ... } - >>> output_schema = {"label": "binary"} - >>> - >>> dataset = SampleDataset( + >>> dataset = create_sample_dataset( ... samples=samples, - ... input_schema=input_schema, - ... output_schema=output_schema, - ... dataset_name="example" + ... input_schema={"conditions": "sequence", "procedures": "sequence"}, + ... output_schema={"label": "binary"}, + ... dataset_name="ehr_example", ... ) + >>> model = Transformer(dataset=dataset) + >>> # ... train the model ... >>> - >>> # Initialize Transformer model - >>> model = Transformer( - ... dataset=dataset, - ... embedding_dim=128, - ... heads=2, - ... dropout=0.3, - ... num_layers=2 - ... ) + >>> # Create interpreter and compute attribution + >>> interpreter = CheferRelevance(model) + >>> batch = next(iter(get_dataloader(dataset, batch_size=2))) + >>> + >>> # Default: attribute to predicted class + >>> attributions = interpreter.attribute(**batch) + >>> # Returns dict: {"conditions": tensor, "procedures": tensor} + >>> print(attributions["conditions"].shape) # [batch, num_tokens] + >>> + >>> # Optional: attribute to a specific class (e.g., class 1) + >>> attributions = interpreter.attribute(class_index=1, **batch) >>> - >>> # Split data and create dataloaders - >>> train_data, val_data, test_data = split_by_patient(dataset, [0.7, 0.15, 0.15]) - >>> train_loader = get_dataloader(train_data, batch_size=32, shuffle=True) - >>> val_loader = get_dataloader(val_data, batch_size=32, shuffle=False) - >>> test_loader = get_dataloader(test_data, batch_size=1, shuffle=False) + >>> # Example 2: TorchvisionModel ViT for image data + >>> from pyhealth.datasets import COVID19CXRDataset + >>> from pyhealth.models import TorchvisionModel + >>> from pyhealth.interpret.utils import visualize_image_attr >>> - >>> # Train model - >>> trainer = Trainer(model=model, device="cuda:0") - >>> trainer.train( - ... train_dataloader=train_loader, - ... val_dataloader=val_loader, - ... epochs=10, - ... monitor="roc_auc" + >>> base_dataset = COVID19CXRDataset(root="/path/to/data") + >>> sample_dataset = base_dataset.set_task() + >>> model = TorchvisionModel( + ... dataset=sample_dataset, + ... model_name="vit_b_16", + ... model_config={"weights": "DEFAULT"}, ... ) + >>> # ... train the model ... >>> - >>> # Compute relevance scores for test samples - >>> relevance = CheferRelevance(model) - >>> data_batch = next(iter(test_loader)) + >>> # Create interpreter and compute attribution + >>> # Task schema: input_schema={"image": "image"}, so feature_key="image" + >>> interpreter = CheferRelevance(model) >>> - >>> # Option 1: Specify target class explicitly - >>> data_batch['class_index'] = 0 - >>> scores = relevance.get_relevance_matrix(**data_batch) - >>> print(scores) - {'conditions': tensor([[1.2210]], device='cuda:0'), - 'procedures': tensor([[1.0865]], device='cuda:0'), - 'drugs': tensor([[1.0000]], device='cuda:0')} + >>> # Default: attribute to predicted class + >>> result = interpreter.attribute(**batch) + >>> # Returns dict keyed by feature_key: {"image": tensor} + >>> attr_map = result["image"] # Shape: [batch, 1, H, W] >>> - >>> # Option 2: Use predicted class (omit class_index) - >>> scores = relevance.get_relevance_matrix( - ... conditions=data_batch['conditions'], - ... 
procedures=data_batch['procedures'], - ... drugs=data_batch['drugs'], - ... label=data_batch['label'] + >>> # Optional: attribute to a specific class (e.g., predicted class) + >>> pred_class = model(**batch)["y_prob"].argmax().item() + >>> result = interpreter.attribute(class_index=pred_class, **batch) + >>> + >>> # Visualize + >>> img, attr, overlay = visualize_image_attr( + ... image=batch["image"][0], + ... attribution=result["image"][0, 0], ... ) """ - def __init__(self, model: Transformer): - """Initialize Chefer relevance interpreter. - - Args: - model: A trained PyHealth Transformer model to interpret. - Must be an instance of pyhealth.models.Transformer. - - Raises: - AssertionError: If model is not a Transformer instance. - """ + def __init__(self, model: BaseModel): super().__init__(model) - assert isinstance(model, Transformer), ( - f"CheferRelevance only works with Transformer models, " - f"got {type(model).__name__}" - ) - - def attribute(self, **data): - """Compute relevance scores for each token in the input features. - - This method performs a forward pass through the model and computes - gradient-based relevance scores for each input token across all feature - modalities (e.g., conditions, procedures, drugs). The relevance scores - indicate the importance of each token for the predicted class. Higher - relevance scores suggest that the token contributed more to the model's - prediction. + + # Determine model type + self._is_transformer = isinstance(model, Transformer) + self._is_vit = False + + if HAS_TORCHVISION_MODEL and TorchvisionModel is not None: + if isinstance(model, TorchvisionModel): + self._is_vit = model.is_vit_model() + + if not self._is_transformer and not self._is_vit: + raise ValueError( + f"CheferRelevance requires a Transformer or TorchvisionModel (ViT), " + f"got {type(model).__name__}. For TorchvisionModel, only ViT variants " + f"(vit_b_16, vit_b_32, etc.) are supported." + ) + + def attribute( + self, + interpolate: bool = True, + class_index: int = None, + **data, + ) -> Dict[str, torch.Tensor]: + """Compute relevance scores for each token/patch. + + This is the primary method for computing attributions. Returns a + dictionary keyed by the model's feature keys (from the task schema). Args: - **data: Input data dictionary from a dataloader batch containing: - - Feature keys (e.g., 'conditions', 'procedures', 'drugs'): - Input tensors or sequences for each modality - - 'label': Ground truth label tensor - - 'class_index' (optional): Integer specifying target class for - relevance computation. If not provided, uses the predicted - class (argmax of model output). + interpolate: For ViT models, if True interpolate attribution to image size. + class_index: Target class index to compute attribution for. If None + (default), uses the model's predicted class. This is useful when + you want to explain why a specific class was predicted or to + compare attributions across different classes. + **data: Input data from dataloader batch containing: + - For Transformer: feature keys (conditions, procedures, etc.) + label + - For ViT: image feature key (e.g., "image") + label Returns: - Dict[str, torch.Tensor]: Dictionary mapping each feature key to its - relevance score tensor. Each tensor has shape ``[batch_size, - num_tokens]`` where higher values indicate greater relevance for - the prediction. Scores are non-negative due to the clamping - operation in the relevance propagation algorithm. 
- - Note: - - This method requires gradients, so it should not be called within - a ``torch.no_grad()`` context. - - The method modifies model state temporarily (registers hooks) but - restores it after computation. - - For batch processing, it's recommended to use batch_size=1 to get - per-sample interpretability. - - Examples: - >>> from pyhealth.interpret.methods import CheferRelevance - >>> - >>> # Assuming you have a trained transformer model and test data - >>> relevance = CheferRelevance(trained_model) - >>> test_batch = next(iter(test_loader)) - >>> - >>> # Compute relevance for predicted class - >>> scores = relevance.attribute(**test_batch) - >>> print(f"Feature relevance: {scores.keys()}") - >>> print(f"Condition relevance shape: {scores['conditions'].shape}") - >>> - >>> # Compute relevance for specific class (e.g., positive class) - >>> test_batch['class_index'] = 1 - >>> scores_positive = relevance.attribute(**test_batch) - >>> - >>> # Analyze which tokens are most relevant - >>> condition_scores = scores['conditions'][0] # First sample - >>> top_k_indices = torch.topk(condition_scores, k=5).indices - >>> print(f"Most relevant condition tokens: {top_k_indices}") - """ - return self.get_relevance_matrix(**data) - - def get_relevance_matrix(self, **data): - """Compute relevance scores for each token in the input features. - - This method performs a forward pass through the model and computes gradient-based - relevance scores for each input token across all feature modalities (e.g., - conditions, procedures, drugs). The relevance scores indicate the importance - of each token for the predicted class. Higher relevance scores suggest that - the token contributed more to the model's prediction. + Dict[str, torch.Tensor]: Dictionary keyed by feature keys from the task schema. + - For Transformer: ``{"conditions": tensor, "procedures": tensor, ...}`` + where each tensor has shape ``[batch, num_tokens]``. + - For ViT: ``{"image": tensor}`` (or whatever the task's image key is) + where tensor has shape ``[batch, 1, H, W]``. + """ + if self._is_vit: + return self._attribute_vit( + interpolate=interpolate, + class_index=class_index, + **data + ) + return self._attribute_transformer(class_index=class_index, **data) + + def _attribute_transformer( + self, + class_index: int = None, + **data + ) -> Dict[str, torch.Tensor]: + """Compute relevance for PyHealth Transformer models. + Args: - **data: Input data dictionary from a dataloader batch containing: - - Feature keys (e.g., 'conditions', 'procedures', 'drugs'): - Input tensors or sequences for each modality - - 'label': Ground truth label tensor - - 'class_index' (optional): Integer specifying target class for - relevance computation. If not provided, uses the predicted class - (argmax of model output). - - Returns: - dict: Dictionary mapping each feature key to its relevance score tensor. - Each tensor has shape ``[batch_size, num_tokens]`` where higher values - indicate greater relevance for the prediction. Scores are non-negative - due to the clamping operation in the relevance propagation algorithm. - - Note: - - This method requires gradients, so it should not be called within a - ``torch.no_grad()`` context. - - The method modifies model state temporarily (registers hooks) but - restores it after computation. - - For batch processing, it's recommended to use batch_size=1 to get - per-sample interpretability. 
- - Examples: - >>> from pyhealth.interpret.methods import CheferRelevance - >>> - >>> # Assuming you have a trained transformer model and test data - >>> relevance = CheferRelevance(trained_model) - >>> test_batch = next(iter(test_loader)) - >>> - >>> # Compute relevance for predicted class - >>> scores = relevance.get_relevance_matrix(**test_batch) - >>> print(f"Feature relevance: {scores.keys()}") - >>> print(f"Condition relevance shape: {scores['conditions'].shape}") - >>> - >>> # Compute relevance for specific class (e.g., positive class in binary) - >>> test_batch['class_index'] = 1 - >>> scores_positive = relevance.get_relevance_matrix(**test_batch) - >>> - >>> # Analyze which tokens are most relevant - >>> condition_scores = scores['conditions'][0] # First sample - >>> top_k_indices = torch.topk(condition_scores, k=5).indices - >>> print(f"Most relevant condition tokens: {top_k_indices}") + class_index: Target class for attribution. If None, uses predicted class. + **data: Input data from dataloader batch. """ - input = data - input["register_hook"] = True - index = data.get("class_index") + data["register_hook"] = True - logits = self.model(**input)["logit"] - if index == None: - index = torch.argmax(logits, dim=-1) + logits = self.model(**data)["logit"] + if class_index is None: + class_index = torch.argmax(logits, dim=-1) - # create one_hot matrix of n x c, one_hot vecs, for graph computation - one_hot = F.one_hot(torch.tensor(index), logits.size()[1]).float() + one_hot = F.one_hot(torch.tensor(class_index), logits.size()[1]).float() one_hot = one_hot.requires_grad_(True) one_hot = torch.sum(one_hot.to(logits.device) * logits) self.model.zero_grad() one_hot.backward(retain_graph=True) - feature_keys = self.model.feature_keys - # get how many tokens we see per modality + feature_keys = self.model.feature_keys num_tokens = {} for key in feature_keys: feature_transformer = self.model.transformer[key].transformer @@ -316,16 +247,121 @@ def get_relevance_matrix(self, **data): R = ( torch.eye(num_tokens[key]) .unsqueeze(0) - .repeat(len(input[key]), 1, 1) + .repeat(len(data[key]), 1, 1) .to(logits.device) - ) # initialize identity matrix, but batched + ) for blk in self.model.transformer[key].transformer: grad = blk.attention.get_attn_grad() cam = blk.attention.get_attn_map() cam = avg_heads(cam, grad) R += apply_self_attention_rules(R, cam).detach() - - attn[key] = R[:, 0] # get CLS Token - - # return Rs for each feature_key - return attn # Assume CLS token is first row of attention score matrix + attn[key] = R[:, 0] + + return attn + + def _attribute_vit( + self, + interpolate: bool = True, + class_index: int = None, + **data, + ) -> Dict[str, torch.Tensor]: + """Compute ViT attribution and return spatial attribution map. + + Args: + interpolate: If True, interpolate to full image size. + class_index: Target class for attribution. If None, uses predicted class. + **data: Must contain the image feature key. + + Returns: + Dict keyed by the model's feature_key (e.g., "image") with spatial + attribution map of shape [batch, 1, H, W]. + """ + # Get the feature key (first element of feature_keys list) + feature_key = self.model.feature_keys[0] + x = data.get(feature_key) + if x is None: + raise ValueError( + f"Expected feature key '{feature_key}' in data. 
" + f"Available keys: {list(data.keys())}" + ) + + x = x.to(self.model.device) + + # Infer input size from image dimensions (assumes square images) + input_size = x.shape[-1] + + # Forward pass with attention capture + self.model.zero_grad() + logits, attention_maps = self.model.forward_with_attention(x, register_hook=True) + + # Use predicted class if not specified + target_class = class_index + if target_class is None: + target_class = logits.argmax(dim=-1) + + # Backward pass + one_hot = torch.zeros_like(logits) + if isinstance(target_class, int): + one_hot[:, target_class] = 1 + else: + if target_class.dim() == 0: + target_class = target_class.unsqueeze(0) + one_hot.scatter_(1, target_class.unsqueeze(1), 1) + + one_hot = one_hot.requires_grad_(True) + (logits * one_hot).sum().backward(retain_graph=True) + + # Compute gradient-weighted attention + attention_gradients = self.model.get_attention_gradients() + batch_size = attention_maps[0].shape[0] + num_tokens = attention_maps[0].shape[-1] + device = attention_maps[0].device + + R = torch.eye(num_tokens, device=device) + R = R.unsqueeze(0).expand(batch_size, -1, -1).clone() + + for attn, grad in zip(attention_maps, attention_gradients): + cam = avg_heads(attn, grad) + R = R + apply_self_attention_rules(R.detach(), cam.detach()) + + # CLS token's relevance to patches (excluding CLS itself) + patches_attr = R[:, 0, 1:] + + # Reshape to spatial layout + h_patches, w_patches = self.model.get_num_patches(input_size) + attr_map = patches_attr.reshape(batch_size, 1, h_patches, w_patches) + + if interpolate: + attr_map = F.interpolate( + attr_map, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) + + # Return keyed by the model's feature key (e.g., "image") + return {feature_key: attr_map} + + # Backwards compatibility aliases + def get_relevance_matrix(self, **data): + """Alias for _attribute_transformer. Use attribute() instead.""" + return self._attribute_transformer(**data) + + def get_vit_attribution_map( + self, + interpolate: bool = True, + class_index: int = None, + **data + ): + """Alias for attribute() for ViT. Use attribute() instead. + + Returns the attribution tensor directly (not wrapped in a dict). + """ + result = self._attribute_vit( + interpolate=interpolate, + class_index=class_index, + **data + ) + # Return the attribution tensor directly (get the first/only value) + feature_key = self.model.feature_keys[0] + return result[feature_key] diff --git a/pyhealth/interpret/utils.py b/pyhealth/interpret/utils.py new file mode 100644 index 000000000..348898bb4 --- /dev/null +++ b/pyhealth/interpret/utils.py @@ -0,0 +1,366 @@ +"""Visualization utilities for interpretability methods. + +This module provides visualization functions for interpretability in PyHealth, +particularly useful for medical imaging applications. 
It includes utilities for: + +- **Overlay visualizations**: Show attribution/saliency maps on top of images +- **Attribution normalization**: Prepare raw attributions for visualization +- **Interpolation**: Resize patch-level attributions (e.g., from ViT) to image size + +Example Usage +------------- + +Basic attribution overlay: + +>>> from pyhealth.interpret.utils import show_cam_on_image, normalize_attribution +>>> # Assume we have an image and attribution from an interpreter +>>> attr_normalized = normalize_attribution(attribution) +>>> overlay = show_cam_on_image(image, attr_normalized) + +Image attribution visualization: + +>>> from pyhealth.interpret.methods import CheferRelevance +>>> from pyhealth.interpret.utils import visualize_image_attr +>>> interpreter = CheferRelevance(model) +>>> attribution = interpreter.get_vit_attribution_map(**batch) +>>> image, attr_map, overlay = visualize_image_attr( +... image=batch["image"][0], +... attribution=attribution[0, 0], +... interpolate=True, # Resize attribution to match image +... ) + +See Also +-------- +pyhealth.interpret.methods : Attribution methods (DeepLift, IntegratedGradients, etc.) +""" + +import numpy as np +from typing import Tuple, Union, TYPE_CHECKING + +if TYPE_CHECKING: + import torch + +try: + import cv2 + HAS_CV2 = True +except ImportError: + HAS_CV2 = False + + +def show_cam_on_image( + img: np.ndarray, + mask: np.ndarray, + use_rgb: bool = True, + colormap: int = None, + image_weight: float = 0.5, +) -> np.ndarray: + """Overlay a Class Activation Map (CAM) or attribution map on an image. + + This function creates a visualization by blending an attribution/saliency + map with the original image using a colormap (typically 'jet' for heatmap + visualization). + + Args: + img: Input image as numpy array with shape (H, W, 3) for RGB or (H, W) + for grayscale. Values should be in range [0, 1]. + mask: Attribution/saliency map with shape (H, W). Values should be + in range [0, 1] where higher values indicate more importance. + use_rgb: If True, return RGB format. If False, return BGR format. + Default is True. + colormap: OpenCV colormap constant. If None, uses cv2.COLORMAP_JET. + Common options: cv2.COLORMAP_JET, cv2.COLORMAP_HOT, cv2.COLORMAP_VIRIDIS + image_weight: Weight of the original image in the blend (0 to 1). + Default is 0.5 for equal blend. + + Returns: + Blended visualization as uint8 numpy array with shape (H, W, 3) in + range [0, 255]. + + Raises: + ValueError: If inputs are invalid or cv2 is not available. + + Examples: + >>> import numpy as np + >>> from pyhealth.interpret.utils import show_cam_on_image + >>> + >>> # Create sample image and attribution + >>> image = np.random.rand(224, 224, 3) # RGB image + >>> attribution = np.random.rand(224, 224) # Saliency map + >>> + >>> # Create visualization + >>> overlay = show_cam_on_image(image, attribution) + >>> overlay.shape + (224, 224, 3) + """ + if not HAS_CV2: + # Fallback implementation without cv2 + return _show_cam_fallback(img, mask, image_weight) + + if colormap is None: + colormap = cv2.COLORMAP_JET + + # Ensure image is RGB format with 3 channels + if img.ndim == 2: + img = np.stack([img] * 3, axis=-1) + elif img.shape[-1] == 1: + img = np.concatenate([img] * 3, axis=-1) + + # Validate inputs + if img.max() > 1.0 + 1e-6: + raise ValueError( + f"Image values should be in [0, 1], got max={img.max():.4f}. 
" + "Normalize with: img = (img - img.min()) / (img.max() - img.min())" + ) + + # Normalize mask to [0, 1] + mask = mask.astype(np.float32) + if mask.max() > mask.min(): + mask = (mask - mask.min()) / (mask.max() - mask.min()) + + # Apply colormap to mask + heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) + if use_rgb: + heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) + heatmap = np.float32(heatmap) / 255 + + # Blend image and heatmap + cam = (1 - image_weight) * heatmap + image_weight * img + cam = cam / cam.max() # Normalize + + return np.uint8(255 * cam) + + +def _show_cam_fallback( + img: np.ndarray, + mask: np.ndarray, + image_weight: float = 0.5, +) -> np.ndarray: + """Fallback implementation of show_cam_on_image without OpenCV. + + Uses matplotlib colormaps instead of cv2.applyColorMap. + """ + try: + from matplotlib import cm + except ImportError: + raise ImportError( + "Either cv2 (opencv-python) or matplotlib is required for " + "visualization. Install with: pip install opencv-python matplotlib" + ) + + # Ensure image is RGB format with 3 channels + if img.ndim == 2: + img = np.stack([img] * 3, axis=-1) + elif img.shape[-1] == 1: + img = np.concatenate([img] * 3, axis=-1) + + # Normalize mask to [0, 1] + mask = mask.astype(np.float32) + if mask.max() > mask.min(): + mask = (mask - mask.min()) / (mask.max() - mask.min()) + + # Apply jet colormap + heatmap = cm.jet(mask)[:, :, :3] # Remove alpha channel + + # Blend image and heatmap + cam = (1 - image_weight) * heatmap + image_weight * img + cam = cam / cam.max() # Normalize + + return np.uint8(255 * cam) + + +def interpolate_attribution_map( + attribution: np.ndarray, + target_size: Tuple[int, int], + mode: str = "bilinear", +) -> np.ndarray: + """Interpolate attribution map to target size. + + This is useful for models where the attribution is computed at a lower + resolution (e.g., 14x14 patch grid for ViT-B/16) and needs to be + upsampled to the original image resolution (e.g., 224x224). + + Args: + attribution: Attribution map as numpy array or torch tensor. + Shape can be (H, W) or (B, H, W) or (1, 1, H, W). + target_size: Target (height, width) for interpolation. + mode: Interpolation mode. Options: "bilinear", "nearest". + Default is "bilinear" for smooth gradients. + + Returns: + Interpolated attribution map with shape (target_h, target_w). + + Examples: + >>> # For ViT-B/16 with 14x14 patch grid + >>> attr_patches = np.random.rand(14, 14) + >>> attr_full = interpolate_attribution_map(attr_patches, (224, 224)) + >>> attr_full.shape + (224, 224) + """ + import torch + import torch.nn.functional as F + + # Convert to tensor if needed + is_numpy = isinstance(attribution, np.ndarray) + if is_numpy: + attribution = torch.from_numpy(attribution).float() + + # Ensure 4D tensor: (B, C, H, W) + while attribution.dim() < 4: + attribution = attribution.unsqueeze(0) + + # Interpolate + interpolated = F.interpolate( + attribution, + size=target_size, + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + + # Remove batch and channel dims, convert back to numpy + result = interpolated.squeeze() + if is_numpy: + result = result.numpy() + + return result + + +def normalize_attribution( + attribution: Union[np.ndarray, "torch.Tensor"], + method: str = "minmax", +) -> np.ndarray: + """Normalize attribution values for visualization. + + Args: + attribution: Raw attribution values. + method: Normalization method. 
Options: + - "minmax": Scale to [0, 1] using min-max normalization + - "abs_max": Scale by absolute maximum, keeping sign + - "percentile": Clip to [5, 95] percentile then normalize + + Returns: + Normalized attribution as numpy array in [0, 1]. + """ + import torch + + if isinstance(attribution, torch.Tensor): + attribution = attribution.detach().cpu().numpy() + + attr = attribution.astype(np.float32) + + if method == "minmax": + if attr.max() > attr.min(): + return (attr - attr.min()) / (attr.max() - attr.min()) + return np.zeros_like(attr) + + elif method == "abs_max": + abs_max = np.abs(attr).max() + if abs_max > 0: + return (attr / abs_max + 1) / 2 # Map [-1, 1] to [0, 1] + return np.zeros_like(attr) + 0.5 + + elif method == "percentile": + p5, p95 = np.percentile(attr, [5, 95]) + attr = np.clip(attr, p5, p95) + if p95 > p5: + return (attr - p5) / (p95 - p5) + return np.zeros_like(attr) + + else: + raise ValueError(f"Unknown normalization method: {method}") + + +def visualize_image_attr( + image: Union[np.ndarray, "torch.Tensor"], + attribution: Union[np.ndarray, "torch.Tensor"], + normalize: bool = True, + interpolate: bool = True, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Generate visualization components from an image and attribution map. + + This is a convenience function that prepares image and attribution for + visualization, handling common format conversions, interpolation, and + creating an overlay. Works with any image-based model (CNN, ViT, etc.). + + Args: + image: Input image as numpy array or torch tensor. + Accepted shapes: [H, W], [H, W, C], [C, H, W]. + Values can be in any range (will be normalized to [0, 1]). + attribution: Attribution map as numpy array or torch tensor. + Shape should be [H, W]. If different from image size, will be + interpolated to match when interpolate=True. + normalize: If True, normalize attribution to [0, 1] range. + Default is True. + interpolate: If True, interpolate attribution map to match image + dimensions if they differ. Default is True. + + Returns: + Tuple of (image, attribution_map, overlay) where: + - image: Normalized image as numpy array [H, W] or [H, W, C] in [0, 1] + - attribution_map: Attribution as numpy array [H, W] in [0, 1] + - overlay: Attribution overlay on image as numpy array [H, W, 3] + in [0, 255] + + Examples: + >>> from pyhealth.interpret.methods import CheferRelevance + >>> from pyhealth.interpret.utils import visualize_image_attr + >>> + >>> # Compute attribution with interpreter + >>> interpreter = CheferRelevance(model) + >>> attr_map = interpreter.get_vit_attribution_map(**batch) + >>> + >>> # Generate visualization (auto-interpolates to image size) + >>> image, attr_display, overlay = visualize_image_attr( + ... image=batch["image"][0], + ... attribution=attr_map[0, 0], + ... interpolate=True, + ... 
) + >>> + >>> # Display + >>> import matplotlib.pyplot as plt + >>> plt.imshow(overlay) + >>> plt.savefig("attribution.png") + """ + import torch + + # Convert image to numpy + if isinstance(image, torch.Tensor): + image = image.detach().cpu().numpy() + + # Handle channel dimension - convert [C, H, W] to [H, W, C] + if image.ndim == 3 and image.shape[0] in [1, 3]: + image = np.transpose(image, (1, 2, 0)) + + # Handle single-channel images + if image.ndim == 3 and image.shape[-1] == 1: + image = image.squeeze(-1) + + # Normalize image to [0, 1] + image = image.astype(np.float32) + image = (image - image.min()) / (image.max() - image.min() + 1e-8) + + # Get image spatial dimensions + img_h, img_w = image.shape[:2] + + # Convert attribution to numpy + if isinstance(attribution, torch.Tensor): + attribution = attribution.detach().cpu().numpy() + + # Ensure attribution is 2D + attribution = np.squeeze(attribution) + + # Interpolate attribution to match image size if needed + if interpolate and attribution.shape != (img_h, img_w): + attribution = interpolate_attribution_map(attribution, (img_h, img_w)) + + # Normalize attribution if requested + if normalize: + attribution = normalize_attribution(attribution) + + # Create overlay + if image.ndim == 2: + image_rgb = np.stack([image] * 3, axis=-1) + else: + image_rgb = image + overlay = show_cam_on_image(image_rgb, attribution) + + return image, attribution, overlay diff --git a/pyhealth/models/torchvision_model.py b/pyhealth/models/torchvision_model.py index 2016a9e73..07041b6b7 100644 --- a/pyhealth/models/torchvision_model.py +++ b/pyhealth/models/torchvision_model.py @@ -1,7 +1,8 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Tuple import torch import torch.nn as nn +import torch.nn.functional as F import torchvision from ..datasets import SampleDataset @@ -42,43 +43,83 @@ class TorchvisionModel(BaseModel): - """Models from PyTorch's torchvision package. + """Models from PyTorch's torchvision package for image classification. - This class is a wrapper for models from torchvision. It will automatically load - the corresponding model and weights from torchvision. The final layer will be - replaced with a linear layer with the correct output size. + This class is a wrapper for pretrained models from torchvision. It will + automatically load the corresponding model and weights from torchvision. + The final classification layer is replaced with a linear layer matching + the dataset's output size, enabling transfer learning on custom datasets. + + The model supports: + - Standard forward pass for training/inference + - Embedding extraction for interpretability methods + - Attention map capture for ViT models (used by CheferRelevance) Supported Models: - ---------------- - ResNet: + ----------------- + ResNet (resnet18, resnet34, resnet50, resnet101, resnet152): Paper: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. - Deep Residual Learning for Image Recognition. CVPR 2016. + "Deep Residual Learning for Image Recognition." + IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016. - DenseNet: + DenseNet (densenet121, densenet161, densenet169, densenet201): Paper: Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger. - Densely Connected Convolutional Networks. CVPR 2017. + "Densely Connected Convolutional Networks." + IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017. 
-    Vision Transformer (ViT):
+    Vision Transformer (vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14):
         Paper: Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, et al.
-        An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
-        ICLR 2021.
+        "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale."
+        International Conference on Learning Representations (ICLR), 2021.
 
-    Swin Transformer:
+    Swin Transformer (swin_t, swin_s, swin_b):
         Paper: Ze Liu, Yutong Lin, Yue Cao, et al.
-        Swin Transformer: Hierarchical Vision Transformer Using Shifted Windows.
-        ICCV 2021.
+        "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows."
+        IEEE/CVF International Conference on Computer Vision (ICCV), 2021.
 
         Paper: Ze Liu, Han Hu, Yutong Lin, et al.
-        Swin Transformer V2: Scaling Up Capacity and Resolution. CVPR 2022.
+        "Swin Transformer V2: Scaling Up Capacity and Resolution."
+        IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2022.
 
     Args:
-        dataset: The dataset to train the model. Used to query information
-            such as the set of all tokens.
-        model_name: Name of the model to use (e.g., "resnet18").
-            See SUPPORTED_MODELS in the source code for the full list.
-        model_config: Dictionary of kwargs to pass to the model constructor.
-            Example: {"weights": "DEFAULT"}. See torchvision documentation for
-            supported kwargs for each model.
+        dataset (SampleDataset): The dataset to train the model. Used to query
+            information such as the number of output classes. Must have exactly
+            one feature key (the image) and one label key.
+        model_name (str): Name of the model to use. Must be one of:
+            resnet18, resnet34, resnet50, resnet101, resnet152,
+            densenet121, densenet161, densenet169, densenet201,
+            vit_b_16, vit_b_32, vit_l_16, vit_l_32, vit_h_14,
+            swin_t, swin_s, swin_b.
+        model_config (Dict[str, Any]): Dictionary of kwargs to pass to the model
+            constructor. Common options include:
+            - ``{"weights": "DEFAULT"}``: Use pretrained ImageNet weights
+            - ``{"weights": None}``: Random initialization
+            See torchvision documentation for all supported kwargs.
+
+    Example:
+        >>> from pyhealth.datasets import COVID19CXRDataset, get_dataloader
+        >>> from pyhealth.models import TorchvisionModel
+        >>> from pyhealth.trainer import Trainer
+        >>>
+        >>> # Load a medical imaging dataset
+        >>> base_dataset = COVID19CXRDataset(root="/path/to/COVID-19_Radiography_Dataset")
+        >>> sample_dataset = base_dataset.set_task()
+        >>>
+        >>> # Create a ViT model with pretrained weights
+        >>> model = TorchvisionModel(
+        ...     dataset=sample_dataset,
+        ...     model_name="vit_b_16",
+        ...     model_config={"weights": "DEFAULT"},
+        ... )
+        >>>
+        >>> # Train the model
+        >>> train_loader = get_dataloader(train_data, batch_size=32, shuffle=True)
+        >>> trainer = Trainer(model=model, device="cuda:0")
+        >>> trainer.train(train_dataloader=train_loader, epochs=10)
+        >>>
+        >>> # Inference
+        >>> test_loader = get_dataloader(test_data, batch_size=32, shuffle=False)
+        >>> y_true, y_prob, _ = trainer.inference(test_loader)
     """
 
     def __init__(
@@ -117,6 +158,15 @@ def __init__(
         output_size = self.get_output_size()
         layer_name = final_layer_name.split(".")[0]
         setattr(self.model, layer_name, nn.Linear(hidden_dim, output_size))
+
+        # Initialize attention hooks storage for ViT interpretability
+        self._attention_maps: List[torch.Tensor] = []
+        self._attention_gradients: List[torch.Tensor] = []
+        self._hooks: List[Any] = []
+
+        # Setup attention hooks for ViT models
+        if "vit" in model_name:
+            self._setup_vit_attention_hooks()
 
     def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
         """Forward propagation.
@@ -220,3 +270,287 @@ def _extract_embeddings(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         return embeddings
+
+    def forward_from_embedding(
+        self,
+        embeddings: torch.Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """Forward pass from pre-computed embeddings.
+
+        This method allows running the classification head on embeddings that
+        were computed externally, useful for interpretability methods like
+        DeepLift and Integrated Gradients.
+
+        Args:
+            embeddings: Pre-computed embeddings tensor of shape
+                (batch_size, hidden_dim).
+            **kwargs: Must contain label_key for loss computation.
+
+        Returns:
+            Dictionary with:
+            - loss: classification loss
+            - y_prob: predicted probabilities
+            - y_true: true labels
+            - logit: raw logits
+        """
+        embeddings = embeddings.to(self.device)
+
+        # Get the final classification layer
+        final_layer_name = SUPPORTED_MODELS_FINAL_LAYER[self.model_name]
+        layer_name = final_layer_name.split(".")[0]
+        fc_layer = getattr(self.model, layer_name)
+
+        # Apply classification head
+        logits = fc_layer(embeddings)
+
+        y_true = kwargs[self.label_key].to(self.device)
+        loss = self.get_loss_function()(logits, y_true)
+        y_prob = self.prepare_y_prob(logits)
+
+        return {
+            "loss": loss,
+            "y_prob": y_prob,
+            "y_true": y_true,
+            "logit": logits,
+        }
+
+    # =========================================================================
+    # ViT Attention Hooks for Interpretability (used by CheferRelevance)
+    # =========================================================================
+
+    def _setup_vit_attention_hooks(self) -> None:
+        """Locate and cache the ViT encoder blocks used for attention capture.
+
+        The cached blocks are traversed by ``forward_with_attention``, which
+        records attention maps and their gradients for Chefer-style attribution.
+        """
+        if "vit" not in self.model_name:
+            return
+
+        # Access the encoder blocks (different paths for different torchvision versions)
+        try:
+            encoder = self.model.encoder
+            if hasattr(encoder, 'layers'):
+                blocks = encoder.layers
+            else:
+                blocks = list(encoder.children())
+        except AttributeError:
+            print(f"Warning: Could not setup attention hooks for {self.model_name}")
+            return
+
+        self._vit_blocks = blocks
+
+    def clear_attention_storage(self) -> None:
+        """Clear stored attention maps and gradients."""
+        self._attention_maps = []
+        self._attention_gradients = []
+
+    def get_attention_maps(self) -> List[torch.Tensor]:
+        """Get stored attention maps from last forward pass.
+
+        Returns:
+            List of attention tensors, one per encoder block.
+        """
+        return self._attention_maps
+
+    def get_attention_gradients(self) -> List[torch.Tensor]:
+        """Get stored attention gradients from last backward pass.
+
+        Returns:
+            List of attention gradient tensors, one per encoder block.
+        """
+        return self._attention_gradients
+
+    def is_vit_model(self) -> bool:
+        """Check if this is a Vision Transformer model.
+
+        Returns:
+            True if model is ViT, False otherwise.
+        """
+        return "vit" in self.model_name
+
+    def _compute_manual_attention(
+        self,
+        mha: nn.MultiheadAttention,
+        x: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute attention manually to enable gradient flow through attention weights.
+
+        This method replaces the black-box nn.MultiheadAttention call with explicit
+        QKV computation, ensuring that attention weights are part of the computational
+        graph and gradients can flow through them for interpretability methods.
+
+        Args:
+            mha: The nn.MultiheadAttention module to extract weights from.
+            x: Input tensor of shape [batch, seq_len, embed_dim].
+
+        Returns:
+            Tuple of (attn_output, attn_weights) where:
+            - attn_output: [batch, seq_len, embed_dim] - the attention output
+            - attn_weights: [batch, num_heads, seq_len, seq_len] - attention weights
+              that ARE in the computation graph (gradients will flow through them)
+        """
+        batch_size, seq_len, embed_dim = x.shape
+        num_heads = mha.num_heads
+        head_dim = embed_dim // num_heads
+        scale = head_dim ** -0.5
+
+        # Extract projection weights from MHA module
+        # in_proj_weight: [3*embed_dim, embed_dim] contains Q, K, V projections
+        W_qkv = mha.in_proj_weight  # [3*embed_dim, embed_dim]
+        b_qkv = mha.in_proj_bias  # [3*embed_dim] or None
+
+        # Project input to Q, K, V using the MHA's learned weights
+        qkv = F.linear(x, W_qkv, b_qkv)  # [batch, seq_len, 3*embed_dim]
+
+        # Split into Q, K, V
+        q, k, v = qkv.chunk(3, dim=-1)  # Each: [batch, seq_len, embed_dim]
+
+        # Reshape for multi-head attention
+        # [batch, seq_len, embed_dim] -> [batch, seq_len, num_heads, head_dim]
+        # -> [batch, num_heads, seq_len, head_dim]
+        q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
+        k = k.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
+        v = v.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
+
+        # Compute attention scores: [batch, heads, seq, seq]
+        # THIS is the key: attn_weights is computed inline and stays in the graph
+        attn_weights = (q @ k.transpose(-2, -1)) * scale
+        attn_weights = attn_weights.softmax(dim=-1)
+
+        # Apply attention to values: [batch, heads, seq, head_dim]
+        attn_output = attn_weights @ v
+
+        # Reshape back: [batch, heads, seq, head_dim] -> [batch, seq, embed_dim]
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
+
+        # Apply output projection
+        attn_output = F.linear(attn_output, mha.out_proj.weight, mha.out_proj.bias)
+
+        return attn_output, attn_weights
+
+    def forward_with_attention(
+        self,
+        x: torch.Tensor,
+        register_hook: bool = True,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward pass with attention map capture for ViT models.
+
+        This method provides explicit access to attention maps for interpretability,
+        used by CheferRelevance interpreter.
+
+        Args:
+            x: Input tensor of shape (batch_size, channels, height, width).
+            register_hook: If True, register hooks to capture attention gradients.
+
+        Returns:
+            Tuple of (logits, attention_maps) where attention_maps is a list of
+            attention tensors from each encoder block.
+
+        Raises:
+            ValueError: If model is not a ViT model.
+        """
+        if not self.is_vit_model():
+            raise ValueError("forward_with_attention only works with ViT models")
+
+        self.clear_attention_storage()
+
+        # Move input to device (consistent with forward method)
+        x = x.to(self.device)
+
+        # Handle channel dimension
+        if x.shape[1] == 1:
+            x = x.repeat((1, 3, 1, 1))
+
+        # Ensure input requires grad for gradient-based attribution
+        if register_hook and not x.requires_grad:
+            x = x.requires_grad_(True)
+
+        # Process input (conv projection + position embeddings)
+        x = self.model._process_input(x)
+
+        # Add CLS token
+        batch_size = x.shape[0]
+        batch_class_token = self.model.class_token.expand(batch_size, -1, -1)
+        x = torch.cat([batch_class_token, x], dim=1)
+
+        # Forward through encoder blocks with attention capture
+        attention_maps = []
+
+        # Access encoder layers
+        if hasattr(self.model.encoder, 'layers'):
+            encoder_layers = self.model.encoder.layers
+        else:
+            encoder_layers = list(self.model.encoder.children())
+
+        # Pre-allocate one gradient slot per block so each gradient is stored
+        # at its block index (backward hooks fire in reverse block order).
+        self._attention_gradients = [None] * len(encoder_layers)
+
+        for idx, block in enumerate(encoder_layers):
+            # Each block is an EncoderBlock with self_attention
+            if hasattr(block, 'self_attention'):
+                # Apply layer norm
+                ln_x = block.ln_1(x)
+
+                # Use manual attention computation for gradient flow
+                # This computes Q, K, V inline so attention weights stay in the graph
+                attn_output, attn_weights = self._compute_manual_attention(
+                    block.self_attention, ln_x
+                )
+
+                # Store attention weights (now in computation graph!)
+                attention_maps.append(attn_weights)
+                if register_hook:
+                    # Register hook to capture gradients during backprop;
+                    # store each gradient by block index rather than relying
+                    # on arrival order.
+                    attn_weights.register_hook(
+                        lambda grad, i=idx: self._attention_gradients.__setitem__(i, grad)
+                    )
+
+                # Continue with residual connections
+                x = x + block.dropout(attn_output)
+                x = x + block.mlp(block.ln_2(x))
+            else:
+                # Fallback: just pass through the block
+                x = block(x)
+
+        # Apply layer norm
+        x = self.model.encoder.ln(x)
+
+        # Get CLS token embedding and classify
+        cls_embedding = x[:, 0]
+        logits = self.model.heads(cls_embedding)
+
+        self._attention_maps = attention_maps
+        return logits, attention_maps
+
+    def get_patch_size(self) -> int:
+        """Get the patch size for ViT models.
+
+        Returns:
+            Patch size (e.g., 16 for vit_b_16).
+
+        Raises:
+            ValueError: If model is not a ViT model.
+        """
+        if not self.is_vit_model():
+            raise ValueError("get_patch_size only works with ViT models")
+
+        # Extract from model name
+        parts = self.model_name.split("_")
+        for part in parts:
+            if part.isdigit():
+                return int(part)
+
+        # Default fallback
+        return 16
+
+    def get_num_patches(self, input_size: int = 224) -> Tuple[int, int]:
+        """Get the number of patches for ViT models.
+
+        Args:
+            input_size: Input image size (default 224).
+
+        Returns:
+            Tuple of (height_patches, width_patches). For standard 224x224 input
+            with patch_size=16, this is (14, 14).
+        """
+        patch_size = self.get_patch_size()
+        return (input_size // patch_size, input_size // patch_size)
diff --git a/pyhealth/processors/label_processor.py b/pyhealth/processors/label_processor.py
index ff32dabf8..ae8d1f8aa 100644
--- a/pyhealth/processors/label_processor.py
+++ b/pyhealth/processors/label_processor.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Iterable
+from typing import Any, Dict, Iterable
 
 import torch
@@ -68,7 +68,7 @@ def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
     def process(self, value: Any) -> torch.Tensor:
         index = self.label_vocab[value]
         return torch.tensor(index, dtype=torch.long)
-    
+
     def size(self):
         return len(self.label_vocab)
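The attention-capture API added to ``TorchvisionModel`` above is meant to be consumed by
``CheferRelevance``, but it can also be driven directly together with ``visualize_image_attr``.
The sketch below shows one possible end-to-end path under stated assumptions: a trained
``vit_b_16`` model, a batch dict whose ``"image"`` key is hypothetical (the actual feature key
depends on the dataset), gradients stored per encoder block as in the hook above, and a manual
gradient-weighted attention rollout that only illustrates the Chefer-style computation which
``CheferRelevance`` is intended to encapsulate.

.. code-block:: python

    import torch

    from pyhealth.interpret.utils import visualize_image_attr

    # Assumes `model` is a trained TorchvisionModel(dataset, "vit_b_16", {...})
    # and `batch` holds preprocessed images; the "image" key is hypothetical.
    model.eval()
    image = batch["image"][:1]

    # Capture per-block attention maps during the forward pass.
    logits, attn_maps = model.forward_with_attention(image, register_hook=True)

    # Backpropagate from the predicted class to populate attention gradients.
    target = logits.argmax(dim=-1).item()
    logits[0, target].backward()
    grads = model.get_attention_gradients()

    # Chefer-style rollout: gradient-weighted attention accumulated over blocks.
    num_tokens = attn_maps[0].shape[-1]
    relevance = torch.eye(num_tokens, device=attn_maps[0].device)
    for attn, grad in zip(attn_maps, grads):
        cam = (grad * attn).clamp(min=0).mean(dim=1)[0]  # average over heads
        relevance = relevance + cam @ relevance

    # CLS-token row over the patch tokens -> 2D spatial map.
    h, w = model.get_num_patches(input_size=image.shape[-1])
    patch_relevance = relevance[0, 1:].reshape(h, w)

    # Overlay the relevance map on the original image for inspection.
    img_np, attr_np, overlay = visualize_image_attr(
        image=image[0],
        attribution=patch_relevance,
        interpolate=True,
    )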