# =============================================================================
# File: pyhealth/tasks/mimic3_llm_diagnosis.py  (new file)
# =============================================================================
"""MIMIC-III LLM diagnosis task.

Author: Stephen Moy
NetID: moy26

Aggregates MIMIC-III clinical notes by patient and applies a seq2seq
language model (e.g., FLAN-T5) to classify whether each patient has a
specified diagnosis. Integrates preprocessing, patient-level aggregation,
and LLM-based classification into a single PyHealth Task class for
reproducibility and open-source contribution.
"""
import math

import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from pyhealth.tasks import BaseTask


class MIMIC3LLMDiagnosisTask(BaseTask):
    """Aggregate MIMIC-III notes by patient and diagnose using an LLM."""

    def __init__(self, notes_path: str, diagnosis: str,
                 model_name: str = "google/flan-t5-large", device: str = "cuda"):
        """Initialize the MIMIC-III LLM diagnosis task.

        Args:
            notes_path (str): Path to the NOTEEVENTS.csv file from MIMIC-III.
            diagnosis (str): The diagnosis to classify (e.g., "heart failure").
            model_name (str): Hugging Face model name for the LLM.
            device (str): Device to run the model on ("cuda" or "cpu").
                Falls back to "cpu" if CUDA is requested but unavailable.

        Attributes:
            notes_df (pd.DataFrame): Raw notes loaded from CSV.
            patients_df (pd.DataFrame): Aggregated notes per patient.
            tokenizer (AutoTokenizer): Hugging Face tokenizer.
            model (AutoModelForSeq2SeqLM): Hugging Face seq2seq model.
        """
        super().__init__()
        self.notes_path = notes_path
        self.diagnosis = diagnosis
        # Graceful fallback so the task still runs on CPU-only machines
        # instead of crashing on .to("cuda").
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.device = device

        # Load data.
        self.notes_df = pd.read_csv(notes_path, low_memory=False)
        self.patients_df = self.aggregate_by_patient()

        # Load LLM (inference only; eval() disables dropout etc.).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        self.model.eval()

    def aggregate_by_patient(self, text_col: str = "TEXT",
                             patient_col: str = "SUBJECT_ID") -> pd.DataFrame:
        """Aggregate notes per patient into a single string.

        Args:
            text_col (str): Column name containing note text. Defaults to "TEXT".
            patient_col (str): Column name containing patient IDs.
                Defaults to "SUBJECT_ID".

        Returns:
            pd.DataFrame: DataFrame with columns:
                - SUBJECT_ID (int): Patient identifier.
                - PATIENT_NOTES (str): Concatenated notes for the patient.

        Example:
            >>> dataset = MIMIC3LLMDiagnosisTask("NOTEEVENTS.csv", "heart failure")
            >>> patients = dataset.aggregate_by_patient()
            >>> print(patients.head())
        """
        agg_df = (
            self.notes_df.groupby(patient_col)[text_col]
            .apply(lambda x: " \n\n ".join(x.dropna()))
            .reset_index()
            .rename(columns={text_col: "PATIENT_NOTES"})
        )
        return agg_df

    def classify_patient(self, patient_notes: str) -> dict:
        """Classify whether a patient has the specified diagnosis using the LLM.

        Args:
            patient_notes (str): Concatenated notes for a single patient.

        Returns:
            dict: Dictionary with keys:
                - "diagnosis" (str): The diagnosis being tested.
                - "verdict" (str): "YES" or "NO".
                - "confidence" (float): Confidence score in [0, 1].

        Example:
            >>> task = MIMIC3LLMDiagnosisTask("NOTEEVENTS.csv", "heart failure")
            >>> result = task.classify_patient("Patient has CHF and hypertension.")
            >>> print(result)
            {'diagnosis': 'heart failure', 'verdict': 'YES', 'confidence': 0.92}
        """
        prompt = (
            f"Read the following clinical notes:\n\n{patient_notes}\n\n"
            f"Question: Does this patient have {self.diagnosis}?\n"
            f"Answer with YES or NO."
        )
        enc = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)

        def sequence_log_prob(target: str) -> float:
            # Score by SUMMED token log-likelihood (loss is the per-token
            # mean CE), so multi-token candidates are ranked by their true
            # sequence probability rather than an average.
            labels = self.tokenizer(target, return_tensors="pt").input_ids.to(self.device)
            with torch.no_grad():  # inference only; don't build a graph
                out = self.model(**enc, labels=labels)
            return -float(out.loss.item()) * labels.shape[1]

        score_yes = sequence_log_prob("YES")
        score_no = sequence_log_prob("NO")

        # Numerically stable two-way softmax (subtract the max before exp).
        m = max(score_yes, score_no)
        e_yes = math.exp(score_yes - m)
        e_no = math.exp(score_no - m)
        p_yes = e_yes / (e_yes + e_no)

        verdict = "YES" if p_yes >= 0.5 else "NO"
        confidence = max(p_yes, 1.0 - p_yes)
        return {"diagnosis": self.diagnosis, "verdict": verdict,
                "confidence": confidence}

    def run(self, sample_size: int = 10, random_state=None) -> list:
        """Run diagnosis classification for a sample of patients.

        Args:
            sample_size (int): Number of patients to sample. Defaults to 10.
            random_state (int | None): Seed forwarded to ``DataFrame.sample``
                for reproducible sampling. Defaults to None (non-deterministic,
                matching the previous behavior).

        Returns:
            list: List of dictionaries, each containing:
                - "patient_id" (int): Patient identifier.
                - "diagnosis" (str): Diagnosis tested.
                - "verdict" (str): "YES" or "NO".
                - "confidence" (float): Confidence score.

        Example:
            >>> task = MIMIC3LLMDiagnosisTask("NOTEEVENTS.csv", "heart failure")
            >>> results = task.run(sample_size=5)
            >>> print(results[0])
            {'patient_id': 123, 'diagnosis': 'heart failure', 'verdict': 'YES', 'confidence': 0.87}
        """
        sample_df = self.patients_df.sample(
            min(sample_size, len(self.patients_df)), random_state=random_state
        )
        results = []
        for _, row in sample_df.iterrows():
            res = self.classify_patient(row["PATIENT_NOTES"])
            res["patient_id"] = int(row["SUBJECT_ID"])
            results.append(res)
        return results


