Skip to content

Commit 51e4d30

Browse files
philippottonormanrz
authored andcommitted
Implement anisotropic downsampling (#50)
* implement anisotropic downsampling (fixes #49) * adapt tests to refactored downsampling * disable dm3 test due to data being unavailable * change format of anisotropic_target_mag arg
1 parent 54de6f9 commit 51e4d30

File tree

3 files changed

+151
-70
lines changed

3 files changed

+151
-70
lines changed

.circleci/config.yml

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -64,27 +64,27 @@ jobs:
6464
[ -d testoutput/tiff/color/1 ]
6565
[ $(find testoutput/tiff/color/1 -mindepth 3 -name "*.wkw" | wc -l) -eq 1 ]
6666
67-
- run:
68-
name: Test DM3 cubing
69-
command: |
70-
set -x
71-
mkdir -p testdata/dm3
72-
mkdir -p testoutput/dm3
73-
wget http://www.loci.wisc.edu/files/software/data/dnasample1.zip
74-
unzip -d testdata/dm3 dnasample1.zip dnasample1.dm3
75-
docker run \
76-
-v "${PWD}/testdata:/testdata" \
77-
-v "${PWD}/testoutput:/testoutput" \
78-
--rm \
79-
scalableminds/webknossos-cuber:${CIRCLE_BUILD_NUM} \
80-
wkcuber.cubing \
81-
--verbose \
82-
--jobs 1 \
83-
--layer_name color \
84-
/testdata/dm3 /testoutput/dm3
85-
[ -d testoutput/dm3/color ]
86-
[ -d testoutput/dm3/color/1 ]
87-
[ $(find testoutput/dm3/color/1 -mindepth 3 -name "*.wkw" | wc -l) -eq 16 ]
67+
# - run:
68+
# name: Test DM3 cubing
69+
# command: |
70+
# set -x
71+
# mkdir -p testdata/dm3
72+
# mkdir -p testoutput/dm3
73+
# wget http://www.loci.wisc.edu/files/software/data/dnasample1.zip
74+
# unzip -d testdata/dm3 dnasample1.zip dnasample1.dm3
75+
# docker run \
76+
# -v "${PWD}/testdata:/testdata" \
77+
# -v "${PWD}/testoutput:/testoutput" \
78+
# --rm \
79+
# scalableminds/webknossos-cuber:${CIRCLE_BUILD_NUM} \
80+
# wkcuber.cubing \
81+
# --verbose \
82+
# --jobs 1 \
83+
# --layer_name color \
84+
# /testdata/dm3 /testoutput/dm3
85+
# [ -d testoutput/dm3/color ]
86+
# [ -d testoutput/dm3/color/1 ]
87+
# [ $(find testoutput/dm3/color/1 -mindepth 3 -name "*.wkw" | wc -l) -eq 16 ]
8888

8989
- run:
9090
name: Test tile cubing

tests/test_downsampling.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_downsample_cube():
2626
buffer = np.zeros((CUBE_EDGE_LEN,) * 3, dtype=np.uint8)
2727
buffer[:, :, :] = np.arange(0, CUBE_EDGE_LEN)
2828

29-
output = downsample_cube(buffer, 2, InterpolationModes.MEDIAN)
29+
output = downsample_cube(buffer, (2, 2, 2), InterpolationModes.MEDIAN)
3030

3131
assert output.shape == (CUBE_EDGE_LEN // 2,) * 3
3232
assert buffer[0, 0, 0] == 0
@@ -59,7 +59,7 @@ def downsample_test_helper(use_compress):
5959
downsample_cube_job(
6060
source_info,
6161
target_info,
62-
2,
62+
(2, 2, 2),
6363
InterpolationModes.MAX,
6464
CUBE_EDGE_LEN,
6565
offset,
@@ -79,7 +79,7 @@ def downsample_test_helper(use_compress):
7979
assert np.any(target_buffer != 0)
8080

8181
assert np.all(
82-
target_buffer == downsample_cube(source_buffer, 2, InterpolationModes.MAX)
82+
target_buffer == downsample_cube(source_buffer, (2, 2, 2), InterpolationModes.MAX)
8383
)
8484

8585

@@ -114,7 +114,7 @@ def test_downsample_multi_channel():
114114
downsample_cube_job(
115115
source_info,
116116
target_info,
117-
2,
117+
(2, 2, 2),
118118
InterpolationModes.MAX,
119119
CUBE_EDGE_LEN,
120120
tuple(a * WKW_CUBE_SIZE for a in offset),
@@ -123,7 +123,7 @@ def test_downsample_multi_channel():
123123

124124
channels = []
125125
for channel_index in range(num_channels):
126-
channels.append(downsample_cube(source_data[channel_index], 2, InterpolationModes.MAX))
126+
channels.append(downsample_cube(source_data[channel_index], (2, 2, 2), InterpolationModes.MAX))
127127
joined_buffer = np.stack(channels)
128128

129129
target_buffer = read_wkw(

wkcuber/downsampling.py

Lines changed: 125 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,6 @@ def create_parser():
6767
default="uint8",
6868
)
6969

70-
parser.add_argument(
71-
"--max", "-m", help="Max resolution to be downsampled", type=int, default=512
72-
)
73-
7470
parser.add_argument(
7571
"--from_mag",
7672
"--from",
@@ -80,6 +76,17 @@ def create_parser():
8076
default=1,
8177
)
8278

79+
# Either provide the maximum resolution to be downsampled OR a specific, anisotropic magnification.
80+
group = parser.add_mutually_exclusive_group()
81+
group.add_argument(
82+
"--max", "-m", help="Max resolution to be downsampled", type=int, default=512
83+
)
84+
group.add_argument(
85+
"--anisotropic_target_mag",
86+
help="Specify an anisotropic target magnification which should be created (e.g., --anisotropic_target_mag 2-2-1)",
87+
type=str,
88+
)
89+
8390
parser.add_argument(
8491
"--buffer_cube_size",
8592
"-b",
@@ -106,25 +113,61 @@ def cube_addresses(source_wkw_info):
106113
return wkw_addresses
107114

108115

116+
def mag_as_vector(mag):
117+
if isinstance(mag, int):
118+
return (mag, mag, mag)
119+
else:
120+
return mag
121+
122+
123+
def is_mag_greater_than(mag1, mag2):
124+
for (m1, m2) in zip(mag1, mag2):
125+
if m1 > m2:
126+
return True
127+
return False
128+
129+
130+
def assert_valid_mag(mag):
131+
if isinstance(mag, int):
132+
assert log2(mag) % 1 == 0, "magnification needs to be power of 2."
133+
else:
134+
assert len(mag) == 3, "magnification must be int or a vector3 of ints"
135+
for mag_dim in mag:
136+
assert log2(mag_dim) % 1 == 0, "magnification needs to be power of 2."
137+
138+
139+
def mag_as_layer_name(mag):
140+
if isinstance(mag, int):
141+
return str(mag)
142+
else:
143+
x, y, z = mag
144+
return "{}-{}-{}".format(x, y, z)
145+
146+
109147
def downsample(
110148
source_wkw_info,
111149
target_wkw_info,
112-
source_mag,
113-
target_mag,
150+
_source_mag,
151+
_target_mag,
114152
interpolation_mode,
115153
cube_edge_len,
116154
jobs,
117155
compress,
118156
):
119-
assert source_mag < target_mag
120-
logging.info("Downsampling mag {} from mag {}".format(target_mag, source_mag))
157+
source_mag = mag_as_vector(_source_mag)
158+
target_mag = mag_as_vector(_target_mag)
159+
assert not is_mag_greater_than(source_mag, target_mag)
160+
logging.info("Downsampling mag {} from mag {}".format(_target_mag, _source_mag))
121161

122-
mag_factor = int(target_mag / source_mag)
162+
mag_factors = [int(t / s) for (t, s) in zip(target_mag, source_mag)]
123163
# Detect the cubes that we want to downsample
124164
source_cube_addresses = cube_addresses(source_wkw_info)
125165

126166
target_cube_addresses = list(
127-
set(tuple(x // mag_factor for x in xyz) for xyz in source_cube_addresses)
167+
set(
168+
tuple(dim // mag_factor for (dim, mag_factor) in zip(xyz, mag_factors))
169+
for xyz in source_cube_addresses
170+
)
128171
)
129172
target_cube_addresses.sort()
130173
logging.debug(
@@ -150,7 +193,7 @@ def downsample(
150193
downsample_cube_job,
151194
source_wkw_info,
152195
target_wkw_info,
153-
mag_factor,
196+
mag_factors,
154197
interpolation_mode,
155198
cube_edge_len,
156199
target_cube_xyz,
@@ -163,7 +206,7 @@ def downsample(
163206
def downsample_cube_job(
164207
source_wkw_info,
165208
target_wkw_info,
166-
mag_factor,
209+
mag_factors,
167210
interpolation_mode,
168211
cube_edge_len,
169212
target_cube_xyz,
@@ -204,12 +247,12 @@ def downsample_cube_job(
204247
target_offset = np.array(
205248
tile
206249
) * tile_length + wkw_cubelength * np.array(target_cube_xyz)
207-
source_offset = mag_factor * target_offset
250+
source_offset = mag_factors * target_offset
208251

209252
# Read source buffer
210253
cube_buffer_channels = source_wkw.read(
211254
source_offset,
212-
(wkw_cubelength * mag_factor // tile_count_per_dim,) * 3,
255+
(wkw_cubelength * np.array(mag_factors) // tile_count_per_dim),
213256
)
214257

215258
for channel_index in range(num_channels):
@@ -225,7 +268,7 @@ def downsample_cube_job(
225268
# Downsample the buffer
226269

227270
data_cube = downsample_cube(
228-
cube_buffer, mag_factor, interpolation_mode
271+
cube_buffer, mag_factors, interpolation_mode
229272
)
230273

231274
buffer_offset = target_offset - file_offset
@@ -247,29 +290,46 @@ def downsample_cube_job(
247290
raise exc
248291

249292

250-
def non_linear_filter_3d(data, factor, func):
293+
def non_linear_filter_3d(data, factors, func):
251294
ds = data.shape
252-
assert not any((d % factor > 0 for d in ds))
253-
data = data.reshape((ds[0], factor, ds[1] // factor, ds[2]), order="F")
295+
assert not any((d % factor > 0 for (d, factor) in zip(ds, factors)))
296+
data = data.reshape((ds[0], factors[1], ds[1] // factors[1], ds[2]), order="F")
254297
data = data.swapaxes(0, 1)
255298
data = data.reshape(
256-
(factor * factor, ds[0] * ds[1] // (factor * factor), factor, ds[2] // factor),
299+
(
300+
factors[0] * factors[1],
301+
ds[0] * ds[1] // (factors[0] * factors[1]),
302+
factors[2],
303+
ds[2] // factors[2],
304+
),
257305
order="F",
258306
)
259307
data = data.swapaxes(2, 1)
260308
data = data.reshape(
261309
(
262-
factor * factor * factor,
263-
(ds[0] * ds[1] * ds[2]) // (factor * factor * factor),
310+
factors[0] * factors[1] * factors[2],
311+
(ds[0] * ds[1] * ds[2]) // (factors[0] * factors[1] * factors[2]),
264312
),
265313
order="F",
266314
)
267315
data = func(data)
268-
data = data.reshape((ds[0] // factor, ds[1] // factor, ds[2] // factor), order="F")
316+
data = data.reshape(
317+
(ds[0] // factors[0], ds[1] // factors[1], ds[2] // factors[2]), order="F"
318+
)
269319
return data
270320

271321

272-
def linear_filter_3d(data, factor, order):
322+
def linear_filter_3d(data, factors, order):
323+
factors = np.array(factors)
324+
325+
if not np.all(factors == factors[0]):
326+
logging.debug(
327+
"the selected filtering strategy does not support anisotropic downsampling. Selecting {} as uniform downsampling factor".format(
328+
factors[0]
329+
)
330+
)
331+
factor = factors[0]
332+
273333
ds = data.shape
274334
assert not any((d % factor > 0 for d in ds))
275335
return zoom(
@@ -303,21 +363,21 @@ def _mode(x):
303363
return mode(x, axis=0, nan_policy="omit")[0][0]
304364

305365

306-
def downsample_cube(cube_buffer, factor, interpolation_mode):
366+
def downsample_cube(cube_buffer, factors, interpolation_mode):
307367
if interpolation_mode == InterpolationModes.MODE:
308-
return non_linear_filter_3d(cube_buffer, factor, _mode)
368+
return non_linear_filter_3d(cube_buffer, factors, _mode)
309369
elif interpolation_mode == InterpolationModes.MEDIAN:
310-
return non_linear_filter_3d(cube_buffer, factor, _median)
370+
return non_linear_filter_3d(cube_buffer, factors, _median)
311371
elif interpolation_mode == InterpolationModes.NEAREST:
312-
return linear_filter_3d(cube_buffer, factor, 0)
372+
return linear_filter_3d(cube_buffer, factors, 0)
313373
elif interpolation_mode == InterpolationModes.BILINEAR:
314-
return linear_filter_3d(cube_buffer, factor, 1)
374+
return linear_filter_3d(cube_buffer, factors, 1)
315375
elif interpolation_mode == InterpolationModes.BICUBIC:
316-
return linear_filter_3d(cube_buffer, factor, 2)
376+
return linear_filter_3d(cube_buffer, factors, 2)
317377
elif interpolation_mode == InterpolationModes.MAX:
318-
return non_linear_filter_3d(cube_buffer, factor, _max)
378+
return non_linear_filter_3d(cube_buffer, factors, _max)
319379
elif interpolation_mode == InterpolationModes.MIN:
320-
return non_linear_filter_3d(cube_buffer, factor, _min)
380+
return non_linear_filter_3d(cube_buffer, factors, _min)
321381
else:
322382
raise Exception("Invalid interpolation mode: {}".format(interpolation_mode))
323383

@@ -342,8 +402,12 @@ def downsample_mag(
342402
else:
343403
interpolation_mode = InterpolationModes[interpolation_mode.upper()]
344404

345-
source_wkw_info = WkwDatasetInfo(path, layer_name, dtype, source_mag)
346-
target_wkw_info = WkwDatasetInfo(path, layer_name, dtype, target_mag)
405+
source_wkw_info = WkwDatasetInfo(
406+
path, layer_name, dtype, mag_as_layer_name(source_mag)
407+
)
408+
target_wkw_info = WkwDatasetInfo(
409+
path, layer_name, dtype, mag_as_layer_name(target_mag)
410+
)
347411
downsample(
348412
source_wkw_info,
349413
target_wkw_info,
@@ -391,14 +455,31 @@ def downsample_mags(
391455
if args.verbose:
392456
logging.basicConfig(level=logging.DEBUG)
393457

394-
downsample_mags(
395-
args.path,
396-
args.layer_name,
397-
args.from_mag,
398-
args.max,
399-
args.dtype,
400-
args.interpolation_mode,
401-
args.buffer_cube_size,
402-
args.jobs,
403-
args.compress,
404-
)
458+
if args.anisotropic_target_mag:
459+
anisotropic_target_mag = tuple(map(int, args.anisotropic_target_mag.split("-")))
460+
assert_valid_mag(args.from_mag)
461+
assert_valid_mag(anisotropic_target_mag)
462+
463+
downsample_mag(
464+
args.path,
465+
args.layer_name,
466+
args.from_mag,
467+
anisotropic_target_mag,
468+
args.dtype,
469+
args.interpolation_mode,
470+
args.buffer_cube_size,
471+
args.jobs,
472+
args.compress,
473+
)
474+
else:
475+
downsample_mags(
476+
args.path,
477+
args.layer_name,
478+
args.from_mag,
479+
args.max,
480+
args.dtype,
481+
args.interpolation_mode,
482+
args.buffer_cube_size,
483+
args.jobs,
484+
args.compress,
485+
)

0 commit comments

Comments
 (0)