Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7ae3bd1
[PromptEHR] Setup: Foundation structure with placeholder files (Phase 1)
jalengg Nov 23, 2025
9919d75
[PromptEHR] Data: Tokenization using PyHealth infrastructure (Phase 2.1)
jalengg Nov 23, 2025
34907d9
[PromptEHR] Data: Dataset and collator implementation (Phase 2.2)
jalengg Nov 23, 2025
515f55c
[PromptEHR] Model: Conditional prompt encoder (Phase 3.1)
jalengg Nov 24, 2025
b1b8230
[PromptEHR] Model: BART encoder with prompt injection (Phase 3.2)
jalengg Nov 24, 2025
6630e61
[PromptEHR] Model: BART decoder with prompt injection (Phase 3.3)
jalengg Nov 24, 2025
4c2959a
[PromptEHR] Model: Main PromptEHR model with PyHealth integration (Ph…
jalengg Nov 24, 2025
e2ff4e2
[PromptEHR] Fix: Encoder and decoder bug fixes for robustness
jalengg Nov 24, 2025
2d3b84a
[PromptEHR] Generation: Visit structure sampling and generation funct…
jalengg Nov 26, 2025
191be9d
[PromptEHR] Training: PyHealth Trainer integration and checkpoint loa…
jalengg Nov 26, 2025
ef07281
[PromptEHR] Examples: Training and generation scripts for MIMIC-III
jalengg Nov 26, 2025
0b52196
Fix: Update AdamW import and venv path in training scripts
jalengg Nov 26, 2025
74fdf86
Fix: Add logger parameter to PromptEHRDataset initialization
jalengg Nov 26, 2025
c78cba6
Fix: Add logger parameter to EHRDataCollator
jalengg Nov 26, 2025
40ae97c
Fix: Correct PromptEHR API usage in example script
jalengg Nov 26, 2025
b3cf335
Fix: Add DeviceAwareCollatorWrapper to resolve CUDA device mismatch
jalengg Dec 5, 2025
bc72f88
Fix: Correct import path for VisitStructureSampler
jalengg Dec 5, 2025
9e4be9b
Fix: Add weights_only=False for PyTorch 2.6+ checkpoint loading
jalengg Dec 5, 2025
cd36122
Add: Local CPU generation script and tokenizer compatibility
jalengg Dec 7, 2025
fe706bc
Merge remote-tracking branch 'origin/master' into promptehr-port
jalengg Dec 8, 2025
3d191d4
Add: PromptEHR synthetic data generation with holdout training
jalengg Jan 14, 2026
3652499
Add: PromptEHR synthetic patient datasets
jalengg Jan 14, 2026
7fab7cd
add csv
jalengg Jan 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
400 changes: 400 additions & 0 deletions docs/promptehr/IMPLEMENTATION_REPORT.md

Large diffs are not rendered by default.

59 changes: 59 additions & 0 deletions examples/promptehr_generate_10k.slurm
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/bin/bash
#SBATCH --account=jalenj4-ic
#SBATCH --partition=IllinoisComputes-GPU
#SBATCH --gres=gpu:1
#SBATCH --mem=64G
#SBATCH --cpus-per-task=4
#SBATCH --time=04:00:00
#SBATCH --job-name=promptehr_gen_10k
#SBATCH --output=logs/promptehr_gen_10k_%j.out
#SBATCH --error=logs/promptehr_gen_10k_%j.err
#
# Generate 10,000 synthetic patients from a trained PromptEHR checkpoint.
#
# Optional environment overrides (defaults reproduce the original
# hard-coded values, so a plain `sbatch` behaves as before):
#   PROMPTEHR_VENV         virtual environment to activate
#   PROMPTEHR_PROJECT_DIR  PyHealth checkout to run from
#   PROMPTEHR_MIMIC3_ROOT  directory holding the MIMIC-III CSV files
#   PROMPTEHR_OUTPUT_DIR   output directory (checkpoint is read from
#                          $PROMPTEHR_OUTPUT_DIR/checkpoints/final_model.pt)
#
# NOTE(review): the logs/ directory named in --output/--error above must
# already exist in the submission directory when the job starts; the
# `mkdir -p logs` below runs too late to create it for this job's own
# log files -- confirm against your site's Slurm behaviour.

# Exit on error
set -e
set -o pipefail

# Load modules
module purge
module load gcc/11.2.0 || true   # gcc module is optional; keep going without it
module load cuda/12.6

# Resolve configurable locations
VENV_PATH="${PROMPTEHR_VENV:-/u/jalenj4/pehr_scratch/venv}"
PROJECT_DIR="${PROMPTEHR_PROJECT_DIR:-/u/jalenj4/final/PyHealth}"
MIMIC3_ROOT="${PROMPTEHR_MIMIC3_ROOT:-/u/jalenj4/pehr_scratch/data_files}"
OUTPUT_DIR="${PROMPTEHR_OUTPUT_DIR:-./promptehr_outputs}"
CHECKPOINT="$OUTPUT_DIR/checkpoints/final_model.pt"

# Environment setup
if [ -d "$VENV_PATH" ]; then
    source "$VENV_PATH/bin/activate"
    echo "Activated environment: $VENV_PATH"
else
    echo "ERROR: Virtual environment not found at $VENV_PATH" >&2
    exit 1
fi

# Change to project directory
cd "$PROJECT_DIR"

# Create logs directory
mkdir -p logs

# Fail fast, before any Python/CUDA start-up, when there is nothing to load.
if [ ! -f "$CHECKPOINT" ]; then
    echo "ERROR: Checkpoint not found at $CHECKPOINT" >&2
    exit 1
fi

# Print environment info
echo "==================== Environment Info ===================="
echo "Date: $(date)"
echo "Node: $(hostname)"
echo "Job ID: $SLURM_JOB_ID"
echo "Python: $(which python3)"
echo "PyTorch version: $(python3 -c 'import torch; print(torch.__version__)')"
echo "CUDA available: $(python3 -c 'import torch; print(torch.cuda.is_available())')"
echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader)"
echo "=========================================================="

