Skip to content

Commit 66dd9f5

Browse files
authored
Refactor WkwDatasetInfo to include all wkw-header-related properties (#144)
* extend the WkwDatasetInfo to include all important properties * convert dtype to a VALID_VOXEL_TYPE * use voxel_type of header * use parameter 'dtype' on lower-level methods as np.dtype consistently * reformatting of code * fix test_element_class_convertion test * Info for KnossosDatasets does not contain a header * resolve circular dependency * black
1 parent e330475 commit 66dd9f5

File tree

10 files changed

+111
-71
lines changed

10 files changed

+111
-71
lines changed

tests/test_downsampling.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
WKW_CUBE_SIZE = 1024
1717
CUBE_EDGE_LEN = 256
1818

19-
source_info = WkwDatasetInfo("testdata/WT1_wkw", "color", "uint8", 1)
20-
target_info = WkwDatasetInfo("testoutput/WT1_wkw", "color", "uint8", 2)
19+
source_info = WkwDatasetInfo("testdata/WT1_wkw", "color", 1, wkw.Header(np.uint8))
20+
target_info = WkwDatasetInfo("testoutput/WT1_wkw", "color", 2, wkw.Header(np.uint8))
2121

2222

23-
def read_wkw(wkw_info, offset, size, **kwargs):
24-
with open_wkw(wkw_info, **kwargs) as wkw_dataset:
23+
def read_wkw(wkw_info, offset, size):
24+
with open_wkw(wkw_info) as wkw_dataset:
2525
return wkw_dataset.read(offset, size)
2626

2727

@@ -110,11 +110,10 @@ def downsample_test_helper(use_compress):
110110
block_type = (
111111
wkw.Header.BLOCK_TYPE_LZ4HC if use_compress else wkw.Header.BLOCK_TYPE_RAW
112112
)
113+
target_info.header.block_type = block_type
114+
113115
target_buffer = read_wkw(
114-
target_info,
115-
tuple(a * WKW_CUBE_SIZE for a in offset),
116-
(CUBE_EDGE_LEN,) * 3,
117-
block_type=block_type,
116+
target_info, tuple(a * WKW_CUBE_SIZE for a in offset), (CUBE_EDGE_LEN,) * 3
118117
)[0]
119118
assert np.any(target_buffer != 0)
120119

@@ -133,14 +132,6 @@ def test_compressed_downsample_cube_job():
133132

134133

135134
def test_downsample_multi_channel():
136-
source_info = WkwDatasetInfo("testoutput/multi-channel-test", "color", "uint8", 1)
137-
target_info = WkwDatasetInfo("testoutput/multi-channel-test", "color", "uint8", 2)
138-
try:
139-
shutil.rmtree(source_info.dataset_path)
140-
shutil.rmtree(target_info.dataset_path)
141-
except:
142-
pass
143-
144135
offset = (0, 0, 0)
145136
num_channels = 3
146137
size = (32, 32, 10)
@@ -149,9 +140,25 @@ def test_downsample_multi_channel():
149140
).astype("uint8")
150141
file_len = 32
151142

