Skip to content

Commit ce679ed

Browse files
authored
Merge branch 'sunlabuiuc:master' into feature/eeg-proj
2 parents af6948e + 470f89c commit ce679ed

File tree

9 files changed

+1345
-37
lines changed

9 files changed

+1345
-37
lines changed

docs/api/models.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ We implement the following models for supporting multiple healthcare predictive
2424
models/pyhealth.models.ContraWR
2525
models/pyhealth.models.SparcNet
2626
models/pyhealth.models.StageNet
27+
models/pyhealth.models.StageAttentionNet
2728
models/pyhealth.models.AdaCare
2829
models/pyhealth.models.ConCare
2930
models/pyhealth.models.Agent
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
pyhealth.models.StageAttentionNet
2+
===================================
3+
4+
The separate callable StageNetAttentionLayer and the complete StageAttentionNet model.
5+
6+
.. autoclass:: pyhealth.models.StageNetAttentionLayer
7+
:members:
8+
:undoc-members:
9+
:show-inheritance:
10+
11+
.. autoclass:: pyhealth.models.StageAttentionNet
12+
:members:
13+
:undoc-members:
14+
:show-inheritance:
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""Conventional Conformal Prediction (LABEL) on TUEV EEG Events using ContraWR.
2+
3+
This script:
4+
1) Loads the TUEV dataset and applies the EEGEventsTUEV task.
5+
2) Splits into train/val/cal/test using split conformal protocol.
6+
3) Trains a ContraWR model.
7+
4) Calibrates a LABEL prediction-set predictor on the calibration split.
8+
5) Evaluates prediction-set coverage/miscoverage and efficiency on the test split.
9+
10+
Example (from repo root):
11+
python examples/conformal_eeg/tuev_conventional_conformal.py --root downloads/tuev/v2.0.1/edf
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import argparse
17+
import random
18+
from pathlib import Path
19+
20+
import numpy as np
21+
import torch
22+
23+
from pyhealth.calib.predictionset import LABEL
24+
from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal
25+
from pyhealth.models import ContraWR
26+
from pyhealth.tasks import EEGEventsTUEV
27+
from pyhealth.trainer import Trainer, get_metrics_fn
28+
29+
30+
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the LABEL conformal-prediction example.

    Returns:
        The parsed namespace with attributes ``root``, ``subset``, ``seed``,
        ``batch_size``, ``epochs``, ``alpha``, ``ratios``, ``n_fft`` and
        ``device``.

    Raises:
        SystemExit: Raised by argparse on ``--help``, on malformed arguments,
            or when ``--ratios`` contains a negative value or does not sum
            to 1.0.
    """
    parser = argparse.ArgumentParser(
        description="Conventional conformal prediction (LABEL) on TUEV EEG events using ContraWR."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="downloads/tuev/v2.0.1/edf",
        help="Path to TUEV edf/ folder.",
    )
    parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.1,
        # argparse %-formats help strings, so a literal percent sign must be
        # doubled; a bare "%" makes --help fail with ValueError.
        help="Miscoverage rate (e.g., 0.1 => 90%% target coverage).",
    )
    parser.add_argument(
        "--ratios",
        type=float,
        nargs=4,
        default=(0.6, 0.1, 0.15, 0.15),
        metavar=("TRAIN", "VAL", "CAL", "TEST"),
        help="Split ratios for train/val/cal/test. Must sum to 1.0.",
    )
    parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.")
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.",
    )
    args = parser.parse_args()

    # Enforce the contract stated in the --ratios help text up front, with a
    # small tolerance for floating-point rounding in the sum.
    if any(ratio < 0 for ratio in args.ratios) or abs(sum(args.ratios) - 1.0) > 1e-6:
        parser.error(f"--ratios must be non-negative and sum to 1.0, got {tuple(args.ratios)}")
    return args
61+
62+
63+
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every generator; CUDA generators are seeded
            too when CUDA is available.
    """
    # Same order as before: stdlib RNG, NumPy global RNG, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
