Skip to content

Commit 46033a8

Browse files
Merge pull request #55 from scalableminds/mag
Mag
2 parents 17e2bfa + ae03e88 commit 46033a8

File tree

9 files changed

+173
-106
lines changed

9 files changed

+173
-106
lines changed

setup.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
from setuptools import setup, find_packages
22

33
setup(
4-
name='wkcuber',
4+
name="wkcuber",
55
packages=find_packages(exclude=("tests",)),
6-
version='0.2.0',
7-
install_requires=[
8-
'scipy',
9-
'numpy',
10-
'pillow',
11-
'pyyaml',
12-
'wkw'
13-
],
14-
description='A cubing tool for webKnossos',
15-
author='Norman Rzepka',
16-
author_email='[email protected]',
17-
url='https://scalableminds.com'
6+
version="0.2.0",
7+
install_requires=["scipy", "numpy", "pillow", "pyyaml", "wkw"],
8+
description="A cubing tool for webKnossos",
9+
author="Norman Rzepka",
10+
author_email="[email protected]",
11+
url="https://scalableminds.com",
1812
)

testdata/WT1_wkw.tar.gz

14.5 MB
Binary file not shown.

tests/test_downsampling.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def downsample_test_helper(use_compress):
7979
assert np.any(target_buffer != 0)
8080

8181
assert np.all(
82-
target_buffer == downsample_cube(source_buffer, (2, 2, 2), InterpolationModes.MAX)
82+
target_buffer
83+
== downsample_cube(source_buffer, (2, 2, 2), InterpolationModes.MAX)
8384
)
8485

8586

@@ -103,10 +104,14 @@ def test_downsample_multi_channel():
103104
offset = (0, 0, 0)
104105
num_channels = 3
105106
size = (32, 32, 10)
106-
source_data = (128 * np.random.randn(num_channels, size[0], size[1], size[2])).astype('uint8')
107+
source_data = (
108+
128 * np.random.randn(num_channels, size[0], size[1], size[2])
109+
).astype("uint8")
107110
file_len = 32
108111

109-
with open_wkw(source_info, num_channels=num_channels, file_len=file_len) as wkw_dataset:
112+
with open_wkw(
113+
source_info, num_channels=num_channels, file_len=file_len
114+
) as wkw_dataset:
110115
print("writing source_data shape", source_data.shape)
111116
wkw_dataset.write(offset, source_data)
112117
assert np.any(source_data != 0)
@@ -123,17 +128,19 @@ def test_downsample_multi_channel():
123128

124129
channels = []
125130
for channel_index in range(num_channels):
126-
channels.append(downsample_cube(source_data[channel_index], (2, 2, 2), InterpolationModes.MAX))
131+
channels.append(
132+
downsample_cube(
133+
source_data[channel_index], (2, 2, 2), InterpolationModes.MAX
134+
)
135+
)
127136
joined_buffer = np.stack(channels)
128137

