11from __future__ import annotations
22
33import os
4+ from enum import Enum
45from typing import Any , Sequence
56
67import pandas as pd
1213
1314from tiledbsoma_ml import ExperimentDataset , experiment_dataloader
1415from tiledbsoma_ml ._common import MiniBatch
16+ from tiledbsoma_ml ._query_ids import QueryIDs
17+ from tiledbsoma_ml .x_locator import XLocator
1518
1619DEFAULT_DATALOADER_KWARGS : dict [str , Any ] = {
1720 "pin_memory" : torch .cuda .is_available (),
2023}
2124
2225
class DatasetSplit(Enum):
    """Identifies which partition of the data a dataloader serves."""

    TRAIN = "train"
    VAL = "val"
31+
32+
2333class SCVIDataModule (LightningDataModule ): # type: ignore[misc]
2434 """PyTorch Lightning DataModule for training scVI models from SOMA data.
2535
@@ -38,6 +48,8 @@ def __init__(
3848 batch_column_names : Sequence [str ] | None = None ,
3949 batch_labels : Sequence [str ] | None = None ,
4050 dataloader_kwargs : dict [str , Any ] | None = None ,
51+ train_size : float = 1.0 ,
52+ seed : int = 42 ,
4153 ** kwargs : Any ,
4254 ):
4355 """Args:
@@ -63,6 +75,13 @@ def __init__(
6375
6476 dataloader_kwargs: dict, optional
6577 Keyword arguments passed to `tiledbsoma_ml.experiment_dataloader()`, e.g. `num_workers`.
78+
79+ train_size: float, optional
80+ Fraction of data to use for training (between 0 and 1). Default is 1.0 (use all data for training).
81+ If less than 1.0, the remaining data will be used for validation.
82+
83+ seed: int, optional
84+ Random seed for deterministic train/validation split. Default is 42.
6685 """
6786 super ().__init__ ()
6887 self .query = query
@@ -93,22 +112,95 @@ def __init__(
93112 batch_labels = obs_df [self .batch_colname ].unique ()
94113 self .batch_labels = batch_labels
95114 self .batch_encoder = LabelEncoder ().fit (self .batch_labels )
115+ self .train_size = train_size
116+ self .seed = seed
117+ self .train_query_ids : QueryIDs | None = None
118+ self .val_query_ids : QueryIDs | None = None
119+ self .x_locator : XLocator | None = None
120+ self .layer_name = kwargs .get ("layer_name" , "raw" )
96121
97122 def setup (self , stage : str | None = None ) -> None :
98- # Instantiate the ExperimentDataset with the provided args and kwargs.
99- self . train_dataset = ExperimentDataset (
100- self . query ,
101- * self .dataset_args ,
102- obs_column_names = self .batch_column_names , # type: ignore[arg-type]
103- ** self .dataset_kwargs , # type: ignore[misc]
123+ # Create QueryIDs and XLocator from the query
124+ query_ids = QueryIDs . create ( self . query )
125+ self . x_locator = XLocator . create (
126+ self .query . experiment ,
127+ measurement_name = self .query . measurement_name ,
128+ layer_name = self .layer_name ,
104129 )
105130
106- def train_dataloader (self ) -> DataLoader :
131+ # Split data into train and validation sets if train_size < 1.0
132+ if self .train_size < 1.0 :
133+ # Use QueryIDs.random_split() for efficient splitting
134+ val_size = 1.0 - self .train_size
135+ train_ids , val_ids = query_ids .random_split (
136+ self .train_size , val_size , seed = self .seed
137+ )
138+ self .train_query_ids = train_ids
139+ self .val_query_ids = val_ids
140+ else :
141+ # Use all data for training
142+ self .train_query_ids = query_ids
143+ self .val_query_ids = None
144+
145+ def _create_dataloader (self , split : DatasetSplit ) -> DataLoader | None :
146+ """Create a dataloader for the specified dataset split.
147+
148+ Args:
149+ split: The dataset split (TRAIN or VAL)
150+
151+ Returns:
152+ DataLoader for the specified split, or None if the split doesn't exist
153+ """
154+ # Get the appropriate query_ids based on split
155+ query_ids_map = {
156+ DatasetSplit .TRAIN : self .train_query_ids ,
157+ DatasetSplit .VAL : self .val_query_ids ,
158+ }
159+
160+ query_ids = query_ids_map .get (split )
161+ if query_ids is None or self .x_locator is None :
162+ return None
163+
164+ # Filter out query and layer_name from dataset_kwargs since we're using x_locator and query_ids
165+ filtered_kwargs = {
166+ k : v
167+ for k , v in self .dataset_kwargs .items ()
168+ if k not in ("query" , "layer_name" )
169+ }
170+
171+ # Create dataset with appropriate query_ids
172+ dataset = ExperimentDataset (
173+ x_locator = self .x_locator ,
174+ query_ids = query_ids ,
175+ obs_column_names = list (self .batch_column_names ),
176+ ** filtered_kwargs ,
177+ )
107178 return experiment_dataloader (
108- self . train_dataset ,
179+ dataset ,
109180 ** self .dataloader_kwargs ,
110181 )
111182
183+ def train_dataloader (self ) -> DataLoader :
184+ """Create the training dataloader.
185+
186+ Returns:
187+ DataLoader for training data
188+
189+ Raises:
190+ AssertionError: If setup() hasn't been called
191+ """
192+ loader = self ._create_dataloader (DatasetSplit .TRAIN )
193+ assert loader is not None , "setup() must be called before train_dataloader()"
194+ return loader
195+
196+ def val_dataloader (self ) -> DataLoader | None :
197+ """Create the validation dataloader.
198+
199+ Returns:
200+ DataLoader for validation data, or None if no validation split exists
201+ """
202+ return self ._create_dataloader (DatasetSplit .VAL )
203+
112204 def _add_batch_col (
113205 self , obs_df : pd .DataFrame , inplace : bool = False
114206 ) -> pd .DataFrame :