55import argparse
66import cluster_tools
77import json
8+ import os
9+ import psutil
10+ from typing import List , Tuple , Union
811from glob import iglob
912from collections import namedtuple
1013from multiprocessing import cpu_count , Lock
1114import concurrent
1215from concurrent .futures import ProcessPoolExecutor
1316from os import path , getpid
14- from platform import python_version
1517from math import floor , ceil
16- from .mag import Mag
18+ from logging import getLogger
19+ import traceback
1720
18- from .knossos import KnossosDataset , CUBE_EDGE_LEN
21+ from .knossos import KnossosDataset
22+ from .mag import Mag
1923
2024WkwDatasetInfo = namedtuple (
2125 "WkwDatasetInfo" , ("dataset_path" , "layer_name" , "dtype" , "mag" )
2630
2731BLOCK_LEN = 32
2832
33+ logger = getLogger (__name__ )
34+
2935
3036def open_wkw (info , ** kwargs ):
3137 if hasattr (info , "dtype" ):
@@ -210,13 +216,16 @@ def wait_and_ensure_success(futures):
210216class BufferedSliceWriter (object ):
211217 def __init__ (
212218 self ,
213- dataset_path ,
214- layer_name ,
219+ dataset_path : str ,
220+ layer_name : str ,
215221 dtype ,
216- bounding_box ,
217- origin ,
218- buffer_size = 32 ,
219- mag = Mag (1 ),
222+ origin : Union [Tuple [int , int , int ], List [int ]],
223+ # buffer_size specifies, how many slices should be aggregated until they are flushed.
224+ buffer_size : int = 32 ,
225+ # file_len specifies, how many buckets written per dimension into a wkw cube. Using 32,
226+ # results in 1 GB/wkw file for 8-bit data
227+ file_len : int = 32 ,
228+ mag : Mag = Mag ("1" ),
220229 ):
221230
222231 self .dataset_path = dataset_path
@@ -225,9 +234,11 @@ def __init__(
225234
226235 layer_path = path .join (self .dataset_path , self .layer_name , mag .to_layer_name ())
227236
228- self .dataset = wkw .Dataset .open (layer_path , wkw .Header (dtype ))
237+ self .dtype = dtype
238+ self .dataset = wkw .Dataset .open (
239+ layer_path , wkw .Header (dtype , file_len = file_len )
240+ )
229241 self .origin = origin
230- self .bounding_box = bounding_box
231242
232243 self .buffer = []
233244 self .current_z = None
@@ -255,33 +266,66 @@ def _write_buffer(self):
255266 if len (self .buffer ) == 0 :
256267 return
257268
258- assert len (self .buffer ) <= self .buffer_size
269+ assert (
270+ len (self .buffer ) <= self .buffer_size
271+ ), "The WKW buffer is larger than the defined batch_size. The buffer should have been flushed earlier. This is probably a bug in the BufferedSliceWriter."
259272
260- logging .debug (
273+ uniq_dtypes = set (map (lambda _slice : _slice .dtype , self .buffer ))
274+ assert (
275+ len (uniq_dtypes ) == 1
276+ ), "The buffer of BufferedSliceWriter contains slices with differing dtype."
277+ assert uniq_dtypes .pop () == self .dtype , (
278+ "The buffer of BufferedSliceWriter contains slices with a dtype "
279+ "which differs from the dtype with which the BufferedSliceWriter was instantiated."
280+ )
281+
282+ logger .debug (
261283 "({}) Writing {} slices at position {}." .format (
262284 getpid (), len (self .buffer ), self .buffer_start_z
263285 )
264286 )
265-
266- origin_with_offset = self .origin .copy ()
267- origin_with_offset [2 ] = self .buffer_start_z
268- x_max = max (slice .shape [0 ] for slice in self .buffer )
269- y_max = max (slice .shape [1 ] for slice in self .buffer )
270- self .buffer = [
271- np .pad (
272- slice ,
273- mode = "constant" ,
274- pad_width = [(0 , x_max - slice .shape [0 ]), (0 , y_max - slice .shape [1 ])],
287+ log_memory_consumption ()
288+
289+ try :
290+ origin_with_offset = list (self .origin )
291+ origin_with_offset [2 ] = self .buffer_start_z
292+ x_max = max (slice .shape [0 ] for slice in self .buffer )
293+ y_max = max (slice .shape [1 ] for slice in self .buffer )
294+
295+ self .buffer = [
296+ np .pad (
297+ slice ,
298+ mode = "constant" ,
299+ pad_width = [
300+ (0 , x_max - slice .shape [0 ]),
301+ (0 , y_max - slice .shape [1 ]),
302+ ],
303+ )
304+ for slice in self .buffer
305+ ]
306+
307+ data = np .concatenate (
308+ [np .expand_dims (slice , 2 ) for slice in self .buffer ], axis = 2
275309 )
276- for slice in self .buffer
277- ]
278- data = np .concatenate (
279- [np .expand_dims (slice , 2 ) for slice in self .buffer ], axis = 2
280- )
281-
282- self .dataset .write (origin_with_offset , data )
310+ self .dataset .write (origin_with_offset , data )
311+
312+ except Exception as exc :
313+ logger .error (
314+ "({}) An exception occurred in BufferedSliceWriter._write_buffer with {} "
315+ "slices at position {}. Original error is:\n {}:{}\n \n Traceback:" .format (
316+ getpid (),
317+ len (self .buffer ),
318+ self .buffer_start_z ,
319+ type (exc ).__name__ ,
320+ exc ,
321+ )
322+ )
323+ traceback .print_tb (exc .__traceback__ )
324+ logger .error ("\n " )
283325
284- self .buffer = []
326+ raise exc
327+ finally :
328+ self .buffer = []
285329
286330 def close (self ):
287331
@@ -291,5 +335,19 @@ def close(self):
    def __enter__(self):
        # Context-manager entry: no setup needed beyond __init__;
        # hand the writer itself to the `with` block.
        return self
293337
294- def __exit__ (self , type , value , tb ):
338+ def __exit__ (self , _type , _value , _tb ):
295339 self .close ()
340+
341+
def log_memory_consumption(additional_output=""):
    """Log this process's resident memory and the system's available memory.

    Reports RSS of the current process and `psutil.virtual_memory().available`,
    both converted to GiB, at INFO level.

    Args:
        additional_output: Optional extra text appended to the log line.
    """
    pid = os.getpid()
    process = psutil.Process(pid)
    # Use the module-level `logger` — the rest of this file was migrated from
    # the bare `logging` module to `logger = getLogger(__name__)` (see the
    # logger.debug/logger.error calls in _write_buffer); `logging.info` here
    # would raise NameError if `logging` is not imported.
    logger.info(
        "Currently consuming {:.2f} GB of memory ({:.2f} GB still available) "
        "in process {}. {}".format(
            process.memory_info().rss / 1024 ** 3,
            psutil.virtual_memory().available / 1024 ** 3,
            pid,
            additional_output,
        )
    )
0 commit comments