Skip to content

Commit 24e06fa

Browse files
authored
Merge pull request #6 from rvandewater/task-inspection
Task inspection
2 parents 6465dda + 275c31a commit 24e06fa

File tree

5 files changed

+88
-9
lines changed

5 files changed

+88
-9
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ include src/MEDS_Inspect/assets/MIMIC-IV-DEMO-MEDS/data/train/*
22
include src/MEDS_Inspect/assets/MIMIC-IV-DEMO-MEDS/data/tuning/*
33
include src/MEDS_Inspect/assets/MIMIC-IV-DEMO-MEDS/data/held_out/*
44
include src/MEDS_Inspect/assets/MIMIC-IV-DEMO-MEDS/metadata/*
5+
include src/MEDS_Inspect/assets/MIMIC-IV-DEMO-MEDS/tasks/*

src/MEDS_Inspect/app.py

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import importlib.resources as pkg_resources
2+
import os
23

4+
import pandas as pd
35
import plotly.express as px
46
import polars as pl
57
from dash import Dash, Input, Output, State, dash_table, dcc, html
@@ -148,6 +150,7 @@ def render_content(tab, file_path):
148150
# codes = top_codes['code'].unique().to_list()
149151

150152
numerical_codes = numerical_code_data.select("code").unique().collect()["code"].to_list()
153+
151154
if tab == "tab-1":
152155
fig_code_count_years = px.histogram(
153156
code_count_years, x="Date", y="Amount of codes", nbins=len(code_count_years)
@@ -278,13 +281,22 @@ def render_content(tab, file_path):
278281
html.H2(children="Codes over time for a single patient", style={"textAlign": "center"}),
279282
dcc.Dropdown(
280283
id="patient-dropdown",
281-
options=[{"label": pid, "value": pid} for pid in subject_ids],
284+
options=[{"label": pid, "value": pid} for pid in subject_ids[:1000]],
282285
placeholder="Select a patient ID",
286+
value=None,
287+
multi=False,
288+
searchable=True,
289+
clearable=True,
290+
style={"width": "100%"},
291+
),
292+
dcc.Dropdown(
293+
id="task-dropdown",
294+
placeholder="Select a task",
283295
),
284296
dcc.Loading(
285297
id="loading-fig-patient-codes",
286298
type="default",
287-
children=dcc.Graph(id="fig_patient_codes", style={"width": "90hh", "height": "50vh"}),
299+
children=dcc.Graph(id="fig_patient_codes", style={"width": "90hh", "height": "90vh"}),
288300
),
289301
],
290302
style=card_style,
@@ -530,25 +542,84 @@ def update_top_codes(top_n, scale):
530542
)
531543
return fig_top_codes
532544

533-
@app.callback(Output("fig_patient_codes", "figure"), Input("patient-dropdown", "value"))
534-
def update_patient_codes(patient_id):
545+
# (column, hover label) pairs for the optional task-label columns shown on hover.
_TASK_VALUE_FIELDS = (
    ("boolean_value", "Boolean Value"),
    ("integer_value", "Integer Value"),
    ("float_value", "Float Value"),
    ("categorical_value", "Categorical Value"),
)


def _list_task_options(file_path):
    """Return dropdown options for every regular file in ``<file_path>/tasks``.

    An empty list is returned when ``file_path`` is falsy or the ``tasks``
    directory does not exist, so datasets without tasks do not break the callback.
    """
    if not file_path:
        return []
    tasks_path = os.path.join(file_path, "tasks")
    try:
        entries = sorted(os.listdir(tasks_path))
    except (FileNotFoundError, NotADirectoryError):
        return []
    return [
        {"label": os.path.splitext(entry)[0], "value": entry}
        for entry in entries
        if os.path.isfile(os.path.join(tasks_path, entry))
    ]


def _task_hover_text(task_name, row):
    """Build the ``<br>``-separated hover text for one task-label row."""
    parts = [f"Task: {task_name}", f"Prediction Time: {row['prediction_time']}"]
    parts.extend(
        f"{label}: {row[column]}" for column, label in _TASK_VALUE_FIELDS if row.get(column) is not None
    )
    return "<br>".join(parts)


@app.callback(
    Output("fig_patient_codes", "figure"),
    Output("task-dropdown", "options"),
    Input("patient-dropdown", "value"),
    Input("hidden-file-path", "value"),
    Input("task-dropdown", "value"),
)
def update_patient_codes_and_task_dropdown(patient_id, file_path, selected_task):
    """Update the single-patient code timeline and the list of selectable tasks.

    Args:
        patient_id: Value of the patient dropdown; ``None`` when nothing is selected.
        file_path: Dataset root taken from the hidden file-path component.
        selected_task: File name of the task chosen in the task dropdown, if any.

    Returns:
        A ``(figure, task_options)`` tuple. ``figure`` is ``{}`` when no patient is
        selected or the patient has no rows; otherwise it is a scatter plot of codes
        over time, coloured by the code prefix before the first ``/``. When a
        detected task is selected, one dashed vertical line per task-label row of
        this patient is drawn at its prediction time (red if ``boolean_value`` is
        truthy, green otherwise).
    """
    task_options = _list_task_options(file_path)

    if patient_id is None:
        return {}, task_options

    patient_data = (
        pl.scan_parquet(return_data_path(file_path))
        .filter(pl.col("subject_id") == patient_id)
        .select(pl.col("time"), pl.col("code"), pl.col("numeric_value"), pl.col("text_value"))
        .with_columns(pl.col("code").str.split("/").list.first().alias("coding_dict"))
        .collect()
    )

    if patient_data.is_empty():
        return {}, task_options

    # Create the scatter plot with color based on the category
    fig_patient_codes = px.scatter(
        patient_data,
        x="time",
        y="code",
        color="coding_dict",
        title=f"Codes over time for patient {patient_id}",
        labels={"coding_dict": "Code Category"},
        hover_data={"code": True, "numeric_value": True, "text_value": True},
    )

    # Only accept task names that were actually detected in the tasks folder. The
    # dropdown value comes from the client, so this blocks stale selections (after
    # switching datasets) and path components such as "../" from reaching the
    # filesystem.
    detected_tasks = {option["value"] for option in task_options}
    if not selected_task or selected_task not in detected_tasks:
        return fig_patient_codes, task_options

    task_file_path = os.path.join(file_path, "tasks", selected_task)
    task_label = pl.scan_parquet(task_file_path).filter(pl.col("subject_id") == patient_id).collect()
    if task_label.is_empty():
        return fig_patient_codes, task_options

    task_name = os.path.splitext(selected_task)[0]
    for row in task_label.iter_rows(named=True):
        prediction_time = row["prediction_time"]
        if prediction_time is None:
            # A marker cannot be placed on the time axis without a prediction time.
            continue

        color = "red" if row.get("boolean_value", False) else "green"

        # Pass the datetime itself as x: converting a naive datetime with
        # ``.timestamp()`` applies the server's local timezone and would shift the
        # marker relative to the event scatter on hosts that are not in UTC.
        fig_patient_codes.add_scatter(
            x=[prediction_time, prediction_time],
            y=[0, 1],
            mode="lines",
            line=dict(color=color, dash="dash"),
            customdata=pd.Series(data=row),
            hovertemplate=_task_hover_text(task_name, row),
            # Built without nested double quotes so it parses on Python < 3.12.
            name=f"{task_name} {prediction_time}",
            yaxis="y2",
        )

    # The marker lines use a secondary y-axis spanning 0..1; hide its tick labels.
    fig_patient_codes.update_layout(yaxis2=dict(showticklabels=False))

    return fig_patient_codes, task_options
552623

553624
@app.callback(
554625
Output("fig_code_distribution", "figure"),
Binary file not shown.

src/MEDS_Inspect/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,9 @@ def return_data_path(file_path):
4242
else:
4343
logging.error("No data found in the specified paths.")
4444
return None
45+
46+
47+
def get_detected_tasks(file_path):
    """Return the names of the task files found in ``<file_path>/tasks``.

    Args:
        file_path: Root directory of the dataset. May be ``None`` or empty.

    Returns:
        A sorted list of the names of regular files directly inside the ``tasks``
        directory (sub-directories are skipped). The list is empty when
        ``file_path`` is falsy or the ``tasks`` directory does not exist.
    """
    if not file_path:
        return []
    tasks_path = os.path.join(file_path, "tasks")
    try:
        entries = os.listdir(tasks_path)
    except (FileNotFoundError, NotADirectoryError):
        # A dataset without a tasks folder simply has no tasks to offer.
        return []
    # Sorted so the result does not depend on filesystem enumeration order.
    return sorted(entry for entry in entries if os.path.isfile(os.path.join(tasks_path, entry)))

task_extraction.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/Users/robin/Documents/git/MEDS-DEV/src/MEDS_DEV/tasks/mortality/in_icu/first_24h.yaml

0 commit comments

Comments
 (0)