1 change: 1 addition & 0 deletions docs/api/models.rst
@@ -31,6 +31,7 @@ We implement the following models for supporting multiple healthcare predictive
models/pyhealth.models.GRASP
models/pyhealth.models.MedLink
models/pyhealth.models.TCN
models/pyhealth.models.TFMTokenizer
models/pyhealth.models.GAN
models/pyhealth.models.VAE
models/pyhealth.models.SDOH
25 changes: 25 additions & 0 deletions docs/api/models/pyhealth.models.TFMTokenizer.rst
@@ -0,0 +1,25 @@
pyhealth.models.TFMTokenizer
===================================

TFM-Tokenizer model for EEG signal tokenization using VQ-VAE.

.. autoclass:: pyhealth.models.TFMTokenizer
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: pyhealth.models.TFM_VQVAE2_deep
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: pyhealth.models.TFM_TOKEN_Classifier
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: pyhealth.models.get_tfm_tokenizer_2x2x8

.. autofunction:: pyhealth.models.get_tfm_token_classifier_64x4

.. autofunction:: pyhealth.models.load_embedding_weights
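
A minimal usage sketch (the constructor arguments mirror the quickstart
notebook and are illustrative rather than an exhaustive reference):

.. code-block:: python

    from pyhealth.models import TFMTokenizer

    # sample_dataset: a PyHealth sample dataset with "stft" and "signal"
    # tensor inputs and a "multiclass" label, as in the quickstart notebook.
    model = TFMTokenizer(
        dataset=sample_dataset,
        emb_size=64,
        code_book_size=8192,
        use_classifier=True,
    )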
351 changes: 351 additions & 0 deletions examples/conformal_eeg/tfm_tokenizer_quickstart.ipynb
@@ -0,0 +1,351 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "intro",
"metadata": {},
"source": [
"# TFM-Tokenizer for EEG Signal Tokenization\n",
"\n",
"This notebook demonstrates the TFM-Tokenizer model for tokenizing EEG signals into discrete tokens and continuous embeddings.\n",
"\n",
"**Note**: This example uses dummy data. The EEG-specific processor for generating STFT features from raw signals is under development."
]
},
{
"cell_type": "markdown",
"id": "setup",
"metadata": {},
"source": [
"## 1. Environment Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "imports",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'litdata'",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 5\u001b[39m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpyhealth\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdatasets\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m create_sample_dataset, get_dataloader\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpyhealth\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodels\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TFMTokenizer, get_tfm_tokenizer_2x2x8\n\u001b[32m 8\u001b[39m SEED = \u001b[32m42\u001b[39m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/__init__.py:49\u001b[39m\n\u001b[32m 41\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mwarnings\u001b[39;00m\n\u001b[32m 43\u001b[39m warnings.warn(\n\u001b[32m 44\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mThe SampleSignalDataset class is deprecated and will be removed in a future version.\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 45\u001b[39m \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m,\n\u001b[32m 46\u001b[39m )\n\u001b[32m---> \u001b[39m\u001b[32m49\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase_dataset\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m BaseDataset\n\u001b[32m 50\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mcardiology\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m CardiologyDataset\n\u001b[32m 51\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mchestxray14\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ChestXray14Dataset\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/PyHealth/pyhealth/datasets/base_dataset.py:18\u001b[39m\n\u001b[32m 15\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmultiprocessing\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mqueues\u001b[39;00m\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mshutil\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m18\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mstreaming\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mitem_loader\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m ParquetLoader\n\u001b[32m 20\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlitdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mprocessing\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdata_processor\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m in_notebook\n",
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'litdata'"
]
}
],
"source": [
"import random\n",
"import numpy as np\n",
"import torch\n",
"\n",
"from pyhealth.datasets import create_sample_dataset, get_dataloader\n",
"from pyhealth.models import TFMTokenizer, get_tfm_tokenizer_2x2x8\n",
"\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"if torch.cuda.is_available():\n",
" torch.cuda.manual_seed_all(SEED)\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Running on device: {device}\")"
]
},
{
"cell_type": "markdown",
"id": "data_prep",
"metadata": {},
"source": [
"## 2. Create Sample Dataset\n",
"\n",
"TFM-Tokenizer expects two inputs:\n",
"- `stft`: STFT spectrogram of shape (n_freq, n_time), e.g., (100, 60)\n",
"- `signal`: Raw temporal signal of shape (n_samples,), e.g., (1280,)\n",
"\n",
"For demonstration, we'll use dummy data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "create_data",
"metadata": {},
"outputs": [],
"source": [
"# Create dummy samples (in practice, these would come from EEG preprocessing)\n",
"samples = [\n",
" {\n",
" \"patient_id\": f\"patient-{i}\",\n",
" \"visit_id\": \"visit-0\",\n",
" \"stft\": torch.randn(100, 60).numpy().tolist(), # STFT spectrogram\n",
" \"signal\": torch.randn(1280).numpy().tolist(), # Raw signal\n",
" \"label\": i % 6, # 6 classes for TUEV events\n",
" }\n",
" for i in range(50)\n",
"]\n",
"\n",
"input_schema = {\n",
" \"stft\": \"tensor\",\n",
" \"signal\": \"tensor\",\n",
"}\n",
"output_schema = {\"label\": \"multiclass\"}\n",
"\n",
"dataset = create_sample_dataset(\n",
" samples=samples,\n",
" input_schema=input_schema,\n",
" output_schema=output_schema,\n",
" dataset_name=\"tfm_demo\",\n",
")\n",
"\n",
"print(f\"Created dataset with {len(dataset)} samples\")\n",
"print(f\"Input schema: {dataset.input_schema}\")\n",
"print(f\"Output schema: {dataset.output_schema}\")"
]
},
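{
"cell_type": "markdown",
"id": "stft_sketch_md",
"metadata": {},
"source": [
"### Sketch: deriving `stft` from a raw signal\n",
"\n",
"A minimal sketch of how the `(stft, signal)` pair above could be computed from a real EEG window. `scipy.signal.stft` and the `fs`/`nperseg`/`noverlap` values here are illustrative assumptions; PyHealth's EEG-specific STFT processor is still under development (see the note at the top)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stft_sketch",
"metadata": {},
"outputs": [],
"source": [
"from scipy.signal import stft as scipy_stft\n",
"\n",
"fs = 256  # assumed sampling rate (Hz); 1280 samples would then span 5 s\n",
"raw = torch.randn(1280).numpy()  # stand-in for one raw EEG channel window\n",
"\n",
"# nperseg/noverlap are illustrative; the real TFM-Tokenizer preprocessing\n",
"# may use different STFT settings to reach the (100, 60) shape above.\n",
"freqs, times, Z = scipy_stft(raw, fs=fs, nperseg=256, noverlap=236)\n",
"spec = np.abs(Z)  # magnitude spectrogram, shape (n_freq, n_time)\n",
"print(f\"STFT shape: {spec.shape}\")"
]
},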
{
"cell_type": "markdown",
"id": "split_data",
"metadata": {},
"source": [
"## 3. Split Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "split",
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.datasets.splitter import split_by_sample\n",
"\n",
"train_data, val_data, test_data = split_by_sample(dataset, [0.7, 0.15, 0.15], seed=SEED)\n",
"\n",
"print(f\"Train: {len(train_data)} samples\")\n",
"print(f\"Val: {len(val_data)} samples\")\n",
"print(f\"Test: {len(test_data)} samples\")\n",
"\n",
"train_loader = get_dataloader(train_data, batch_size=8, shuffle=True)\n",
"val_loader = get_dataloader(val_data, batch_size=8, shuffle=False)\n",
"test_loader = get_dataloader(test_data, batch_size=8, shuffle=False)"
]
},
{
"cell_type": "markdown",
"id": "model_init",
"metadata": {},
"source": [
"## 4. Initialize TFM-Tokenizer Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "create_model",
"metadata": {},
"outputs": [],
"source": [
"model = TFMTokenizer(\n",
" dataset=dataset,\n",
" emb_size=64,\n",
" code_book_size=8192,\n",
" trans_freq_encoder_depth=2,\n",
" trans_temporal_encoder_depth=2,\n",
" trans_decoder_depth=8,\n",
" use_classifier=True,\n",
" classifier_depth=4,\n",
")\n",
"\n",
"model = model.to(device)\n",
"print(f\"Model created with {sum(p.numel() for p in model.parameters())} parameters\")"
]
},
{
"cell_type": "markdown",
"id": "forward_pass",
"metadata": {},
"source": [
"## 5. Test Forward Pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "test_forward",
"metadata": {},
"outputs": [],
"source": [
"batch = next(iter(train_loader))\n",
"\n",
"with torch.no_grad():\n",
" outputs = model(**batch)\n",
"\n",
"print(\"Output keys:\", outputs.keys())\n",
"print(f\"Loss: {outputs['loss'].item():.4f}\")\n",
"print(f\"Logits shape: {outputs['logit'].shape}\")\n",
"print(f\"Tokens shape: {outputs['tokens'].shape}\")\n",
"print(f\"Embeddings shape: {outputs['embeddings'].shape}\")"
]
},
{
"cell_type": "markdown",
"id": "training",
"metadata": {},
"source": [
"## 6. Train Model (Optional)\n",
"\n",
"Train the model using PyHealth's Trainer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "train",
"metadata": {},
"outputs": [],
"source": [
"from pyhealth.trainer import Trainer\n",
"\n",
"trainer = Trainer(model=model, device=device)\n",
"trainer.train(\n",
" train_dataloader=train_loader,\n",
" val_dataloader=val_loader,\n",
" epochs=3,\n",
" monitor=\"accuracy\",\n",
")"
]
},
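{
"cell_type": "markdown",
"id": "evaluate_md",
"metadata": {},
"source": [
"A quick test-set evaluation with the same Trainer. `trainer.evaluate` follows the pattern of other PyHealth examples; the metric names in the returned dict are an assumption:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "evaluate",
"metadata": {},
"outputs": [],
"source": [
"results = trainer.evaluate(test_loader)\n",
"print(results)"
]
},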
{
"cell_type": "markdown",
"id": "embeddings",
"metadata": {},
"source": [
"## 7. Extract Embeddings for Analysis\n",
"\n",
"Extract patient embeddings for downstream tasks like clustering or conformal prediction:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "extract_embeddings",
"metadata": {},
"outputs": [],
"source": [
"# Extract embeddings from test set\n",
"test_embeddings = model.get_embeddings(test_loader)\n",
"print(f\"Test embeddings shape: {test_embeddings.shape}\")\n",
"\n",
"# Get patient-level representation (mean pooling)\n",
"patient_embeddings = test_embeddings.mean(dim=1)\n",
"print(f\"Patient-level embeddings shape: {patient_embeddings.shape}\")"
]
},
{
"cell_type": "markdown",
"id": "tokens",
"metadata": {},
"source": [
"## 8. Extract Discrete Tokens"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "extract_tokens",
"metadata": {},
"outputs": [],
"source": [
"# Extract tokens from test set\n",
"test_tokens = model.get_tokens(test_loader)\n",
"print(f\"Test tokens shape: {test_tokens.shape}\")\n",
"\n",
"# Analyze token vocabulary usage\n",
"unique_tokens = torch.unique(test_tokens)\n",
"print(f\"Active tokens: {len(unique_tokens)} / {model.code_book_size}\")\n",
"print(f\"Token usage: {len(unique_tokens) / model.code_book_size * 100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "clustering",
"metadata": {},
"source": [
"## 9. Patient Clustering (Example)\n",
"\n",
"Use embeddings for k-means clustering:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "clustering_example",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.cluster import KMeans\n",
"\n",
"# Cluster patients based on embeddings\n",
"kmeans = KMeans(n_clusters=3, random_state=SEED)\n",
"clusters = kmeans.fit_predict(patient_embeddings.cpu().numpy())\n",
"\n",
"print(\"Cluster distribution:\")\n",
"unique, counts = np.unique(clusters, return_counts=True)\n",
"for cluster_id, count in zip(unique, counts):\n",
" print(f\" Cluster {cluster_id}: {count} patients ({count/len(clusters)*100:.1f}%)\")"
]
},
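{
"cell_type": "markdown",
"id": "conformal_md",
"metadata": {},
"source": [
"### Sketch: split-conformal prediction on the classifier\n",
"\n",
"Since this example lives under `examples/conformal_eeg/`, here is a hedged sketch of split-conformal prediction over the classifier's softmax scores, calibrated on the validation split. It follows the standard split-conformal recipe and is not a PyHealth API; in particular, reading integer labels from `batch[\"label\"]` is an assumption about the collated batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "conformal",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"def collect_scores(loader):\n",
"    # Run the model over a loader, collecting softmax scores and labels.\n",
"    probs, labels = [], []\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        for batch in loader:\n",
"            out = model(**batch)\n",
"            probs.append(F.softmax(out[\"logit\"], dim=-1).cpu())\n",
"            labels.append(torch.as_tensor(batch[\"label\"]).cpu())\n",
"    return torch.cat(probs), torch.cat(labels)\n",
"\n",
"alpha = 0.1  # target 90% marginal coverage\n",
"cal_probs, cal_labels = collect_scores(val_loader)\n",
"\n",
"# Nonconformity score: 1 - softmax probability of the true class.\n",
"cal_scores = 1.0 - cal_probs[torch.arange(len(cal_labels)), cal_labels]\n",
"n = len(cal_scores)\n",
"q_level = min(1.0, (n + 1) * (1 - alpha) / n)\n",
"qhat = torch.quantile(cal_scores, q_level, interpolation=\"higher\")\n",
"\n",
"# Prediction set: every class whose score clears the calibrated threshold.\n",
"test_probs, test_labels = collect_scores(test_loader)\n",
"pred_sets = test_probs >= 1.0 - qhat\n",
"covered = pred_sets[torch.arange(len(test_labels)), test_labels].float().mean()\n",
"print(f\"Empirical coverage: {covered:.2f}\")\n",
"print(f\"Average set size: {pred_sets.sum(dim=1).float().mean():.2f}\")"
]
},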
{
"cell_type": "markdown",
"id": "pretrained",
"metadata": {},
"source": [
"## 10. Loading Pre-trained Weights\n",
"\n",
"Load pre-trained weights from checkpoint:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "load_weights",
"metadata": {},
"outputs": [],
"source": [
"# Uncomment and set the path to load pre-trained weights\n",
"# model.load_pretrained_weights(\"path/to/tfm_encoder_best_model.pth\")\n",
"print(\"To load pre-trained weights:\")\n",
"print(\"model.load_pretrained_weights('tfm_encoder_best_model.pth')\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}