99import numpy as np
1010import psutil
1111
12- from webknossos .geometry import Vec3Int , Vec3IntLike
12+ from webknossos .geometry import BoundingBox , Vec3Int , Vec3IntLike
1313
1414if TYPE_CHECKING :
1515 from webknossos .dataset import View
@@ -69,21 +69,43 @@ def __init__(
6969 )
7070 self .dimension = dimension
7171
72+ effective_offset = Vec3Int .full (0 )
73+ if self .relative_offset is not None :
74+ effective_offset = self .view .bounding_box .topleft + self .relative_offset
75+
76+ if self .absolute_offset is not None :
77+ effective_offset = self .absolute_offset
78+
79+ view_chunk_depth = self .view .info .chunk_shape [self .dimension ]
80+ if (
81+ effective_offset is not None
82+ and effective_offset [self .dimension ] % view_chunk_depth != 0
83+ ):
84+ warnings .warn (
85+ "[WARNING] Using an offset that doesn't align with the datataset's chunk size, "
86+ + "will slow down the buffered slice writer, because twice as many chunks will be written." ,
87+ )
88+ if buffer_size >= view_chunk_depth and buffer_size % view_chunk_depth > 0 :
89+ warnings .warn (
90+ "[WARNING] Using a buffer size that doesn't align with the datataset's chunk size, "
91+ + "will slow down the buffered slice writer." ,
92+ )
93+
7294 assert 0 <= dimension <= 2
7395
74- self .buffer : List [np .ndarray ] = []
96+ self .slices_to_write : List [np .ndarray ] = []
7597 self .current_slice : Optional [int ] = None
7698 self .buffer_start_slice : Optional [int ] = None
7799
78- def _write_buffer (self ) -> None :
79- if len (self .buffer ) == 0 :
100+ def _flush_buffer (self ) -> None :
101+ if len (self .slices_to_write ) == 0 :
80102 return
81103
82104 assert (
83- len (self .buffer ) <= self .buffer_size
105+ len (self .slices_to_write ) <= self .buffer_size
84106 ), "The WKW buffer is larger than the defined batch_size. The buffer should have been flushed earlier. This is probably a bug in the BufferedSliceWriter."
85107
86- uniq_dtypes = set (map (lambda _slice : _slice .dtype , self .buffer ))
108+ uniq_dtypes = set (map (lambda _slice : _slice .dtype , self .slices_to_write ))
87109 assert (
88110 len (uniq_dtypes ) == 1
89111 ), "The buffer of BufferedSliceWriter contains slices with differing dtype."
@@ -95,7 +117,7 @@ def _write_buffer(self) -> None:
95117 if self .use_logging :
96118 info (
97119 "({}) Writing {} slices at position {}." .format (
98- getpid (), len (self .buffer ), self .buffer_start_slice
120+ getpid (), len (self .slices_to_write ), self .buffer_start_slice
99121 )
100122 )
101123 log_memory_consumption ()
@@ -104,44 +126,65 @@ def _write_buffer(self) -> None:
104126 assert (
105127 self .buffer_start_slice is not None
106128 ), "Failed to write buffer: The buffer_start_slice is not set."
107- max_width = max (slice .shape [- 2 ] for slice in self .buffer )
108- max_height = max (slice .shape [- 1 ] for slice in self .buffer )
109-
110- self .buffer = [
111- np .pad (
112- slice ,
113- mode = "constant" ,
114- pad_width = [
115- (0 , 0 ),
116- (0 , max_width - slice .shape [- 2 ]),
117- (0 , max_height - slice .shape [- 1 ]),
118- ],
119- )
120- for slice in self .buffer
121- ]
129+ max_width = max (section .shape [- 2 ] for section in self .slices_to_write )
130+ max_height = max (section .shape [- 1 ] for section in self .slices_to_write )
131+ channel_count = self .slices_to_write [0 ].shape [0 ]
132+
133+ buffer_bbox = BoundingBox (
134+ (0 , 0 , 0 ), (max_width , max_height , self .buffer_size )
135+ )
122136
123- data = np .concatenate (
124- [np .expand_dims (slice , self .dimension + 1 ) for slice in self .buffer ],
125- axis = self .dimension + 1 ,
137+ shard_dimensions = self .view ._get_file_dimensions ().moveaxis (
138+ - 1 , self .dimension
126139 )
127- buffer_start_list = [0 , 0 , 0 ]
128- buffer_start_list [self .dimension ] = self .buffer_start_slice
129- buffer_start = Vec3Int (buffer_start_list )
130- buffer_start_mag1 = buffer_start * self .view .mag .to_vec3_int ()
131- self .view .write (
132- data ,
133- offset = buffer_start .add_or_none (self .offset ),
134- relative_offset = buffer_start_mag1 .add_or_none (self .relative_offset ),
135- absolute_offset = buffer_start_mag1 .add_or_none (self .absolute_offset ),
136- json_update_allowed = self .json_update_allowed ,
140+ chunk_size = Vec3Int (
141+ min (shard_dimensions [0 ], max_width ),
142+ min (shard_dimensions [1 ], max_height ),
143+ self .buffer_size ,
137144 )
145+ for chunk_bbox in buffer_bbox .chunk (chunk_size ):
146+ info (f"Writing chunk { chunk_bbox } " )
147+ width , height , _ = chunk_bbox .size
148+ data = np .zeros (
149+ (channel_count , width , height , self .buffer_size ),
150+ dtype = self .slices_to_write [0 ].dtype ,
151+ )
152+
153+ z = 0
154+ for section in self .slices_to_write :
155+ section_chunk = section [
156+ :,
157+ chunk_bbox .topleft .x : chunk_bbox .bottomright .x ,
158+ chunk_bbox .topleft .y : chunk_bbox .bottomright .y ,
159+ ]
160+ data [
161+ :, 0 : section_chunk .shape [- 2 ], 0 : section_chunk .shape [- 1 ], z
162+ ] = section_chunk
163+
164+ z += 1
165+
166+ buffer_start = Vec3Int (
167+ chunk_bbox .topleft .x , chunk_bbox .topleft .y , self .buffer_start_slice
168+ ).moveaxis (- 1 , self .dimension )
169+ buffer_start_mag1 = buffer_start * self .view .mag .to_vec3_int ()
170+
171+ data = np .moveaxis (data , - 1 , self .dimension + 1 )
172+
173+ self .view .write (
174+ data ,
175+ offset = buffer_start .add_or_none (self .offset ),
176+ relative_offset = buffer_start_mag1 .add_or_none (self .relative_offset ),
177+ absolute_offset = buffer_start_mag1 .add_or_none (self .absolute_offset ),
178+ json_update_allowed = self .json_update_allowed ,
179+ )
180+ del data
138181
139182 except Exception as exc :
140183 error (
141- "({}) An exception occurred in BufferedSliceWriter._write_buffer with {} "
184+ "({}) An exception occurred in BufferedSliceWriter._flush_buffer with {} "
142185 "slices at position {}. Original error is:\n {}:{}\n \n Traceback:" .format (
143186 getpid (),
144- len (self .buffer ),
187+ len (self .slices_to_write ),
145188 self .buffer_start_slice ,
146189 type (exc ).__name__ ,
147190 exc ,
@@ -152,23 +195,23 @@ def _write_buffer(self) -> None:
152195
153196 raise exc
154197 finally :
155- self .buffer = []
198+ self .slices_to_write = []
156199
157200 def _get_slice_generator (self ) -> Generator [None , np .ndarray , None ]:
158201 current_slice = 0
159202 while True :
160203 data = yield # Data gets send from the user
161- if len (self .buffer ) == 0 :
204+ if len (self .slices_to_write ) == 0 :
162205 self .buffer_start_slice = current_slice
163206 if len (data .shape ) == 2 :
164207 # The input data might contain channel data or not.
165208 # Bringing it into the same shape simplifies the code
166209 data = np .expand_dims (data , axis = 0 )
167- self .buffer .append (data )
210+ self .slices_to_write .append (data )
168211 current_slice += 1
169212
170213 if current_slice % self .buffer_size == 0 :
171- self ._write_buffer ()
214+ self ._flush_buffer ()
172215
173216 def __enter__ (self ) -> Generator [None , np .ndarray , None ]:
174217 gen = self ._get_slice_generator ()
@@ -182,4 +225,4 @@ def __exit__(
182225 _value : Optional [BaseException ],
183226 _tb : Optional [TracebackType ],
184227 ) -> None :
185- self ._write_buffer ()
228+ self ._flush_buffer ()
0 commit comments