@@ -55,27 +55,36 @@ def __iter__(self) -> Iterator[IOBatch]:
5555 """Emit |IOBatch|'s."""
5656 # Because obs/var IDs have been partitioned/split/shuffled upstream of this class, this RNG does not need to be
5757 # identical across sub-processes, but seeding is supported anyway, for testing/reproducibility.
58- shuffle_rng = np .random .default_rng (self .seed )
5958 X = self .X
6059 context = X .context
60+
61+ # only build rng if we shuffle
62+ shuffle_rng = np .random .default_rng (self .seed ) if self .shuffle else None
63+
6164 obs_column_names = (
6265 list (self .obs_column_names )
6366 if "soma_joinid" in self .obs_column_names
6467 else ["soma_joinid" , * self .obs_column_names ]
6568 )
6669 # NOTE: `.astype("int64")` works around the `np.int64` singleton failing reference-equality after cross-process
6770 # SerDes.
68- var_joinids = self .var_joinids .astype ("int64" )
71+ var_joinids = np .asarray (
72+ self .var_joinids , dtype = np .int64
73+ ) # as array only typecasts if needed
6974 var_indexer = IntIndexer (var_joinids , context = context )
7075
7176 for obs_coords in self .io_batch_ids :
7277 st_time = time .perf_counter ()
73- obs_shuffled_coords = (
78+
79+ if shuffle_rng is None :
80+ obs_order = np .fromiter (
81+ obs_coords , dtype = np .int64 , count = len (obs_coords )
82+ )
83+ else :
7484 np .array (obs_coords )
75- if not self .shuffle
76- else shuffle_rng .permuted (obs_coords )
77- )
78- obs_indexer = IntIndexer (obs_shuffled_coords , context = context )
85+ obs_order = shuffle_rng .permuted (obs_coords )
86+
87+ obs_indexer = IntIndexer (obs_order , context = context )
7988 logger .debug (
8089 f"Retrieving next SOMA IO batch of length { len (obs_coords )} ..."
8190 )
@@ -123,14 +132,14 @@ def make_io_buffer(
123132 .concat ()
124133 .to_pandas ()
125134 .set_index ("soma_joinid" )
126- .reindex (obs_shuffled_coords , copy = False )
135+ .reindex (obs_order , copy = False )
127136 .reset_index () # demote "soma_joinid" to a column
128137 [self .obs_column_names ]
129138 ) # fmt: on
130139
131140 X_io_batch = CSR_IO_Buffer .merge (tuple (_io_buf_iter ))
132141
133- del obs_indexer , obs_coords , obs_shuffled_coords , _io_buf_iter
142+ del obs_indexer , obs_coords , obs_order , _io_buf_iter
134143 gc .collect ()
135144
136145 tm = time .perf_counter () - st_time
0 commit comments