Skip to content

Commit 324987b

Browse files
committed
Fast Non Shuffled Path
1 parent b56749d commit 324987b

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

src/tiledbsoma_ml/_io_batch_iterable.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,12 @@ def __iter__(self) -> Iterator[IOBatch]:
5555
"""Emit |IOBatch|'s."""
5656
# Because obs/var IDs have been partitioned/split/shuffled upstream of this class, this RNG does not need to be
5757
# identical across sub-processes, but seeding is supported anyway, for testing/reproducibility.
58-
shuffle_rng = np.random.default_rng(self.seed)
5958
X = self.X
6059
context = X.context
60+
61+
# Only construct the RNG when shuffling is enabled; the non-shuffled path never uses it.
62+
shuffle_rng = np.random.default_rng(self.seed) if self.shuffle else None
63+
6164
obs_column_names = (
6265
list(self.obs_column_names)
6366
if "soma_joinid" in self.obs_column_names
@@ -70,12 +73,14 @@ def __iter__(self) -> Iterator[IOBatch]:
7073

7174
for obs_coords in self.io_batch_ids:
7275
st_time = time.perf_counter()
73-
obs_shuffled_coords = (
76+
77+
if shuffle_rng is None:
78+
obs_order = obs_shuffled_coords
79+
else:
7480
np.array(obs_coords)
75-
if not self.shuffle
76-
else shuffle_rng.permuted(obs_coords)
77-
)
78-
obs_indexer = IntIndexer(obs_shuffled_coords, context=context)
81+
obs_order = rng.permuted(obs_coords)
82+
83+
obs_indexer = IntIndexer(obs_order, context=context)
7984
logger.debug(
8085
f"Retrieving next SOMA IO batch of length {len(obs_coords)}..."
8186
)

0 commit comments

Comments
 (0)