55from __future__ import annotations
66
77import logging
8- from typing import Iterator , List , Optional , Sequence , Tuple
8+ from enum import Enum
9+ from typing import Any , Iterator , List , Optional , Sequence , Tuple
910
1011import numpy as np
1112import torch
3233DEFAULT_IO_BATCH_SIZE = 2 ** 16
3334
3435
class ShuffleMode(str, Enum):
    """Which backend performs row shuffling, and at what granularity.

    Subclasses ``str`` so members compare equal to their plain string
    values (e.g. ``ShuffleMode.CPU == "cpu"``), which keeps string-based
    configuration round-trippable.
    """

    CPU = "cpu"
    # Mirrors the CPU behavior, but performs the shuffle of each IO batch on the GPU.
    GPU_IOBATCH = "gpu_iobatch"
    # Shuffles only at mini-batch granularity, on the GPU.
    GPU_MINIBATCH = "gpu_minibatch"
43+
def _shuffle_mode_converter(v: Any) -> ShuffleMode:
    """Coerce *v* to a :class:`ShuffleMode`.

    Accepts an existing :class:`ShuffleMode` (returned unchanged), a
    case-insensitive string name (``"cpu"`` | ``"gpu_iobatch"`` |
    ``"gpu_minibatch"``, plus the convenience alias ``"gpu"`` which maps
    to :attr:`ShuffleMode.GPU_IOBATCH`), or any other value accepted by
    the :class:`ShuffleMode` constructor.

    Raises:
        ValueError: if *v* does not correspond to any member.
    """
    if isinstance(v, ShuffleMode):
        return v
    if isinstance(v, str):
        lowered = v.lower()
        # "gpu" is a simplicity alias for IO-batch-level GPU shuffling.
        return (
            ShuffleMode.GPU_IOBATCH if lowered == "gpu" else ShuffleMode(lowered)
        )
    return ShuffleMode(v)