129138
target_buffer = read_wkw(
130139
target_info,
131140
tuple(a * WKW_CUBE_SIZE for a in offset),
132141
list(map(lambda x: x // 2, size)),
133-
file_len=file_len
142+
file_len=file_len,
134143
)
135144
assert np.any(target_buffer != 0)
136145

137-
assert np.all(
138-
target_buffer == joined_buffer
139-
)
146+
assert np.all(target_buffer == joined_buffer)

tests/test_mag.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from wkcuber.mag import Mag
2+
from wkcuber.metadata import detect_resolutions
3+
4+
5+
def test_detect_resolutions():
6+
resolutions = sorted(list(detect_resolutions("/testdata/WT1_wkw", "color")))
7+
assert [mag.to_layer_name() for mag in resolutions] == ["1", "2-2-1"]
8+
9+
10+
def test_mag_constructor():
11+
mag = Mag(16)
12+
assert mag.to_array() == [16, 16, 16]
13+
14+
mag = Mag("256")
15+
assert mag.to_array() == [256, 256, 256]
16+
17+
mag = Mag("16-2-4")
18+
19+
assert mag.to_array() == [16, 2, 4]
20+
21+
mag1 = Mag("16-2-4")
22+
mag2 = Mag("8-2-4")
23+
24+
assert mag1 > mag2
25+
assert mag1.to_layer_name() == "16-2-4"

wkcuber/__main__.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from argparse import ArgumentParser
2-
from os import path
3-
from uuid import uuid4
2+
43
import logging
5-
import shutil
64

75
from .cubing import cubing, BLOCK_LEN
86
from .downsampling import downsample_mags, DEFAULT_EDGE_LEN
9-
from .compress import compress_mag
7+
from .compress import compress_mag_inplace
108
from .metadata import write_webknossos_metadata
119
from .utils import add_verbose_flag, add_jobs_flag
10+
from .mag import Mag
1211

1312

1413
def create_parser():
@@ -72,18 +71,6 @@ def create_parser():
7271
return parser
7372

7473

75-
def compress_mag_inplace(target_path, layer_name, mag, jobs):
76-
compress_target_path = "{}.compress-{}".format(target_path, uuid4())
77-
compress_mag(target_path, layer_name, compress_target_path, mag, jobs)
78-
79-
shutil.rmtree(path.join(args.target_path, args.layer_name, str(mag)))
80-
shutil.move(
81-
path.join(compress_target_path, layer_name, str(mag)),
82-
path.join(target_path, layer_name, str(mag)),
83-
)
84-
shutil.rmtree(compress_target_path)
85-
86-
8774
if __name__ == "__main__":
8875
args = create_parser().parse_args()
8976

@@ -100,13 +87,13 @@ def compress_mag_inplace(target_path, layer_name, mag, jobs):
10087
)
10188

10289
if not args.no_compress:
103-
compress_mag_inplace(args.target_path, args.layer_name, 1, args.jobs)
90+
compress_mag_inplace(args.target_path, args.layer_name, Mag(1), args.jobs)
10491

10592
downsample_mags(
10693
args.target_path,
10794
args.layer_name,
108-
1,
109-
args.max_mag,
95+
Mag(1),
96+
Mag(args.max_mag),
11097
args.dtype,
11198
"default",
11299
DEFAULT_EDGE_LEN,

wkcuber/compress.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
from argparse import ArgumentParser
77
from os import path, makedirs
8+
from uuid import uuid4
9+
from .mag import Mag
810

911
from .utils import (
1012
add_verbose_flag,
@@ -14,6 +16,7 @@
1416
ParallelExecutor,
1517
)
1618
from .metadata import detect_resolutions
19+
from typing import List
1720

1821

1922
def create_parser():
@@ -68,25 +71,39 @@ def compress_file_job(source_path, target_path):
6871
raise exc
6972

7073

71-
def compress_mag(source_path, layer_name, target_path, mag, jobs):
74+
def compress_mag(source_path, layer_name, target_path, mag: Mag, jobs):
7275
if path.exists(path.join(target_path, layer_name, str(mag))):
7376
logging.error("Target path '{}' already exists".format(target_path))
7477
exit(1)
7578

7679
source_wkw_info = WkwDatasetInfo(source_path, layer_name, None, mag)
7780
target_mag_path = path.join(target_path, layer_name, str(mag))
78-
logging.info("Compressing mag {0} in '{1}'".format(mag, target_mag_path))
81+
logging.info("Compressing mag {0} in '{1}'".format(str(mag), target_mag_path))
7982

8083
with open_wkw(source_wkw_info) as source_wkw, ParallelExecutor(jobs) as pool:
8184
source_wkw.compress(target_mag_path)
8285
for file in source_wkw.list_files():
8386
rel_file = path.relpath(file, source_wkw.root)
8487
pool.submit(compress_file_job, file, path.join(target_mag_path, rel_file))
8588

86-
logging.info("Mag {0} succesfully compressed".format(mag))
89+
logging.info("Mag {0} succesfully compressed".format(str(mag)))
8790

8891

89-
def compress_mags(source_path, layer_name, target_path=None, mags=None, jobs=1):
92+
def compress_mag_inplace(target_path, layer_name, mag: Mag, jobs):
93+
compress_target_path = "{}.compress-{}".format(target_path, uuid4())
94+
compress_mag(target_path, layer_name, compress_target_path, mag, jobs)
95+
96+
shutil.rmtree(path.join(target_path, layer_name, str(mag)))
97+
shutil.move(
98+
path.join(compress_target_path, layer_name, str(mag)),
99+
path.join(target_path, layer_name, str(mag)),
100+
)
101+
shutil.rmtree(compress_target_path)
102+
103+
104+
def compress_mags(
105+
source_path, layer_name, target_path=None, mags: List[Mag] = None, jobs=1
106+
):
90107
with_tmp_dir = target_path is None
91108
target_path = source_path + ".tmp" if with_tmp_dir else target_path
92109

wkcuber/downsampling.py

Lines changed: 24 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from itertools import product
1111
from functools import lru_cache
1212
from enum import Enum
13+
from .mag import Mag
1314

1415
from .utils import (
1516
add_jobs_flag,
@@ -113,53 +114,23 @@ def cube_addresses(source_wkw_info):
113114
return wkw_addresses
114115

115116

116-
def mag_as_vector(mag):
117-
if isinstance(mag, int):
118-
return (mag, mag, mag)
119-
else:
120-
return mag
121-
122-
123-
def is_mag_greater_than(mag1, mag2):
124-
for (m1, m2) in zip(mag1, mag2):
125-
if m1 > m2:
126-
return True
127-
return False
128-
129-
130-
def assert_valid_mag(mag):
131-
if isinstance(mag, int):
132-
assert log2(mag) % 1 == 0, "magnification needs to be power of 2."
133-
else:
134-
assert len(mag) == 3, "magnification must be int or a vector3 of ints"
135-
for mag_dim in mag:
136-
assert log2(mag_dim) % 1 == 0, "magnification needs to be power of 2."
137-
138-
139-
def mag_as_layer_name(mag):
140-
if isinstance(mag, int):
141-
return str(mag)
142-
else:
143-
x, y, z = mag
144-
return "{}-{}-{}".format(x, y, z)
145-
146-
147117
def downsample(
148118
source_wkw_info,
149119
target_wkw_info,
150-
_source_mag,
151-
_target_mag,
120+
source_mag: Mag,
121+
target_mag: Mag,
152122
interpolation_mode,
153123
cube_edge_len,
154124
jobs,
155125
compress,
156126
):
157-
source_mag = mag_as_vector(_source_mag)
158-
target_mag = mag_as_vector(_target_mag)
159-
assert not is_mag_greater_than(source_mag, target_mag)
160-
logging.info("Downsampling mag {} from mag {}".format(_target_mag, _source_mag))
161127

162-
mag_factors = [int(t / s) for (t, s) in zip(target_mag, source_mag)]
128+
assert source_mag < target_mag
129+
logging.info("Downsampling mag {} from mag {}".format(target_mag, source_mag))
130+
131+
mag_factors = [
132+
t // s for (t, s) in zip(target_mag.to_array(), source_mag.to_array())
133+
]
163134
# Detect the cubes that we want to downsample
164135
source_cube_addresses = cube_addresses(source_wkw_info)
165136

@@ -385,8 +356,8 @@ def downsample_cube(cube_buffer, factors, interpolation_mode):
385356
def downsample_mag(
386357
path,
387358
layer_name,
388-
source_mag,
389-
target_mag,
359+
source_mag: Mag,
360+
target_mag: Mag,
390361
dtype="uint8",
391362
interpolation_mode="default",
392363
cube_edge_len=DEFAULT_EDGE_LEN,
@@ -403,10 +374,10 @@ def downsample_mag(
403374
interpolation_mode = InterpolationModes[interpolation_mode.upper()]
404375

405376
source_wkw_info = WkwDatasetInfo(
406-
path, layer_name, dtype, mag_as_layer_name(source_mag)
377+
path, layer_name, dtype, source_mag.to_layer_name()
407378
)
408379
target_wkw_info = WkwDatasetInfo(
409-
path, layer_name, dtype, mag_as_layer_name(target_mag)
380+
path, layer_name, dtype, target_mag.to_layer_name()
410381
)
411382
downsample(
412383
source_wkw_info,
@@ -423,18 +394,17 @@ def downsample_mag(
423394
def downsample_mags(
424395
path,
425396
layer_name,
426-
from_mag,
427-
max_mag,
397+
from_mag: Mag,
398+
max_mag: Mag,
428399
dtype,
429400
interpolation_mode,
430401
cube_edge_len,
431402
jobs,
432403
compress,
433404
):
434-
assert log2(from_mag) % 1 == 0, "'from_mag' needs to be power of 2."
435-
target_mag = from_mag * 2
405+
target_mag = from_mag.scaled_by(2)
436406
while target_mag <= max_mag:
437-
source_mag = target_mag // 2
407+
source_mag = target_mag.divided_by(2)
438408
downsample_mag(
439409
path,
440410
layer_name,
@@ -446,7 +416,7 @@ def downsample_mags(
446416
jobs,
447417
compress,
448418
)
449-
target_mag = target_mag * 2
419+
target_mag.scale_by(2)
450420

451421

452422
if __name__ == "__main__":
@@ -455,15 +425,15 @@ def downsample_mags(
455425
if args.verbose:
456426
logging.basicConfig(level=logging.DEBUG)
457427

428+
from_mag = Mag(args.from_mag)
429+
max_mag = Mag(args.max)
458430
if args.anisotropic_target_mag:
459-
anisotropic_target_mag = tuple(map(int, args.anisotropic_target_mag.split("-")))
460-
assert_valid_mag(args.from_mag)
461-
assert_valid_mag(anisotropic_target_mag)
431+
anisotropic_target_mag = Mag(args.anisotropic_target_mag)
462432

463433
downsample_mag(
464434
args.path,
465435
args.layer_name,
466-
args.from_mag,
436+
from_mag,
467437
anisotropic_target_mag,
468438
args.dtype,
469439
args.interpolation_mode,
@@ -475,8 +445,8 @@ def downsample_mags(
475445
downsample_mags(
476446
args.path,
477447
args.layer_name,
478-
args.from_mag,
479-
args.max,
448+
from_mag,
449+
max_mag,
480450
args.dtype,
481451
args.interpolation_mode,
482452
args.buffer_cube_size,

0 commit comments

Comments
 (0)