diff --git a/docs/api/models.rst b/docs/api/models.rst index a0df0d943..4ab9d6ed8 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -8,6 +8,7 @@ We implement the following models for supporting multiple healthcare predictive :maxdepth: 3 models/pyhealth.models.BaseModel + models/pyhealth.models.BHCToAVS models/pyhealth.models.LogisticRegression models/pyhealth.models.MLP models/pyhealth.models.CNN diff --git a/docs/api/models/pyhealth.models.BHCToAVS.rst b/docs/api/models/pyhealth.models.BHCToAVS.rst new file mode 100644 index 000000000..5c8d197b1 --- /dev/null +++ b/docs/api/models/pyhealth.models.BHCToAVS.rst @@ -0,0 +1,11 @@ +pyhealth.models.BHCToAVS +======================== + +BHCToAVS +------------------------------ + +.. autoclass:: pyhealth.models.bhc_to_avs.BHCToAVS + :members: + :inherited-members: + :show-inheritance: + :undoc-members: \ No newline at end of file diff --git a/examples/bhc_to_avs_example.py b/examples/bhc_to_avs_example.py new file mode 100644 index 000000000..ef438622f --- /dev/null +++ b/examples/bhc_to_avs_example.py @@ -0,0 +1,21 @@ +from pyhealth.models.bhc_to_avs import BHCToAVS + +# Initialize the model +model = BHCToAVS() + +# Example Brief Hospital Course (BHC) text with common clinical abbreviations generated synthetically via ChatGPT 5.1 +bhc = ( + "Pt admitted with acute onset severe epigastric pain and hypotension. " + "Labs notable for elevated lactate, WBC 18K, mild AST/ALT elevation, and Cr 1.4 (baseline 0.9). " + "CT A/P w/ contrast demonstrated peripancreatic fat stranding c/w acute pancreatitis; " + "no necrosis or peripancreatic fluid collection. " + "Pt received aggressive IVFs, electrolyte repletion, IV analgesia, and NPO status initially. " + "Serial abd exams remained benign with no rebound or guarding. " + "BP stabilized, lactate downtrended, and pt tolerated ADAT to low-fat diet without recurrence of sx. " + "Discharged in stable condition w/ instructions for GI f/u and outpatient CMP in 1 week." +) + +# Generate a patient-friendly After-Visit Summary +print(model.predict(bhc)) + +# Expected output: A simplified, patient-friendly summary explaining the hospital stay without medical jargon. \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5c3683bc1..659bb7f88 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,6 +1,7 @@ from .adacare import AdaCare, AdaCareLayer from .agent import Agent, AgentLayer from .base_model import BaseModel +from .bhc_to_avs import BHCToAVS from .cnn import CNN, CNNLayer from .concare import ConCare, ConCareLayer from .contrawr import ContraWR, ResBlock2D @@ -26,4 +27,4 @@ from .transformer import Transformer, TransformerLayer from .transformers_model import TransformersModel from .vae import VAE -from .sdoh import SdohClassifier \ No newline at end of file +from .sdoh import SdohClassifier diff --git a/pyhealth/models/bhc_to_avs.py b/pyhealth/models/bhc_to_avs.py new file mode 100644 index 000000000..d12a2744a --- /dev/null +++ b/pyhealth/models/bhc_to_avs.py @@ -0,0 +1,155 @@ +""" +BHC to AVS Model + +Generates patient-friendly After Visit Summaries (AVS) from Brief Hospital Course (BHC) +notes using a fine-tuned Mistral 7B model with a LoRA adapter. + +This model requires access to a gated Hugging Face repository. Provide credentials +using one of the following methods: + +1. Set an environment variable: + export HF_TOKEN="hf_..." + +2. Pass the token explicitly when creating the model: + model = BHCToAVS(hf_token="hf_...") + +If no token is provided and the repository is gated, a RuntimeError will be raised. +""" + +# Author: Charan Williams +# NetID: charanw2 + + +from dataclasses import dataclass, field +import os +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from peft import PeftModelForCausalLM +from pyhealth.models.base_model import BaseModel + +# System prompt used during inference +_SYSTEM_PROMPT = ( + "You are a clinical summarization model. Produce accurate, patient-friendly summaries " + "using only information from the doctor's note. Do not add new details.\n\n" +) + +# Prompt used during fine-tuning +_PROMPT = ( + "Summarize for the patient what happened during the hospital stay based on this doctor's note:\n" + "{bhc}\n\n" + "Summary for the patient:\n" +) + + +@dataclass +class BHCToAVS(BaseModel): + """ + BHCToAVS is a model class designed to generate After-Visit Summaries (AVS) from + Brief Hospital Course (BHC) notes using a pre-trained base model and a LoRA adapter. + + Attributes: + base_model_id (str): The HuggingFace repository identifier for the base + Mistral 7B model. + adapter_model_id (str): The HuggingFace repository identifier for the LoRA + adapter weights. + hf_token (str | None): HuggingFace access token for gated repositories. + + Methods: + _get_pipeline(): Creates and caches a HuggingFace text-generation pipeline + using the base model and LoRA adapter. + predict(bhc_text: str) -> str: Generates a patient-friendly After-Visit + Summary (AVS) from a given Brief Hospital Course (BHC) note. + """ + + base_model_id: str = field(default="mistralai/Mistral-7B-Instruct-v0.3") + adapter_model_id: str = field(default="williach31/mistral-7b-bhc-to-avs-lora") + hf_token: str | None = None + + def __post_init__(self): + # Ensure nn.Module (via BaseModel) is initialized + super().__init__() + + def _resolve_token(self): + return self.hf_token or os.getenv("HF_TOKEN") + + def _get_pipeline(self): + """Create and cache the text-generation pipeline.""" + if not hasattr(self, "_pipeline"): + # Resolve HuggingFace token + token = self._resolve_token() + + # Throw RuntimeError if token is not found + if token is None: + raise RuntimeError( + "Hugging Face token not found. This model requires access to a gated repository.\n\n" + "Set the HF_TOKEN environment variable or pass hf_token=... when initializing BHCToAVS.\n\n" + "Example:\n" + " export HF_TOKEN='hf_...'\n" + " model = BHCToAVS()\n" + ) + + # Load base model + base = AutoModelForCausalLM.from_pretrained( + self.base_model_id, + torch_dtype=torch.bfloat16, + device_map="auto", + token=token, + ) + + # Load LoRA adapter + model = PeftModelForCausalLM.from_pretrained( + base, + self.adapter_model_id, + torch_dtype=torch.bfloat16, + token=token, + ) + + tokenizer = AutoTokenizer.from_pretrained(self.base_model_id, token=token) + + # Create HF pipeline + self._pipeline = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + model_kwargs={"torch_dtype": torch.bfloat16}, + ) + + return self._pipeline + + def predict(self, bhc_text: str) -> str: + """ + Generate an After-Visit Summary (AVS) from a Brief Hospital Course (BHC) note. + + Parameters + ---------- + bhc_text : str + Raw BHC text. + + Returns + ------- + str + Patient-friendly summary. + """ + + # Validate input to provide clear error messages and avoid unexpected failures. + if not isinstance(bhc_text, str): + raise TypeError( + f"bhc_text must be a string, got {type(bhc_text).__name__}." + ) + if not bhc_text.strip(): + raise ValueError("bhc_text must be a non-empty string.") + prompt = _SYSTEM_PROMPT + _PROMPT.format(bhc=bhc_text) + + pipe = self._get_pipeline() + eos_id = pipe.tokenizer.eos_token_id + outputs = pipe( + prompt, + max_new_tokens=512, + temperature=0.0, + eos_token_id=eos_id, + pad_token_id=eos_id, + return_full_text=False, + ) + + # Output is a single text string + return outputs[0]["generated_text"].strip() diff --git a/tests/core/test_bhc_to_avs.py b/tests/core/test_bhc_to_avs.py new file mode 100644 index 000000000..e1951a57d --- /dev/null +++ b/tests/core/test_bhc_to_avs.py @@ -0,0 +1,96 @@ +""" +Unit tests for the BHCToAVS model. + +These tests validate both the unit-level behavior of the predict method +(using a mocked pipeline) and an optional integration path that runs +against the real Hugging Face model when credentials are provided. +""" + +import os +import unittest +from unittest.mock import patch + +from tests.base import BaseTestCase +from pyhealth.models.bhc_to_avs import BHCToAVS + + +class _DummyPipeline: + """ + Lightweight mock pipeline used to simulate Hugging Face text generation. + + This avoids downloading models or requiring authentication during unit tests. + """ + + def __call__(self, prompt, **kwargs): + """Return a fixed, deterministic generated response.""" + return [ + { + "generated_text": "Your pain improved with supportive care and you were discharged in good condition." + } + ] + + +class TestBHCToAVS(BaseTestCase): + """Unit and integration tests for the BHCToAVS model.""" + + def setUp(self): + """Set a deterministic random seed before each test.""" + self.set_random_seed() + + def test_predict_unit(self): + """ + Test the predict method using a mocked pipeline. + + This test verifies that: + - The model returns a string output + - The output is non-empty + - The output differs from the input text + """ + + bhc_text = ( + "Patient admitted with abdominal pain. Imaging showed no acute findings. " + "Pain improved with supportive care and the patient was discharged in stable condition." + ) + + with patch.object(BHCToAVS, "_get_pipeline", return_value=_DummyPipeline()): + model = BHCToAVS() + summary = model.predict(bhc_text) + + # Output must be type str + self.assertIsInstance(summary, str) + + # Output should not be empty + self.assertGreater(len(summary.strip()), 0) + + # Output should be different from input + self.assertNotIn(bhc_text[:40], summary) + + @unittest.skipUnless( + os.getenv("RUN_BHC_TO_AVS_INTEGRATION") == "1" and os.getenv("HF_TOKEN"), + "Integration test disabled. Set RUN_BHC_TO_AVS_INTEGRATION=1 and HF_TOKEN to enable.", + ) + def test_predict_integration(self): + """ + Integration test for the BHCToAVS model. + + This test runs the full inference pipeline using the real Hugging Face model. + It requires the HF_TOKEN environment variable to be set and is skipped by default. + """ + + # For Mistral weights, you will need HF_TOKEN set in the environment. + bhc_text = ( + "Patient admitted with abdominal pain. Imaging showed no acute findings. " + "Pain improved with supportive care and the patient was discharged in stable condition." + ) + + model = BHCToAVS() + summary = model.predict(bhc_text) + + # Output must be type str + self.assertIsInstance(summary, str) + + # Output should not be empty + self.assertGreater(len(summary.strip()), 0) + + # Output should be different from input + self.assertNotIn(bhc_text[:40], summary)