69+
70+
71+
def main() -> None:
    """Run the LABEL conformal-prediction example on TUEV end to end.

    Parses the CLI options, seeds the RNGs, builds the EEGEventsTUEV task
    dataset, splits it into train/val/cal/test, trains ContraWR, calibrates a
    LABEL predictor on the calibration split and prints accuracy, miscoverage,
    coverage and average prediction-set size for the test split.

    Raises:
        FileNotFoundError: If the ``--root`` directory does not exist.
        RuntimeError: If the task produces no samples, or if the calibration
            or test split is empty.
    """
    args = parse_args()
    set_seed(args.seed)

    device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    root = Path(args.root)
    if not root.exists():
        raise FileNotFoundError(
            f"TUEV root not found: {root}. "
            "Pass --root to point to your downloaded TUEV edf/ directory."
        )

    print("=" * 80)
    print("STEP 1: Load TUEV + build task dataset")
    print("=" * 80)
    dataset = TUEVDataset(root=str(root), subset=args.subset)
    sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache")

    print(f"Task samples: {len(sample_dataset)}")
    print(f"Input schema: {sample_dataset.input_schema}")
    print(f"Output schema: {sample_dataset.output_schema}")

    if len(sample_dataset) == 0:
        raise RuntimeError("No samples produced. Verify TUEV root/subset/task.")

    print("\n" + "=" * 80)
    print("STEP 2: Split train/val/cal/test")
    print("=" * 80)
    train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
        dataset=sample_dataset, ratios=list(args.ratios), seed=args.seed
    )
    print(f"Train: {len(train_ds)}")
    print(f"Val: {len(val_ds)}")
    print(f"Cal: {len(cal_ds)}")
    print(f"Test: {len(test_ds)}")

    # Calibration and evaluation both need data; fail here with a clear
    # message instead of deep inside the predictor or the metrics code.
    if len(cal_ds) == 0 or len(test_ds) == 0:
        raise RuntimeError(
            "Calibration and test splits must be non-empty. "
            "Adjust --ratios or use a larger subset."
        )

    train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True)
    # The validation split is optional: a zero ratio disables model selection.
    val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None
    test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False)

    print("\n" + "=" * 80)
    print("STEP 3: Train ContraWR")
    print("=" * 80)
    model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device)
    trainer = Trainer(model=model, device=device, enable_logging=False)

    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=args.epochs,
        monitor="accuracy" if val_loader is not None else None,
    )

    print("\nBase model performance on test set:")
    y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader)
    base_metrics = get_metrics_fn("multiclass")(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"])
    for metric, value in base_metrics.items():
        print(f" {metric}: {value:.4f}")

    print("\n" + "=" * 80)
    print("STEP 4: Conventional Conformal Prediction (LABEL)")
    print("=" * 80)
    print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})")

    label_predictor = LABEL(model=model, alpha=float(args.alpha))
    print("Calibrating LABEL predictor...")
    label_predictor.calibrate(cal_dataset=cal_ds)

    print("Evaluating LABEL predictor on test set...")
    # Reuse the device and logging settings of the training Trainer so the
    # evaluation honours --device and does not start writing log output.
    eval_trainer = Trainer(model=label_predictor, device=device, enable_logging=False)
    y_true, y_prob, _loss, extra = eval_trainer.inference(
        test_loader, additional_outputs=["y_predset"]
    )

    label_metrics = get_metrics_fn("multiclass")(
        y_true,
        y_prob,
        metrics=["accuracy", "miscoverage_ps"],
        y_predset=extra["y_predset"],
    )

    # y_predset may arrive as a NumPy array or a tensor; as_tensor accepts both.
    # Summing along dim 1 counts the entries set per row, i.e. the set size.
    predset_t = torch.as_tensor(extra["y_predset"])
    avg_set_size = predset_t.float().sum(dim=1).mean().item()

    # The metric may be a scalar or an array; np.mean collapses either form
    # to one number.
    miscoverage = float(np.mean(label_metrics["miscoverage_ps"]))

    print("\nLABEL Results:")
    print(f" Accuracy: {label_metrics['accuracy']:.4f}")
    print(f" Empirical miscoverage: {miscoverage:.4f}")
    print(f" Empirical coverage: {1 - miscoverage:.4f}")
    print(f" Average set size: {avg_set_size:.2f}")