152-
with open_wkw(
153-
source_info, num_channels=num_channels, file_len=file_len
154-
) as wkw_dataset:
143+
source_info = WkwDatasetInfo(
144+
"testoutput/multi-channel-test",
145+
"color",
146+
1,
147+
wkw.Header(np.uint8, num_channels, file_len=file_len),
148+
)
149+
target_info = WkwDatasetInfo(
150+
"testoutput/multi-channel-test",
151+
"color",
152+
2,
153+
wkw.Header(np.uint8, file_len=file_len),
154+
)
155+
try:
156+
shutil.rmtree(source_info.dataset_path)
157+
shutil.rmtree(target_info.dataset_path)
158+
except:
159+
pass
160+
161+
with open_wkw(source_info) as wkw_dataset:
155162
print("writing source_data shape", source_data.shape)
156163
wkw_dataset.write(offset, source_data)
157164
assert np.any(source_data != 0)
@@ -180,7 +187,6 @@ def test_downsample_multi_channel():
180187
target_info,
181188
tuple(a * WKW_CUBE_SIZE for a in offset),
182189
list(map(lambda x: x // 2, size)),
183-
file_len=file_len,
184190
)
185191
assert np.any(target_buffer != 0)
186192

tests/test_metadata.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import os
3+
import wkw
34

45
from wkcuber.cubing import ensure_wkw
56
from wkcuber.utils import WkwDatasetInfo, open_wkw
@@ -16,9 +17,9 @@ def test_element_class_convertion():
1617
test_wkw_path = os.path.join("testoutput", "test_metadata")
1718
prediction_layer_name = "prediction"
1819
prediction_wkw_info = WkwDatasetInfo(
19-
test_wkw_path, prediction_layer_name, np.float32, 1
20+
test_wkw_path, prediction_layer_name, 1, wkw.Header(np.float32, num_channels=3)
2021
)
21-
ensure_wkw(prediction_wkw_info, num_channels=3)
22+
ensure_wkw(prediction_wkw_info)
2223

2324
write_custom_layer(test_wkw_path, "prediction", np.float32, num_channels=3)
2425
write_webknossos_metadata(
@@ -61,9 +62,11 @@ def write_custom_layer(target_path, layer_name, dtype, num_channels):
6162
.reshape((num_channels, 4, 4, 4))
6263
.astype(dtype)
6364
)
64-
prediction_wkw_info = WkwDatasetInfo(target_path, layer_name, dtype, 1)
65-
ensure_wkw(prediction_wkw_info, num_channels=num_channels)
66-
with open_wkw(prediction_wkw_info, num_channels=num_channels) as dataset:
65+
prediction_wkw_info = WkwDatasetInfo(
66+
target_path, layer_name, 1, wkw.Header(dtype, num_channels)
67+
)
68+
ensure_wkw(prediction_wkw_info)
69+
with open_wkw(prediction_wkw_info) as dataset:
6770
dataset.write(off=(0, 0, 0), data=data)
6871

6972

wkcuber/compress.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
setup_logging,
1919
)
2020
from .metadata import detect_resolutions
21+
from .metadata import convert_element_class_to_dtype
2122
from typing import List
2223

2324

@@ -80,10 +81,10 @@ def compress_mag(source_path, layer_name, target_path, mag: Mag, args=None):
8081
exit(1)
8182

8283
if args is not None and hasattr(args, "dtype"):
83-
dtype = args.dtype
84+
header = wkw.Header(convert_element_class_to_dtype(args.dtype))
8485
else:
85-
dtype = None
86-
source_wkw_info = WkwDatasetInfo(source_path, layer_name, dtype, mag)
86+
header = None
87+
source_wkw_info = WkwDatasetInfo(source_path, layer_name, mag, header)
8788
target_mag_path = path.join(target_path, layer_name, str(mag))
8889
logging.info("Compressing mag {0} in '{1}'".format(str(mag), target_mag_path))
8990

wkcuber/convert_knossos.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
setup_logging,
1919
)
2020
from .knossos import KnossosDataset, CUBE_EDGE_LEN
21+
from .metadata import convert_element_class_to_dtype
2122

2223

2324
def create_parser():
@@ -76,7 +77,9 @@ def convert_knossos(
7677
source_path, target_path, layer_name, dtype, mag=1, jobs=1, args=None
7778
):
7879
source_knossos_info = KnossosDatasetInfo(source_path, dtype)
79-
target_wkw_info = WkwDatasetInfo(target_path, layer_name, dtype, mag)
80+
target_wkw_info = WkwDatasetInfo(
81+
target_path, layer_name, mag, wkw.Header(convert_element_class_to_dtype(dtype))
82+
)
8083

8184
ensure_wkw(target_wkw_info)
8285

wkcuber/cubing.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import time
22
import logging
33
import numpy as np
4+
import wkw
45
from argparse import ArgumentParser
56
from os import path
67
from natsort import natsorted
@@ -22,6 +23,7 @@
2223
setup_logging,
2324
)
2425
from .image_readers import image_reader
26+
from .metadata import convert_element_class_to_dtype
2527

2628
BLOCK_LEN = 32
2729

@@ -121,15 +123,14 @@ def cubing_job(args):
121123
source_file_batches,
122124
batch_size,
123125
image_size,
124-
num_channels,
125126
pad,
126127
) = args
127128
if len(z_batches) == 0:
128129
return
129130

130131
downsampling_needed = target_mag != Mag(1)
131132

