Add BHCToAVS model for patient-friendly summaries #730
@@ -0,0 +1,11 @@
pyhealth.models.bhc_to_avs
==========================

BHCToAVS
------------------------------

.. autoclass:: pyhealth.models.bhc_to_avs.BHCToAVS
   :members:
   :inherited-members:
   :show-inheritance:
   :undoc-members:
@@ -0,0 +1,21 @@
from pyhealth.models.bhc_to_avs import BHCToAVS

# Initialize the model
model = BHCToAVS()

# Example Brief Hospital Course (BHC) text with common clinical abbreviations, generated synthetically via ChatGPT 5.1
bhc = (
    "Pt admitted with acute onset severe epigastric pain and hypotension. "
    "Labs notable for elevated lactate, WBC 18K, mild AST/ALT elevation, and Cr 1.4 (baseline 0.9). "
    "CT A/P w/ contrast demonstrated peripancreatic fat stranding c/w acute pancreatitis; "
    "no necrosis or peripancreatic fluid collection. "
    "Pt received aggressive IVFs, electrolyte repletion, IV analgesia, and NPO status initially. "
    "Serial abd exams remained benign with no rebound or guarding. "
    "BP stabilized, lactate downtrended, and pt tolerated ADAT to low-fat diet without recurrence of sx. "
    "Discharged in stable condition w/ instructions for GI f/u and outpatient CMP in 1 week."
)

# Generate a patient-friendly After-Visit Summary
print(model.predict(bhc))

# Expected output: a simplified, patient-friendly summary explaining the hospital stay without medical jargon.
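
Because the model caches the loaded HF pipeline on the instance (see `_get_pipeline` in the model code below), the 7B weights are loaded once per process and reused across calls. A minimal sketch of that pattern over several notes; the note texts here are illustrative, not from this PR:

from pyhealth.models.bhc_to_avs import BHCToAVS

model = BHCToAVS()

# Illustrative BHC notes (not from the PR); the weights are loaded on the
# first predict() call and reused for the rest via the cached pipeline.
notes = [
    "Pt admitted with CAP, started on IV ceftriaxone/azithromycin, weaned to room air, discharged on PO abx.",
    "Pt admitted with CHF exacerbation, diuresed with IV furosemide, discharged euvolemic on home regimen.",
]

for note in notes:
    print(model.predict(note))
    print("-" * 40)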
@@ -0,0 +1,98 @@
# Author: Charan Williams
# NetID: charanw2
# Description: Converts clinical Brief Hospital Course (BHC) notes into After-Visit Summaries using a fine-tuned Mistral 7B model.

from dataclasses import dataclass, field

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModelForCausalLM

from pyhealth.models.base_model import BaseModel


# System prompt used during inference
_SYSTEM_PROMPT = (
    "You are a clinical summarization model. Produce accurate, patient-friendly summaries "
    "using only information from the doctor's note. Do not add new details.\n\n"
)

# Prompt template matching the format used during fine-tuning
_PROMPT = (
    "Summarize for the patient what happened during the hospital stay based on this doctor's note:\n"
    "{bhc}\n\n"
    "Summary for the patient:\n"
)


@dataclass
class BHCToAVS(BaseModel):
    """Generates a patient-friendly After-Visit Summary (AVS) from a Brief Hospital Course (BHC) note."""

    base_model_id: str = field(default="mistralai/Mistral-7B-Instruct")
    """HuggingFace repo containing the base Mistral 7B model."""

    adapter_model_id: str = field(default="williach31/mistral-7b-bhc-to-avs-lora")
    """HuggingFace repo containing only the LoRA adapter weights."""

    def _get_pipeline(self):
        """Create and cache the text-generation pipeline."""
        if not hasattr(self, "_pipeline"):
            # Load the base model
            base = AutoModelForCausalLM.from_pretrained(
                self.base_model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

            # Load the LoRA adapter on top of the base model
            model = PeftModelForCausalLM.from_pretrained(
                base,
                self.adapter_model_id,
                torch_dtype=torch.bfloat16,
            )

            tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)

            # Create the HF text-generation pipeline
            self._pipeline = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device_map="auto",
                model_kwargs={"torch_dtype": torch.bfloat16},
            )

        return self._pipeline

    def predict(self, bhc_text: str) -> str:
        """
        Generate an After-Visit Summary (AVS) from a Brief Hospital Course (BHC) note.

        Parameters
        ----------
        bhc_text : str
            Raw BHC text.

        Returns
        -------
        str
            Patient-friendly summary.
        """
        prompt = _SYSTEM_PROMPT + _PROMPT.format(bhc=bhc_text)

        pipe = self._get_pipeline()
        outputs = pipe(
            prompt,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding (the intent of the original temperature=0.0 setting)
            return_full_text=False,  # return only the newly generated summary, not the echoed prompt
            eos_token_id=[pipe.tokenizer.eos_token_id],
            pad_token_id=pipe.tokenizer.eos_token_id,
        )

        # The pipeline returns a list with a single generated-text entry
        return outputs[0]["generated_text"].strip()
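
Not part of this PR, but since the adapter is loaded through PEFT, a common deployment option is to merge the LoRA deltas into the base weights so inference and export no longer depend on the adapter at runtime. A hedged sketch, reusing the default repo IDs above; the output directory name is illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModelForCausalLM

BASE = "mistralai/Mistral-7B-Instruct"             # default base_model_id above
ADAPTER = "williach31/mistral-7b-bhc-to-avs-lora"  # default adapter_model_id above

base = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModelForCausalLM.from_pretrained(base, ADAPTER, torch_dtype=torch.bfloat16)

# merge_and_unload() folds the LoRA weights into the base model and returns a
# plain transformers model that can be saved and served without peft.
merged = model.merge_and_unload()
merged.save_pretrained("mistral-7b-bhc-to-avs-merged")
AutoTokenizer.from_pretrained(BASE).save_pretrained("mistral-7b-bhc-to-avs-merged")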
@@ -0,0 +1,36 @@
from tests.base import BaseTestCase
from pyhealth.models.bhc_to_avs import BHCToAVS


class TestBHCToAVS(BaseTestCase):
    """Unit tests for the BHCToAVS model."""

    def setUp(self):
        self.set_random_seed()

    def test_predict(self):
        """Test the predict method of BHCToAVS."""
        bhc_text = (
            "Patient admitted with abdominal pain. Imaging showed no acute findings. "
            "Pain improved with supportive care and the patient was discharged in stable condition."
        )
        model = BHCToAVS()
        try:
            summary = model.predict(bhc_text)

            # Output must be a string
            self.assertIsInstance(summary, str)

            # Output should not be empty
            self.assertGreater(len(summary.strip()), 0)

            # Output should differ from the input note
            self.assertNotIn(bhc_text[:40], summary)

        except OSError as e:
            # Allow the test to pass if the model download fails, e.g. on GitHub workflows
            if "gated repo" in str(e).lower() or "404" in str(e):
                pass
            else:
                raise
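
The test above still requires the full 7B download when the hub is reachable. A hedged sketch of a complementary offline unit test that pre-populates the `_pipeline` cache checked by `_get_pipeline()` with a stub, so the prompt construction and return path can be exercised without any download; the stub class and canned summary text are illustrative:

from types import SimpleNamespace
from unittest import TestCase

from pyhealth.models.bhc_to_avs import BHCToAVS


class _StubPipeline:
    """Stands in for the HF text-generation pipeline and returns a canned summary."""

    # predict() reads pipe.tokenizer.eos_token_id, so expose a dummy tokenizer.
    tokenizer = SimpleNamespace(eos_token_id=2)

    def __call__(self, prompt, **kwargs):
        return [{"generated_text": "You came to the hospital with belly pain and went home feeling better."}]


class TestBHCToAVSOffline(TestCase):
    def test_predict_with_stubbed_pipeline(self):
        model = BHCToAVS()
        # Pre-populate the cache so _get_pipeline() never downloads weights.
        model._pipeline = _StubPipeline()
        summary = model.predict("Patient admitted with abdominal pain.")
        self.assertEqual(
            summary,
            "You came to the hospital with belly pain and went home feeling better.",
        )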