diff --git a/docs/api/models.rst b/docs/api/models.rst index eb69c50f..1983ac07 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -31,6 +31,7 @@ We implement the following models for supporting multiple healthcare predictive models/pyhealth.models.GRASP models/pyhealth.models.MedLink models/pyhealth.models.TCN + models/pyhealth.models.TFMTokenizer models/pyhealth.models.GAN models/pyhealth.models.VAE models/pyhealth.models.SDOH \ No newline at end of file diff --git a/docs/api/models/pyhealth.models.TFMTokenizer.rst b/docs/api/models/pyhealth.models.TFMTokenizer.rst new file mode 100644 index 00000000..719615c5 --- /dev/null +++ b/docs/api/models/pyhealth.models.TFMTokenizer.rst @@ -0,0 +1,25 @@ +pyhealth.models.TFMTokenizer +=================================== + +TFM-Tokenizer model for EEG signal tokenization using VQ-VAE. + +.. autoclass:: pyhealth.models.TFMTokenizer + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.TFM_VQVAE2_deep + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.TFM_TOKEN_Classifier + :members: + :undoc-members: + :show-inheritance: + +.. autofunction:: pyhealth.models.get_tfm_tokenizer_2x2x8 + +.. autofunction:: pyhealth.models.get_tfm_token_classifier_64x4 + +.. autofunction:: pyhealth.models.load_embedding_weights diff --git a/examples/conformal_eeg/tfm_tokenizer_quickstart.ipynb b/examples/conformal_eeg/tfm_tokenizer_quickstart.ipynb new file mode 100644 index 00000000..506aad93 --- /dev/null +++ b/examples/conformal_eeg/tfm_tokenizer_quickstart.ipynb @@ -0,0 +1,351 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": [ + "# TFM-Tokenizer for EEG Signal Tokenization\n", + "\n", + "This notebook demonstrates the TFM-Tokenizer model for tokenizing EEG signals into discrete tokens and continuous embeddings.\n", + "\n", + "**Note**: This example uses dummy data. The EEG-specific processor for generating STFT features from raw signals is under development." + ] + }, + { + "cell_type": "markdown", + "id": "setup", + "metadata": {}, + "source": [ + "## 1. 
Environment Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "imports",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import random\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "\n",
+    "from pyhealth.datasets import create_sample_dataset, get_dataloader\n",
+    "from pyhealth.models import TFMTokenizer, get_tfm_tokenizer_2x2x8\n",
+    "\n",
+    "SEED = 42\n",
+    "random.seed(SEED)\n",
+    "np.random.seed(SEED)\n",
+    "torch.manual_seed(SEED)\n",
+    "if torch.cuda.is_available():\n",
+    "    torch.cuda.manual_seed_all(SEED)\n",
+    "\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "print(f\"Running on device: {device}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "data_prep",
+   "metadata": {},
+   "source": [
+    "## 2. Create Sample Dataset\n",
+    "\n",
+    "TFM-Tokenizer expects two inputs:\n",
+    "- `stft`: STFT spectrogram of shape (n_freq, n_time), e.g., (100, 60)\n",
+    "- `signal`: Raw temporal signal of shape (n_samples,), e.g., (6100,)\n",
+    "\n",
+    "With the default temporal patch embedding (kernel 200, stride 100), a 6100-sample signal yields 60 temporal patches, matching the 60 STFT frames.\n",
+    "\n",
+    "For demonstration, we'll use dummy data:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "create_data",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create dummy samples (in practice, these would come from EEG preprocessing)\n",
+    "samples = [\n",
+    "    {\n",
+    "        \"patient_id\": f\"patient-{i}\",\n",
+    "        \"visit_id\": \"visit-0\",\n",
+    "        \"stft\": torch.randn(100, 60).numpy().tolist(),  # STFT spectrogram\n",
+    "        \"signal\": torch.randn(6100).numpy().tolist(),  # Raw signal\n",
+    "        \"label\": i % 6,  # 6 classes for TUEV events\n",
+    "    }\n",
+    "    for i in range(50)\n",
+    "]\n",
+    "\n",
+    "input_schema = {\n",
+    "    \"stft\": \"tensor\",\n",
+    "    \"signal\": \"tensor\",\n",
+    "}\n",
+    "output_schema = {\"label\": \"multiclass\"}\n",
+    "\n",
+    "dataset = create_sample_dataset(\n",
+    "    samples=samples,\n",
+    "    input_schema=input_schema,\n",
+    "    output_schema=output_schema,\n",
+    "    dataset_name=\"tfm_demo\",\n",
+    ")\n",
+    "\n",
+    "print(f\"Created dataset with {len(dataset)} samples\")\n",
+    "print(f\"Input schema: {dataset.input_schema}\")\n",
+    "print(f\"Output schema: {dataset.output_schema}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "split_data",
+   "metadata": {},
+   "source": [
+    "## 3. 
Split Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "split", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets.splitter import split_by_sample\n", + "\n", + "train_data, val_data, test_data = split_by_sample(dataset, [0.7, 0.15, 0.15], seed=SEED)\n", + "\n", + "print(f\"Train: {len(train_data)} samples\")\n", + "print(f\"Val: {len(val_data)} samples\")\n", + "print(f\"Test: {len(test_data)} samples\")\n", + "\n", + "train_loader = get_dataloader(train_data, batch_size=8, shuffle=True)\n", + "val_loader = get_dataloader(val_data, batch_size=8, shuffle=False)\n", + "test_loader = get_dataloader(test_data, batch_size=8, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "id": "model_init", + "metadata": {}, + "source": [ + "## 4. Initialize TFM-Tokenizer Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "create_model", + "metadata": {}, + "outputs": [], + "source": [ + "model = TFMTokenizer(\n", + " dataset=dataset,\n", + " emb_size=64,\n", + " code_book_size=8192,\n", + " trans_freq_encoder_depth=2,\n", + " trans_temporal_encoder_depth=2,\n", + " trans_decoder_depth=8,\n", + " use_classifier=True,\n", + " classifier_depth=4,\n", + ")\n", + "\n", + "model = model.to(device)\n", + "print(f\"Model created with {sum(p.numel() for p in model.parameters())} parameters\")" + ] + }, + { + "cell_type": "markdown", + "id": "forward_pass", + "metadata": {}, + "source": [ + "## 5. Test Forward Pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "test_forward", + "metadata": {}, + "outputs": [], + "source": [ + "batch = next(iter(train_loader))\n", + "\n", + "with torch.no_grad():\n", + " outputs = model(**batch)\n", + "\n", + "print(\"Output keys:\", outputs.keys())\n", + "print(f\"Loss: {outputs['loss'].item():.4f}\")\n", + "print(f\"Logits shape: {outputs['logit'].shape}\")\n", + "print(f\"Tokens shape: {outputs['tokens'].shape}\")\n", + "print(f\"Embeddings shape: {outputs['embeddings'].shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "training", + "metadata": {}, + "source": [ + "## 6. Train Model (Optional)\n", + "\n", + "Train the model using PyHealth's Trainer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "train", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.trainer import Trainer\n", + "\n", + "trainer = Trainer(model=model, device=device)\n", + "trainer.train(\n", + " train_dataloader=train_loader,\n", + " val_dataloader=val_loader,\n", + " epochs=3,\n", + " monitor=\"accuracy\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "embeddings", + "metadata": {}, + "source": [ + "## 7. Extract Embeddings for Analysis\n", + "\n", + "Extract patient embeddings for downstream tasks like clustering or conformal prediction:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "extract_embeddings", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract embeddings from test set\n", + "test_embeddings = model.get_embeddings(test_loader)\n", + "print(f\"Test embeddings shape: {test_embeddings.shape}\")\n", + "\n", + "# Get patient-level representation (mean pooling)\n", + "patient_embeddings = test_embeddings.mean(dim=1)\n", + "print(f\"Patient-level embeddings shape: {patient_embeddings.shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "tokens", + "metadata": {}, + "source": [ + "## 8. 
Extract Discrete Tokens" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "extract_tokens", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract tokens from test set\n", + "test_tokens = model.get_tokens(test_loader)\n", + "print(f\"Test tokens shape: {test_tokens.shape}\")\n", + "\n", + "# Analyze token vocabulary usage\n", + "unique_tokens = torch.unique(test_tokens)\n", + "print(f\"Active tokens: {len(unique_tokens)} / {model.code_book_size}\")\n", + "print(f\"Token usage: {len(unique_tokens) / model.code_book_size * 100:.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "id": "clustering", + "metadata": {}, + "source": [ + "## 9. Patient Clustering (Example)\n", + "\n", + "Use embeddings for k-means clustering:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "clustering_example", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.cluster import KMeans\n", + "\n", + "# Cluster patients based on embeddings\n", + "kmeans = KMeans(n_clusters=3, random_state=SEED)\n", + "clusters = kmeans.fit_predict(patient_embeddings.cpu().numpy())\n", + "\n", + "print(\"Cluster distribution:\")\n", + "unique, counts = np.unique(clusters, return_counts=True)\n", + "for cluster_id, count in zip(unique, counts):\n", + " print(f\" Cluster {cluster_id}: {count} patients ({count/len(clusters)*100:.1f}%)\")" + ] + }, + { + "cell_type": "markdown", + "id": "pretrained", + "metadata": {}, + "source": [ + "## 10. Loading Pre-trained Weights\n", + "\n", + "Load pre-trained weights from checkpoint:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "load_weights", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment and set the path to load pre-trained weights\n", + "# model.load_pretrained_weights(\"path/to/tfm_encoder_best_model.pth\")\n", + "print(\"To load pre-trained weights:\")\n", + "print(\"model.load_pretrained_weights('tfm_encoder_best_model.pth')\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index b822dd57..2cb78494 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -23,6 +23,14 @@ from .stagenet import StageNet, StageNetLayer from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer from .tcn import TCN, TCNLayer +from .tfm_tokenizer import ( + TFMTokenizer, + TFM_VQVAE2_deep, + TFM_TOKEN_Classifier, + get_tfm_tokenizer_2x2x8, + get_tfm_token_classifier_64x4, + load_embedding_weights, +) from .torchvision_model import TorchvisionModel from .transformer import Transformer, TransformerLayer from .transformers_model import TransformersModel diff --git a/pyhealth/models/tfm_tokenizer.py b/pyhealth/models/tfm_tokenizer.py new file mode 100644 index 00000000..8fe216f3 --- /dev/null +++ b/pyhealth/models/tfm_tokenizer.py @@ -0,0 +1,880 @@ +import math +from typing import Dict, Optional, Tuple, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from linear_attention_transformer import LinearAttentionTransformer + +from pyhealth.datasets import SampleDataset +from 
pyhealth.models import BaseModel + +class PositionalEncoding(nn.Module): + """Positional encoding for transformer models. + + Args: + d_model: dimension of the model embedding. + dropout: dropout probability. Default is 0.1. + max_len: maximum sequence length. Default is 1000. + """ + + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + """Forward propagation. + + Args: + x: input embeddings of shape (batch, max_len, d_model). + + Returns: + output tensor of shape (batch, max_len, d_model). + """ + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class TransformerEncoder(nn.Module): + """Transformer encoder using linear attention. + + Args: + emb_size: embedding size. Default is 64. + num_heads: number of attention heads. Default is 8. + depth: number of transformer layers. Default is 4. + max_seq_len: maximum sequence length. Default is 1024. + """ + + def __init__( + self, + emb_size: int = 64, + num_heads: int = 8, + depth: int = 4, + max_seq_len: int = 1024, + ): + super().__init__() + + self.transformer = LinearAttentionTransformer( + dim=emb_size, + heads=num_heads, + depth=depth, + max_seq_len=max_seq_len, + attn_layer_dropout=0.2, + attn_dropout=0.2, + ) + + def forward(self, x): + """Forward propagation. + + Args: + x: input tensor of shape (batch, seq_len, emb_size). + + Returns: + output tensor of shape (batch, seq_len, emb_size). + """ + x = self.transformer(x) + return x + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +class EMAVectorQuantizer(nn.Module): + """Exponential Moving Average Vector Quantizer. + + Args: + emb_size: dimensionality of embeddings. + code_book_size: number of codebook entries. + decay: exponential moving average decay factor. Default is 0.99. + eps: small constant for numerical stability. Default is 1e-5. + """ + + def __init__( + self, emb_size: int, code_book_size: int, decay: float = 0.99, eps: float = 1e-5 + ): + super().__init__() + self.emb_size = emb_size + self.code_book_size = code_book_size + self.decay = decay + self.eps = eps + + self.embedding = nn.Embedding(code_book_size, emb_size) + self.embedding.weight.data.uniform_(-1 / code_book_size, 1 / code_book_size) + + self.register_buffer("cluster_size", torch.zeros(code_book_size)) + self.register_buffer("ema_w", self.embedding.weight.data.clone()) + + def forward(self, x): + """Forward propagation. + + Args: + x: input tensor of shape (B, T, emb_size). + + Returns: + quantized: quantized vectors of shape (B, T, emb_size). + encoding_indices: indices of selected codebook entries of shape (B, T). 
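+
+        Example (illustrative shapes only; the codebook here is freshly
+        initialized, so the returned indices are arbitrary):
+
+        >>> vq = EMAVectorQuantizer(emb_size=64, code_book_size=512)
+        >>> x = torch.randn(2, 60, 64)
+        >>> quantized, indices = vq(x)
+        >>> quantized.shape, indices.shape
+        (torch.Size([2, 60, 64]), torch.Size([2, 60]))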
+ """ + flat_x = x.reshape(-1, self.emb_size) + + dist = ( + flat_x.pow(2).sum(dim=1, keepdim=True) + - 2 * flat_x @ self.embedding.weight.t() + + self.embedding.weight.pow(2).sum(dim=1, keepdim=True).t() + ) + + encoding_indices = torch.argmin(dist, dim=1) + quantized = self.embedding(encoding_indices).view_as(x) + + if self.training: + encodings_one_hot = F.one_hot(encoding_indices, self.code_book_size).type_as( + flat_x + ) + + new_cluster_size = encodings_one_hot.sum(dim=0) + self.cluster_size.data.mul_(self.decay).add_( + new_cluster_size, alpha=1 - self.decay + ) + + dw = encodings_one_hot.t() @ flat_x + self.ema_w.data.mul_(self.decay).add_(dw, alpha=1 - self.decay) + + n = self.cluster_size.sum() + cluster_size = ( + (self.cluster_size + self.eps) + / (n + self.code_book_size * self.eps) + * n + ) + + embed_normalized = self.ema_w / cluster_size.unsqueeze(1) + self.embedding.weight.data.copy_(embed_normalized) + + encoding_indices = encoding_indices.reshape(x.size(0), x.size(1)) + return quantized, encoding_indices + + +def freq_bin_temporal_masking( + X, + freq_mask_ratio: float = 0.5, + freq_bin_size: int = 5, + time_mask_ratio: float = 0.5, + time_bin_size: int = 10, +): + """Apply frequency-bin and temporal masking to spectrograms. + + Args: + X: input spectrogram of shape (B, F, T). + freq_mask_ratio: ratio of frequency bins to mask. Default is 0.5. + freq_bin_size: size of frequency bins. Default is 5. + time_mask_ratio: ratio of time bins to mask. Default is 0.5. + time_bin_size: size of time bins. Default is 10. + + Returns: + X_masked: masked spectrogram (unmasked regions). + X_masked_sym: inverse masked spectrogram (masked regions). + full_mask: boolean mask for unmasked regions. + full_mask_sym: boolean mask for masked regions. + """ + B, F, T = X.shape + + num_freq_bins = F // freq_bin_size + X_freq_binned = X.view(B, num_freq_bins, freq_bin_size, T) + freq_mask = torch.ones_like(X_freq_binned) + num_freq_bins_to_mask = int(num_freq_bins * freq_mask_ratio) + freq_bins_to_mask = torch.randperm(num_freq_bins)[:num_freq_bins_to_mask] + freq_mask[:, freq_bins_to_mask, ...] = 0 + freq_mask = freq_mask.view(B, F, T) + + num_time_bins = T // time_bin_size + X_time_binned = X.view(B, F, num_time_bins, time_bin_size) + time_mask = torch.ones_like(X_time_binned) + num_time_bins_to_mask = int(num_time_bins * time_mask_ratio) + time_bins_to_mask = torch.randperm(num_time_bins)[:num_time_bins_to_mask] + time_mask[:, :, time_bins_to_mask, :] = 0 + time_mask = time_mask.view(B, F, T) + + full_mask = freq_mask * time_mask + full_mask_sym = 1 - full_mask + full_mask = full_mask.to(torch.bool) + full_mask_sym = full_mask_sym.to(torch.bool) + X_masked = X * full_mask + X_masked_sym = X * full_mask_sym + + return X_masked, X_masked_sym, full_mask, full_mask_sym + + +class TFM_VQVAE2_deep(nn.Module): + """TFM-Tokenizer module with raw EEG and STFT as input. + + Args: + in_channels: number of input channels. Default is 1. + n_freq: number of frequency bins in STFT. Default is 100. + n_freq_patch: frequency patch size. Default is 5. + emb_size: embedding dimension. Default is 64. + code_book_size: size of the VQ codebook. Default is 8192. + trans_freq_encoder_depth: depth of frequency encoder. Default is 4. + trans_temporal_encoder_depth: depth of temporal encoder. Default is 4. + trans_decoder_depth: depth of decoder. Default is 4. + beta: weight for commitment loss. Default is 1.0. 
+ """ + + def __init__( + self, + in_channels: int = 1, + n_freq: int = 100, + n_freq_patch: int = 5, + emb_size: int = 64, + code_book_size: int = 8192, + trans_freq_encoder_depth: int = 4, + trans_temporal_encoder_depth: int = 4, + trans_decoder_depth: int = 4, + beta: float = 1.0, + ): + super().__init__() + self.n_freq_patch = n_freq_patch + self.emb_size = emb_size + self.code_book_size = code_book_size + + # bin wise frequency embedding + self.freq_patch_embedding = nn.Sequential( + nn.Conv1d(in_channels, emb_size, kernel_size=n_freq_patch, stride=n_freq_patch), + nn.GELU(), + nn.GroupNorm(emb_size // 4, emb_size), + nn.Conv1d(emb_size, emb_size, kernel_size=1, stride=1), + nn.GELU(), + nn.GroupNorm(emb_size // 4, emb_size), + nn.Conv1d(emb_size, emb_size, kernel_size=1, stride=1), + nn.GELU(), + nn.GroupNorm(emb_size // 4, emb_size), + ) + + # Freq Encoder + self.trans_freq_encoder = TransformerEncoder( + emb_size=emb_size, + num_heads=8, + depth=trans_freq_encoder_depth, + max_seq_len=n_freq // n_freq_patch, + ) + + # Temporal embedding + self.temporal_patch_embedding = nn.Sequential( + nn.Conv1d(in_channels, emb_size, kernel_size=200, stride=100), + nn.GELU(), + nn.GroupNorm(emb_size // 4, emb_size), + nn.Conv1d(emb_size, emb_size, kernel_size=1, stride=1), + nn.GELU(), + nn.GroupNorm(emb_size // 4, emb_size), + nn.Conv1d(emb_size, emb_size // 2, kernel_size=1, stride=1), + nn.GELU(), + nn.GroupNorm(emb_size // 4, emb_size // 2), + ) + + # attention based aggregation + global_freq_divider = n_freq // (n_freq_patch * n_freq_patch) + self.freq_patch_embedding_2_atten = nn.Sequential( + nn.Conv1d( + emb_size, + emb_size // (global_freq_divider * 2), + kernel_size=n_freq_patch, + stride=n_freq_patch, + ), + nn.Sigmoid(), + ) + self.freq_patch_embedding_2 = nn.Sequential( + nn.Conv1d( + emb_size, + emb_size // (global_freq_divider * 2), + kernel_size=n_freq_patch, + stride=n_freq_patch, + ), + ) + + # Temporal Encoder + self.trans_temporal_encoder = TransformerEncoder( + emb_size=emb_size, num_heads=8, depth=trans_temporal_encoder_depth + ) + + # Vector quantization bottleneck + self.quantizer = EMAVectorQuantizer(emb_size, code_book_size) + self.beta = beta + + # Decoder + self.trans_decoder = TransformerEncoder( + emb_size=emb_size, num_heads=8, depth=trans_decoder_depth + ) + + # self.decoder = nn.Linear(emb_size, n_freq) + self.decoder = nn.Sequential( + nn.Linear(emb_size, emb_size), nn.Tanh(), nn.Linear(emb_size, n_freq) + ) + + @torch.jit.ignore + def no_weight_decay(self): + return {"quantizer.embedding.weight"} + + def tokenize(self, x, x_temporal): + """Tokenize EEG signals into discrete tokens. + + Args: + x: STFT spectrogram of shape (B, F, T). + x_temporal: raw temporal signal of shape (B, n_samples). + + Returns: + quant_out: quantized output. + indices: discrete token indices. + quant_in: input to quantizer (before quantization). 
+ """ + B, F, T = x.shape + x = x.permute(0, 2, 1).reshape(-1, 1, F) + x = self.freq_patch_embedding(x) + x = x.permute(0, 2, 1) + + x = self.trans_freq_encoder(x) + + x = x.permute(0, 2, 1) + atten = self.freq_patch_embedding_2_atten(x) + x = self.freq_patch_embedding_2(x) * atten + x = x.reshape(-1, x.size(1) * x.size(2)) + + x = rearrange(x, "(B T) E -> B T E", T=T) + + x_temporal = x_temporal.unsqueeze(1) + x_temporal = self.temporal_patch_embedding(x_temporal) + x_temporal = rearrange(x_temporal, "B E T -> B T E") + + x = torch.cat((x, x_temporal), dim=-1) + + x = self.trans_temporal_encoder(x) + + quant_in = l2norm(x) + quant_out, indices = self.quantizer(quant_in) + + return quant_out, indices, quant_in + + def forward(self, x, x_temporal): + """Forward propagation. + + Args: + x: STFT spectrogram of shape (B, F, T). + x_temporal: raw temporal signal of shape (B, n_samples). + + Returns: + x: reconstructed STFT spectrogram. + indices: discrete token indices. + quant_out: quantized output. + quant_in: input to quantizer. + """ + quant_out, indices, quant_in = self.tokenize(x, x_temporal) + quant_out = quant_in + (quant_out - quant_in).detach() + x = self.trans_decoder(quant_out) + x = self.decoder(x) + x = x.permute(0, 2, 1) + + return x, indices, quant_out, quant_in + + def vec_quantizer_loss(self, quant_in, quant_out): + """Compute vector quantizer losses. + + Args: + quant_in: input to quantizer. + quant_out: output from quantizer. + + Returns: + loss: total VQ loss. + code_book_loss: codebook loss component. + commitment_loss: commitment loss component. + """ + commitment_loss = torch.mean((quant_out.detach() - quant_in) ** 2) + code_book_loss = torch.mean((quant_out - quant_in.detach()) ** 2) + loss = code_book_loss + self.beta * commitment_loss + return loss, code_book_loss, commitment_loss + + @torch.no_grad() + def forward_ana(self, x, x_temporal): + """Forward propagation with intermediate outputs for analysis. + + Returns: + x_dec: reconstructed output. + indices: quantizer indices. + quant_out: quantized representation. + quant_in: input to quantizer. + freq_encoded: frequency encoder tokens. + temporal_encoded: temporal encoder tokens. + """ + B, F, T = x.shape + + x_freq = x.permute(0, 2, 1).reshape(-1, 1, F) + x_freq = self.freq_patch_embedding(x_freq) + x_freq = x_freq.permute(0, 2, 1) + + freq_encoded = self.trans_freq_encoder(x_freq) + + x_freq_agg = freq_encoded.permute(0, 2, 1) + atten = self.freq_patch_embedding_2_atten(x_freq_agg) + x_freq_agg = self.freq_patch_embedding_2(x_freq_agg) * atten + x_freq_agg = x_freq_agg.reshape(-1, x_freq_agg.size(1) * x_freq_agg.size(2)) + x_freq_agg = rearrange(x_freq_agg, "(B T) E -> B T E", T=T) + + x_temporal_branch = x_temporal.unsqueeze(1) + x_temporal_branch = self.temporal_patch_embedding(x_temporal_branch) + x_temporal_branch = rearrange(x_temporal_branch, "B E T -> B T E") + + x_combined = torch.cat((x_freq_agg, x_temporal_branch), dim=-1) + + temporal_encoded = self.trans_temporal_encoder(x_combined) + + quant_in = l2norm(temporal_encoded) + quant_out, indices = self.quantizer(quant_in) + quant_out = quant_in + (quant_out - quant_in).detach() + + x_dec = self.trans_decoder(quant_out) + x_dec = self.decoder(x_dec) + x_dec = x_dec.permute(0, 2, 1) + + return x_dec, indices, quant_out, quant_in, freq_encoded, temporal_encoded + + +class TFM_TOKEN_Classifier(nn.Module): + """Downstream classifier using TFM tokens. + + Args: + emb_size: embedding dimension. Default is 256. + code_book_size: size of the VQ codebook. 
Default is 8192. + num_heads: number of attention heads. Default is 8. + depth: number of transformer layers. Default is 12. + max_seq_len: maximum sequence length. Default is 61. + n_classes: number of output classes. Default is 5. + """ + + def __init__( + self, + emb_size: int = 256, + code_book_size: int = 8192, + num_heads: int = 8, + depth: int = 12, + max_seq_len: int = 61, + n_classes: int = 5, + ): + super().__init__() + + self.eeg_token_embedding = nn.Embedding(code_book_size + 1, emb_size) + self.channel_embed = nn.Embedding(16, emb_size) + self.index = nn.Parameter(torch.LongTensor(range(16)), requires_grad=False) + self.temporal_pos_embed = PositionalEncoding(emb_size) + self.pos_drop = nn.Dropout(p=0.1) + self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_size)) + + self.LAT = LinearAttentionTransformer( + dim=emb_size, + heads=num_heads, + depth=depth, + max_seq_len=max_seq_len, + attn_layer_dropout=0.2, + attn_dropout=0.2, + ) + + self.classification_head = nn.Linear(emb_size, n_classes) + + def forward(self, x, num_ch: int = 16): + """Forward propagation. + + Args: + x: token indices of shape (B, C, T). + num_ch: number of channels. Default is 16. + + Returns: + pred: class predictions of shape (B, n_classes). + """ + x = self.eeg_token_embedding(x) + + for i in range(x.shape[1]): + used_channel_embed = ( + self.channel_embed(self.index[i]) + .unsqueeze(0) + .unsqueeze(0) + .expand(x.size(0), -1, -1) + ) + x[:, i] = self.temporal_pos_embed(x[:, i] + used_channel_embed) + + x = rearrange(x, "B C T E -> B (C T) E") + + cls_tokens = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = self.LAT(x) + pred = self.classification_head(x[:, 0]) + return pred + + def masked_prediction(self, x, num_ch: int = 16): + """Forward propagation with masked prediction (all tokens). + + Args: + x: token indices of shape (B, C, T). + num_ch: number of channels. Default is 16. + + Returns: + pred: predictions for all tokens (excluding CLS). + """ + x = self.eeg_token_embedding(x) + + for i in range(x.shape[1]): + used_channel_embed = ( + self.channel_embed(self.index[i]) + .unsqueeze(0) + .unsqueeze(0) + .expand(x.size(0), -1, -1) + ) + x[:, i] = self.temporal_pos_embed(x[:, i] + used_channel_embed) + + x = rearrange(x, "B C T E -> B (C T) E") + + cls_tokens = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = self.LAT(x) + pred = self.classification_head(x[:, 1:]) + return pred + + @torch.jit.ignore + def no_weight_decay(self): + return {"temporal_pos_embed", "cls_token"} + + +def get_tfm_tokenizer_2x2x8(code_book_size: int = 8192, emb_size: int = 64): + """Create TFM-Tokenizer with 2x2x8 architecture. + + Args: + code_book_size: size of the VQ codebook. Default is 8192. + emb_size: embedding dimension. Default is 64. + + Returns: + TFM_VQVAE2_deep model instance. + """ + vqvae = TFM_VQVAE2_deep( + in_channels=1, + n_freq=100, + n_freq_patch=5, + emb_size=emb_size, + code_book_size=code_book_size, + trans_freq_encoder_depth=2, + trans_temporal_encoder_depth=2, + trans_decoder_depth=8, + beta=1.0, + ) + return vqvae + + +def get_tfm_token_classifier_64x4( + n_classes: int = 5, code_book_size: int = 8192, emb_size: int = 64 +): + """Create TFM-Token classifier with 64x4 architecture. + + Args: + n_classes: number of output classes. Default is 5. + code_book_size: size of the VQ codebook. Default is 8192. + emb_size: embedding dimension. Default is 64. + + Returns: + TFM_TOKEN_Classifier model instance. 
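+
+    Example (illustrative; random token indices for 2 samples over 16
+    channels and 60 time steps, drawn from a 512-entry codebook):
+
+    >>> clf = get_tfm_token_classifier_64x4(n_classes=5, code_book_size=512)
+    >>> tokens = torch.randint(0, 512, (2, 16, 60))
+    >>> clf(tokens).shape
+    torch.Size([2, 5])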
+ """ + classifier = TFM_TOKEN_Classifier( + emb_size=emb_size, + code_book_size=code_book_size, + num_heads=8, + depth=4, + max_seq_len=2048, + n_classes=n_classes, + ) + return classifier + + +def load_embedding_weights(source_model, target_model): + """Load embedding weights from tokenizer to classifier. + + Args: + source_model: the tokenizer model (TFM_VQVAE2_deep). + target_model: the classifier model (TFM_TOKEN_Classifier). + """ + source_weights = source_model.quantizer.embedding.weight.data + target_weights = target_model.eeg_token_embedding.weight.data + + src_vocab_size, src_emb_dim = source_weights.shape + tgt_vocab_size, tgt_emb_dim = target_weights.shape + + print(f"Source Embedding Shape: {source_weights.shape}") + print(f"Target Embedding Shape: {target_weights.shape}") + + if src_emb_dim != tgt_emb_dim: + raise ValueError( + f"Embedding size mismatch: {src_emb_dim} (source) vs {tgt_emb_dim} (target)" + ) + + if src_vocab_size > tgt_vocab_size: + adapted_weights = source_weights[:tgt_vocab_size, :] + print(f"Trimming source embeddings from {src_vocab_size} to {tgt_vocab_size}") + elif src_vocab_size < tgt_vocab_size: + adapted_weights = torch.zeros( + (tgt_vocab_size, tgt_emb_dim), dtype=source_weights.dtype + ) + adapted_weights[:src_vocab_size, :] = source_weights + print(f"Padding source embeddings from {src_vocab_size} to {tgt_vocab_size}") + else: + adapted_weights = source_weights + + target_model.eeg_token_embedding.weight.data.copy_(adapted_weights) + print("Successfully loaded embedding weights!") + + +class TFMTokenizer(BaseModel): + """TFM-Tokenizer model. + + This model uses VQ-VAE with transformers to tokenize EEG signals. It can + extract discrete tokens and continuous embeddings for downstream tasks. + + The model expects two inputs: + - STFT spectrogram: shape (batch, n_freq, n_time) + - Raw temporal signal: shape (batch, n_samples) + + Args: + dataset: the dataset to train the model. + emb_size: embedding dimension. Default is 64. + code_book_size: size of the VQ codebook. Default is 8192. + n_freq: number of frequency bins in STFT. Default is 100. + n_freq_patch: frequency patch size. Default is 5. + trans_freq_encoder_depth: depth of frequency encoder. Default is 2. + trans_temporal_encoder_depth: depth of temporal encoder. Default is 2. + trans_decoder_depth: depth of decoder. Default is 8. + use_classifier: whether to use the classifier head. Default is True. + classifier_depth: depth of classifier transformer. Default is 4. + classifier_heads: number of attention heads in classifier. Default is 8. 
+ + Examples: + >>> from pyhealth.datasets import TUEVDataset + >>> from pyhealth.models import TFMTokenizer + >>> dataset = TUEVDataset(root="/path/to/tuev") + >>> sample_dataset = dataset.set_task() + >>> model = TFMTokenizer(dataset=sample_dataset) + >>> model.load_pretrained_weights("tfm_encoder_best_model.pth") + """ + + def __init__( + self, + dataset: SampleDataset, + emb_size: int = 64, + code_book_size: int = 8192, + n_freq: int = 100, + n_freq_patch: int = 5, + trans_freq_encoder_depth: int = 2, + trans_temporal_encoder_depth: int = 2, + trans_decoder_depth: int = 8, + use_classifier: bool = True, + classifier_depth: int = 4, + classifier_heads: int = 8, + **kwargs, + ): + super().__init__(dataset=dataset) + + self.emb_size = emb_size + self.code_book_size = code_book_size + self.use_classifier = use_classifier + + self.tokenizer = TFM_VQVAE2_deep( + in_channels=1, + n_freq=n_freq, + n_freq_patch=n_freq_patch, + emb_size=emb_size, + code_book_size=code_book_size, + trans_freq_encoder_depth=trans_freq_encoder_depth, + trans_temporal_encoder_depth=trans_temporal_encoder_depth, + trans_decoder_depth=trans_decoder_depth, + beta=1.0, + ) + + if use_classifier: + output_size = self.get_output_size() + self.classifier = TFM_TOKEN_Classifier( + emb_size=emb_size, + code_book_size=code_book_size, + num_heads=classifier_heads, + depth=classifier_depth, + max_seq_len=2048, + n_classes=output_size, + ) + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation. + + Args: + **kwargs: keyword arguments containing 'stft', 'signal', and label key. + + Returns: + a dictionary containing loss, y_prob, y_true, logit, tokens, embeddings. + """ + stft = kwargs.get("stft") + signal = kwargs.get("signal") + + if stft is None or signal is None: + raise ValueError("Both 'stft' and 'signal' must be provided in inputs") + + stft = stft.to(self.device) + signal = signal.to(self.device) + + reconstructed, tokens, quant_out, quant_in = self.tokenizer(stft, signal) + + recon_loss = F.mse_loss(reconstructed, stft) + vq_loss, _, _ = self.tokenizer.vec_quantizer_loss(quant_in, quant_out) + + results = { + "recon_loss": recon_loss, + "vq_loss": vq_loss, + "tokens": tokens, + "embeddings": quant_out, + } + + if self.use_classifier and len(self.label_keys) > 0: + label_key = self.label_keys[0] + y_true = kwargs[label_key].to(self.device) + + # Reshape tokens to (B, C, T) for multi-channel classifier + # tokens shape: (B, T) -> (B, 1, T) + tokens_reshaped = tokens.unsqueeze(1) + logits = self.classifier(tokens_reshaped) + loss_fn = self.get_loss_function() + cls_loss = loss_fn(logits, y_true) + total_loss = recon_loss + vq_loss + cls_loss + y_prob = self.prepare_y_prob(logits) + + results.update( + { + "loss": total_loss, + "cls_loss": cls_loss, + "y_prob": y_prob, + "y_true": y_true, + "logit": logits, + } + ) + else: + results["loss"] = recon_loss + vq_loss + + return results + + def get_embeddings(self, dataloader) -> torch.Tensor: + """Extract continuous embeddings for all samples in a dataloader. + + Args: + dataloader: PyHealth dataloader. + + Returns: + tensor of shape (n_samples, seq_len, emb_size). 
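+
+        Example (a minimal sketch; ``loader`` is assumed to be a PyHealth
+        dataloader whose batches provide 'stft' and 'signal' tensors):
+
+        >>> emb = model.get_embeddings(loader)  # (n_samples, seq_len, emb_size)
+        >>> patient_emb = emb.mean(dim=1)  # mean-pool over the token axis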
+ """ + self.eval() + all_embeddings = [] + + with torch.no_grad(): + for batch in dataloader: + stft = batch.get("stft").to(self.device) + signal = batch.get("signal").to(self.device) + _, _, quant_out, _ = self.tokenizer(stft, signal) + all_embeddings.append(quant_out.cpu()) + + return torch.cat(all_embeddings, dim=0) + + def get_tokens(self, dataloader) -> torch.Tensor: + """Extract discrete tokens for all samples in a dataloader. + + Args: + dataloader: PyHealth dataloader. + + Returns: + tensor of shape (n_samples, seq_len). + """ + self.eval() + all_tokens = [] + + with torch.no_grad(): + for batch in dataloader: + stft = batch.get("stft").to(self.device) + signal = batch.get("signal").to(self.device) + _, tokens, _, _ = self.tokenizer(stft, signal) + all_tokens.append(tokens.cpu()) + + return torch.cat(all_tokens, dim=0) + + def load_pretrained_weights( + self, checkpoint_path: str, strict: bool = True, map_location: str = None + ): + """Load pre-trained weights from checkpoint. + + Args: + checkpoint_path: path to the checkpoint file. + strict: whether to strictly enforce key matching. Default is True. + map_location: device to map the loaded tensors. Default is None. + """ + if map_location is None: + map_location = str(self.device) + + checkpoint = torch.load(checkpoint_path, map_location=map_location) + + if "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + elif "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + + try: + self.tokenizer.load_state_dict(state_dict, strict=strict) + print(f"✓ Successfully loaded weights from {checkpoint_path}") + except RuntimeError as e: + print(f"Warning: Could not load weights with strict={strict}: {e}") + if strict: + print("Retrying with strict=False...") + self.tokenizer.load_state_dict(state_dict, strict=False) + print("✓ Loaded weights with strict=False") + + if self.use_classifier and hasattr(self, "classifier"): + try: + load_embedding_weights(self.tokenizer, self.classifier) + except Exception as e: + print(f"Note: Could not transfer embeddings to classifier: {e}") + + +if __name__ == "__main__": + print("Testing TFM-Tokenizer components...") + + tokenizer = get_tfm_tokenizer_2x2x8() + print(f"✓ Created tokenizer: {tokenizer.__class__.__name__}") + + classifier = get_tfm_token_classifier_64x4(n_classes=6) + print(f"✓ Created classifier: {classifier.__class__.__name__}") + + batch_size = 2 + n_freq = 100 + n_time = 60 + n_samples = 1280 + + dummy_stft = torch.randn(batch_size, n_freq, n_time) + dummy_signal = torch.randn(batch_size, n_samples) + + recon, tokens, quant_out, quant_in = tokenizer(dummy_stft, dummy_signal) + print(f"✓ Tokenizer forward pass:") + print(f" Reconstructed shape: {recon.shape}") + print(f" Tokens shape: {tokens.shape}") + print(f" Embeddings shape: {quant_out.shape}") + + preds = classifier(tokens) + print(f"✓ Classifier forward pass:") + print(f" Predictions shape: {preds.shape}") + + print("\n✓ All tests passed!") diff --git a/pyproject.toml b/pyproject.toml index c96ca785..e0fe4c8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ dependencies = [ "pyarrow~=22.0.0", "narwhals~=2.13.0", "more-itertools~=10.8.0", + "einops>=0.8.0", + "linear-attention-transformer>=0.19.1", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] diff --git a/tests/core/test_tfm_tokenizer.py b/tests/core/test_tfm_tokenizer.py new file mode 100644 index 00000000..e52c87d4 --- /dev/null +++ b/tests/core/test_tfm_tokenizer.py @@ 
-0,0 +1,158 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import TFMTokenizer, get_tfm_tokenizer_2x2x8, get_tfm_token_classifier_64x4 + + +class TestTFMTokenizer(unittest.TestCase): + """Test cases for the TFMTokenizer model.""" + + def setUp(self): + """Set up test data and model.""" + # Create dummy EEG-style samples with STFT and signal inputs + self.samples = [ + { + "patient_id": "patient-0", + "visit_id": "visit-0", + "stft": torch.randn(100, 60).numpy().tolist(), # (n_freq, n_time) + "signal": torch.randn(6100).numpy().tolist(), # (n_samples,) + "label": 1, + }, + { + "patient_id": "patient-1", + "visit_id": "visit-0", + "stft": torch.randn(100, 60).numpy().tolist(), + "signal": torch.randn(6100).numpy().tolist(), + "label": 0, + }, + ] + + self.input_schema = { + "stft": "tensor", + "signal": "tensor", + } + self.output_schema = {"label": "binary"} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="test", + ) + + self.model = TFMTokenizer( + dataset=self.dataset, + emb_size=64, + code_book_size=128, # Small for testing + use_classifier=True, + ) + + def test_model_initialization(self): + """Test that the TFMTokenizer model initializes correctly.""" + self.assertIsInstance(self.model, TFMTokenizer) + self.assertEqual(self.model.emb_size, 64) + self.assertEqual(self.model.code_book_size, 128) + self.assertTrue(self.model.use_classifier) + self.assertEqual(len(self.model.feature_keys), 2) + self.assertIn("stft", self.model.feature_keys) + self.assertIn("signal", self.model.feature_keys) + + def test_model_forward(self): + """Test that the TFMTokenizer forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + self.assertIn("tokens", ret) + self.assertIn("embeddings", ret) + + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_true"].shape[0], 2) + self.assertEqual(ret["logit"].shape[0], 2) + self.assertEqual(ret["loss"].dim(), 0) + self.assertEqual(ret["tokens"].shape[0], 2) + self.assertEqual(ret["embeddings"].shape[0], 2) + + def test_model_backward(self): + """Test that the TFMTokenizer backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters have gradients after backward pass") + + def test_tokenizer_only(self): + """Test TFMTokenizer without classifier.""" + model_no_classifier = TFMTokenizer( + dataset=self.dataset, + emb_size=64, + code_book_size=128, + use_classifier=False, + ) + + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=True) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = model_no_classifier(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("tokens", ret) + self.assertIn("embeddings", ret) + self.assertNotIn("y_prob", ret) + self.assertNotIn("logit", ret) + + def test_get_embeddings(self): + """Test extraction of embeddings from 
dataloader.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + + embeddings = self.model.get_embeddings(train_loader) + + self.assertEqual(embeddings.shape[0], 2) # 2 samples + self.assertEqual(embeddings.shape[2], 64) # emb_size + + def test_get_tokens(self): + """Test extraction of tokens from dataloader.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + + tokens = self.model.get_tokens(train_loader) + + self.assertEqual(tokens.shape[0], 2) # 2 samples + self.assertTrue(torch.all(tokens >= 0)) + self.assertTrue(torch.all(tokens < 128)) # Within codebook size + + +class TestTFMTokenizerFactories(unittest.TestCase): + """Test factory functions for TFM-Tokenizer.""" + + def test_get_tfm_tokenizer_2x2x8(self): + """Test factory function for tokenizer.""" + tokenizer = get_tfm_tokenizer_2x2x8(code_book_size=512, emb_size=64) + + self.assertIsNotNone(tokenizer) + self.assertEqual(tokenizer.code_book_size, 512) + self.assertEqual(tokenizer.emb_size, 64) + + def test_get_tfm_token_classifier_64x4(self): + """Test factory function for classifier.""" + classifier = get_tfm_token_classifier_64x4(n_classes=5, code_book_size=512) + + self.assertIsNotNone(classifier) + self.assertEqual(classifier.classification_head.out_features, 5) + + +if __name__ == "__main__": + unittest.main()