Skip to content

Commit bfbe3c7

Browse files
use fast mode implementation
1 parent f7374ef commit bfbe3c7

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

tests/test_downsampling.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
import wkw
1010
from wkcuber.utils import WkwDatasetInfo, open_wkw
11+
from wkcuber.downsampling import _mode
1112
import shutil
1213

1314
WKW_CUBE_SIZE = 1024
@@ -26,13 +27,24 @@ def test_downsample_cube():
2627
buffer = np.zeros((CUBE_EDGE_LEN,) * 3, dtype=np.uint8)
2728
buffer[:, :, :] = np.arange(0, CUBE_EDGE_LEN)
2829

29-
output = downsample_cube(buffer, (2, 2, 2), InterpolationModes.MEDIAN)
30+
output = downsample_cube(buffer, (2, 2, 2), InterpolationModes.MODE)
3031

3132
assert output.shape == (CUBE_EDGE_LEN // 2,) * 3
3233
assert buffer[0, 0, 0] == 0
3334
assert buffer[0, 0, 1] == 1
3435
assert np.all(output[:, :, :] == np.arange(0, CUBE_EDGE_LEN, 2))
3536

37+
def test_downsample_mode():
    """Check that `_mode` yields the expected per-column result for a small 2D array."""
    # Columns 0 and 4 contain only distinct values; the expected result for
    # them is the smallest entry of the column.
    columns = np.array(
        [
            [1, 3, 4, 2, 2, 7],
            [5, 2, 2, 1, 4, 1],
            [3, 3, 2, 2, 1, 1],
        ]
    )
    expected_result = np.array([1, 3, 2, 2, 1, 1])

    assert np.all(_mode(columns) == expected_result)
47+
3648

3749
def test_cube_addresses():
3850
addresses = cube_addresses(source_info)

wkcuber/downsampling.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from argparse import ArgumentParser
66
from math import floor, log2
77
from os import path, listdir
8-
from scipy.stats import mode
98
from scipy.ndimage.interpolation import zoom
109
from itertools import product
1110
from functools import lru_cache
@@ -73,8 +72,8 @@ def create_parser():
7372
"--from",
7473
"-f",
7574
help="Resolution to base downsampling on",
76-
type=int,
77-
default=1,
75+
type=str,
76+
default='1',
7877
)
7978

8079
# Either provide the maximum resolution to be downsampled OR a specific, anisotropic magnification.
@@ -331,7 +330,42 @@ def _median(x):
331330

332331

333332
def _mode(x):
334-
return mode(x, axis=0, nan_policy="omit")[0][0]
333+
"""
334+
Fast mode implementation from: https://stackoverflow.com/a/35674754
335+
"""
336+
# Check inputs
337+
ndim = x.ndim
338+
axis = 0
339+
# Sort array
340+
sort = np.sort(x, axis=axis)
341+
# Create array to transpose along the axis and get padding shape
342+
transpose = np.roll(np.arange(ndim)[::-1], axis)
343+
shape = list(sort.shape)
344+
shape[axis] = 1
345+
# Create a boolean array along strides of unique values
346+
strides = np.concatenate([np.zeros(shape=shape, dtype='bool'),
347+
np.diff(sort, axis=axis) == 0,
348+
np.zeros(shape=shape, dtype='bool')],
349+
axis=axis).transpose(transpose).ravel()
350+
# Count the stride lengths
351+
counts = np.cumsum(strides)
352+
counts[~strides] = np.concatenate([[0], np.diff(counts[~strides])])
353+
counts[strides] = 0
354+
# Get shape of padded counts and slice to return to the original shape
355+
shape = np.array(sort.shape)
356+
shape[axis] += 1
357+
shape = shape[transpose]
358+
slices = [slice(None)] * ndim
359+
slices[axis] = slice(1, None)
360+
# Reshape and compute final counts
361+
counts = counts.reshape(shape).transpose(transpose)[slices] + 1
362+
363+
# Find maximum counts and return modals/counts
364+
slices = [slice(None, i) for i in sort.shape]
365+
del slices[axis]
366+
index = np.ogrid[slices]
367+
index.insert(axis, np.argmax(counts, axis=axis))
368+
return sort[index]
335369

336370

337371
def downsample_cube(cube_buffer, factors, interpolation_mode):

0 commit comments

Comments
 (0)