Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 65 additions & 68 deletions README.rst

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/api/tasks/pyhealth.tasks.readmission_prediction.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pyhealth.tasks.readmission_prediction
=======================================

.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_mimic3_fn
.. autofunction:: pyhealth.tasks.readmission_prediction.ReadmissionPredictionMIMIC3
.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_mimic4_fn
.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn
.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn2
Expand Down
7 changes: 3 additions & 4 deletions examples/readmission_mimic3_fairness.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import readmission_prediction_mimic3_fn
from pyhealth.tasks import ReadmissionPredictionMIMIC3
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.metrics import fairness_metrics_fn
from pyhealth.models import Transformer
Expand All @@ -11,11 +11,10 @@
root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
)
base_dataset.stat()
base_dataset.stats()

# STEP 2: set task
sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn)
sample_dataset.stat()
sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3(exclude_minors=False)) # Must include minors to get any readmission samples on the synthetic dataset

train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
Expand Down
17 changes: 5 additions & 12 deletions examples/readmission_mimic3_rnn.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import RNN
from pyhealth.tasks import readmission_prediction_mimic3_fn
from pyhealth.tasks import ReadmissionPredictionMIMIC3
from pyhealth.trainer import Trainer

# STEP 1: load data
base_dataset = MIMIC3Dataset(
root="/srv/local/data/physionet.org/files/mimiciii/1.4",
root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III",
tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
dev=False,
refresh_cache=True,
)
base_dataset.stat()
base_dataset.stats()

# STEP 2: set task
sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn)
sample_dataset.stat()
sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3(exclude_minors=False)) # Must include minors to get any readmission samples on the synthetic dataset

