Add BHCToAVS model for patient-friendly summaries #730
base: master
New file (+11 lines), reStructuredText API doc stub:

```rst
pyhealth.models.BHCToAVS
========================

BHCToAVS
------------------------------

.. autoclass:: pyhealth.models.bhc_to_avs.BHCToAVS
    :members:
    :inherited-members:
    :show-inheritance:
    :undoc-members:
```
New file (+21 lines), usage example:

```python
from pyhealth.models.bhc_to_avs import BHCToAVS

# Initialize the model
model = BHCToAVS()

# Example Brief Hospital Course (BHC) text with common clinical abbreviations,
# generated synthetically via ChatGPT 5.1
bhc = (
    "Pt admitted with acute onset severe epigastric pain and hypotension. "
    "Labs notable for elevated lactate, WBC 18K, mild AST/ALT elevation, and Cr 1.4 (baseline 0.9). "
    "CT A/P w/ contrast demonstrated peripancreatic fat stranding c/w acute pancreatitis; "
    "no necrosis or peripancreatic fluid collection. "
    "Pt received aggressive IVFs, electrolyte repletion, IV analgesia, and NPO status initially. "
    "Serial abd exams remained benign with no rebound or guarding. "
    "BP stabilized, lactate downtrended, and pt tolerated ADAT to low-fat diet without recurrence of sx. "
    "Discharged in stable condition w/ instructions for GI f/u and outpatient CMP in 1 week."
)

# Generate a patient-friendly After-Visit Summary
print(model.predict(bhc))

# Expected output: a simplified, patient-friendly summary explaining the
# hospital stay without medical jargon.
```
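The Mistral base weights are gated, so a token has to be in place before `BHCToAVS()` is constructed. A minimal setup sketch (the token value is a placeholder, not a real credential):

```python
import os

# Placeholder credential: substitute a token generated in your
# Hugging Face account settings.
os.environ.setdefault("HF_TOKEN", "hf_placeholder")

# Alternatively, pass the token to the constructor instead of the environment:
# model = BHCToAVS(hf_token=os.environ["HF_TOKEN"])
```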
New file (+155 lines; only the portion visible in the diff is reproduced below):

```python
"""
BHC to AVS Model

Generates patient-friendly After Visit Summaries (AVS) from Brief Hospital Course (BHC)
notes using a fine-tuned Mistral 7B model with a LoRA adapter.

This model requires access to a gated Hugging Face repository. Provide credentials
using one of the following methods:

1. Set an environment variable:
   export HF_TOKEN="hf_..."

2. Pass the token explicitly when creating the model:
   model = BHCToAVS(hf_token="hf_...")

If no token is provided and the repository is gated, a RuntimeError will be raised.
"""

# Author: Charan Williams
# NetID: charanw2


from dataclasses import dataclass, field
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModelForCausalLM
from pyhealth.models.base_model import BaseModel

# System prompt used during inference
_SYSTEM_PROMPT = (
    "You are a clinical summarization model. Produce accurate, patient-friendly summaries "
    "using only information from the doctor's note. Do not add new details.\n\n"
)

# Prompt template used during fine-tuning and reused at inference
_PROMPT = (
    "Summarize for the patient what happened during the hospital stay based on this doctor's note:\n"
    "{bhc}\n\n"
    "Summary for the patient:\n"
)


@dataclass
class BHCToAVS(BaseModel):
    """
    BHCToAVS is a model class designed to generate After-Visit Summaries (AVS) from
    Brief Hospital Course (BHC) notes using a pre-trained base model and a LoRA adapter.

    Attributes:
        base_model_id (str): The HuggingFace repository identifier for the base
            Mistral 7B model.
        adapter_model_id (str): The HuggingFace repository identifier for the LoRA
            adapter weights.
        hf_token (str | None): HuggingFace access token for gated repositories.

    Methods:
        _get_pipeline(): Creates and caches a HuggingFace text-generation pipeline
            using the base model and LoRA adapter.
        predict(bhc_text: str) -> str: Generates a patient-friendly After-Visit
            Summary (AVS) from a given Brief Hospital Course (BHC) note.
    """

    base_model_id: str = field(default="mistralai/Mistral-7B-Instruct-v0.3")
    adapter_model_id: str = field(default="williach31/mistral-7b-bhc-to-avs-lora")
    hf_token: str | None = None

    def __post_init__(self):
        # Ensure nn.Module (via BaseModel) is initialized
        super().__init__()
```

(The remaining lines of the 155-line file, including the `_get_pipeline` and `predict` implementations, are collapsed in this diff view.)
Copilot AI commented on Dec 29, 2025:

The error message spans lines 87-92 and provides good guidance, but the message could be more specific about where to obtain a HuggingFace token. Consider adding a link to the HuggingFace token generation page (https://huggingface.co/settings/tokens) to help users quickly resolve this issue.
Copilot AI commented on Dec 29, 2025:

The _get_pipeline method loads a large 7B parameter model without any explicit guidance on resource requirements or expected load time. Consider adding documentation (either in the class docstring or method docstring) about: (1) expected memory requirements (GPU/CPU), (2) approximate model loading time, and (3) recommended hardware specifications. This would help users understand the resource implications before attempting to use the model.
New file (+96 lines), unit tests:

```python
"""
Unit tests for the BHCToAVS model.

These tests validate both the unit-level behavior of the predict method
(using a mocked pipeline) and an optional integration path that runs
against the real Hugging Face model when credentials are provided.
"""

import os
import unittest
from unittest.mock import patch

from tests.base import BaseTestCase
from pyhealth.models.bhc_to_avs import BHCToAVS


class _DummyPipeline:
    """
    Lightweight mock pipeline used to simulate Hugging Face text generation.

    This avoids downloading models or requiring authentication during unit tests.
    """

    def __call__(self, prompt, **kwargs):
        """Return a fixed, deterministic generated response."""
        return [
            {
                "generated_text": "Your pain improved with supportive care and you were discharged in good condition."
            }
        ]


class TestBHCToAVS(BaseTestCase):
    """Unit and integration tests for the BHCToAVS model."""

    def setUp(self):
        """Set a deterministic random seed before each test."""
        self.set_random_seed()

    def test_predict_unit(self):
        """
        Test the predict method using a mocked pipeline.

        This test verifies that:
        - The model returns a string output
        - The output is non-empty
        - The output differs from the input text
        """
        bhc_text = (
            "Patient admitted with abdominal pain. Imaging showed no acute findings. "
            "Pain improved with supportive care and the patient was discharged in stable condition."
        )

        with patch.object(BHCToAVS, "_get_pipeline", return_value=_DummyPipeline()):
            model = BHCToAVS()
            summary = model.predict(bhc_text)

        # Output must be type str
        self.assertIsInstance(summary, str)

        # Output should not be empty
        self.assertGreater(len(summary.strip()), 0)

        # Output should be different from input
        self.assertNotIn(bhc_text[:40], summary)

    @unittest.skipUnless(
        os.getenv("RUN_BHC_TO_AVS_INTEGRATION") == "1" and os.getenv("HF_TOKEN"),
        "Integration test disabled. Set RUN_BHC_TO_AVS_INTEGRATION=1 and HF_TOKEN to enable.",
    )
    def test_predict_integration(self):
        """
        Integration test for the BHCToAVS model.

        This test runs the full inference pipeline using the real Hugging Face model.
        It requires the HF_TOKEN environment variable to be set and is skipped by default.
        """
        # For Mistral weights, you will need HF_TOKEN set in the environment.
        bhc_text = (
            "Patient admitted with abdominal pain. Imaging showed no acute findings. "
            "Pain improved with supportive care and the patient was discharged in stable condition."
        )

        model = BHCToAVS()
        summary = model.predict(bhc_text)

        # Output must be type str
        self.assertIsInstance(summary, str)

        # Output should not be empty
        self.assertGreater(len(summary.strip()), 0)

        # Output should be different from input
        self.assertNotIn(bhc_text[:40], summary)
```
The comment on line 36 states "Prompt used during fine-tuning" but this prompt is actually used during inference (as seen on line 146). If this prompt was indeed used during fine-tuning and is also being reused during inference, the comment should clarify this. If it's only used during inference, the comment is misleading and should be corrected to "Prompt template used during inference" or similar.