132-
with open_wkw(target_wkw_info, num_channels=num_channels) as target_wkw:
133+
with open_wkw(target_wkw_info) as target_wkw:
133134
# Iterate over batches of continuous z sections
134135
# The batches have a maximum size of `batch_size`
135136
# Batched iterations allows to utilize IO more efficiently
@@ -144,7 +145,9 @@ def cubing_job(args):
144145
# Iterate over each z section in the batch
145146
for z, file_name in zip(z_batch, source_file_batch):
146147
# Image shape will be (x, y, channel_count, z=1)
147-
image = read_image_file(file_name, target_wkw_info.dtype)
148+
image = read_image_file(
149+
file_name, target_wkw_info.header.voxel_type
150+
)
148151
if not pad:
149152
assert (
150153
image.shape[0:2] == image_size
@@ -171,7 +174,9 @@ def cubing_job(args):
171174
for _slice in slices
172175
]
173176

174-
buffer = prepare_slices_for_wkw(slices, num_channels)
177+
buffer = prepare_slices_for_wkw(
178+
slices, target_wkw_info.header.num_channels
179+
)
175180
if downsampling_needed:
176181
buffer = downsample_unpadded_data(
177182
buffer, target_mag, interpolation_mode
@@ -195,8 +200,20 @@ def cubing_job(args):
195200

196201
def cubing(source_path, target_path, layer_name, dtype, batch_size, args=None) -> dict:
197202

203+
source_files = find_source_filenames(source_path)
204+
205+
# All images are assumed to have equal dimensions
206+
num_x, num_y = image_reader.read_dimensions(source_files[0])
207+
num_channels = image_reader.read_channel_count(source_files[0])
208+
num_z = len(source_files)
209+
198210
target_mag = Mag(args.target_mag)
199-
target_wkw_info = WkwDatasetInfo(target_path, layer_name, dtype, target_mag)
211+
target_wkw_info = WkwDatasetInfo(
212+
target_path,
213+
layer_name,
214+
target_mag,
215+
wkw.Header(convert_element_class_to_dtype(dtype), num_channels),
216+
)
200217
interpolation_mode = parse_interpolation_mode(
201218
args.interpolation_mode, target_wkw_info.layer_name
202219
)
@@ -205,16 +222,9 @@ def cubing(source_path, target_path, layer_name, dtype, batch_size, args=None) -
205222
f"Downsampling the cubed image to {target_mag} in memory with interpolation mode {interpolation_mode}."
206223
)
207224

208-
source_files = find_source_filenames(source_path)
209-
210-
# All images are assumed to have equal dimensions
211-
num_x, num_y = image_reader.read_dimensions(source_files[0])
212-
num_channels = image_reader.read_channel_count(source_files[0])
213-
num_z = len(source_files)
214-
215225
logging.info("Found source files: count={} size={}x{}".format(num_z, num_x, num_y))
216226

217-
ensure_wkw(target_wkw_info, num_channels=num_channels)
227+
ensure_wkw(target_wkw_info)
218228

219229
with get_executor_for_args(args) as executor:
220230
job_args = []
@@ -233,7 +243,6 @@ def cubing(source_path, target_path, layer_name, dtype, batch_size, args=None) -
233243
source_files[z:max_z],
234244
batch_size,
235245
(num_x, num_y),
236-
num_channels,
237246
args.pad,
238247
)
239248
)

wkcuber/downsampling.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def determine_buffer_edge_len(dataset):
3838
return min(DEFAULT_EDGE_LEN, dataset.header.file_len * dataset.header.block_len)
3939

4040

