
Commit 4be5dba

Improve speed and RAM consumption of buffered slice writer (#937)
* add tests for buffered slice writer
* reduce data size for test
* improve performance of buffered slice writer
* format
* format and lint
* re-add support for dimension parameter in buffered slice writer; clean up tests
* implement warnings for unaligned writes in buffered slice writer
* fix type
* linting
* format
* try to fix flaky test

Co-authored-by: Norman Rzepka <[email protected]>
1 parent 4b5ad93 commit 4be5dba
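
In practice, the commit rewards chunk-aligned usage of `get_buffered_slice_writer`: unaligned offsets or buffer sizes still work but now emit a `UserWarning` and take a slower path. A minimal usage sketch pieced together from the new tests (the output path and the top-level `Dataset` import are assumptions, not part of this commit):

```python
import numpy as np
from webknossos import Dataset

# Hypothetical output path; any writable directory works.
dataset = Dataset("testoutput/example_ds", voxel_size=(1, 1, 1))
layer = dataset.add_layer(
    layer_name="color", category="color", dtype_per_channel="uint8", num_channels=1
)
# Chunk depth in z is 32, so a z-aligned offset and a buffer_size that is a
# multiple of 32 avoid the new "unaligned write" warnings (and the slow path).
mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(8, 8, 8))

shape = (512, 512, 32)
data = np.random.randint(0, 255, shape, dtype=np.uint8)

with mag1.get_buffered_slice_writer(absolute_offset=(0, 0, 0), buffer_size=32) as writer:
    for z in range(shape[2]):
        writer.send(data[:, :, z])  # one 2D (x, y) slice per send()

written = mag1.read(absolute_offset=(0, 0, 0), size=shape)
assert np.all(written == data)
```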

File tree

3 files changed: +223 -44 lines changed


webknossos/tests/dataset/test_buffered_slice_utils.py

Lines changed: 123 additions & 2 deletions
@@ -1,3 +1,4 @@
+import warnings
 from pathlib import Path
 
 import numpy as np
@@ -8,6 +9,10 @@
 from webknossos.geometry import BoundingBox, Mag, Vec3Int
 from webknossos.utils import rmtree
 
+# This module effectively tests BufferedSliceWriter and
+# BufferedSliceReader (by calling get_buffered_slice_writer
+# and get_buffered_slice_reader).
+
 
 def test_buffered_slice_writer() -> None:
     test_img = np.arange(24 * 24).reshape(24, 24).astype(np.uint16) + 1
@@ -77,11 +82,13 @@ def test_buffered_slice_writer() -> None:
 def test_buffered_slice_writer_along_different_axis(tmp_path: Path) -> None:
     test_cube = (np.random.random((3, 13, 13, 13)) * 100).astype(np.uint8)
     cube_size_without_channel = test_cube.shape[1:]
-    offset = Vec3Int(5, 10, 20)
+    offset = Vec3Int(64, 96, 32)
 
     for dim in [0, 1, 2]:
         ds = Dataset(tmp_path / f"buffered_slice_writer_{dim}", voxel_size=(1, 1, 1))
-        mag_view = ds.add_layer("color", COLOR_CATEGORY, num_channels=3).add_mag(1)
+        mag_view = ds.add_layer(
+            "color", COLOR_CATEGORY, num_channels=test_cube.shape[0]
+        ).add_mag(1)
 
         with mag_view.get_buffered_slice_writer(
            absolute_offset=offset, buffer_size=5, dimension=dim
@@ -129,3 +136,117 @@ def test_buffered_slice_reader_along_different_axis(tmp_path: Path) -> None:
 
     assert np.array_equal(slice_data_a, original_slice)
     assert np.array_equal(slice_data_b, original_slice)
+
+
+def test_basic_buffered_slice_writer(tmp_path: Path) -> None:
+    # Create DS
+    dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
+    layer = dataset.add_layer(
+        layer_name="color", category="color", dtype_per_channel="uint8", num_channels=1
+    )
+    mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(8, 8, 8))
+
+    # Allocate some data (~ 8 MB)
+    shape = (512, 512, 32)
+    data = np.random.randint(0, 255, shape, dtype=np.uint8)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("error")  # This escalates the warning to an error
+
+        # Write some slices
+        with mag1.get_buffered_slice_writer() as writer:
+            for z in range(0, shape[2]):
+                section = data[:, :, z]
+                writer.send(section)
+
+    written_data = mag1.read(absolute_offset=(0, 0, 0), size=shape)
+
+    assert np.all(data == written_data)
+
+
+def test_buffered_slice_writer_should_warn_about_unaligned_usage(
+    tmp_path: Path,
+) -> None:
+    # Create DS
+    dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
+    layer = dataset.add_layer(
+        layer_name="color", category="color", dtype_per_channel="uint8", num_channels=1
+    )
+    mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(8, 8, 8))
+
+    offset = (1, 1, 1)
+
+    # Allocate some data (~ 8 MB)
+    shape = (512, 512, 32)
+    data = np.random.randint(0, 255, shape, dtype=np.uint8)
+
+    with warnings.catch_warnings(record=True) as recorded_warnings:
+        warnings.filterwarnings("default", module="webknossos", message=r"\[WARNING\]")
+        # Write some slices
+        with mag1.get_buffered_slice_writer(
+            absolute_offset=offset, buffer_size=35
+        ) as writer:
+            for z in range(0, shape[2]):
+                section = data[:, :, z]
+                writer.send(section)
+
+        warning1, warning2 = recorded_warnings
+        assert issubclass(warning1.category, UserWarning) and "Using an offset" in str(
+            warning1.message
+        )
+        assert issubclass(
+            warning2.category, UserWarning
+        ) and "Using a buffer size" in str(warning2.message)
+
+    written_data = mag1.read(absolute_offset=offset, size=shape)
+
+    assert np.all(data == written_data)
+
+
+def test_basic_buffered_slice_writer_multi_shard(tmp_path: Path) -> None:
+    # Create DS
+    dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
+    layer = dataset.add_layer(
+        layer_name="color", category="color", dtype_per_channel="uint8", num_channels=1
+    )
+    mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(4, 4, 4))
+
+    # Allocate some data (~ 3 MB) that covers multiple shards (also in z)
+    shape = (160, 150, 140)
+    data = np.random.randint(0, 255, shape, dtype=np.uint8)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("error")  # This escalates the warning to an error
+
+        # Write some slices
+        with mag1.get_buffered_slice_writer() as writer:
+            for z in range(0, shape[2]):
+                section = data[:, :, z]
+                writer.send(section)
+
+    written_data = mag1.read(absolute_offset=(0, 0, 0), size=shape)
+
+    assert np.all(data == written_data)
+
+
+def test_basic_buffered_slice_writer_multi_shard_multi_channel(tmp_path: Path) -> None:
+    # Create DS
+    dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
+    layer = dataset.add_layer(
+        layer_name="color", category="color", dtype_per_channel="uint8", num_channels=3
+    )
+    mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(4, 4, 4))
+
+    # Allocate some data (~ 3 MB) that covers multiple shards (also in z)
+    shape = (3, 160, 150, 140)
+    data = np.random.randint(0, 255, shape, dtype=np.uint8)
+
+    # Write some slices
+    with mag1.get_buffered_slice_writer() as writer:
+        for z in range(0, shape[-1]):
+            section = data[:, :, :, z]
+            writer.send(section)
+
+    written_data = mag1.read(absolute_offset=(0, 0, 0), size=shape[1:])
+
+    assert np.all(data == written_data)
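
The new tests lean on the standard library's warning controls rather than any webknossos-specific helper: `filterwarnings("error")` turns an unexpected alignment warning into a test failure, while `catch_warnings(record=True)` collects the expected ones for inspection. A minimal, library-free sketch of that pattern (`slow_path` is a made-up stand-in for an unaligned buffered write):

```python
import warnings

def slow_path() -> None:
    # Stand-in for an unaligned buffered write.
    warnings.warn("[WARNING] Using an offset that doesn't align with the chunk size")

# Escalate: any warning raised inside this block becomes an exception.
with warnings.catch_warnings():
    warnings.filterwarnings("error")
    try:
        slow_path()
    except UserWarning as w:
        print("escalated:", w)

# Record: collect warnings for later assertions instead of printing them.
with warnings.catch_warnings(record=True) as recorded:
    warnings.simplefilter("default")
    slow_path()
assert len(recorded) == 1 and "Using an offset" in str(recorded[0].message)
```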

webknossos/webknossos/dataset/_utils/buffered_slice_writer.py

Lines changed: 85 additions & 42 deletions
@@ -9,7 +9,7 @@
 import numpy as np
 import psutil
 
-from webknossos.geometry import Vec3Int, Vec3IntLike
+from webknossos.geometry import BoundingBox, Vec3Int, Vec3IntLike
 
 if TYPE_CHECKING:
     from webknossos.dataset import View
@@ -69,21 +69,43 @@ def __init__(
         )
         self.dimension = dimension
 
+        effective_offset = Vec3Int.full(0)
+        if self.relative_offset is not None:
+            effective_offset = self.view.bounding_box.topleft + self.relative_offset
+
+        if self.absolute_offset is not None:
+            effective_offset = self.absolute_offset
+
+        view_chunk_depth = self.view.info.chunk_shape[self.dimension]
+        if (
+            effective_offset is not None
+            and effective_offset[self.dimension] % view_chunk_depth != 0
+        ):
+            warnings.warn(
+                "[WARNING] Using an offset that doesn't align with the dataset's chunk size, "
+                + "will slow down the buffered slice writer, because twice as many chunks will be written.",
+            )
+        if buffer_size >= view_chunk_depth and buffer_size % view_chunk_depth > 0:
+            warnings.warn(
+                "[WARNING] Using a buffer size that doesn't align with the dataset's chunk size, "
+                + "will slow down the buffered slice writer.",
+            )
+
         assert 0 <= dimension <= 2
 
-        self.buffer: List[np.ndarray] = []
+        self.slices_to_write: List[np.ndarray] = []
         self.current_slice: Optional[int] = None
         self.buffer_start_slice: Optional[int] = None
 
-    def _write_buffer(self) -> None:
-        if len(self.buffer) == 0:
+    def _flush_buffer(self) -> None:
+        if len(self.slices_to_write) == 0:
             return
 
         assert (
-            len(self.buffer) <= self.buffer_size
+            len(self.slices_to_write) <= self.buffer_size
         ), "The WKW buffer is larger than the defined batch_size. The buffer should have been flushed earlier. This is probably a bug in the BufferedSliceWriter."
 
-        uniq_dtypes = set(map(lambda _slice: _slice.dtype, self.buffer))
+        uniq_dtypes = set(map(lambda _slice: _slice.dtype, self.slices_to_write))
         assert (
             len(uniq_dtypes) == 1
         ), "The buffer of BufferedSliceWriter contains slices with differing dtype."
@@ -95,7 +117,7 @@ def _write_buffer(self) -> None:
         if self.use_logging:
             info(
                 "({}) Writing {} slices at position {}.".format(
-                    getpid(), len(self.buffer), self.buffer_start_slice
+                    getpid(), len(self.slices_to_write), self.buffer_start_slice
                 )
             )
             log_memory_consumption()
@@ -104,44 +126,65 @@ def _write_buffer(self) -> None:
             assert (
                 self.buffer_start_slice is not None
             ), "Failed to write buffer: The buffer_start_slice is not set."
-            max_width = max(slice.shape[-2] for slice in self.buffer)
-            max_height = max(slice.shape[-1] for slice in self.buffer)
-
-            self.buffer = [
-                np.pad(
-                    slice,
-                    mode="constant",
-                    pad_width=[
-                        (0, 0),
-                        (0, max_width - slice.shape[-2]),
-                        (0, max_height - slice.shape[-1]),
-                    ],
-                )
-                for slice in self.buffer
-            ]
+            max_width = max(section.shape[-2] for section in self.slices_to_write)
+            max_height = max(section.shape[-1] for section in self.slices_to_write)
+            channel_count = self.slices_to_write[0].shape[0]
+
+            buffer_bbox = BoundingBox(
+                (0, 0, 0), (max_width, max_height, self.buffer_size)
+            )
 
-            data = np.concatenate(
-                [np.expand_dims(slice, self.dimension + 1) for slice in self.buffer],
-                axis=self.dimension + 1,
+            shard_dimensions = self.view._get_file_dimensions().moveaxis(
+                -1, self.dimension
             )
-            buffer_start_list = [0, 0, 0]
-            buffer_start_list[self.dimension] = self.buffer_start_slice
-            buffer_start = Vec3Int(buffer_start_list)
-            buffer_start_mag1 = buffer_start * self.view.mag.to_vec3_int()
-            self.view.write(
-                data,
-                offset=buffer_start.add_or_none(self.offset),
-                relative_offset=buffer_start_mag1.add_or_none(self.relative_offset),
-                absolute_offset=buffer_start_mag1.add_or_none(self.absolute_offset),
-                json_update_allowed=self.json_update_allowed,
+            chunk_size = Vec3Int(
+                min(shard_dimensions[0], max_width),
+                min(shard_dimensions[1], max_height),
+                self.buffer_size,
             )
+            for chunk_bbox in buffer_bbox.chunk(chunk_size):
+                info(f"Writing chunk {chunk_bbox}")
+                width, height, _ = chunk_bbox.size
+                data = np.zeros(
+                    (channel_count, width, height, self.buffer_size),
+                    dtype=self.slices_to_write[0].dtype,
+                )
+
+                z = 0
+                for section in self.slices_to_write:
+                    section_chunk = section[
+                        :,
+                        chunk_bbox.topleft.x : chunk_bbox.bottomright.x,
+                        chunk_bbox.topleft.y : chunk_bbox.bottomright.y,
+                    ]
+                    data[
+                        :, 0 : section_chunk.shape[-2], 0 : section_chunk.shape[-1], z
+                    ] = section_chunk
+
+                    z += 1
+
+                buffer_start = Vec3Int(
+                    chunk_bbox.topleft.x, chunk_bbox.topleft.y, self.buffer_start_slice
+                ).moveaxis(-1, self.dimension)
+                buffer_start_mag1 = buffer_start * self.view.mag.to_vec3_int()
+
+                data = np.moveaxis(data, -1, self.dimension + 1)
+
+                self.view.write(
+                    data,
+                    offset=buffer_start.add_or_none(self.offset),
+                    relative_offset=buffer_start_mag1.add_or_none(self.relative_offset),
+                    absolute_offset=buffer_start_mag1.add_or_none(self.absolute_offset),
+                    json_update_allowed=self.json_update_allowed,
+                )
+                del data
 
         except Exception as exc:
             error(
-                "({}) An exception occurred in BufferedSliceWriter._write_buffer with {} "
+                "({}) An exception occurred in BufferedSliceWriter._flush_buffer with {} "
                 "slices at position {}. Original error is:\n{}:{}\n\nTraceback:".format(
                     getpid(),
-                    len(self.buffer),
+                    len(self.slices_to_write),
                     self.buffer_start_slice,
                     type(exc).__name__,
                    exc,
@@ -152,23 +195,23 @@ def _write_buffer(self) -> None:
 
             raise exc
         finally:
-            self.buffer = []
+            self.slices_to_write = []
 
     def _get_slice_generator(self) -> Generator[None, np.ndarray, None]:
         current_slice = 0
         while True:
             data = yield  # Data gets send from the user
-            if len(self.buffer) == 0:
+            if len(self.slices_to_write) == 0:
                 self.buffer_start_slice = current_slice
             if len(data.shape) == 2:
                 # The input data might contain channel data or not.
                 # Bringing it into the same shape simplifies the code
                 data = np.expand_dims(data, axis=0)
-            self.buffer.append(data)
+            self.slices_to_write.append(data)
             current_slice += 1
 
             if current_slice % self.buffer_size == 0:
-                self._write_buffer()
+                self._flush_buffer()
 
     def __enter__(self) -> Generator[None, np.ndarray, None]:
         gen = self._get_slice_generator()
@@ -182,4 +225,4 @@ def __exit__(
         _value: Optional[BaseException],
         _tb: Optional[TracebackType],
     ) -> None:
-        self._write_buffer()
+        self._flush_buffer()
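
The core of the change is visible in `_flush_buffer` above: instead of padding every slice and concatenating the whole buffer into one large array, the buffered slices are copied into shard-sized XY tiles, and each tile is written and released on its own. The sketch below is a library-free illustration of that tiling idea in plain NumPy; `flush_tiled`, `write_tile`, and the sizes are made up for the example and are not part of the webknossos API:

```python
import numpy as np
from typing import Callable, List, Tuple

def flush_tiled(
    slices: List[np.ndarray],    # each slice has shape (channels, width, height)
    shard_wh: Tuple[int, int],   # XY extent of one shard, e.g. (256, 256)
    write_tile: Callable[[Tuple[int, int], np.ndarray], None],
) -> None:
    """Assemble shard-aligned tiles from buffered 2D slices and emit them one by one."""
    channels, width, height = slices[0].shape
    depth = len(slices)
    tile_w, tile_h = shard_wh
    for x0 in range(0, width, tile_w):
        for y0 in range(0, height, tile_h):
            x1, y1 = min(x0 + tile_w, width), min(y0 + tile_h, height)
            # Only one tile (not the whole buffer) is materialized at a time.
            tile = np.zeros((channels, x1 - x0, y1 - y0, depth), dtype=slices[0].dtype)
            for z, section in enumerate(slices):
                tile[:, :, :, z] = section[:, x0:x1, y0:y1]
            write_tile((x0, y0), tile)
            del tile  # free the tile before assembling the next one

# Usage sketch: 32 buffered slices of a 512x512 single-channel image,
# flushed as 256x256 tiles (collected here instead of written to disk).
buffered = [np.random.randint(0, 255, (1, 512, 512), dtype=np.uint8) for _ in range(32)]
written = []
flush_tiled(buffered, (256, 256), lambda origin, tile: written.append((origin, tile.shape)))
print(written)  # four tiles of shape (1, 256, 256, 32)
```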

webknossos/webknossos/geometry/vec3_int.py

Lines changed: 15 additions & 0 deletions
@@ -166,6 +166,21 @@ def __repr__(self) -> str:
     def add_or_none(self, other: Optional["Vec3Int"]) -> Optional["Vec3Int"]:
         return None if other is None else self + other
 
+    def moveaxis(
+        self, source: Union[int, List[int]], target: Union[int, List[int]]
+    ) -> "Vec3Int":
+        """
+        Allows to move one element at index `source` to another index `target`. Similar to
+        np.moveaxis, this is *not* a swap operation but instead it moves the specified
+        source so that the other elements move when necessary.
+        """
+
+        # Piggy-back on np.moveaxis by creating an auxiliary array where the indices 0, 1 and
+        # 2 appear in the shape.
+        indices = np.moveaxis(np.zeros((0, 1, 2)), source, target).shape
+        arr = self.to_np()[np.array(indices)]
+        return Vec3Int(arr)
+
     @classmethod
     def zeros(cls) -> "Vec3Int":
         return cls(0, 0, 0)
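
`Vec3Int.moveaxis` mirrors `np.moveaxis` on a 3-vector: the component at `source` is pulled out and reinserted at `target`, shifting the remaining components rather than swapping two of them. A small sketch of the index bookkeeping it relies on, using only NumPy (the concrete vectors are illustrative):

```python
import numpy as np

# The same trick as in Vec3Int.moveaxis: build a dummy array whose shape is
# (0, 1, 2), move the axis, and read the permuted index order off the shape.
indices = np.moveaxis(np.zeros((0, 1, 2)), -1, 1).shape
print(indices)  # (0, 2, 1): the last component moves to position 1

# Applying that permutation to a concrete vector, e.g. Vec3Int(10, 20, 30):
vec = np.array([10, 20, 30])
print(vec[np.array(indices)])  # [10 30 20]

# The writer uses this to rotate an (x, y, slice-index) start position into the
# configured write dimension, e.g. dimension=0 moves the slice index to the front:
order = np.array(np.moveaxis(np.zeros((0, 1, 2)), -1, 0).shape)
print(np.array([5, 7, 42])[order])  # [42  5  7]
```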