# Run generation for 10,000 patients
python3 examples/promptehr_mimic3.py \
    --mimic3_root "$MIMIC3_ROOT" \
    --output_dir "$OUTPUT_DIR" \
    --checkpoint "$CHECKPOINT" \
    --generate_only \
    --num_synthetic 10000 \
    --num_patients 1000 \
    --temperature 0.7 \
    --device cuda

echo "Generation completed at $(date)"
157 changes: 157 additions & 0 deletions examples/promptehr_generate_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""Quick local generation test for PromptEHR (CPU-only).

This script demonstrates how to:
1. Load a trained PromptEHR checkpoint
2. Generate synthetic patients on CPU (no GPU required)
3. Display results in human-readable format

Usage:
python3 examples/promptehr_generate_local.py
"""

import sys
sys.path.insert(0, '/u/jalenj4/final/PyHealth')

import torch
import logging
from pathlib import Path

# PyHealth imports
from pyhealth.models import PromptEHR
from pyhealth.datasets.promptehr_dataset import load_mimic_data
from pyhealth.models.promptehr import (
VisitStructureSampler,
generate_patient_with_structure_constraints
)


def _add_legacy_tokenizer_compat(tokenizer):
    """Attach attributes that older checkpoints' tokenizers may lack.

    Each group of attributes is only added when it is missing, so a
    tokenizer that already provides them is left untouched.

    Args:
        tokenizer: Tokenizer object restored from the checkpoint. It is
            mutated in place.

    Returns:
        The same tokenizer object, for convenience.
    """
    if not hasattr(tokenizer, 'bos_token_id'):
        tokenizer.pad_token_id = tokenizer.vocabulary("<pad>")  # ID 0
        tokenizer.bos_token_id = tokenizer.vocabulary("<s>")  # ID 1
        tokenizer.eos_token_id = tokenizer.vocabulary("</s>")  # ID 2
        # First diagnosis code ID (after 7 special tokens)
        tokenizer.code_offset = 7

    if not hasattr(tokenizer, 'convert_tokens_to_ids'):
        # Method alias: pehr_scratch API uses convert_tokens_to_ids(token) -> int
        def convert_tokens_to_ids(token: str) -> int:
            return tokenizer.convert_tokens_to_indices([token])[0]

        tokenizer.convert_tokens_to_ids = convert_tokens_to_ids

    if not hasattr(tokenizer, 'vocab'):
        class VocabCompat:
            """Expose idx2code / code2idx views over the tokenizer vocabulary."""

            def __init__(self, tok):
                self.idx2code = tok.vocabulary.idx2token
                self.code2idx = tok.vocabulary.token2idx

            def __len__(self):
                return len(self.idx2code)

        tokenizer.vocab = VocabCompat(tokenizer)

    return tokenizer


def _print_patient(number, result):
    """Print one generated patient in human-readable form.

    Args:
        number: 1-based patient number shown in the heading.
        result: Mapping returned by the generation call; the keys
            'demographics', 'num_visits' and 'generated_visits' are read.
    """
    demo = result['demographics']
    print(f"Patient {number}:")
    print(f" Age: {demo['age']} years")
    print(f" Sex: {'Male' if demo['sex'] == 0 else 'Female'}")
    print(f" Number of visits: {result['num_visits']}")
    print(f" Diagnosis codes:")

    for visit_idx, codes in enumerate(result['generated_visits'], 1):
        if codes:
            print(f" Visit {visit_idx}: {', '.join(codes)}")
        else:
            print(f" Visit {visit_idx}: (no diagnoses)")
    print()


def main(
    checkpoint_path="./promptehr_outputs/checkpoints/final_model.pt",
    mimic_root="/u/jalenj4/pehr_scratch/data_files",
    n_patients=10,
):
    """Generate synthetic patients locally on CPU and print them.

    Args:
        checkpoint_path: Path to the trained PromptEHR checkpoint.
        mimic_root: Directory containing PATIENTS.csv, ADMISSIONS.csv and
            DIAGNOSES_ICD.csv, used to sample visit structures.
        n_patients: Number of synthetic patients to generate.

    Returns:
        Process exit code: 0 on success, 1 when the checkpoint is missing.
    """
    # Setup
    device = torch.device("cpu")  # Force CPU (no GPU required)
    logging.basicConfig(
        level=logging.WARNING,  # Reduce noise, only show warnings/errors
        format='%(message)s'
    )
    logger = logging.getLogger(__name__)

    print("\n" + "="*80)
    print("PromptEHR Local Generation Test (CPU mode)")
    print("="*80)

    # Load checkpoint
    print("\n[1/4] Loading trained checkpoint...")

    if not Path(checkpoint_path).exists():
        print(f"ERROR: Checkpoint not found at {checkpoint_path}")
        print("Please ensure training has completed and checkpoint exists.")
        # Non-zero so shells and schedulers see the failure.
        return 1

    # SECURITY: weights_only=False unpickles arbitrary Python objects (needed
    # here because the checkpoint stores the tokenizer object). Only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Add convenience properties and methods if not present
    # (for compatibility with old checkpoints saved before these were added)
    tokenizer = _add_legacy_tokenizer_compat(checkpoint['tokenizer'])

    # Rebuild model
    print("[2/4] Rebuilding model from checkpoint...")
    config = checkpoint['config']
    model = PromptEHR(**config)
    model.bart_model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f" Model vocabulary size: {config['_custom_vocab_size']}")
    print(f" Hidden dimension: {config['d_hidden']}")
    print(f" Prompt length: {config['prompt_length']}")

    # Load MIMIC data for structure sampling
    print("[3/4] Loading MIMIC-III data for structure sampling...")
    print(" (Loading 1000 patients for realistic visit distributions)")

    data_dir = Path(mimic_root)
    patient_records, _ = load_mimic_data(
        patients_path=str(data_dir / "PATIENTS.csv"),
        admissions_path=str(data_dir / "ADMISSIONS.csv"),
        diagnoses_path=str(data_dir / "DIAGNOSES_ICD.csv"),
        num_patients=1000,
        logger=logger
    )

    # Initialize structure sampler
    structure_sampler = VisitStructureSampler(patient_records, seed=42)
    print(f" {structure_sampler}")

    # Generate synthetic patients
    print(f"\n[4/4] Generating {n_patients} synthetic patients...")
    print(" (This will take ~10-15 seconds)")
    print()

    print("="*80)
    print("SYNTHETIC PATIENTS")
    print("="*80)
    print()

    # Inference only: skip autograd bookkeeping while sampling.
    with torch.no_grad():
        for i in range(n_patients):
            # Sample realistic visit structure
            target_structure = structure_sampler.sample_structure()

            # Generate patient
            result = generate_patient_with_structure_constraints(
                model=model,
                tokenizer=tokenizer,
                device=device,
                target_structure=target_structure,
                temperature=0.7,
                top_k=40,
                top_p=0.9,
                max_codes_per_visit=25
            )

            # Display patient
            _print_patient(i + 1, result)

    print("="*80)
    print("Generation complete!")
    print("="*80)
    print()
    print(f"Successfully generated {n_patients} synthetic patients on CPU.")
    print()
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a missing checkpoint is a failing exit code.
    sys.exit(main())
Loading