172 changes: 172 additions & 0 deletions examples/conformal_eeg/tuev_conventional_conformal.py
@@ -0,0 +1,172 @@
"""Conventional Conformal Prediction (LABEL) on TUEV EEG Events using ContraWR.

This script:
1) Loads the TUEV dataset and applies the EEGEventsTUEV task.
2) Splits into train/val/cal/test following the split conformal protocol.
3) Trains a ContraWR model.
4) Calibrates a LABEL prediction-set predictor on the calibration split.
5) Evaluates prediction-set coverage/miscoverage and efficiency on the test split.

Example (from repo root):
python examples/conformal_eeg/tuev_conventional_conformal.py --root downloads/tuev/v2.0.1/edf
"""

from __future__ import annotations

import argparse
import random
from pathlib import Path

import numpy as np
import torch

from pyhealth.calib.predictionset import LABEL
from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal
from pyhealth.models import ContraWR
from pyhealth.tasks import EEGEventsTUEV
from pyhealth.trainer import Trainer, get_metrics_fn


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Conventional conformal prediction (LABEL) on TUEV EEG events using ContraWR."
)
parser.add_argument(
"--root",
type=str,
default="downloads/tuev/v2.0.1/edf",
help="Path to TUEV edf/ folder.",
)
parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"])
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g., 0.1 => 90% target coverage).")
parser.add_argument(
"--ratios",
type=float,
nargs=4,
default=(0.6, 0.1, 0.15, 0.15),
metavar=("TRAIN", "VAL", "CAL", "TEST"),
help="Split ratios for train/val/cal/test. Must sum to 1.0.",
)
parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.")
parser.add_argument(
"--device",
type=str,
default=None,
help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.",
)
return parser.parse_args()


def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)


def main() -> None:
args = parse_args()
set_seed(args.seed)

device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")
root = Path(args.root)
if not root.exists():
raise FileNotFoundError(
f"TUEV root not found: {root}. "
"Pass --root to point to your downloaded TUEV edf/ directory."
)

print("=" * 80)
print("STEP 1: Load TUEV + build task dataset")
print("=" * 80)
dataset = TUEVDataset(root=str(root), subset=args.subset)
sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache")

print(f"Task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

if len(sample_dataset) == 0:
raise RuntimeError("No samples produced. Verify TUEV root/subset/task.")

print("\n" + "=" * 80)
print("STEP 2: Split train/val/cal/test")
print("=" * 80)
train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
dataset=sample_dataset, ratios=list(args.ratios), seed=args.seed
)
print(f"Train: {len(train_ds)}")
print(f"Val: {len(val_ds)}")
print(f"Cal: {len(cal_ds)}")
print(f"Test: {len(test_ds)}")

train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False)

print("\n" + "=" * 80)
print("STEP 3: Train ContraWR")
print("=" * 80)
model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device)
trainer = Trainer(model=model, device=device, enable_logging=False)

trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=args.epochs,
monitor="accuracy" if val_loader is not None else None,
)

print("\nBase model performance on test set:")
y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader)
base_metrics = get_metrics_fn("multiclass")(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"])
for metric, value in base_metrics.items():
print(f" {metric}: {value:.4f}")

print("\n" + "=" * 80)
print("STEP 4: Conventional Conformal Prediction (LABEL)")
print("=" * 80)
print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})")

label_predictor = LABEL(model=model, alpha=float(args.alpha))
print("Calibrating LABEL predictor...")
label_predictor.calibrate(cal_dataset=cal_ds)

print("Evaluating LABEL predictor on test set...")
y_true, y_prob, _loss, extra = Trainer(model=label_predictor).inference(
test_loader, additional_outputs=["y_predset"]
)

label_metrics = get_metrics_fn("multiclass")(
y_true,
y_prob,
metrics=["accuracy", "miscoverage_ps"],
y_predset=extra["y_predset"],
)

predset = extra["y_predset"]
if isinstance(predset, np.ndarray):
predset_t = torch.tensor(predset)
else:
predset_t = predset
avg_set_size = predset_t.float().sum(dim=1).mean().item()

miscoverage = label_metrics["miscoverage_ps"]
if isinstance(miscoverage, np.ndarray):
miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean())
else:
miscoverage = float(miscoverage)

print("\nLABEL Results:")
print(f" Accuracy: {label_metrics['accuracy']:.4f}")
print(f" Empirical miscoverage: {miscoverage:.4f}")
print(f" Empirical coverage: {1 - miscoverage:.4f}")
print(f" Average set size: {avg_set_size:.2f}")


if __name__ == "__main__":
main()
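
For readers new to the LABEL-style prediction sets used above, here is a minimal NumPy sketch of the split-conformal thresholding idea. It assumes a single shared threshold over softmax scores; the actual pyhealth.calib LABEL implementation may differ (for example, by calibrating per-class thresholds), so treat this as an illustration only.

import numpy as np

def calibrate_threshold(cal_probs, cal_labels, alpha=0.1):
    # Score of the true class for each calibration sample.
    n = len(cal_labels)
    true_scores = cal_probs[np.arange(n), cal_labels]
    # Finite-sample-corrected lower quantile: roughly an alpha fraction of
    # calibration samples have their true-class score below the threshold.
    q_level = np.floor(alpha * (n + 1)) / n
    return np.quantile(true_scores, q_level)

def predict_sets(test_probs, threshold):
    # The prediction set contains every class whose score clears the threshold.
    return test_probs >= threshold

# Toy usage with random probabilities (6 classes, matching TUEV's event labels).
rng = np.random.default_rng(0)
cal_probs = rng.dirichlet(np.ones(6), size=500)
cal_labels = rng.integers(0, 6, size=500)
thr = calibrate_threshold(cal_probs, cal_labels, alpha=0.1)
sets = predict_sets(rng.dirichlet(np.ones(6), size=100), thr)
print("avg set size:", sets.sum(axis=1).mean())

The empirical coverage, miscoverage, and average set size printed by the script are the same quantities this sketch computes on the toy data.
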
189 changes: 189 additions & 0 deletions examples/conformal_eeg/tuev_covariate_shift_conformal.py
@@ -0,0 +1,189 @@
"""Covariate-Shift Adaptive Conformal Prediction (CovariateLabel) on TUEV EEG Events using ContraWR.

This script:
1) Loads the TUEV dataset and applies the EEGEventsTUEV task.
2) Splits into train/val/cal/test following the split conformal protocol.
3) Trains a ContraWR model.
4) Extracts embeddings for the calibration and test splits using embed=True.
5) Calibrates a CovariateLabel prediction-set predictor (KDE-based shift correction).
6) Evaluates prediction-set coverage/miscoverage and efficiency on the test split.

Example (from repo root):
python examples/conformal_eeg/tuev_covariate_shift_conformal.py --root downloads/tuev/v2.0.1/edf

Notes:
- CovariateLabel requires access to test embeddings/features to estimate density ratios.
"""

from __future__ import annotations

import argparse
import random
from pathlib import Path

import numpy as np
import torch

from pyhealth.calib.predictionset.covariate import CovariateLabel
from pyhealth.calib.utils import extract_embeddings
from pyhealth.datasets import TUEVDataset, get_dataloader, split_by_sample_conformal
from pyhealth.models import ContraWR
from pyhealth.tasks import EEGEventsTUEV
from pyhealth.trainer import Trainer, get_metrics_fn


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Covariate-shift adaptive conformal prediction (CovariateLabel) on TUEV EEG events using ContraWR."
)
parser.add_argument(
"--root",
type=str,
default="downloads/tuev/v2.0.1/edf",
help="Path to TUEV edf/ folder.",
)
parser.add_argument("--subset", type=str, default="both", choices=["train", "eval", "both"])
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage rate (e.g., 0.1 => 90% target coverage).")
parser.add_argument(
"--ratios",
type=float,
nargs=4,
default=(0.6, 0.1, 0.15, 0.15),
metavar=("TRAIN", "VAL", "CAL", "TEST"),
help="Split ratios for train/val/cal/test. Must sum to 1.0.",
)
parser.add_argument("--n-fft", type=int, default=128, help="STFT FFT size used by ContraWR.")
parser.add_argument(
"--device",
type=str,
default=None,
help="Device string, e.g. 'cuda:0' or 'cpu'. Defaults to auto-detect.",
)
return parser.parse_args()


def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)


def main() -> None:
args = parse_args()
set_seed(args.seed)

device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")
root = Path(args.root)
if not root.exists():
raise FileNotFoundError(
f"TUEV root not found: {root}. "
"Pass --root to point to your downloaded TUEV edf/ directory."
)

print("=" * 80)
print("STEP 1: Load TUEV + build task dataset")
print("=" * 80)
dataset = TUEVDataset(root=str(root), subset=args.subset)
sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir="examples/conformal_eeg/cache")

print(f"Task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

if len(sample_dataset) == 0:
raise RuntimeError("No samples produced. Verify TUEV root/subset/task.")

print("\n" + "=" * 80)
print("STEP 2: Split train/val/cal/test")
print("=" * 80)
train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
dataset=sample_dataset, ratios=list(args.ratios), seed=args.seed
)
print(f"Train: {len(train_ds)}")
print(f"Val: {len(val_ds)}")
print(f"Cal: {len(cal_ds)}")
print(f"Test: {len(test_ds)}")

train_loader = get_dataloader(train_ds, batch_size=args.batch_size, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=args.batch_size, shuffle=False) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=args.batch_size, shuffle=False)

print("\n" + "=" * 80)
print("STEP 3: Train ContraWR")
print("=" * 80)
model = ContraWR(dataset=sample_dataset, n_fft=args.n_fft).to(device)
trainer = Trainer(model=model, device=device, enable_logging=False)

trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=args.epochs,
monitor="accuracy" if val_loader is not None else None,
)

print("\nBase model performance on test set:")
y_true_base, y_prob_base, _loss_base = trainer.inference(test_loader)
base_metrics = get_metrics_fn("multiclass")(y_true_base, y_prob_base, metrics=["accuracy", "f1_weighted"])
for metric, value in base_metrics.items():
print(f" {metric}: {value:.4f}")

print("\n" + "=" * 80)
print("STEP 4: Covariate Shift Adaptive Conformal Prediction (CovariateLabel)")
print("=" * 80)
print(f"Target miscoverage alpha: {args.alpha} (target coverage {1 - args.alpha:.0%})")

print("Extracting embeddings for calibration split...")
cal_embeddings = extract_embeddings(model, cal_ds, batch_size=args.batch_size, device=device)
print(f" cal_embeddings shape: {cal_embeddings.shape}")

print("Extracting embeddings for test split...")
test_embeddings = extract_embeddings(model, test_ds, batch_size=args.batch_size, device=device)
print(f" test_embeddings shape: {test_embeddings.shape}")

cov_predictor = CovariateLabel(model=model, alpha=float(args.alpha))
print("Calibrating CovariateLabel predictor (fits KDEs internally)...")
cov_predictor.calibrate(
cal_dataset=cal_ds,
cal_embeddings=cal_embeddings,
test_embeddings=test_embeddings,
)

print("Evaluating CovariateLabel predictor on test set...")
y_true, y_prob, _loss, extra = Trainer(model=cov_predictor).inference(
test_loader, additional_outputs=["y_predset"]
)

cov_metrics = get_metrics_fn("multiclass")(
y_true,
y_prob,
metrics=["accuracy", "miscoverage_ps"],
y_predset=extra["y_predset"],
)

predset = extra["y_predset"]
if isinstance(predset, np.ndarray):
predset_t = torch.tensor(predset)
else:
predset_t = predset
avg_set_size = predset_t.float().sum(dim=1).mean().item()

miscoverage = cov_metrics["miscoverage_ps"]
if isinstance(miscoverage, np.ndarray):
miscoverage = float(miscoverage.item() if miscoverage.size == 1 else miscoverage.mean())
else:
miscoverage = float(miscoverage)

print("\nCovariateLabel Results:")
print(f" Accuracy: {cov_metrics['accuracy']:.4f}")
print(f" Empirical miscoverage: {miscoverage:.4f}")
print(f" Empirical coverage: {1 - miscoverage:.4f}")
print(f" Average set size: {avg_set_size:.2f}")


if __name__ == "__main__":
main()
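
The covariate-shift variant reweights the calibration scores by an estimated density ratio between test and calibration inputs. Below is a rough sketch of that weighting, assuming SciPy's gaussian_kde on the embedding vectors and a simplified single weighted quantile (proper weighted conformal uses per-test-point weights, and the actual CovariateLabel estimator may normalize and smooth differently).

import numpy as np
from scipy.stats import gaussian_kde

def density_ratio_weights(cal_emb, test_emb):
    # KDE estimates of p_test and p_cal on the embedding space; the weight
    # w(x) = p_test(x) / p_cal(x) up-weights calibration points that look
    # like the test distribution. Embeddings are (n_samples, n_dims).
    kde_cal = gaussian_kde(cal_emb.T)
    kde_test = gaussian_kde(test_emb.T)
    dens_cal = np.clip(kde_cal(cal_emb.T), 1e-12, None)
    return kde_test(cal_emb.T) / dens_cal

def weighted_threshold(true_scores, weights, alpha=0.1):
    # Weighted lower alpha-quantile of the true-class calibration scores.
    order = np.argsort(true_scores)
    scores, w = true_scores[order], weights[order]
    cdf = np.cumsum(w) / (w.sum() + 1.0)  # one unit of mass reserved for the test point
    idx = np.searchsorted(cdf, alpha)
    return scores[min(idx, len(scores) - 1)]

In practice, KDE-based ratio estimation is only reliable in low dimensions, so reducing the embedding dimensionality first (e.g., with PCA) is a common precaution.
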
4 changes: 2 additions & 2 deletions pyhealth/models/contrawr.py
@@ -282,7 +282,7 @@ def torch_stft(self, X):
def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
"""Forward propagation."""
# concat the info within one batch (batch, channel, length)
x = kwargs[self.feature_keys[0]]
x = kwargs[self.feature_keys[0]].to(self.device)
# obtain the stft spectrogram (batch, channel, freq, time step)
x_spectrogram = self.torch_stft(x)
# final layer embedding (batch, embedding)
@@ -291,7 +291,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
# (patient, label_size)
logits = self.fc(emb)
# obtain y_true, loss, y_prob
y_true = kwargs[self.label_keys[0]]
y_true = kwargs[self.label_keys[0]].to(self.device)
loss = self.get_loss_function()(logits, y_true)
y_prob = self.prepare_y_prob(logits)

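
The two-line change to contrawr.py moves the batched input and label tensors onto the model's device before the STFT and loss computation, so the example scripts work when the dataloader collates on CPU while the model sits on GPU. The same pattern in isolation, as a hypothetical standalone module (not PyHealth code):

import torch

class DeviceSafeHead(torch.nn.Module):
    # Illustrative only: tensors collated on CPU are moved to the module's
    # device inside forward(), so callers need not pre-move the batch.
    def __init__(self, in_dim: int, n_classes: int):
        super().__init__()
        self.fc = torch.nn.Linear(in_dim, n_classes)

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x = x.to(self.device)  # mirrors kwargs[self.feature_keys[0]].to(self.device)
        y = y.to(self.device)  # mirrors kwargs[self.label_keys[0]].to(self.device)
        logits = self.fc(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        return {"loss": loss, "logits": logits}
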