if __name__ == "__main__":
    main()
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""Covariate-Shift Adaptive Conformal Prediction (CovariateLabel) on TUEV EEG Events using ContraWR.
2+
3+
This script:
4+
1) Loads the TUEV dataset and applies the EEGEventsTUEV task.
5+
2) Splits into train/val/cal/test using split conformal protocol.
6+
3) Trains a ContraWR model.
7+
4) Extracts embeddings for calibration and test splits using embed=True.
8+
5) Calibrates a CovariateLabel prediction-set predictor (KDE-based shift correction).
9+
6) Evaluates prediction-set coverage/miscoverage and efficiency on the test split.
10+
11+
Example (from repo root):
12+
python examples/conformal_eeg/tuev_covariate_shift_conformal.py --root downloads/tuev/v2.0.1/edf
13+
14+
Notes:
15+
- CovariateLabel requires access to test embeddings/features to estimate density ratios.
16+
"""
17+
18+
from __future__ import annotations
19+
20+
import argparse
21+
import random
22+
from pathlib import Path
23+
24+
import numpy as np
25+
import torch
26+
27+
from pyhealth.calib.predictionset.covariate import CovariateLabel
28+
from pyhealth.calib.utils import extract_embeddings
29+
from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal
30+
from pyhealth.models import ContraWR
31+
from pyhealth.tasks import EEGEventsTUEV
32+
from pyhealth.trainer import Trainer, get_metrics_fn
33+
34+
35+
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the CovariateLabel conformal example.

    Returns:
        The parsed namespace with attributes ``root``, ``subset``, ``seed``,
        ``batch_size``, ``epochs``, ``alpha``, ``ratios``, ``n_fft`` and
        ``device``.

    Raises:
        SystemExit: Raised by argparse on ``--help``, on malformed arguments,
            or when ``--ratios`` contains a negative value or does not sum
            to 1.0.
    """
    parser = argparse.ArgumentParser(
        description="Covariate-shift adaptive conformal prediction (CovariateLabel) on TUEV EEG events using ContraWR."
    )
    parser.add_argument(
        "--root",
        type=str,
        default="downloads/tuev/v2.0.1/edf",
        help="Path to TUEV edf/ folder.",
    )
    parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.1,
        # argparse %-formats help strings, so a literal percent sign must be
        # doubled; a bare "%" makes --help fail with ValueError.
        help="Miscoverage rate (e.g., 0.1 => 90%% target coverage).",
    )
    parser.add_argument(
        "--ratios",
        type=float,
        nargs=4,
        default=(0.6, 0.1, 0.15, 0.15),
        metavar=("TRAIN", "VAL", "CAL", "TEST"),
        help="Split ratios for train/val/cal/test. Must sum to 1.0.",
    )
    parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.")
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.",
    )
    args = parser.parse_args()

    # Enforce the contract stated in the --ratios help text up front, with a
    # small tolerance for floating-point rounding in the sum.
    if any(ratio < 0 for ratio in args.ratios) or abs(sum(args.ratios) - 1.0) > 1e-6:
        parser.error(f"--ratios must be non-negative and sum to 1.0, got {tuple(args.ratios)}")
    return args
