@@ -51,50 +51,31 @@ def __init__(
5151 ) -> None :
5252 """see `View.get_buffered_slice_writer()`"""
5353
54- self .view = view
55- self .buffer_size = buffer_size
56- self .dtype = self .view .get_dtype ()
57- self .use_logging = use_logging
58- self .json_update_allowed = json_update_allowed
59- self .bbox : NDBoundingBox
60-
61- if (
62- offset is None
63- and relative_offset is None
64- and absolute_offset is None
65- and relative_bounding_box is None
66- and absolute_bounding_box is None
67- ):
68- relative_offset = Vec3Int .zeros ()
69- if offset is not None :
70- warnings .warn (
71- "[DEPRECATION] Using offset for a buffered slice writer is deprecated. "
72- + "Please use the parameter relative_offset or absolute_offset in Mag(1) instead." ,
73- DeprecationWarning ,
74- )
75- self .offset = None if offset is None else Vec3Int (offset )
76-
77- if relative_offset is not None :
78- self .bbox = BoundingBox (
79- self .view .bounding_box .topleft + relative_offset , Vec3Int .zeros ()
80- )
81-
82- if absolute_offset is not None :
83- self .bbox = BoundingBox (absolute_offset , Vec3Int .zeros ())
84-
85- if relative_bounding_box is not None :
86- self .bbox = relative_bounding_box .offset (self .view .bounding_box .topleft )
87-
88- if absolute_bounding_box is not None :
89- self .bbox = absolute_bounding_box
54+ self ._view = view
55+ self ._buffer_size = buffer_size
56+ self ._dtype = self ._view .get_dtype ()
57+ self ._use_logging = use_logging
58+ self ._json_update_allowed = json_update_allowed
59+ self ._bbox : NDBoundingBox
60+ self ._slices_to_write : List [np .ndarray ] = []
61+ self ._current_slice : Optional [int ] = None
62+ self ._buffer_start_slice : Optional [int ] = None
63+
64+ self .reset_offset (
65+ offset ,
66+ relative_offset ,
67+ absolute_offset ,
68+ relative_bounding_box ,
69+ absolute_bounding_box ,
70+ )
9071
9172 assert 0 <= dimension <= 2 # either x (0), y (1) or z (2)
9273 self .dimension = dimension
9374
94- view_chunk_depth = self .view .info .chunk_shape [dimension ]
75+ view_chunk_depth = self ._view .info .chunk_shape [dimension ]
9576 if (
96- self .bbox is not None
97- and self .bbox .topleft_xyz [self .dimension ] % view_chunk_depth != 0
77+ self ._bbox is not None
78+ and self ._bbox .topleft_xyz [self .dimension ] % view_chunk_depth != 0
9879 ):
9980 warnings .warn (
10081 "[WARNING] Using an offset that doesn't align with the datataset's chunk size, "
@@ -106,54 +87,50 @@ def __init__(
10687 + "will slow down the buffered slice writer." ,
10788 )
10889
109- self .slices_to_write : List [np .ndarray ] = []
110- self .current_slice : Optional [int ] = None
111- self .buffer_start_slice : Optional [int ] = None
112-
11390 def _flush_buffer (self ) -> None :
114- if len (self .slices_to_write ) == 0 :
91+ if len (self ._slices_to_write ) == 0 :
11592 return
11693
11794 assert (
118- len (self .slices_to_write ) <= self .buffer_size
95+ len (self ._slices_to_write ) <= self ._buffer_size
11996 ), "The WKW buffer is larger than the defined batch_size. The buffer should have been flushed earlier. This is probably a bug in the BufferedSliceWriter."
12097
121- uniq_dtypes = set (map (lambda _slice : _slice .dtype , self .slices_to_write ))
98+ uniq_dtypes = set (map (lambda _slice : _slice .dtype , self ._slices_to_write ))
12299 assert (
123100 len (uniq_dtypes ) == 1
124101 ), "The buffer of BufferedSliceWriter contains slices with differing dtype."
125- assert uniq_dtypes .pop () == self .dtype , (
102+ assert uniq_dtypes .pop () == self ._dtype , (
126103 "The buffer of BufferedSliceWriter contains slices with a dtype "
127104 "which differs from the dtype with which the BufferedSliceWriter was instantiated."
128105 )
129106
130- if self .use_logging :
107+ if self ._use_logging :
131108 info (
132109 "({}) Writing {} slices at position {}." .format (
133- getpid (), len (self .slices_to_write ), self .buffer_start_slice
110+ getpid (), len (self ._slices_to_write ), self ._buffer_start_slice
134111 )
135112 )
136113 log_memory_consumption ()
137114
138115 try :
139116 assert (
140- self .buffer_start_slice is not None
117+ self ._buffer_start_slice is not None
141118 ), "Failed to write buffer: The buffer_start_slice is not set."
142- max_width = max (section .shape [- 2 ] for section in self .slices_to_write )
143- max_height = max (section .shape [- 1 ] for section in self .slices_to_write )
144- channel_count = self .slices_to_write [0 ].shape [0 ]
145- buffer_depth = min (self .buffer_size , len (self .slices_to_write ))
119+ max_width = max (section .shape [- 2 ] for section in self ._slices_to_write )
120+ max_height = max (section .shape [- 1 ] for section in self ._slices_to_write )
121+ channel_count = self ._slices_to_write [0 ].shape [0 ]
122+ buffer_depth = min (self ._buffer_size , len (self ._slices_to_write ))
146123 buffer_start = Vec3Int .zeros ().with_replaced (
147- self .dimension , self .buffer_start_slice
124+ self .dimension , self ._buffer_start_slice
148125 )
149126
150- bbox = self .bbox .with_size_xyz (
127+ bbox = self ._bbox .with_size_xyz (
151128 Vec3Int (max_width , max_height , buffer_depth ).moveaxis (
152129 - 1 , self .dimension
153130 )
154131 ).offset (buffer_start )
155132
156- shard_dimensions = self .view ._get_file_dimensions ()
133+ shard_dimensions = self ._view ._get_file_dimensions ()
157134 chunk_size = Vec3Int (
158135 min (shard_dimensions [0 ], max_width ),
159136 min (shard_dimensions [1 ], max_height ),
@@ -164,7 +141,7 @@ def _flush_buffer(self) -> None:
164141
165142 data = np .zeros (
166143 (channel_count , * chunk_bbox .size ),
167- dtype = self .slices_to_write [0 ].dtype ,
144+ dtype = self ._slices_to_write [0 ].dtype ,
168145 )
169146 section_topleft = Vec3Int (
170147 (chunk_bbox .topleft_xyz - bbox .topleft_xyz ).moveaxis (
@@ -180,7 +157,7 @@ def _flush_buffer(self) -> None:
180157 z_index = chunk_bbox .index_xyz [self .dimension ]
181158
182159 z = 0
183- for section in self .slices_to_write :
160+ for section in self ._slices_to_write :
184161 section_chunk = section [
185162 :,
186163 section_topleft .x : section_bottomright .x ,
@@ -215,10 +192,10 @@ def _flush_buffer(self) -> None:
215192
216193 z += 1
217194
218- self .view .write (
195+ self ._view .write (
219196 data ,
220- json_update_allowed = self .json_update_allowed ,
221- absolute_bounding_box = chunk_bbox .from_mag_to_mag1 (self .view ._mag ),
197+ json_update_allowed = self ._json_update_allowed ,
198+ absolute_bounding_box = chunk_bbox .from_mag_to_mag1 (self ._view ._mag ),
222199 )
223200 del data
224201
@@ -227,8 +204,8 @@ def _flush_buffer(self) -> None:
227204 "({}) An exception occurred in BufferedSliceWriter._flush_buffer with {} "
228205 "slices at position {}. Original error is:\n {}:{}\n \n Traceback:" .format (
229206 getpid (),
230- len (self .slices_to_write ),
231- self .buffer_start_slice ,
207+ len (self ._slices_to_write ),
208+ self ._buffer_start_slice ,
232209 type (exc ).__name__ ,
233210 exc ,
234211 )
@@ -238,29 +215,81 @@ def _flush_buffer(self) -> None:
238215
239216 raise exc
240217 finally :
241- self .slices_to_write = []
218+ self ._slices_to_write = []
242219
243220 def _get_slice_generator (self ) -> Generator [None , np .ndarray , None ]:
244221 current_slice = 0
245222 while True :
246223 data = yield # Data gets send from the user
247- if len (self .slices_to_write ) == 0 :
248- self .buffer_start_slice = current_slice
224+ if len (self ._slices_to_write ) == 0 :
225+ self ._buffer_start_slice = current_slice
249226 if len (data .shape ) == 2 :
250227 # The input data might contain channel data or not.
251228 # Bringing it into the same shape simplifies the code
252229 data = np .expand_dims (data , axis = 0 )
253- self .slices_to_write .append (data )
230+ self ._slices_to_write .append (data )
254231 current_slice += 1
255232
256- if current_slice % self .buffer_size == 0 :
233+ if current_slice % self ._buffer_size == 0 :
257234 self ._flush_buffer ()
258235
259- def __enter__ (self ) -> Generator [None , np .ndarray , None ]:
260- gen = self ._get_slice_generator ()
261- # It is necessary to start the generator by sending "None"
262- gen .send (None ) # type: ignore
263- return gen
236+ def send (self , value : np .ndarray ) -> None :
237+ self ._generator .send (value )
238+
239+ def reset_offset (
240+ self ,
241+ offset : Optional [Vec3IntLike ] = None , # deprecated, relative in current mag
242+ relative_offset : Optional [Vec3IntLike ] = None , # in mag1
243+ absolute_offset : Optional [Vec3IntLike ] = None , # in mag1
244+ relative_bounding_box : Optional [NDBoundingBox ] = None , # in mag1
245+ absolute_bounding_box : Optional [NDBoundingBox ] = None , # in mag1
246+ ) -> None :
247+ if self ._slices_to_write :
248+ self ._flush_buffer ()
249+
250+ # Reset the generator
251+ self ._generator = self ._get_slice_generator ()
252+ next (self ._generator )
253+
254+ if (
255+ offset is None
256+ and relative_offset is None
257+ and absolute_offset is None
258+ and relative_bounding_box is None
259+ and absolute_bounding_box is None
260+ ):
261+ relative_offset = Vec3Int .zeros ()
262+ if offset is not None :
263+ warnings .warn (
264+ "[DEPRECATION] Using offset for a buffered slice writer is deprecated. "
265+ + "Please use the parameter relative_offset or absolute_offset in Mag(1) instead." ,
266+ DeprecationWarning ,
267+ )
268+
269+ if offset is not None :
270+ self ._bbox = BoundingBox (
271+ self ._view .bounding_box .topleft_xyz + Vec3Int (offset ) * self ._view .mag ,
272+ Vec3Int .zeros (),
273+ )
274+
275+ if relative_offset is not None :
276+ self ._bbox = BoundingBox (
277+ self ._view .bounding_box .topleft + relative_offset , Vec3Int .zeros ()
278+ )
279+
280+ if absolute_offset is not None :
281+ self ._bbox = BoundingBox (absolute_offset , Vec3Int .zeros ())
282+
283+ if relative_bounding_box is not None :
284+ self ._bbox = relative_bounding_box .offset (self ._view .bounding_box .topleft )
285+
286+ if absolute_bounding_box is not None :
287+ self ._bbox = absolute_bounding_box
288+
289+ def __enter__ (self ) -> "BufferedSliceWriter" :
290+ self ._generator = self ._get_slice_generator ()
291+ next (self ._generator )
292+ return self
264293
265294 def __exit__ (
266295 self ,
0 commit comments