Skip to content

Commit 4f70c96

Browse files
Merge pull request #62 from scalableminds/fix_downsampling
Fix downsampling for non uint8 data
2 parents f7374ef + 3932893 commit 4f70c96

File tree

4 files changed

+95
-26
lines changed

4 files changed

+95
-26
lines changed

tests/test_downsampling.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99
import wkw
1010
from wkcuber.utils import WkwDatasetInfo, open_wkw
11+
from wkcuber.downsampling import _mode, non_linear_filter_3d
1112
import shutil
1213

1314
WKW_CUBE_SIZE = 1024
@@ -26,14 +27,50 @@ def test_downsample_cube():
2627
buffer = np.zeros((CUBE_EDGE_LEN,) * 3, dtype=np.uint8)
2728
buffer[:, :, :] = np.arange(0, CUBE_EDGE_LEN)
2829

29-
output = downsample_cube(buffer, (2, 2, 2), InterpolationModes.MEDIAN)
30+
output = downsample_cube(buffer, (2, 2, 2), InterpolationModes.MODE)
3031

3132
assert output.shape == (CUBE_EDGE_LEN // 2,) * 3
3233
assert buffer[0, 0, 0] == 0
3334
assert buffer[0, 0, 1] == 1
3435
assert np.all(output[:, :, :] == np.arange(0, CUBE_EDGE_LEN, 2))
3536

3637

38+
def test_downsample_mode():
39+
40+
a = np.array([[1, 3, 4, 2, 2, 7], [5, 2, 2, 1, 4, 1], [3, 3, 2, 2, 1, 1]])
41+
42+
result = _mode(a)
43+
expected_result = np.array([1, 3, 2, 2, 1, 1])
44+
45+
assert np.all(result == expected_result)
46+
47+
48+
def test_downsample_median():
49+
50+
a = np.array([[1, 3, 4, 2, 2, 7], [5, 2, 2, 1, 4, 1], [3, 3, 2, 2, 1, 1]])
51+
52+
result = np.median(a, axis=0)
53+
expected_result = np.array([3, 3, 2, 2, 2, 1])
54+
55+
assert np.all(result == expected_result)
56+
57+
58+
def test_non_linear_filter_reshape():
59+
a = np.array([[[1, 3], [1, 4]], [[4, 2], [3, 1]]], dtype=np.uint8)
60+
61+
a_filtered = non_linear_filter_3d(a, [2, 2, 2], _mode)
62+
assert a_filtered.dtype == np.uint8
63+
expected_result = [1]
64+
assert np.all(expected_result == a_filtered)
65+
66+
a = np.array([[[1, 3], [1, 4]], [[4, 3], [3, 1]]], np.uint32)
67+
68+
a_filtered = non_linear_filter_3d(a, [2, 2, 1], _mode)
69+
assert a_filtered.dtype == np.uint32
70+
expected_result = [1, 3]
71+
assert np.all(expected_result == a_filtered)
72+
73+
3774
def test_cube_addresses():
3875
addresses = cube_addresses(source_info)
3976
assert len(addresses) == 5 * 5 * 1

wkcuber/__main__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def create_parser():
9494
args.layer_name,
9595
Mag(1),
9696
Mag(args.max_mag),
97-
args.dtype,
9897
"default",
9998
DEFAULT_EDGE_LEN,
10099
args.jobs,

wkcuber/downsampling.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from argparse import ArgumentParser
66
from math import floor, log2
77
from os import path, listdir
8-
from scipy.stats import mode
98
from scipy.ndimage.interpolation import zoom
109
from itertools import product
1110
from functools import lru_cache
@@ -61,20 +60,13 @@ def create_parser():
6160
default="default",
6261
)
6362

64-
parser.add_argument(
65-
"--dtype",
66-
"-d",
67-
help="Target datatype (e.g. uint8, uint16, uint32)",
68-
default="uint8",
69-
)
70-
7163
parser.add_argument(
7264
"--from_mag",
7365
"--from",
7466
"-f",
7567
help="Resolution to base downsampling on",
76-
type=int,
77-
default=1,
68+
type=str,
69+
default="1",
7870
)
7971