# =============================================================================
# File: tests/core/test_mimic3_llm_diagnosis.py  (new file)
# =============================================================================
# NOTE(review): the original tests constructed the task with the REAL
# "google/flan-t5-large" model (multi-GB download inside __init__) and then
# swapped in a DummyLLM whose interface does not match what classify_patient
# calls (task.tokenizer was set to None -> TypeError). The dummies below
# mimic just enough of the HF tokenizer/model API, and the task is built via
# __new__ so no model is ever downloaded.
import pandas as pd
import pytest
import torch
from pathlib import Path

from pyhealth.tasks.mimic3_llm_diagnosis import MIMIC3LLMDiagnosisTask


class _Encoding(dict):
    """Minimal stand-in for a transformers BatchEncoding (mapping + .to)."""

    def to(self, device):
        return self


class DummyTokenizer:
    """Encodes text as per-character ordinals; enough for the task's API."""

    def __call__(self, text, **kwargs):
        enc = _Encoding(input_ids=torch.tensor([[ord(c) for c in text]]))
        enc.input_ids = enc["input_ids"]
        return enc


class DummyModel:
    """Fake seq2seq model: low loss (preferred) whenever the label is YES."""

    def __call__(self, input_ids=None, labels=None, **kwargs):
        class _Out:
            pass

        out = _Out()
        is_yes = labels is not None and int(labels[0, 0]) == ord("Y")
        out.loss = torch.tensor(0.1 if is_yes else 2.0)
        return out


@pytest.fixture
def dummy_notes_csv(tmp_path: Path):
    """Create a temporary NOTEEVENTS.csv with minimal content."""
    csv_path = tmp_path / "NOTEEVENTS.csv"
    pd.DataFrame({
        "SUBJECT_ID": [1, 1, 2],
        "TEXT": [
            "Patient has CHF and hypertension.",
            "Follow-up note: stable condition.",
            "Patient denies chest pain.",
        ],
    }).to_csv(csv_path, index=False)
    return csv_path


def _make_task(csv_path) -> MIMIC3LLMDiagnosisTask:
    """Build a task instance WITHOUT downloading the real LLM.

    Bypasses __init__ (which loads flan-t5-large) via __new__ and wires in
    lightweight dummy components instead.
    """
    task = MIMIC3LLMDiagnosisTask.__new__(MIMIC3LLMDiagnosisTask)
    task.notes_path = str(csv_path)
    task.diagnosis = "heart failure"
    task.device = "cpu"
    task.notes_df = pd.read_csv(csv_path, low_memory=False)
    task.patients_df = task.aggregate_by_patient()
    task.tokenizer = DummyTokenizer()
    task.model = DummyModel()
    return task


def test_aggregate_by_patient(dummy_notes_csv):
    task = _make_task(dummy_notes_csv)
    patients = task.aggregate_by_patient()
    assert "PATIENT_NOTES" in patients.columns
    assert patients.shape[0] == 2  # two patients aggregated


def test_classify_patient(dummy_notes_csv):
    task = _make_task(dummy_notes_csv)
    result = task.classify_patient("Patient has CHF and hypertension.")
    assert result["diagnosis"] == "heart failure"
    assert result["verdict"] == "YES"  # DummyModel always prefers YES
    assert 0.0 <= result["confidence"] <= 1.0


def test_run_task(dummy_notes_csv):
    task = _make_task(dummy_notes_csv)
    results = task.run(sample_size=2, random_state=0)
    assert isinstance(results, list)
    assert len(results) == 2
    for res in results:
        assert "patient_id" in res
        assert "diagnosis" in res
        assert "verdict" in res
        assert "confidence" in res