66+
67+
68+
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every generator; CUDA generators are seeded
            too when CUDA is available.
    """
    # Same order as before: stdlib RNG, NumPy global RNG, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
74+
75+
76+
def main() -> None:
    """Run the CovariateLabel conformal-prediction example on TUEV end to end.

    Parses the CLI options, seeds the RNGs, builds the EEGEventsTUEV task
    dataset, splits it into train/val/cal/test, trains ContraWR, extracts
    embeddings for the calibration and test splits, calibrates a
    CovariateLabel predictor with them and prints accuracy, miscoverage,
    coverage and average prediction-set size for the test split.

    Raises:
        FileNotFoundError: If the ``--root`` directory does not exist.
        RuntimeError: If the task produces no samples, or if the calibration
            or test split is empty.
    """
    args = parse_args()
    set_seed(args.seed)

    device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    root = Path(args.root)
    if not root.exists():
        raise FileNotFoundError(
            f"TUEV root not found: {root}. "
            "Pass --root to point to your downloaded TUEV edf/ directory."
        )

    print("=" * 80)
    print("STEP 1: Load TUEV + build task dataset")
    print("=" * 80)
    dataset = TUEVDataset(root=str(root), subset=args.subset)
    sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache")

    print(f"Task samples: {len(sample_dataset)}")
    print(f"Input schema: {sample_dataset.input_schema}")
    print(f"Output schema: {sample_dataset.output_schema}")

    if len(sample_dataset) == 0:
        raise RuntimeError("No samples produced. Verify TUEV root/subset/task.")

    print("\n" + "=" * 80)
    print("STEP 2: Split train/val/cal/test")
    print("=" * 80)
    train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
        dataset=sample_dataset, ratios=list(args.ratios), seed=args.seed
    )
    print(f"Train: {len(train_ds)}")
    print(f"Val: {len(val_ds)}")
    print(f"Cal: {len(cal_ds)}")
    print(f"Test: {len(test_ds)}")

    # Embedding extraction, calibration and evaluation all need data; fail
    # here with a clear message instead of deep inside library code.
    if len(cal_ds) == 0 or len(test_ds) == 0:
        raise RuntimeError(
            "Calibration and test splits must be non-empty. "
            "Adjust --ratios or use a larger subset."
        )

    train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True)
    # The validation split is optional: a zero ratio disables model selection.
    val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None
    test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False)

    print("\n" + "=" * 80)
    print("STEP 3: Train ContraWR")
    print("=" * 80)
    model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device)
    trainer = Trainer(model=model, device=device, enable_logging=False)

    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=args.epochs,
        monitor="accuracy" if val_loader is not None else None,
    )

    print("\nBase model performance on test set:")
    y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader)
    base_metrics = get_metrics_fn("multiclass")(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"])
    for metric, value in base_metrics.items():
        print(f" {metric}: {value:.4f}")

    print("\n" + "=" * 80)
    print("STEP 4: Covariate Shift Adaptive Conformal Prediction (CovariateLabel)")
    print("=" * 80)
    print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})")

    print("Extracting embeddings for calibration split...")
    cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device)
    print(f" cal_embeddings shape: {cal_embeddings.shape}")

    print("Extracting embeddings for test split...")
    test_embeddings = extract_embeddings(model, test_ds, batch_size=args.batch_size, device=device)
    print(f" test_embeddings shape: {test_embeddings.shape}")

    cov_predictor = CovariateLabel(model=model, alpha=float(args.alpha))
    print("Calibrating CovariateLabel predictor (fits KDEs internally)...")
    cov_predictor.calibrate(
        cal_dataset=cal_ds,
        cal_embeddings=cal_embeddings,
        test_embeddings=test_embeddings,
    )

    print("Evaluating CovariateLabel predictor on test set...")
    # Reuse the device and logging settings of the training Trainer so the
    # evaluation honours --device and does not start writing log output.
    eval_trainer = Trainer(model=cov_predictor, device=device, enable_logging=False)
    y_true, y_prob, _loss, extra = eval_trainer.inference(
        test_loader, additional_outputs=["y_predset"]
    )

    cov_metrics = get_metrics_fn("multiclass")(
        y_true,
        y_prob,
        metrics=["accuracy", "miscoverage_ps"],
        y_predset=extra["y_predset"],
    )

    # y_predset may arrive as a NumPy array or a tensor; as_tensor accepts both.
    # Summing along dim 1 counts the entries set per row, i.e. the set size.
    predset_t = torch.as_tensor(extra["y_predset"])
    avg_set_size = predset_t.float().sum(dim=1).mean().item()

    # The metric may be a scalar or an array; np.mean collapses either form
    # to one number.
    miscoverage = float(np.mean(cov_metrics["miscoverage_ps"]))

    print("\nCovariateLabel Results:")
    print(f" Accuracy: {cov_metrics['accuracy']:.4f}")
    print(f" Empirical miscoverage: {miscoverage:.4f}")
    print(f" Empirical coverage: {1 - miscoverage:.4f}")
    print(f" Average set size: {avg_set_size:.2f}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)