Skip to content

Commit 11f9931

Browse files
authored
Adapt BufferedSliceWriter, to enable configuration after initialization. (#1052)
* Adapt BufferedSliceWriter, to enable configuration after initialization. * Update Changelog.md * Adapt test and fix bug with previously persisted current_slice * Adapt test to write at different positions.
1 parent d3fd9ec commit 11f9931

File tree

3 files changed

+143
-74
lines changed

3 files changed

+143
-74
lines changed

webknossos/Changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ For upgrade instructions, please check the respective _Breaking Changes_ section
1717
### Added
1818

1919
### Changed
20+
- The context variable of View.get_buffered_slice_writer() is a BufferedSliceWriter now instead of a Generator. Interaction with the SliceWriter does not change, but updating the offset after first initialization is possible now. [1052](https://github.com/scalableminds/webknossos-libs/pull/1052)
2021

2122
### Fixed
2223

webknossos/tests/dataset/test_buffered_slice_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,42 @@ def test_basic_buffered_slice_writer_multi_shard_multi_channel(tmp_path: Path) -
293293
written_data = mag1.read(absolute_offset=(0, 0, 0), size=shape[1:])
294294

295295
assert np.all(data == written_data)
296+
297+
298+
def test_buffered_slice_writer_reset_offset(tmp_path: Path) -> None:
299+
# Create DS
300+
dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
301+
layer = dataset.add_layer(
302+
layer_name="color", category="color", dtype_per_channel="uint8", num_channels=1
303+
)
304+
mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(8, 8, 8))
305+
306+
# Allocate some data (~ 8 MB)
307+
shape = (512, 512, 32)
308+
data = np.random.randint(0, 255, shape, dtype=np.uint8)
309+
310+
with warnings.catch_warnings():
311+
warnings.filterwarnings("error") # This escalates the warning to an error
312+
313+
# Write some slices
314+
with mag1.get_buffered_slice_writer() as writer:
315+
for z in range(0, shape[2] - 8):
316+
section = data[:, :, z]
317+
writer.send(section)
318+
writer.reset_offset(absolute_offset=(0, 0, shape[2]))
319+
for z in range(shape[2] - 8, shape[2]):
320+
section = data[:, :, z]
321+
writer.send(section)
322+
323+
written_data_before_reset = mag1.read(
324+
absolute_offset=(0, 0, 0), size=(shape[0], shape[1], shape[2] - 8)
325+
)
326+
written_data_after_reset = mag1.read(
327+
absolute_offset=(0, 0, shape[2]), size=(shape[0], shape[1], 8)
328+
)
329+
330+
written_data = np.concatenate(
331+
(written_data_before_reset, written_data_after_reset), axis=3
332+
)
333+
334+
assert np.all(data == written_data)

webknossos/webknossos/dataset/_utils/buffered_slice_writer.py

Lines changed: 103 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -51,50 +51,31 @@ def __init__(
5151
) -> None:
5252
"""see `View.get_buffered_slice_writer()`"""
5353

54-
self.view = view
55-
self.buffer_size = buffer_size
56-
self.dtype = self.view.get_dtype()
57-
self.use_logging = use_logging
58-
self.json_update_allowed = json_update_allowed
59-
self.bbox: NDBoundingBox
60-
61-
if (
62-
offset is None
63-
and relative_offset is None
64-
and absolute_offset is None
65-
and relative_bounding_box is None
66-
and absolute_bounding_box is None
67-
):
68-
relative_offset = Vec3Int.zeros()
69-
if offset is not None:
70-
warnings.warn(
71-
"[DEPRECATION] Using offset for a buffered slice writer is deprecated. "
72-
+ "Please use the parameter relative_offset or absolute_offset in Mag(1) instead.",
73-
DeprecationWarning,
74-
)
75-
self.offset = None if offset is None else Vec3Int(offset)
76-
77-
if relative_offset is not None:
78-
self.bbox = BoundingBox(
79-
self.view.bounding_box.topleft + relative_offset, Vec3Int.zeros()
80-
)
81-
82-
if absolute_offset is not None:
83-
self.bbox = BoundingBox(absolute_offset, Vec3Int.zeros())
84-
85-
if relative_bounding_box is not None:
86-
self.bbox = relative_bounding_box.offset(self.view.bounding_box.topleft)
87-
88-
if absolute_bounding_box is not None:
89-
self.bbox = absolute_bounding_box
54+
self._view = view
55+
self._buffer_size = buffer_size
56+
self._dtype = self._view.get_dtype()
57+
self._use_logging = use_logging
58+
self._json_update_allowed = json_update_allowed
59+
self._bbox: NDBoundingBox
60+
self._slices_to_write: List[np.ndarray] = []
61+
self._current_slice: Optional[int] = None
62+
self._buffer_start_slice: Optional[int] = None
63+
64+
self.reset_offset(
65+
offset,
66+
relative_offset,
67+
absolute_offset,
68+
relative_bounding_box,
69+
absolute_bounding_box,
70+
)
9071

9172
assert 0 <= dimension <= 2 # either x (0), y (1) or z (2)
9273
self.dimension = dimension
9374

94-
view_chunk_depth = self.view.info.chunk_shape[dimension]
75+
view_chunk_depth = self._view.info.chunk_shape[dimension]
9576
if (
96-
self.bbox is not None
97-
and self.bbox.topleft_xyz[self.dimension] % view_chunk_depth != 0
77+
self._bbox is not None
78+
and self._bbox.topleft_xyz[self.dimension] % view_chunk_depth != 0
9879
):
9980
warnings.warn(
10081
"[WARNING] Using an offset that doesn't align with the datataset's chunk size, "
@@ -106,54 +87,50 @@ def __init__(
10687
+ "will slow down the buffered slice writer.",
10788
)
10889

109-
self.slices_to_write: List[np.ndarray] = []
110-
self.current_slice: Optional[int] = None
111-
self.buffer_start_slice: Optional[int] = None
112-
11390
def _flush_buffer(self) -> None:
114-
if len(self.slices_to_write) == 0:
91+
if len(self._slices_to_write) == 0:
11592
return
11693

11794
assert (
118-
len(self.slices_to_write) <= self.buffer_size
95+
len(self._slices_to_write) <= self._buffer_size
11996
), "The WKW buffer is larger than the defined batch_size. The buffer should have been flushed earlier. This is probably a bug in the BufferedSliceWriter."
12097

121-
uniq_dtypes = set(map(lambda _slice: _slice.dtype, self.slices_to_write))
98+
uniq_dtypes = set(map(lambda _slice: _slice.dtype, self._slices_to_write))
12299
assert (
123100
len(uniq_dtypes) == 1
124101
), "The buffer of BufferedSliceWriter contains slices with differing dtype."
125-
assert uniq_dtypes.pop() == self.dtype, (
102+
assert uniq_dtypes.pop() == self._dtype, (
126103
"The buffer of BufferedSliceWriter contains slices with a dtype "
127104
"which differs from the dtype with which the BufferedSliceWriter was instantiated."
128105
)
129106

130-
if self.use_logging:
107+
if self._use_logging:
131108
info(
132109
"({}) Writing {} slices at position {}.".format(
133-
getpid(), len(self.slices_to_write), self.buffer_start_slice
110+
getpid(), len(self._slices_to_write), self._buffer_start_slice
134111
)
135112
)
136113
log_memory_consumption()
137114

138115
try:
139116
assert (
140-
self.buffer_start_slice is not None
117+
self._buffer_start_slice is not None
141118
), "Failed to write buffer: The buffer_start_slice is not set."
142-
max_width = max(section.shape[-2] for section in self.slices_to_write)
143-
max_height = max(section.shape[-1] for section in self.slices_to_write)
144-
channel_count = self.slices_to_write[0].shape[0]
145-
buffer_depth = min(self.buffer_size, len(self.slices_to_write))
119+
max_width = max(section.shape[-2] for section in self._slices_to_write)
120+
max_height = max(section.shape[-1] for section in self._slices_to_write)
121+
channel_count = self._slices_to_write[0].shape[0]
122+
buffer_depth = min(self._buffer_size, len(self._slices_to_write))
146123
buffer_start = Vec3Int.zeros().with_replaced(
147-
self.dimension, self.buffer_start_slice
124+
self.dimension, self._buffer_start_slice
148125
)
149126

150-
bbox = self.bbox.with_size_xyz(
127+
bbox = self._bbox.with_size_xyz(
151128
Vec3Int(max_width, max_height, buffer_depth).moveaxis(
152129
-1, self.dimension
153130
)
154131
).offset(buffer_start)
155132

156-
shard_dimensions = self.view._get_file_dimensions()
133+
shard_dimensions = self._view._get_file_dimensions()
157134
chunk_size = Vec3Int(
158135
min(shard_dimensions[0], max_width),
159136
min(shard_dimensions[1], max_height),
@@ -164,7 +141,7 @@ def _flush_buffer(self) -> None:
164141

165142
data = np.zeros(
166143
(channel_count, *chunk_bbox.size),
167-
dtype=self.slices_to_write[0].dtype,
144+
dtype=self._slices_to_write[0].dtype,
168145
)
169146
section_topleft = Vec3Int(
170147
(chunk_bbox.topleft_xyz - bbox.topleft_xyz).moveaxis(
@@ -180,7 +157,7 @@ def _flush_buffer(self) -> None:
180157
z_index = chunk_bbox.index_xyz[self.dimension]
181158

182159
z = 0
183-
for section in self.slices_to_write:
160+
for section in self._slices_to_write:
184161
section_chunk = section[
185162
:,
186163
section_topleft.x : section_bottomright.x,
@@ -215,10 +192,10 @@ def _flush_buffer(self) -> None:
215192

216193
z += 1
217194

218-
self.view.write(
195+
self._view.write(
219196
data,
220-
json_update_allowed=self.json_update_allowed,
221-
absolute_bounding_box=chunk_bbox.from_mag_to_mag1(self.view._mag),
197+
json_update_allowed=self._json_update_allowed,
198+
absolute_bounding_box=chunk_bbox.from_mag_to_mag1(self._view._mag),
222199
)
223200
del data
224201

@@ -227,8 +204,8 @@ def _flush_buffer(self) -> None:
227204
"({}) An exception occurred in BufferedSliceWriter._flush_buffer with {} "
228205
"slices at position {}. Original error is:\n{}:{}\n\nTraceback:".format(
229206
getpid(),
230-
len(self.slices_to_write),
231-
self.buffer_start_slice,
207+
len(self._slices_to_write),
208+
self._buffer_start_slice,
232209
type(exc).__name__,
233210
exc,
234211
)
@@ -238,29 +215,81 @@ def _flush_buffer(self) -> None:
238215

239216
raise exc
240217
finally:
241-
self.slices_to_write = []
218+
self._slices_to_write = []
242219

243220
def _get_slice_generator(self) -> Generator[None, np.ndarray, None]:
244221
current_slice = 0
245222
while True:
246223
data = yield # Data gets send from the user
247-
if len(self.slices_to_write) == 0:
248-
self.buffer_start_slice = current_slice
224+
if len(self._slices_to_write) == 0:
225+
self._buffer_start_slice = current_slice
249226
if len(data.shape) == 2:
250227
# The input data might contain channel data or not.
251228
# Bringing it into the same shape simplifies the code
252229
data = np.expand_dims(data, axis=0)
253-
self.slices_to_write.append(data)
230+
self._slices_to_write.append(data)
254231
current_slice += 1
255232

256-
if current_slice % self.buffer_size == 0:
233+
if current_slice % self._buffer_size == 0:
257234
self._flush_buffer()
258235

259-
def __enter__(self) -> Generator[None, np.ndarray, None]:
260-
gen = self._get_slice_generator()
261-
# It is necessary to start the generator by sending "None"
262-
gen.send(None) # type: ignore
263-
return gen
236+
def send(self, value: np.ndarray) -> None:
237+
self._generator.send(value)
238+
239+
def reset_offset(
240+
self,
241+
offset: Optional[Vec3IntLike] = None, # deprecated, relative in current mag
242+
relative_offset: Optional[Vec3IntLike] = None, # in mag1
243+
absolute_offset: Optional[Vec3IntLike] = None, # in mag1
244+
relative_bounding_box: Optional[NDBoundingBox] = None, # in mag1
245+
absolute_bounding_box: Optional[NDBoundingBox] = None, # in mag1
246+
) -> None:
247+
if self._slices_to_write:
248+
self._flush_buffer()
249+
250+
# Reset the generator
251+
self._generator = self._get_slice_generator()
252+
next(self._generator)
253+
254+
if (
255+
offset is None
256+
and relative_offset is None
257+
and absolute_offset is None
258+
and relative_bounding_box is None
259+
and absolute_bounding_box is None
260+
):
261+
relative_offset = Vec3Int.zeros()
262+
if offset is not None:
263+
warnings.warn(
264+
"[DEPRECATION] Using offset for a buffered slice writer is deprecated. "
265+
+ "Please use the parameter relative_offset or absolute_offset in Mag(1) instead.",
266+
DeprecationWarning,
267+
)
268+
269+
if offset is not None:
270+
self._bbox = BoundingBox(
271+
self._view.bounding_box.topleft_xyz + Vec3Int(offset) * self._view.mag,
272+
Vec3Int.zeros(),
273+
)
274+
275+
if relative_offset is not None:
276+
self._bbox = BoundingBox(
277+
self._view.bounding_box.topleft + relative_offset, Vec3Int.zeros()
278+
)
279+
280+
if absolute_offset is not None:
281+
self._bbox = BoundingBox(absolute_offset, Vec3Int.zeros())
282+
283+
if relative_bounding_box is not None:
284+
self._bbox = relative_bounding_box.offset(self._view.bounding_box.topleft)
285+
286+
if absolute_bounding_box is not None:
287+
self._bbox = absolute_bounding_box
288+
289+
def __enter__(self) -> "BufferedSliceWriter":
290+
self._generator = self._get_slice_generator()
291+
next(self._generator)
292+
return self
264293

265294
def __exit__(
266295
self,

0 commit comments

Comments
 (0)