
Commit 8ff15e5

Gpu Shuffling (#40)
* CI Run for GPU Shuffling
* Linter
* Func type
* Test Infrastructure
* Added Range
* Fixed imports
* Obtain Batches
* DataLoader Fix
* Cuda check before tests
* Test Suite Ready to Go
* pre-commit run --all-files!
1 parent 3193230 commit 8ff15e5

File tree: 5 files changed, +379 -11 lines changed


src/tiledbsoma_ml/_mini_batch_iterable.py (80 additions & 0 deletions)
@@ -4,11 +4,13 @@
 from __future__ import annotations
 
 import logging
+import os
 from typing import Iterable, Iterator
 
 import attrs
 import numpy as np
 import pandas as pd
+import torch
 from scipy import sparse
 
 from tiledbsoma_ml._common import MiniBatch
@@ -27,6 +29,29 @@ class MiniBatchIterable(Iterable[MiniBatch]):
     use_eager_fetch: bool = True
     return_sparse_X: bool = False
 
+    gpu_shuffle: bool = False
+    gpu_shuffle_mode: str = "iobatch"
+    device: torch.device | None = None
+    seed: int | None = None
+    epoch: int = 0
+
+    def _gpu_perm(self, n: int) -> torch.Tensor:
+        """Deterministic permutation of range(n) seeded by (seed, epoch, pid)."""
+        base = int(self.seed or 0)
+        pid = os.getpid()
+        mixed = (base * 1315423911 + self.epoch * 2654435761 + pid) & 0xFFFFFFFF
+
+        gen_device = (
+            self.device
+            if (
+                self.device is not None and getattr(self.device, "type", None) == "cuda"
+            )
+            else "cpu"
+        )
+        g = torch.Generator(device=gen_device)
+        g.manual_seed(mixed)
+        return torch.randperm(n, generator=g, device=gen_device)
+
     def _iter(self) -> Iterator[MiniBatch]:
         batch_size = self.batch_size
         result: MiniBatch | None = None
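
The seeding scheme above mixes `seed`, `epoch`, and the worker PID through two multiplicative constants, so each worker and epoch draws a distinct but reproducible permutation. A minimal standalone sketch of that scheme (the `mix`/`perm` helpers below are illustrative, not part of the commit):

    import torch

    def mix(seed: int, epoch: int, pid: int) -> int:
        # Same mixing as `_gpu_perm`: two multiplicative constants, folded to 32 bits
        return (seed * 1315423911 + epoch * 2654435761 + pid) & 0xFFFFFFFF

    def perm(seed: int, epoch: int, pid: int, n: int) -> torch.Tensor:
        g = torch.Generator(device="cpu")
        g.manual_seed(mix(seed, epoch, pid))
        return torch.randperm(n, generator=g)

    # Same (seed, epoch, pid) -> identical permutation; bumping the epoch changes it
    assert torch.equal(perm(42, 0, 1234, 100), perm(42, 0, 1234, 100))
    assert not torch.equal(perm(42, 0, 1234, 100), perm(42, 1, 1234, 100))
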
@@ -35,6 +60,39 @@ def _iter(self) -> Iterator[MiniBatch]:
             iob_idx = 0  # current offset into io batch
             iob_len = X_io_batch.shape[0]
 
+            # GPU within-IO-batch shuffle (dense only)
+            if self.gpu_shuffle and self.gpu_shuffle_mode == "iobatch":
+                if self.return_sparse_X:
+                    logger.warning(
+                        "GPU shuffle requested but return_sparse_X=True; leaving IO-batch order unchanged."
+                    )
+                else:
+                    perm = self._gpu_perm(iob_len)
+                    perm_cpu = perm.to("cpu", non_blocking=False).numpy()
+
+                    X_full = X_io_batch.slice_tonumpy(slice(0, iob_len))
+                    X_t = torch.from_numpy(X_full)
+                    if (
+                        self.device is not None
+                        and getattr(self.device, "type", None) == "cuda"
+                    ):
+                        if not X_t.is_pinned():
+                            X_t = X_t.pin_memory()  # faster H2D
+                        X_t = X_t.to(self.device, non_blocking=True)
+                    X_t = X_t.index_select(0, perm).contiguous()
+                    X_cpu = X_t.to("cpu", non_blocking=False).numpy()
+
+                    obs_perm = obs_io_batch.iloc[perm_cpu].reset_index(drop=True)
+
+                    # Emit mini-batches from the permuted IO-batch
+                    for start in range(0, iob_len, self.batch_size):
+                        stop = min(start + self.batch_size, iob_len)
+                        yield (
+                            X_cpu[start:stop],
+                            obs_perm.iloc[start:stop].reset_index(drop=True),
+                        )
+                    continue  # done with this IO-batch
+
             while iob_idx < iob_len:
                 if result is None:
                     # perform zero copy slice where possible
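
The transfer pattern in this branch is: pin the host buffer, copy host-to-device asynchronously, permute rows with `index_select`, then copy back for CPU-side mini-batch slicing. A self-contained round-trip check of that pattern, falling back to CPU when CUDA is unavailable (mirroring the code's own guard):

    import numpy as np
    import torch

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

    X = np.random.rand(8, 4).astype(np.float32)
    perm = torch.randperm(8, device=device)

    X_t = torch.from_numpy(X)
    if device.type == "cuda":
        X_t = X_t.pin_memory().to(device, non_blocking=True)  # pinned H2D copy
    X_t = X_t.index_select(0, perm).contiguous()

    # The GPU permutation round-trips to the same rows as numpy fancy indexing
    np.testing.assert_array_equal(X_t.cpu().numpy(), X[perm.cpu().numpy()])
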
@@ -76,6 +134,29 @@ def _iter(self) -> Iterator[MiniBatch]:
                 iob_idx += to_take
 
             X, obs = result
+
+            if (
+                self.gpu_shuffle
+                and self.gpu_shuffle_mode == "minibatch"
+                and not self.return_sparse_X
+            ):
+                mb_n = X.shape[0]
+                perm = self._gpu_perm(mb_n)
+                perm_cpu = perm.to("cpu", non_blocking=False).numpy()
+
+                X_t = torch.from_numpy(X)
+                if (
+                    self.device is not None
+                    and getattr(self.device, "type", None) == "cuda"
+                ):
+                    if not X_t.is_pinned():
+                        X_t = X_t.pin_memory()
+                    X_t = X_t.to(self.device, non_blocking=True)
+                X_t = X_t.index_select(0, perm).contiguous()
+                X = X_t.to("cpu", non_blocking=False).numpy()
+                obs = obs.iloc[perm_cpu].reset_index(drop=True)
+                result = (X, obs)  # re-pack, so the permuted batch is what `yield result` emits
+
             assert X.shape[0] == obs.shape[0]
             if X.shape[0] == batch_size:
                 yield result

src/tiledbsoma_ml/dataset.py (69 additions & 10 deletions)
@@ -5,7 +5,8 @@
 from __future__ import annotations
 
 import logging
-from typing import Iterator, List, Optional, Sequence, Tuple
+from enum import Enum
+from typing import Any, Iterator, List, Optional, Sequence, Tuple
 
 import numpy as np
 import torch
@@ -32,6 +33,25 @@
 DEFAULT_IO_BATCH_SIZE = 2**16
 
 
+class ShuffleMode(str, Enum):
+    """Shuffling backend selection."""
+
+    CPU = "cpu"
+    GPU_IOBATCH = "gpu_iobatch"  # Emulate CPU shuffling, permuting at the IO-batch level
+    GPU_MINIBATCH = "gpu_minibatch"  # Only shuffle within each mini-batch, on the GPU
+
+
+def _shuffle_mode_converter(v: Any) -> ShuffleMode:
+    if isinstance(v, ShuffleMode):
+        return v
+    if isinstance(v, str):
+        v = v.lower()
+        if v == "gpu":  # convenience alias
+            return ShuffleMode.GPU_IOBATCH
+        return ShuffleMode(v)  # "cpu" | "gpu_iobatch" | "gpu_minibatch"
+    return ShuffleMode(v)
+
+
 @define
 class ExperimentDataset(IterableDataset[MiniBatch]):  # type: ignore[misc]
     r"""An |IterableDataset| implementation that reads from an |ExperimentAxisQuery|.
@@ -117,7 +137,7 @@ class ExperimentDataset(IterableDataset[MiniBatch]):  # type: ignore[misc]
     """Names of ``obs`` columns to return."""
 
     # Configuration fields with defaults
-    batch_size: int = field(default=1, validator=and_(instance_of(int), gt(0)))
+    batch_size: int = field(default=1024, validator=and_(instance_of(int), gt(0)))
     """Number of rows of ``X`` and ``obs`` data to yield in each |MiniBatch|."""
     io_batch_size: int = field(
         default=DEFAULT_IO_BATCH_SIZE, validator=and_(instance_of(int), gt(0))
@@ -135,6 +155,17 @@ class ExperimentDataset(IterableDataset[MiniBatch]):  # type: ignore[misc]
     use_eager_fetch: bool = field(default=True)
     """Pre-fetch one "IO batch" and one "mini batch"."""
 
+    # GPU shuffle configuration
+    shuffle_mode: ShuffleMode = field(
+        default=ShuffleMode.CPU, converter=_shuffle_mode_converter
+    )
+    """Whether to shuffle on CPU or GPU (and at what granularity).
+
+    Only read when ``shuffle=True``.
+    """
+    device: Optional[torch.device] = field(default=None)
+    """Device to move ``X`` to; set to ``torch.device("cuda", N)`` to enable GPU shuffle."""
+
     # Internal state
     epoch: int = field(default=0, init=False)
     rank: int = field(init=False)
@@ -154,6 +185,8 @@ def __init__(
         seed: Optional[int] = None,
         return_sparse_X: bool = False,
         use_eager_fetch: bool = True,
+        shuffle_mode: ShuffleMode = ShuffleMode.CPU,
+        device: Optional[torch.device] = None,
     ):
         r"""Construct a new |ExperimentDataset|.
@@ -223,6 +256,7 @@ def __init__(
         In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you must provide a seed,
         ensuring that the same shuffle is used across all replicas.
         """
+
         if query and layer_name:
             if x_locator or query_ids:
                 raise ValueError(
@@ -255,21 +289,30 @@ def __init__(
             seed=seed,
             return_sparse_X=return_sparse_X,
             use_eager_fetch=use_eager_fetch,
+            shuffle_mode=shuffle_mode,
+            device=device,
         )
 
     def __attrs_post_init__(self) -> None:
         """Validate configuration and initialize distributed state."""
         obs_column_names = self.obs_column_names
         if not obs_column_names:
             raise ValueError("Must specify at least one value in `obs_column_names`")
-
         if self.shuffle:
             # Verify `io_batch_size` is a multiple of `shuffle_chunk_size`
             if self.io_batch_size % self.shuffle_chunk_size:
                 raise ValueError(
                     f"{self.io_batch_size=} is not a multiple of {self.shuffle_chunk_size=}"
                 )
 
+        # Sanity check for GPU shuffle: fall back to CPU if no CUDA device was given
+        if self.shuffle and self.shuffle_mode != ShuffleMode.CPU:
+            if self.device is None or getattr(self.device, "type", None) != "cuda":
+                logger.warning(
+                    "GPU shuffle requested but `device` is not CUDA; defaulting to CPU within-IO shuffle."
+                )
+                object.__setattr__(self, "shuffle_mode", ShuffleMode.CPU)
+
         if self.seed is None:
             object.__setattr__(
                 self, "seed", np.random.default_rng().integers(0, 2**32 - 1)
@@ -333,7 +376,6 @@ def __iter__(self) -> Iterator[MiniBatch]:
             experimental
         """
         self._multiproc_check()
-
         worker_id, n_workers = get_worker_id_and_num()
         partition = Partition(
             rank=self.rank,
@@ -342,15 +384,25 @@ def __iter__(self) -> Iterator[MiniBatch]:
             n_workers=n_workers,
         )
         query_ids = self.query_ids.partitioned(partition)
-        if self.shuffle:
-            chunks = query_ids.shuffle_chunks(
+        use_gpu_shuffle = False
+        gpu_shuffle_mode = "none"
+        if self.shuffle and getattr(self.device, "type", None) == "cuda":
+            if self.shuffle_mode == ShuffleMode.GPU_IOBATCH:
+                use_gpu_shuffle = True
+                gpu_shuffle_mode = "iobatch"
+            elif self.shuffle_mode == ShuffleMode.GPU_MINIBATCH:
+                use_gpu_shuffle = True
+                gpu_shuffle_mode = "minibatch"
+
+        if self.shuffle and self.shuffle_mode not in (ShuffleMode.GPU_MINIBATCH,):
+            # Provide `shuffle_chunk_size`-sized random chunks (upstream randomization)
+            chunks = query_ids.shuffle_chunks(
                 shuffle_chunk_size=self.shuffle_chunk_size,
                 seed=self.seed,
             )
         else:
-            # In no-shuffle mode, all the `obs_joinids` can be treated as one "shuffle chunk",
-            # which IO-batches will stride over.
-            chunks = [query_ids.obs_joinids]
+            # For no shuffling, or mini-batch-only shuffling, keep chunks in sequential order
+            chunks = [query_ids.obs_joinids]
 
         with self.x_locator.open() as (X, obs):
             io_batch_iter = IOBatchIterable(
@@ -361,7 +413,8 @@ def __iter__(self) -> Iterator[MiniBatch]:
                 X=X,
                 obs_column_names=self.obs_column_names,
                 seed=self.seed,
-                shuffle=self.shuffle,
+                # Disable internal shuffling when shuffling on the GPU instead
+                shuffle=(self.shuffle and not use_gpu_shuffle),
                 use_eager_fetch=self.use_eager_fetch,
             )
 
@@ -370,6 +423,12 @@ def __iter__(self) -> Iterator[MiniBatch]:
                 batch_size=self.batch_size,
                 use_eager_fetch=self.use_eager_fetch,
                 return_sparse_X=self.return_sparse_X,
+                # GPU shuffle params
+                gpu_shuffle=use_gpu_shuffle,
+                gpu_shuffle_mode=gpu_shuffle_mode,  # "iobatch" | "minibatch"
+                device=self.device,
+                seed=self.seed,
+                epoch=self.epoch,
             )
 
         self.epoch += 1
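
Because the dataset already emits `(X, obs)` mini-batches, consuming it would typically disable the loader's own batching. A hedged usage sketch (`ds` is the dataset from the earlier example; the worker count is an assumption):

    from torch.utils.data import DataLoader

    # batch_size=None: the dataset already yields mini-batches of `batch_size` rows
    loader = DataLoader(ds, batch_size=None, num_workers=2)
    for X, obs in loader:
        pass  # X: rows of the expression matrix; obs: the matching pd.DataFrame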

tests/_utils.py (53 additions & 0 deletions)
@@ -187,6 +187,59 @@ def add_sparse_array(
     a.write(tensor)
 
 
+def flatten_joinids(batches: List[MiniBatch]) -> List[int]:
+    return [int(i) for _, obs in batches for i in obs["soma_joinid"].tolist()]
+
+
+def minibatch_is_contiguous(ids: List[int]) -> bool:
+    if len(ids) <= 1:
+        return True
+    ids_sorted = sorted(ids)
+    return ids_sorted[-1] - ids_sorted[0] + 1 == len(ids_sorted)
+
+
+def assert_gpu_minibatch_no_upstream_mixing(batches: List[MiniBatch]) -> None:
+    """Each mini-batch should be a contiguous slice; slices increase strictly.
+
+    Test for ``gpu_minibatch`` shuffling.
+    """
+    prev_max = -1
+    for _, obs in batches:
+        ids = [int(i) for i in obs["soma_joinid"].tolist()]
+        assert minibatch_is_contiguous(ids), f"Non-contiguous minibatch: {ids}"
+        ids_sorted = sorted(ids)
+        assert (
+            ids_sorted[0] > prev_max
+        ), f"Detected upstream mixing: start={ids_sorted[0]} <= prev_max={prev_max}"
+        prev_max = ids_sorted[-1]
+
+
+def assert_gpu_iobatch_invariants(
+    batches: List[MiniBatch],
+    batch_size: int,
+    min_noncontig_ratio: float = 0.2,
+    num_workers: int = 1,  # currently unused
+) -> None:
+    """Property checks for IO-batch GPU shuffle (not exact order)."""
+    # Check for unnecessary non-full batches
+    sizes = [len(obs) for _, obs in batches]
+    assert all(1 <= s <= batch_size for s in sizes), f"Invalid sizes: {sizes}"
+    # If there are enough rows overall, expect at least one full mini-batch
+    if sum(sizes) >= batch_size:
+        assert any(s == batch_size for s in sizes), "No full minibatches produced"
+
+    # Measure dispersion between mini-batches; this should not consistently fail
+    non_contig = 0
+    for _, obs in batches:
+        ids = [int(i) for i in obs["soma_joinid"].tolist()]
+        if not minibatch_is_contiguous(ids):
+            non_contig += 1
+    if len(batches) >= 4:  # avoid tiny outliers
+        assert non_contig >= max(
+            1, int(len(batches) * min_noncontig_ratio)
+        ), "Low dispersion in IO-batch GPU shuffle; check upstream shuffle chunk selection."
+
+
 @contextmanager
 def mock_dist_is_initialized():
     with patch("torch.distributed.is_initialized") as mock_dist_is_initialized:
