Skip to content

Commit 2d60468

Browse files
authored
Remove redundant memory alloc for Non Shuffled Scenario (#38)
* Remove redundant memory alloc * Cleanup redundant typecasting * Set default batch size
1 parent 8b60aae commit 2d60468

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

src/tiledbsoma_ml/_io_batch_iterable.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,27 +55,36 @@ def __iter__(self) -> Iterator[IOBatch]:
5555
"""Emit |IOBatch|'s."""
5656
# Because obs/var IDs have been partitioned/split/shuffled upstream of this class, this RNG does not need to be
5757
# identical across sub-processes, but seeding is supported anyway, for testing/reproducibility.
58-
shuffle_rng = np.random.default_rng(self.seed)
5958
X = self.X
6059
context = X.context
60+
61+
# only build rng if we shuffle
62+
shuffle_rng = np.random.default_rng(self.seed) if self.shuffle else None
63+
6164
obs_column_names = (
6265
list(self.obs_column_names)
6366
if "soma_joinid" in self.obs_column_names
6467
else ["soma_joinid", *self.obs_column_names]
6568
)
6669
# NOTE: `.astype("int64")` works around the `np.int64` singleton failing reference-equality after cross-process
6770
# SerDes.
68-
var_joinids = self.var_joinids.astype("int64")
71+
var_joinids = np.asarray(
72+
self.var_joinids, dtype=np.int64
73+
) # as array only typecasts if needed
6974
var_indexer = IntIndexer(var_joinids, context=context)
7075

7176
for obs_coords in self.io_batch_ids:
7277
st_time = time.perf_counter()
73-
obs_shuffled_coords = (
78+
79+
if shuffle_rng is None:
80+
obs_order = np.fromiter(
81+
obs_coords, dtype=np.int64, count=len(obs_coords)
82+
)
83+
else:
7484
np.array(obs_coords)
75-
if not self.shuffle
76-
else shuffle_rng.permuted(obs_coords)
77-
)
78-
obs_indexer = IntIndexer(obs_shuffled_coords, context=context)
85+
obs_order = shuffle_rng.permuted(obs_coords)
86+
87+
obs_indexer = IntIndexer(obs_order, context=context)
7988
logger.debug(
8089
f"Retrieving next SOMA IO batch of length {len(obs_coords)}..."
8190
)
@@ -123,14 +132,14 @@ def make_io_buffer(
123132
.concat()
124133
.to_pandas()
125134
.set_index("soma_joinid")
126-
.reindex(obs_shuffled_coords, copy=False)
135+
.reindex(obs_order, copy=False)
127136
.reset_index() # demote "soma_joinid" to a column
128137
[self.obs_column_names]
129138
) # fmt: on
130139

131140
X_io_batch = CSR_IO_Buffer.merge(tuple(_io_buf_iter))
132141

133-
del obs_indexer, obs_coords, obs_shuffled_coords, _io_buf_iter
142+
del obs_indexer, obs_coords, obs_order, _io_buf_iter
134143
gc.collect()
135144

136145
tm = time.perf_counter() - st_time

0 commit comments

Comments
 (0)