53+
54+
3555@define
3656class ExperimentDataset (IterableDataset [MiniBatch ]): # type: ignore[misc]
3757 r"""An |IterableDataset| implementation that reads from an |ExperimentAxisQuery|.
@@ -117,7 +137,7 @@ class ExperimentDataset(IterableDataset[MiniBatch]): # type: ignore[misc]
117137 """Names of ``obs`` columns to return."""
118138
119139 # Configuration fields with defaults
120- batch_size : int = field (default = 1 , validator = and_ (instance_of (int ), gt (0 )))
140+ batch_size : int = field (default = 1024 , validator = and_ (instance_of (int ), gt (0 )))
121141 """Number of rows of ``X`` and ``obs`` data to yield in each |MiniBatch|."""
122142 io_batch_size : int = field (
123143 default = DEFAULT_IO_BATCH_SIZE , validator = and_ (instance_of (int ), gt (0 ))
@@ -135,6 +155,17 @@ class ExperimentDataset(IterableDataset[MiniBatch]): # type: ignore[misc]
135155 use_eager_fetch : bool = field (default = True )
136156 """Pre-fetch one "IO batch" and one "mini batch"."""
137157
158+ # GPU Shuffle Config
159+ shuffle_mode : ShuffleMode = field (
160+ default = ShuffleMode .CPU , converter = _shuffle_mode_converter
161+ )
162+ """Whether to shuffle on cpu or gpu (and at what granularity).
163+
164+ Only read when shuffle=True
165+ """
166+ device : Optional [torch .device ] = field (default = None )
167+ """Device to move X to; set to torch.device('cuda', N) to enable GPU shuffle."""
168+
138169 # Internal state
139170 epoch : int = field (default = 0 , init = False )
140171 rank : int = field (init = False )
@@ -154,6 +185,8 @@ def __init__(
154185 seed : Optional [int ] = None ,
155186 return_sparse_X : bool = False ,
156187 use_eager_fetch : bool = True ,
188+ shuffle_mode : ShuffleMode = ShuffleMode .CPU ,
189+ device : Optional [torch .device ] = None ,
157190 ):
158191 r"""Construct a new |ExperimentDataset|.
159192
@@ -223,6 +256,7 @@ def __init__(
223256 In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you must provide a seed,
224257 ensuring that the same shuffle is used across all replicas.
225258 """
259+
226260 if query and layer_name :
227261 if x_locator or query_ids :
228262 raise ValueError (
@@ -255,21 +289,30 @@ def __init__(
255289 seed = seed ,
256290 return_sparse_X = return_sparse_X ,
257291 use_eager_fetch = use_eager_fetch ,
292+ shuffle_mode = shuffle_mode ,
293+ device = device ,
258294 )
259295
260296 def __attrs_post_init__ (self ) -> None :
261297 """Validate configuration and initialize distributed state."""
262298 obs_column_names = self .obs_column_names
263299 if not obs_column_names :
264300 raise ValueError ("Must specify at least one value in `obs_column_names`" )
265-
266301 if self .shuffle :
267302 # Verify `io_batch_size` is a multiple of `shuffle_chunk_size`
268303 if self .io_batch_size % self .shuffle_chunk_size :
269304 raise ValueError (
270305 f"{ self .io_batch_size = } is not a multiple of { self .shuffle_chunk_size = } "
271306 )
272307
308+ # Sanity Check for GPU Shuffle
309+ if self .shuffle and self .shuffle_mode != ShuffleMode .CPU :
310+ if self .device is None or getattr (self .device , "type" , None ) != "cuda" :
311+ logger .warning (
312+ "GPU shuffle requested but `device` is not CUDA; defaulting to CPU within-IO shuffle."
313+ )
314+ object .__setattr__ (self , "shuffle_mode" , ShuffleMode .CPU )
315+
273316 if self .seed is None :
274317 object .__setattr__ (
275318 self , "seed" , np .random .default_rng ().integers (0 , 2 ** 32 - 1 )
@@ -333,7 +376,6 @@ def __iter__(self) -> Iterator[MiniBatch]:
333376 experimental
334377 """
335378 self ._multiproc_check ()
336-
337379 worker_id , n_workers = get_worker_id_and_num ()
338380 partition = Partition (
339381 rank = self .rank ,
@@ -342,15 +384,25 @@ def __iter__(self) -> Iterator[MiniBatch]:
342384 n_workers = n_workers ,
343385 )
344386 query_ids = self .query_ids .partitioned (partition )
345- if self .shuffle :
346- chunks = query_ids .shuffle_chunks (
387+ use_gpu_shuffle = False
388+ gpu_shuffle_mode = "none"
389+ if self .shuffle and getattr (self .device , "type" , None ) == "cuda" :
390+ if self .shuffle_mode == ShuffleMode .GPU_IOBATCH :
391+ use_gpu_shuffle = True
392+ gpu_shuffle_mode = "iobatch"
393+ elif self .shuffle_mode == ShuffleMode .GPU_MINIBATCH :
394+ use_gpu_shuffle = True
395+ gpu_shuffle_mode = "minibatch"
396+
397+ if self .shuffle and self .shuffle_mode not in (ShuffleMode .GPU_MINIBATCH ,):
398+ chunks = query_ids .shuffle_chunks ( # provide shuffle chunk size of random chunks (upstream randomization)
347399 shuffle_chunk_size = self .shuffle_chunk_size ,
348400 seed = self .seed ,
349401 )
350402 else :
351- # In no-shuffle mode, all the `obs_joinids` can be treated as one "shuffle chunk",
352- # which IO-batches will stride over.
353- chunks = [ query_ids . obs_joinids ]
403+ chunks = [
404+ query_ids . obs_joinids
405+ ] # For no or just mini batch shuffling, provide sequential order of chunks
354406
355407 with self .x_locator .open () as (X , obs ):
356408 io_batch_iter = IOBatchIterable (
@@ -361,7 +413,8 @@ def __iter__(self) -> Iterator[MiniBatch]:
361413 X = X ,
362414 obs_column_names = self .obs_column_names ,
363415 seed = self .seed ,
364- shuffle = self .shuffle ,
416+ # disable internal shuffling if we are shuffling with GPU
417+ shuffle = (self .shuffle and not use_gpu_shuffle ),
365418 use_eager_fetch = self .use_eager_fetch ,
366419 )
367420
@@ -370,6 +423,12 @@ def __iter__(self) -> Iterator[MiniBatch]:
370423 batch_size = self .batch_size ,
371424 use_eager_fetch = self .use_eager_fetch ,
372425 return_sparse_X = self .return_sparse_X ,
426+ # gpu shuffle params
427+ gpu_shuffle = use_gpu_shuffle ,
428+ gpu_shuffle_mode = gpu_shuffle_mode , # "iobatch" | "minibatch"
429+ device = self .device ,
430+ seed = self .seed ,
431+ epoch = self .epoch ,
373432 )
374433
375434 self .epoch += 1
0 commit comments