8072
# Either provide the maximum resolution to be downsampled OR a specific, anisotropic magnification.
@@ -193,6 +185,7 @@ def downsample_cube_job(
193185

194186
with open_wkw(source_wkw_info) as source_wkw:
195187
num_channels = source_wkw.header.num_channels
188+
source_dtype = source_wkw.header.voxel_type
196189
with open_wkw(
197190
target_wkw_info,
198191
pool_get_lock(),
@@ -203,7 +196,7 @@ def downsample_cube_job(
203196
source_wkw.header.file_len * source_wkw.header.block_len
204197
)
205198
shape = (num_channels,) + (wkw_cubelength,) * 3
206-
file_buffer = np.zeros(shape, target_wkw_info.dtype)
199+
file_buffer = np.zeros(shape, source_dtype)
207200
tile_length = cube_edge_len
208201
tile_count_per_dim = wkw_cubelength // tile_length
209202
assert (
@@ -331,7 +324,50 @@ def _median(x):
331324

332325

333326
def _mode(x):
334-
return mode(x, axis=0, nan_policy="omit")[0][0]
327+
"""
328+
Fast mode implementation from: https://stackoverflow.com/a/35674754
329+
"""
330+
# Check inputs
331+
ndim = x.ndim
332+
axis = 0
333+
# Sort array
334+
sort = np.sort(x, axis=axis)
335+
# Create array to transpose along the axis and get padding shape
336+
transpose = np.roll(np.arange(ndim)[::-1], axis)
337+
shape = list(sort.shape)
338+
shape[axis] = 1
339+
# Create a boolean array along strides of unique values
340+
strides = (
341+
np.concatenate(
342+
[
343+
np.zeros(shape=shape, dtype="bool"),
344+
np.diff(sort, axis=axis) == 0,
345+
np.zeros(shape=shape, dtype="bool"),
346+
],
347+
axis=axis,
348+
)
349+
.transpose(transpose)
350+
.ravel()
351+
)
352+
# Count the stride lengths
353+
counts = np.cumsum(strides)
354+
counts[~strides] = np.concatenate([[0], np.diff(counts[~strides])])
355+
counts[strides] = 0
356+
# Get shape of padded counts and slice to return to the original shape
357+
shape = np.array(sort.shape)
358+
shape[axis] += 1
359+
shape = shape[transpose]
360+
slices = [slice(None)] * ndim
361+
slices[axis] = slice(1, None)
362+
# Reshape and compute final counts
363+
counts = counts.reshape(shape).transpose(transpose)[tuple(slices)] + 1
364+
365+
# Find maximum counts and return modals/counts
366+
slices = [slice(None, i) for i in sort.shape]
367+
del slices[axis]
368+
index = np.ogrid[slices]
369+
index.insert(axis, np.argmax(counts, axis=axis))
370+
return sort[tuple(index)]
335371

336372

337373
def downsample_cube(cube_buffer, factors, interpolation_mode):
@@ -358,7 +394,6 @@ def downsample_mag(
358394
layer_name,
359395
source_mag: Mag,
360396
target_mag: Mag,
361-
dtype="uint8",
362397
interpolation_mode="default",
363398
cube_edge_len=DEFAULT_EDGE_LEN,
364399
jobs=1,
@@ -373,12 +408,11 @@ def downsample_mag(
373408
else:
374409
interpolation_mode = InterpolationModes[interpolation_mode.upper()]
375410

376-
source_wkw_info = WkwDatasetInfo(
377-
path, layer_name, dtype, source_mag.to_layer_name()
378-
)
379-
target_wkw_info = WkwDatasetInfo(
380-
path, layer_name, dtype, target_mag.to_layer_name()
381-
)
411+
source_wkw_info = WkwDatasetInfo(path, layer_name, None, source_mag.to_layer_name())
412+
with open_wkw(source_wkw_info) as source:
413+
target_wkw_info = WkwDatasetInfo(
414+
path, layer_name, source.header.voxel_type, target_mag.to_layer_name()
415+
)
382416
downsample(
383417
source_wkw_info,
384418
target_wkw_info,
@@ -396,7 +430,6 @@ def downsample_mags(
396430
layer_name,
397431
from_mag: Mag,
398432
max_mag: Mag,
399-
dtype,
400433
interpolation_mode,
401434
cube_edge_len,
402435
jobs,
@@ -410,7 +443,6 @@ def downsample_mags(
410443
layer_name,
411444
source_mag,
412445
target_mag,
413-
dtype,
414446
interpolation_mode,
415447
cube_edge_len,
416448
jobs,
@@ -435,7 +467,6 @@ def downsample_mags(
435467
args.layer_name,
436468
from_mag,
437469
anisotropic_target_mag,
438-
args.dtype,
439470
args.interpolation_mode,
440471
args.buffer_cube_size,
441472
args.jobs,
@@ -447,7 +478,6 @@ def downsample_mags(
447478
args.layer_name,
448479
from_mag,
449480
max_mag,
450-
args.dtype,
451481
args.interpolation_mode,
452482
args.buffer_cube_size,
453483
args.jobs,

wkcuber/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121

2222

2323
def _open_wkw(info, **kwargs):
24-
header = wkw.Header(np.dtype(info.dtype), **kwargs)
24+
if info.dtype is not None:
25+
header = wkw.Header(np.dtype(info.dtype), **kwargs)
26+
else:
27+
header = None
2528
ds = wkw.Dataset.open(
2629
path.join(info.dataset_path, info.layer_name, str(info.mag)), header
2730
)

0 commit comments

Comments
 (0)