
Commit 3193230

Validation dataloader so that `train_size` gets used (#39)
* Validation dataloader so that `train_size` gets used
* use XLocator
* filter query and layer_name
* use random_split
* set seed for the validation indices (see the sketch below)
* seed works; no need to print validation indices anymore
* fix linter issues
* abstract dataloader function shared by train and val
1 parent 2d60468 commit 3193230
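The seeding approach described in the message follows the usual reproducible-split pattern: a fixed seed makes the held-out indices identical across runs, which is why printing them for verification became unnecessary. As a point of reference only, a minimal sketch of that pattern using `torch.utils.data.random_split`; this repo's `QueryIDs.random_split()` is its own implementation, so the sketch is an analogy, not code from this commit:

```python
import torch
from torch.utils.data import random_split

# Fractional lengths (PyTorch >= 1.13) plus a seeded generator give a
# deterministic 80/20 split; rerunning with seed 42 reproduces the same
# validation indices, so there is no need to print them to verify.
data = list(range(100))
generator = torch.Generator().manual_seed(42)
train_subset, val_subset = random_split(data, [0.8, 0.2], generator=generator)
print(len(train_subset), len(val_subset))  # 80 20
```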

File tree

1 file changed: +100 -8 lines

src/tiledbsoma_ml/scvi.py

Lines changed: 100 additions & 8 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+from enum import Enum
 from typing import Any, Sequence
 
 import pandas as pd
@@ -12,6 +13,8 @@
 
 from tiledbsoma_ml import ExperimentDataset, experiment_dataloader
 from tiledbsoma_ml._common import MiniBatch
+from tiledbsoma_ml._query_ids import QueryIDs
+from tiledbsoma_ml.x_locator import XLocator
 
 DEFAULT_DATALOADER_KWARGS: dict[str, Any] = {
     "pin_memory": torch.cuda.is_available(),
@@ -20,6 +23,13 @@
 }
 
 
+class DatasetSplit(Enum):
+    """Enum for dataset splits."""
+
+    TRAIN = "train"
+    VAL = "val"
+
+
 class SCVIDataModule(LightningDataModule):  # type: ignore[misc]
     """PyTorch Lightning DataModule for training scVI models from SOMA data.
 
@@ -38,6 +48,8 @@ def __init__(
         batch_column_names: Sequence[str] | None = None,
         batch_labels: Sequence[str] | None = None,
         dataloader_kwargs: dict[str, Any] | None = None,
+        train_size: float = 1.0,
+        seed: int = 42,
         **kwargs: Any,
     ):
         """Args:
@@ -63,6 +75,13 @@ def __init__(
 
         dataloader_kwargs: dict, optional
             Keyword arguments passed to `tiledbsoma_ml.experiment_dataloader()`, e.g. `num_workers`.
+
+        train_size: float, optional
+            Fraction of data to use for training (between 0 and 1). Default is 1.0 (use all data for training).
+            If less than 1.0, the remaining data will be used for validation.
+
+        seed: int, optional
+            Random seed for deterministic train/validation split. Default is 42.
         """
         super().__init__()
         self.query = query
@@ -93,22 +112,95 @@ def __init__(
             batch_labels = obs_df[self.batch_colname].unique()
         self.batch_labels = batch_labels
         self.batch_encoder = LabelEncoder().fit(self.batch_labels)
+        self.train_size = train_size
+        self.seed = seed
+        self.train_query_ids: QueryIDs | None = None
+        self.val_query_ids: QueryIDs | None = None
+        self.x_locator: XLocator | None = None
+        self.layer_name = kwargs.get("layer_name", "raw")
 
     def setup(self, stage: str | None = None) -> None:
-        # Instantiate the ExperimentDataset with the provided args and kwargs.
-        self.train_dataset = ExperimentDataset(
-            self.query,
-            *self.dataset_args,
-            obs_column_names=self.batch_column_names,  # type: ignore[arg-type]
-            **self.dataset_kwargs,  # type: ignore[misc]
+        # Create QueryIDs and XLocator from the query
+        query_ids = QueryIDs.create(self.query)
+        self.x_locator = XLocator.create(
+            self.query.experiment,
+            measurement_name=self.query.measurement_name,
+            layer_name=self.layer_name,
         )
 
-    def train_dataloader(self) -> DataLoader:
+        # Split data into train and validation sets if train_size < 1.0
+        if self.train_size < 1.0:
+            # Use QueryIDs.random_split() for efficient splitting
+            val_size = 1.0 - self.train_size
+            train_ids, val_ids = query_ids.random_split(
+                self.train_size, val_size, seed=self.seed
+            )
+            self.train_query_ids = train_ids
+            self.val_query_ids = val_ids
+        else:
+            # Use all data for training
+            self.train_query_ids = query_ids
+            self.val_query_ids = None
+
+    def _create_dataloader(self, split: DatasetSplit) -> DataLoader | None:
+        """Create a dataloader for the specified dataset split.
+
+        Args:
+            split: The dataset split (TRAIN or VAL)
+
+        Returns:
+            DataLoader for the specified split, or None if the split doesn't exist
+        """
+        # Get the appropriate query_ids based on split
+        query_ids_map = {
+            DatasetSplit.TRAIN: self.train_query_ids,
+            DatasetSplit.VAL: self.val_query_ids,
+        }
+
+        query_ids = query_ids_map.get(split)
+        if query_ids is None or self.x_locator is None:
+            return None
+
+        # Filter out query and layer_name from dataset_kwargs since we're using x_locator and query_ids
+        filtered_kwargs = {
+            k: v
+            for k, v in self.dataset_kwargs.items()
+            if k not in ("query", "layer_name")
+        }
+
+        # Create dataset with appropriate query_ids
+        dataset = ExperimentDataset(
+            x_locator=self.x_locator,
+            query_ids=query_ids,
+            obs_column_names=list(self.batch_column_names),
+            **filtered_kwargs,
+        )
         return experiment_dataloader(
-            self.train_dataset,
+            dataset,
             **self.dataloader_kwargs,
         )
 
+    def train_dataloader(self) -> DataLoader:
+        """Create the training dataloader.
+
+        Returns:
+            DataLoader for training data
+
+        Raises:
+            AssertionError: If setup() hasn't been called
+        """
+        loader = self._create_dataloader(DatasetSplit.TRAIN)
+        assert loader is not None, "setup() must be called before train_dataloader()"
+        return loader
+
+    def val_dataloader(self) -> DataLoader | None:
+        """Create the validation dataloader.
+
+        Returns:
+            DataLoader for validation data, or None if no validation split exists
+        """
+        return self._create_dataloader(DatasetSplit.VAL)
+
     def _add_batch_col(
         self, obs_df: pd.DataFrame, inplace: bool = False
     ) -> pd.DataFrame:
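For context, a hypothetical end-to-end usage sketch of the new parameters. The experiment URI, measurement name, and obs column names below are placeholders, and the constructor arguments beyond `train_size` and `seed` are assumptions read off this diff, not documented API:

```python
import tiledbsoma
from tiledbsoma_ml.scvi import SCVIDataModule

# Placeholder experiment URI and measurement name.
experiment = tiledbsoma.Experiment.open("path/to/experiment")
query = experiment.axis_query(measurement_name="RNA")

datamodule = SCVIDataModule(
    query,
    layer_name="raw",                   # forwarded via **kwargs (defaults to "raw")
    batch_column_names=["dataset_id"],  # placeholder obs column
    train_size=0.8,                     # 20% of obs held out for validation
    seed=42,                            # deterministic QueryIDs.random_split()
)
datamodule.setup()
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()  # None only when train_size == 1.0
```

When the module is handed to a Lightning `Trainer`, `setup()` is invoked automatically; calling it manually as above is only needed for standalone inspection of the loaders.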