41+
def extend_wkw_dataset_info_header(wkw_info, **kwargs):
42+
for key, value in kwargs.items():
43+
setattr(wkw_info.header, key, value)
44+
45+
4146
def calculate_virtual_scale_for_target_mag(target_mag):
4247
"This scale is not the actual scale of the dataset"
4348
"The virtual scale is used for downsample_mags_anisotropic."
@@ -182,13 +187,16 @@ def downsample(
182187
header_block_type = (
183188
wkw.Header.BLOCK_TYPE_LZ4HC if compress else wkw.Header.BLOCK_TYPE_RAW
184189
)
185-
ensure_wkw(
190+
191+
extend_wkw_dataset_info_header(
186192
target_wkw_info,
187-
block_type=header_block_type,
188193
num_channels=num_channels,
189194
file_len=source_wkw.header.file_len,
195+
block_type=header_block_type,
190196
)
191197

198+
ensure_wkw(target_wkw_info)
199+
192200
with get_executor_for_args(args) as executor:
193201
job_args = []
194202
for target_cube_xyz in target_cube_addresses:
@@ -230,12 +238,16 @@ def downsample_cube_job(args):
230238
with open_wkw(source_wkw_info) as source_wkw:
231239
num_channels = source_wkw.header.num_channels
232240
source_dtype = source_wkw.header.voxel_type
233-
with open_wkw(
241+
242+
extend_wkw_dataset_info_header(
234243
target_wkw_info,
235-
block_type=header_block_type,
244+
voxel_type=source_dtype,
236245
num_channels=num_channels,
237246
file_len=source_wkw.header.file_len,
238-
) as target_wkw:
247+
block_type=header_block_type,
248+
)
249+
250+
with open_wkw(target_wkw_info) as target_wkw:
239251
wkw_cubelength = (
240252
source_wkw.header.file_len * source_wkw.header.block_len
241253
)
@@ -470,10 +482,13 @@ def downsample_mag(
470482
):
471483
interpolation_mode = parse_interpolation_mode(interpolation_mode, layer_name)
472484

473-
source_wkw_info = WkwDatasetInfo(path, layer_name, None, source_mag.to_layer_name())
485+
source_wkw_info = WkwDatasetInfo(path, layer_name, source_mag.to_layer_name(), None)
474486
with open_wkw(source_wkw_info) as source:
475487
target_wkw_info = WkwDatasetInfo(
476-
path, layer_name, source.header.voxel_type, target_mag.to_layer_name()
488+
path,
489+
layer_name,
490+
target_mag.to_layer_name(),
491+
wkw.Header(source.header.voxel_type),
477492
)
478493

479494
downsample(

wkcuber/image_readers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
class PillowImageReader:
1414
def read_array(self, file_name, dtype):
15-
this_layer = np.array(Image.open(file_name), np.dtype(dtype))
15+
this_layer = np.array(Image.open(file_name), dtype)
1616
this_layer = this_layer.swapaxes(0, 1)
1717
this_layer = this_layer.reshape(this_layer.shape + (1,))
1818
return this_layer
@@ -34,7 +34,7 @@ def read_channel_count(self, file_name):
3434
def to_target_datatype(data: np.ndarray, target_dtype) -> np.ndarray:
3535

3636
factor = (1 + np.iinfo(data.dtype).max) / (1 + np.iinfo(target_dtype).max)
37-
return (data / factor).astype(np.dtype(target_dtype))
37+
return (data / factor).astype(target_dtype)
3838

3939

4040
class Dm3ImageReader:

wkcuber/metadata.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def read_metadata_for_layer(wkw_path, layer_name):
163163
return layer_info, dtype, bounding_box, origin
164164

165165

166-
def convert_dype_to_element_class(dtype):
166+
def convert_dtype_to_element_class(dtype):
167167
element_class_to_dtype_map = {
168168
"float": np.float32,
169169
"double": np.float64,
@@ -173,20 +173,19 @@ def convert_dype_to_element_class(dtype):
173173
"uint64": np.uint64,
174174
}
175175
conversion_map = {v: k for k, v in element_class_to_dtype_map.items()}
176-
return conversion_map.get(dtype.type, str(dtype))
176+
return conversion_map.get(dtype, str(dtype))
177177

178178

179179
def detect_dtype(dataset_path, layer, mag: Mag = Mag(1)):
180180
layer_path = path.join(dataset_path, layer, str(mag))
181181
if path.exists(layer_path):
182182
with wkw.Dataset.open(layer_path) as dataset:
183-
voxel_type = dataset.header.voxel_type
183+
voxel_size = dataset.header.voxel_type
184184
num_channels = dataset.header.num_channels
185-
voxel_size = np.dtype(voxel_type)
186185
if voxel_size == np.uint8 and num_channels > 1:
187186
return "uint" + str(8 * num_channels)
188187
else:
189-
return convert_dype_to_element_class(voxel_size)
188+
return convert_dtype_to_element_class(voxel_size)
190189

191190

192191
def detect_cubeLength(dataset_path, layer, mag: Mag = Mag(1)):

0 commit comments

Comments
 (0)