Skip to content

Commit 4f037db

Browse files
remove dtype argument and infer dtype from input data
1 parent bfbe3c7 commit 4f037db

File tree

3 files changed

+49
-29
lines changed

3 files changed

+49
-29
lines changed

tests/test_downsampling.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99
import wkw
1010
from wkcuber.utils import WkwDatasetInfo, open_wkw
11-
from wkcuber.downsampling import _mode
11+
from wkcuber.downsampling import _mode, non_linear_filter_3d
1212
import shutil
1313

1414
WKW_CUBE_SIZE = 1024
@@ -34,18 +34,43 @@ def test_downsample_cube():
3434
assert buffer[0, 0, 1] == 1
3535
assert np.all(output[:, :, :] == np.arange(0, CUBE_EDGE_LEN, 2))
3636

37+
3738
def test_downsample_mode():
3839

39-
a = np.array([[1, 3, 4, 2, 2, 7],
40-
[5, 2, 2, 1, 4, 1],
41-
[3, 3, 2, 2, 1, 1]])
40+
a = np.array([[1, 3, 4, 2, 2, 7], [5, 2, 2, 1, 4, 1], [3, 3, 2, 2, 1, 1]])
4241

4342
result = _mode(a)
4443
expected_result = np.array([1, 3, 2, 2, 1, 1])
4544

4645
assert np.all(result == expected_result)
4746

4847

48+
def test_downsample_median():
49+
50+
a = np.array([[1, 3, 4, 2, 2, 7], [5, 2, 2, 1, 4, 1], [3, 3, 2, 2, 1, 1]])
51+
52+
result = np.median(a, axis=0)
53+
expected_result = np.array([3, 3, 2, 2, 2, 1])
54+
55+
assert np.all(result == expected_result)
56+
57+
58+
def test_non_linear_filter_reshape():
59+
a = np.array([[[1, 3], [1, 4]], [[4, 2], [3, 1]]], dtype=np.uint8)
60+
61+
a_filtered = non_linear_filter_3d(a, [2, 2, 2], _mode)
62+
assert a_filtered.dtype == np.uint8
63+
expected_result = [1]
64+
assert np.all(expected_result == a_filtered)
65+
66+
a = np.array([[[1, 3], [1, 4]], [[4, 3], [3, 1]]], np.uint32)
67+
68+
a_filtered = non_linear_filter_3d(a, [2, 2, 1], _mode)
69+
assert a_filtered.dtype == np.uint32
70+
expected_result = [1, 3]
71+
assert np.all(expected_result == a_filtered)
72+
73+
4974
def test_cube_addresses():
5075
addresses = cube_addresses(source_info)
5176
assert len(addresses) == 5 * 5 * 1

wkcuber/downsampling.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,13 @@ def create_parser():
6060
default="default",
6161
)
6262

63-
parser.add_argument(
64-
"--dtype",
65-
"-d",
66-
help="Target datatype (e.g. uint8, uint16, uint32)",
67-
default="uint8",
68-
)
69-
7063
parser.add_argument(
7164
"--from_mag",
7265
"--from",
7366
"-f",
7467
help="Resolution to base downsampling on",
7568
type=str,
76-
default='1',
69+
default="1",
7770
)
7871

7972
# Either provide the maximum resolution to be downsampled OR a specific, anisotropic magnification.
@@ -202,7 +195,7 @@ def downsample_cube_job(
202195
source_wkw.header.file_len * source_wkw.header.block_len
203196
)
204197
shape = (num_channels,) + (wkw_cubelength,) * 3
205-
file_buffer = np.zeros(shape, target_wkw_info.dtype)
198+
file_buffer = np.zeros(shape, target_wkw.header.voxel_type)
206199
tile_length = cube_edge_len
207200
tile_count_per_dim = wkw_cubelength // tile_length
208201
assert (
@@ -343,10 +336,18 @@ def _mode(x):
343336
shape = list(sort.shape)
344337
shape[axis] = 1
345338
# Create a boolean array along strides of unique values
346-
strides = np.concatenate([np.zeros(shape=shape, dtype='bool'),
347-
np.diff(sort, axis=axis) == 0,
348-
np.zeros(shape=shape, dtype='bool')],
349-
axis=axis).transpose(transpose).ravel()
339+
strides = (
340+
np.concatenate(
341+
[
342+
np.zeros(shape=shape, dtype="bool"),
343+
np.diff(sort, axis=axis) == 0,
344+
np.zeros(shape=shape, dtype="bool"),
345+
],
346+
axis=axis,
347+
)
348+
.transpose(transpose)
349+
.ravel()
350+
)
350351
# Count the stride lengths
351352
counts = np.cumsum(strides)
352353
counts[~strides] = np.concatenate([[0], np.diff(counts[~strides])])
@@ -392,7 +393,6 @@ def downsample_mag(
392393
layer_name,
393394
source_mag: Mag,
394395
target_mag: Mag,
395-
dtype="uint8",
396396
interpolation_mode="default",
397397
cube_edge_len=DEFAULT_EDGE_LEN,
398398
jobs=1,
@@ -407,12 +407,8 @@ def downsample_mag(
407407
else:
408408
interpolation_mode = InterpolationModes[interpolation_mode.upper()]
409409

410-
source_wkw_info = WkwDatasetInfo(
411-
path, layer_name, dtype, source_mag.to_layer_name()
412-
)
413-
target_wkw_info = WkwDatasetInfo(
414-
path, layer_name, dtype, target_mag.to_layer_name()
415-
)
410+
source_wkw_info = WkwDatasetInfo(path, layer_name, None, source_mag.to_layer_name())
411+
target_wkw_info = WkwDatasetInfo(path, layer_name, None, target_mag.to_layer_name())
416412
downsample(
417413
source_wkw_info,
418414
target_wkw_info,
@@ -430,7 +426,6 @@ def downsample_mags(
430426
layer_name,
431427
from_mag: Mag,
432428
max_mag: Mag,
433-
dtype,
434429
interpolation_mode,
435430
cube_edge_len,
436431
jobs,
@@ -444,7 +439,6 @@ def downsample_mags(
444439
layer_name,
445440
source_mag,
446441
target_mag,
447-
dtype,
448442
interpolation_mode,
449443
cube_edge_len,
450444
jobs,
@@ -469,7 +463,6 @@ def downsample_mags(
469463
args.layer_name,
470464
from_mag,
471465
anisotropic_target_mag,
472-
args.dtype,
473466
args.interpolation_mode,
474467
args.buffer_cube_size,
475468
args.jobs,
@@ -481,7 +474,6 @@ def downsample_mags(
481474
args.layer_name,
482475
from_mag,
483476
max_mag,
484-
args.dtype,
485477
args.interpolation_mode,
486478
args.buffer_cube_size,
487479
args.jobs,

wkcuber/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121

2222

2323
def _open_wkw(info, **kwargs):
24-
header = wkw.Header(np.dtype(info.dtype), **kwargs)
24+
if info.dtype is not None:
25+
header = wkw.Header(np.dtype(info.dtype), **kwargs)
26+
else:
27+
header = None
2528
ds = wkw.Dataset.open(
2629
path.join(info.dataset_path, info.layer_name, str(info.mag)), header
2730
)

0 commit comments

Comments
 (0)