55from argparse import ArgumentParser
66from math import floor , log2
77from os import path , listdir
8- from scipy .stats import mode
98from scipy .ndimage .interpolation import zoom
109from itertools import product
1110from functools import lru_cache
@@ -61,20 +60,13 @@ def create_parser():
6160 default = "default" ,
6261 )
6362
64- parser .add_argument (
65- "--dtype" ,
66- "-d" ,
67- help = "Target datatype (e.g. uint8, uint16, uint32)" ,
68- default = "uint8" ,
69- )
70-
7163 parser .add_argument (
7264 "--from_mag" ,
7365 "--from" ,
7466 "-f" ,
7567 help = "Resolution to base downsampling on" ,
76- type = int ,
77- default = 1 ,
68+ type = str ,
69+ default = "1" ,
7870 )
7971
8072 # Either provide the maximum resolution to be downsampled OR a specific, anisotropic magnification.
@@ -193,6 +185,7 @@ def downsample_cube_job(
193185
194186 with open_wkw (source_wkw_info ) as source_wkw :
195187 num_channels = source_wkw .header .num_channels
188+ source_dtype = source_wkw .header .voxel_type
196189 with open_wkw (
197190 target_wkw_info ,
198191 pool_get_lock (),
@@ -203,7 +196,7 @@ def downsample_cube_job(
203196 source_wkw .header .file_len * source_wkw .header .block_len
204197 )
205198 shape = (num_channels ,) + (wkw_cubelength ,) * 3
206- file_buffer = np .zeros (shape , target_wkw_info . dtype )
199+ file_buffer = np .zeros (shape , source_dtype )
207200 tile_length = cube_edge_len
208201 tile_count_per_dim = wkw_cubelength // tile_length
209202 assert (
@@ -331,7 +324,50 @@ def _median(x):
331324
332325
333326def _mode (x ):
334- return mode (x , axis = 0 , nan_policy = "omit" )[0 ][0 ]
327+ """
328+ Fast mode implementation from: https://stackoverflow.com/a/35674754
329+ """
330+ # Check inputs
331+ ndim = x .ndim
332+ axis = 0
333+ # Sort array
334+ sort = np .sort (x , axis = axis )
335+ # Create array to transpose along the axis and get padding shape
336+ transpose = np .roll (np .arange (ndim )[::- 1 ], axis )
337+ shape = list (sort .shape )
338+ shape [axis ] = 1
339+ # Create a boolean array along strides of unique values
340+ strides = (
341+ np .concatenate (
342+ [
343+ np .zeros (shape = shape , dtype = "bool" ),
344+ np .diff (sort , axis = axis ) == 0 ,
345+ np .zeros (shape = shape , dtype = "bool" ),
346+ ],
347+ axis = axis ,
348+ )
349+ .transpose (transpose )
350+ .ravel ()
351+ )
352+ # Count the stride lengths
353+ counts = np .cumsum (strides )
354+ counts [~ strides ] = np .concatenate ([[0 ], np .diff (counts [~ strides ])])
355+ counts [strides ] = 0
356+ # Get shape of padded counts and slice to return to the original shape
357+ shape = np .array (sort .shape )
358+ shape [axis ] += 1
359+ shape = shape [transpose ]
360+ slices = [slice (None )] * ndim
361+ slices [axis ] = slice (1 , None )
362+ # Reshape and compute final counts
363+ counts = counts .reshape (shape ).transpose (transpose )[tuple (slices )] + 1
364+
365+ # Find maximum counts and return modals/counts
366+ slices = [slice (None , i ) for i in sort .shape ]
367+ del slices [axis ]
368+ index = np .ogrid [slices ]
369+ index .insert (axis , np .argmax (counts , axis = axis ))
370+ return sort [tuple (index )]
335371
336372
337373def downsample_cube (cube_buffer , factors , interpolation_mode ):
@@ -358,7 +394,6 @@ def downsample_mag(
358394 layer_name ,
359395 source_mag : Mag ,
360396 target_mag : Mag ,
361- dtype = "uint8" ,
362397 interpolation_mode = "default" ,
363398 cube_edge_len = DEFAULT_EDGE_LEN ,
364399 jobs = 1 ,
@@ -373,12 +408,11 @@ def downsample_mag(
373408 else :
374409 interpolation_mode = InterpolationModes [interpolation_mode .upper ()]
375410
376- source_wkw_info = WkwDatasetInfo (
377- path , layer_name , dtype , source_mag .to_layer_name ()
378- )
379- target_wkw_info = WkwDatasetInfo (
380- path , layer_name , dtype , target_mag .to_layer_name ()
381- )
411+ source_wkw_info = WkwDatasetInfo (path , layer_name , None , source_mag .to_layer_name ())
412+ with open_wkw (source_wkw_info ) as source :
413+ target_wkw_info = WkwDatasetInfo (
414+ path , layer_name , source .header .voxel_type , target_mag .to_layer_name ()
415+ )
382416 downsample (
383417 source_wkw_info ,
384418 target_wkw_info ,
@@ -396,7 +430,6 @@ def downsample_mags(
396430 layer_name ,
397431 from_mag : Mag ,
398432 max_mag : Mag ,
399- dtype ,
400433 interpolation_mode ,
401434 cube_edge_len ,
402435 jobs ,
@@ -410,7 +443,6 @@ def downsample_mags(
410443 layer_name ,
411444 source_mag ,
412445 target_mag ,
413- dtype ,
414446 interpolation_mode ,
415447 cube_edge_len ,
416448 jobs ,
@@ -435,7 +467,6 @@ def downsample_mags(
435467 args .layer_name ,
436468 from_mag ,
437469 anisotropic_target_mag ,
438- args .dtype ,
439470 args .interpolation_mode ,
440471 args .buffer_cube_size ,
441472 args .jobs ,
@@ -447,7 +478,6 @@ def downsample_mags(
447478 args .layer_name ,
448479 from_mag ,
449480 max_mag ,
450- args .dtype ,
451481 args .interpolation_mode ,
452482 args .buffer_cube_size ,
453483 args .jobs ,
0 commit comments