train_dataset, val_dataset, test_dataset = split_by_patient(
sample_dataset, [0.8, 0.1, 0.1]
Expand All @@ -28,17 +24,14 @@
# STEP 3: define model
model = RNN(
dataset=sample_dataset,
feature_keys=["conditions", "procedures", "drugs"],
label_key="label",
mode="binary",
)

# STEP 4: define trainer
trainer = Trainer(model=model)
trainer.train(
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
epochs=50,
epochs=1,
monitor="roc_auc",
)

Expand Down
2 changes: 1 addition & 1 deletion pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@
from .patient_linkage import patient_linkage_mimic3_fn
from .readmission_30days_mimic4 import Readmission30DaysMIMIC4
from .readmission_prediction import (
ReadmissionPredictionMIMIC3,
readmission_prediction_eicu_fn,
readmission_prediction_eicu_fn2,
readmission_prediction_mimic3_fn,
readmission_prediction_mimic4_fn,
readmission_prediction_omop_fn,
)
Expand Down
176 changes: 105 additions & 71 deletions pyhealth/tasks/readmission_prediction.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,114 @@
from pyhealth.data import Patient, Visit
from datetime import datetime, timedelta
from typing import Dict, List

from pyhealth.data import Event, Patient
from pyhealth.tasks import BaseTask

# TODO: time_window cannot be passed in to base_dataset
def readmission_prediction_mimic3_fn(patient: Patient, time_window=15):
"""Processes a single patient for the readmission prediction task.

Readmission prediction aims at predicting whether the patient will be readmitted
into hospital within time_window days based on the clinical information from
current visit (e.g., conditions and procedures).

Args:
patient: a Patient object
time_window: the time window threshold (gap < time_window means label=1 for
the task)

Returns:
samples: a list of samples, each sample is a dict with patient_id, visit_id,
and other task-specific attributes as key

Note that we define the task as a binary classification task.

Examples:
>>> from pyhealth.datasets import MIMIC3Dataset
>>> mimic3_base = MIMIC3Dataset(
... root="/srv/local/data/physionet.org/files/mimiciii/1.4",
... tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
... code_mapping={"ICD9CM": "CCSCM"},
... )
>>> from pyhealth.tasks import readmission_prediction_mimic3_fn
>>> mimic3_sample = mimic3_base.set_task(readmission_prediction_mimic3_fn)
>>> mimic3_sample.samples[0]
[{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 1}]
class ReadmissionPredictionMIMIC3(BaseTask):
"""
samples = []
Readmission prediction on the MIMIC3 dataset.

# we will drop the last visit
for i in range(len(patient) - 1):
visit: Visit = patient[i]
next_visit: Visit = patient[i + 1]
This task aims at predicting whether the patient will be readmitted into hospital within
a specified number of days based on clinical information from the current visit.

# get time difference between current visit and next visit
time_diff = (next_visit.encounter_time - visit.encounter_time).days
readmission_label = 1 if time_diff < time_window else 0
Attributes:
task_name (str): The name of the task.
input_schema (Dict[str, str]): The schema for the task input.
output_schema (Dict[str, str]): The schema for the task output.
"""
task_name: str = "ReadmissionPredictionMIMIC3"
input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"}
output_schema: Dict[str, str] = {"readmission": "binary"}

def __init__(self, window: timedelta=timedelta(days=15), exclude_minors: bool=True) -> None:
"""
Initializes the task object.

Args:
window (timedelta): If two admissions are closer than this window, it is considered a readmission. Defaults to 15 days.
exclude_minors (bool): Whether to exclude visits where the patient was under 18 years old. Defaults to True.
"""
self.window = window
self.exclude_minors = exclude_minors

def __call__(self, patient: Patient) -> List[Dict]:
"""
Generates binary classification data samples for a single patient.

Visits with no conditions OR no procedures OR no drugs are excluded from the output but are still used to calculate readmission for prior visits.

Args:
patient (Patient): A patient object.

Returns:
List[Dict]: A list containing a dictionary for each patient visit with:
- 'visit_id': MIMIC3 hadm_id.
- 'patient_id': MIMIC3 subject_id.
- 'conditions': MIMIC3 diagnoses_icd table ICD-9 codes.
- 'procedures': MIMIC3 procedures_icd table ICD-9 codes.
- 'drugs': MIMIC3 prescriptions table drug column entries.
- 'readmission': binary label.

Raises:
ValueError: If any `str` to `datetime` conversions fail.
"""
patients: List[Event] = patient.get_events(event_type="patients")
assert len(patients) == 1

if self.exclude_minors:
try:
dob = datetime.strptime(patients[0].dob, "%Y-%m-%d %H:%M:%S")
except ValueError:
dob = datetime.strptime(patients[0].dob, "%Y-%m-%d")

admissions: List[Event] = patient.get_events(event_type="admissions")
if len(admissions) < 2:
return []

samples = []
for i in range(len(admissions) - 1): # Skip the last admission since we need a "next" admission
if self.exclude_minors:
age = admissions[i].timestamp.year - dob.year
age = age-1 if ((admissions[i].timestamp.month, admissions[i].timestamp.day) < (dob.month, dob.day)) else age
if age < 18:
continue

filter = ("hadm_id", "==", admissions[i].hadm_id)

diagnoses = patient.get_events(event_type="diagnoses_icd", filters=[filter])
diagnoses = [event.icd9_code for event in diagnoses]
if len(diagnoses) == 0:
continue

procedures = patient.get_events(event_type="procedures_icd", filters=[filter])
procedures = [event.icd9_code for event in procedures]
if len(procedures) == 0:
continue

prescriptions = patient.get_events(event_type="prescriptions", filters=[filter])
prescriptions = [event.drug for event in prescriptions]
if len(prescriptions) == 0:
continue

try:
discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d %H:%M:%S")
except ValueError:
discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d")

readmission = int((admissions[i + 1].timestamp - discharge_time) < self.window)

samples.append(
{
"visit_id": admissions[i].hadm_id,
"patient_id": patient.patient_id,
"conditions": diagnoses,
"procedures": procedures,
"drugs": prescriptions,
"readmission": readmission,
}
)

conditions = visit.get_code_list(table="DIAGNOSES_ICD")
procedures = visit.get_code_list(table="PROCEDURES_ICD")
drugs = visit.get_code_list(table="PRESCRIPTIONS")
# exclude: visits without condition, procedure, or drug code
if len(conditions) * len(procedures) * len(drugs) == 0:
continue
# TODO: should also exclude visit with age < 18
samples.append(
{
"visit_id": visit.visit_id,
"patient_id": patient.patient_id,
"conditions": [conditions],
"procedures": [procedures],
"drugs": [drugs],
"label": readmission_label,
}
)
# no cohort selection
return samples
return samples


def readmission_prediction_mimic4_fn(patient: Patient, time_window=15):
Expand Down Expand Up @@ -328,19 +375,6 @@ def readmission_prediction_omop_fn(patient: Patient, time_window=15):


if __name__ == "__main__":
from pyhealth.datasets import MIMIC3Dataset

base_dataset = MIMIC3Dataset(
root="/srv/local/data/physionet.org/files/mimiciii/1.4",
tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
dev=True,
code_mapping={"ICD9CM": "CCSCM", "NDC": "ATC"},
refresh_cache=False,
)
sample_dataset = base_dataset.set_task(task_fn=readmission_prediction_mimic3_fn)
sample_dataset.stat()
print(sample_dataset.available_keys)

from pyhealth.datasets import MIMIC4Dataset

base_dataset = MIMIC4Dataset(
Expand Down
Binary file modified test-resources/core/mimic3demo/ADMISSIONS.csv.gz
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only change to this file was to change the admission and discharge times for subject_id 10088's first admission such that the admission was less than 18 years after the subject's date of birth.

Binary file not shown.
Loading