@@ -8,6 +8,7 @@
 import warnings
 from collections import namedtuple
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial, reduce
 from itertools import product
 from numbers import Integral
@@ -261,6 +262,7 @@ def make_bitmask(rows, cols):
     assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
+    approx_chunk_size = math.prod(c[0] for c in chunks)

     # Shortcut for 1D with size-1 chunks
     if shape == (nchunks,):
@@ -271,21 +273,55 @@ def make_bitmask(rows, cols):

     labels = np.broadcast_to(labels, shape[-labels.ndim :])
     cols = []
-    # Add one to handle the -1 sentinel value
-    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
     ilabels = np.arange(nlabels)
-    for region in slices_from_chunks(chunks):
+
+    def chunk_unique(labels, slicer, nlabels, label_is_present=None):
+        if label_is_present is None:
+            label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        label_is_present[:] = False
+        subset = labels[slicer]
         # This is a quite fast way to find unique integers, when we know how many there are
         # inspired by a similar idea in numpy_groupies for first, last
         # instead of explicitly finding uniques, repeatedly write True to the same location
-        subset = labels[region]
-        # The reshape is not strictly necessary but is about 100ms faster on a test problem.
         label_is_present[subset.reshape(-1)] = True
         # skip the -1 sentinel by slicing
         # Faster than np.argwhere by a lot
         uniques = ilabels[label_is_present[:-1]]
-        cols.append(uniques)
-        label_is_present[:] = False
+        return uniques
+
+    # TODO: refine this heuristic.
+    # The general idea is that with the threadpool, we repeatedly allocate memory
+    # for `label_is_present`. We trade that off against the parallelism across number of chunks.
+    # For large enough number of chunks (relative to number of labels), it makes sense to
+    # suffer the extra allocation in exchange for parallelism.
+    THRESHOLD = 2
+    if nlabels < THRESHOLD * approx_chunk_size:
+        logger.debug(
+            "Using threadpool since num_labels %s < %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        with ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(chunk_unique, labels, slicer, nlabels)
+                for slicer in slices_from_chunks(chunks)
+            ]
+            cols = tuple(f.result() for f in futures)
+
+    else:
+        logger.debug(
+            "Using serial loop since num_labels %s > %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        cols = []
+        # Add one to handle the -1 sentinel value
+        label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        for region in slices_from_chunks(chunks):
+            uniques = chunk_unique(labels, region, nlabels, label_is_present)
+            cols.append(uniques)
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)

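A minimal, self-contained sketch of the two ideas this change combines: finding each chunk's unique labels by writing `True` into a preallocated boolean mask (instead of calling `np.unique`), and choosing between a serial loop that reuses one mask and a `ThreadPoolExecutor` that allocates one mask per task. `per_chunk_uniques`, the toy `labels`/`slicers` inputs, and the way `approx_chunk_size` is estimated here are illustrative assumptions, not flox API; only `chunk_unique` mirrors the helper added in this commit.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def chunk_unique(labels, slicer, nlabels, label_is_present=None):
    # Allocate a fresh mask per call (thread-pool path) or reuse the caller's
    # mask (serial path). The extra slot absorbs the -1 sentinel.
    if label_is_present is None:
        label_is_present = np.empty((nlabels + 1,), dtype=bool)
    label_is_present[:] = False
    subset = labels[slicer]
    # Repeatedly writing True to the same slot is much cheaper than np.unique
    # when the number of possible labels is known up front.
    label_is_present[subset.reshape(-1)] = True
    return np.arange(nlabels)[label_is_present[:-1]]


def per_chunk_uniques(labels, slicers, nlabels, threshold=2):
    # Rough stand-in (an assumption) for the commit's approx_chunk_size heuristic:
    # when labels are few relative to a chunk, pay the per-task allocation for parallelism.
    approx_chunk_size = labels[slicers[0]].size
    if nlabels < threshold * approx_chunk_size:
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(chunk_unique, labels, s, nlabels) for s in slicers]
            return tuple(f.result() for f in futures)
    # Otherwise loop serially, reusing a single mask allocation.
    mask = np.empty((nlabels + 1,), dtype=bool)
    return tuple(chunk_unique(labels, s, nlabels, mask) for s in slicers)


labels = np.array([0, 0, 1, 2, -1, 3, 3, 2])  # -1 marks a missing group
slicers = [slice(0, 4), slice(4, 8)]          # two toy "chunks"
print(per_chunk_uniques(labels, slicers, nlabels=4))
# (array([0, 1, 2]), array([2, 3]))
```

The extra slot at index `nlabels` absorbs the `-1` sentinel for missing groups and the `[:-1]` slice drops it, so no pre-filtering of the labels is needed before the fancy-index assignment.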