
Commit 2d1668f

moved buffered_slice_writer (#135)
* moved buffered_slice_writer
* reformatted files
* fixed test
* reformatted files
1 parent bc85e70 commit 2d1668f


3 files changed: +93 −36 lines changed


requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -7,4 +7,5 @@ wkw==0.0.8
 requests
 black
 cluster_tools==1.41
-natsort
+natsort
+psutil

tests/test_utils.py

Lines changed: 1 addition & 3 deletions
@@ -42,9 +42,7 @@ def test_buffered_slice_writer():
     mag = Mag(1)
     dataset_path = os.path.join(dataset_dir, layer_name, mag.to_layer_name())

-    with BufferedSliceWriter(
-        dataset_dir, layer_name, dtype, bbox, origin, mag=mag
-    ) as writer:
+    with BufferedSliceWriter(dataset_dir, layer_name, dtype, origin, mag=mag) as writer:
         for i in range(13):
             writer.write_slice(i, test_img)
     with wkw.Dataset.open(dataset_path, wkw.Header(dtype)) as data:
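
Read back as ordinary code, the updated call from the test looks roughly like the sketch below; the directory, layer name, dtype, origin, and image are placeholder assumptions, not values copied from the test fixtures:

import numpy as np

from wkcuber.mag import Mag
from wkcuber.utils import BufferedSliceWriter

# Placeholder inputs standing in for the test fixtures.
dataset_dir = "testoutput/buffered_slice_writer"
layer_name = "color"
dtype = np.uint8
origin = (0, 0, 0)
test_img = np.zeros((24, 24), dtype=dtype)

# The bounding_box argument is gone; origin now follows dtype directly.
with BufferedSliceWriter(dataset_dir, layer_name, dtype, origin, mag=Mag(1)) as writer:
    for z in range(13):
        writer.write_slice(z, test_img)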

wkcuber/utils.py

Lines changed: 90 additions & 32 deletions
@@ -5,17 +5,21 @@
 import argparse
 import cluster_tools
 import json
+import os
+import psutil
+from typing import List, Tuple, Union
 from glob import iglob
 from collections import namedtuple
 from multiprocessing import cpu_count, Lock
 import concurrent
 from concurrent.futures import ProcessPoolExecutor
 from os import path, getpid
-from platform import python_version
 from math import floor, ceil
-from .mag import Mag
+from logging import getLogger
+import traceback

-from .knossos import KnossosDataset, CUBE_EDGE_LEN
+from .knossos import KnossosDataset
+from .mag import Mag

 WkwDatasetInfo = namedtuple(
     "WkwDatasetInfo", ("dataset_path", "layer_name", "dtype", "mag")
@@ -26,6 +30,8 @@

 BLOCK_LEN = 32

+logger = getLogger(__name__)
+

 def open_wkw(info, **kwargs):
     if hasattr(info, "dtype"):
@@ -210,13 +216,16 @@ def wait_and_ensure_success(futures):
 class BufferedSliceWriter(object):
     def __init__(
         self,
-        dataset_path,
-        layer_name,
+        dataset_path: str,
+        layer_name: str,
         dtype,
-        bounding_box,
-        origin,
-        buffer_size=32,
-        mag=Mag(1),
+        origin: Union[Tuple[int, int, int], List[int]],
+        # buffer_size specifies how many slices should be aggregated until they are flushed.
+        buffer_size: int = 32,
+        # file_len specifies how many buckets are written per dimension into a wkw cube.
+        # Using 32 results in 1 GB per wkw file for 8-bit data.
+        file_len: int = 32,
+        mag: Mag = Mag("1"),
     ):

         self.dataset_path = dataset_path
@@ -225,9 +234,11 @@ def __init__(

         layer_path = path.join(self.dataset_path, self.layer_name, mag.to_layer_name())

-        self.dataset = wkw.Dataset.open(layer_path, wkw.Header(dtype))
+        self.dtype = dtype
+        self.dataset = wkw.Dataset.open(
+            layer_path, wkw.Header(dtype, file_len=file_len)
+        )
         self.origin = origin
-        self.bounding_box = bounding_box

         self.buffer = []
         self.current_z = None
@@ -255,33 +266,66 @@ def _write_buffer(self):
         if len(self.buffer) == 0:
             return

-        assert len(self.buffer) <= self.buffer_size
+        assert (
+            len(self.buffer) <= self.buffer_size
+        ), "The WKW buffer is larger than the defined buffer_size. The buffer should have been flushed earlier. This is probably a bug in the BufferedSliceWriter."

-        logging.debug(
+        uniq_dtypes = set(map(lambda _slice: _slice.dtype, self.buffer))
+        assert (
+            len(uniq_dtypes) == 1
+        ), "The buffer of BufferedSliceWriter contains slices with differing dtype."
+        assert uniq_dtypes.pop() == self.dtype, (
+            "The buffer of BufferedSliceWriter contains slices with a dtype "
+            "which differs from the dtype with which the BufferedSliceWriter was instantiated."
+        )
+
+        logger.debug(
             "({}) Writing {} slices at position {}.".format(
                 getpid(), len(self.buffer), self.buffer_start_z
             )
         )
-
-        origin_with_offset = self.origin.copy()
-        origin_with_offset[2] = self.buffer_start_z
-        x_max = max(slice.shape[0] for slice in self.buffer)
-        y_max = max(slice.shape[1] for slice in self.buffer)
-        self.buffer = [
-            np.pad(
-                slice,
-                mode="constant",
-                pad_width=[(0, x_max - slice.shape[0]), (0, y_max - slice.shape[1])],
+        log_memory_consumption()
+
+        try:
+            origin_with_offset = list(self.origin)
+            origin_with_offset[2] = self.buffer_start_z
+            x_max = max(slice.shape[0] for slice in self.buffer)
+            y_max = max(slice.shape[1] for slice in self.buffer)
+
+            self.buffer = [
+                np.pad(
+                    slice,
+                    mode="constant",
+                    pad_width=[
+                        (0, x_max - slice.shape[0]),
+                        (0, y_max - slice.shape[1]),
+                    ],
+                )
+                for slice in self.buffer
+            ]
+
+            data = np.concatenate(
+                [np.expand_dims(slice, 2) for slice in self.buffer], axis=2
             )
-            for slice in self.buffer
-        ]
-        data = np.concatenate(
-            [np.expand_dims(slice, 2) for slice in self.buffer], axis=2
-        )
-
-        self.dataset.write(origin_with_offset, data)
+            self.dataset.write(origin_with_offset, data)
+
+        except Exception as exc:
+            logger.error(
+                "({}) An exception occurred in BufferedSliceWriter._write_buffer with {} "
+                "slices at position {}. Original error is:\n{}:{}\n\nTraceback:".format(
+                    getpid(),
+                    len(self.buffer),
+                    self.buffer_start_z,
+                    type(exc).__name__,
+                    exc,
+                )
+            )
+            traceback.print_tb(exc.__traceback__)
+            logger.error("\n")

-        self.buffer = []
+            raise exc
+        finally:
+            self.buffer = []

     def close(self):

@@ -291,5 +335,19 @@ def close(self):
     def __enter__(self):
         return self

-    def __exit__(self, type, value, tb):
+    def __exit__(self, _type, _value, _tb):
         self.close()
+
+
+def log_memory_consumption(additional_output=""):
+    pid = os.getpid()
+    process = psutil.Process(pid)
+    logging.info(
+        "Currently consuming {:.2f} GB of memory ({:.2f} GB still available) "
+        "in process {}. {}".format(
+            process.memory_info().rss / 1024 ** 3,
+            psutil.virtual_memory().available / 1024 ** 3,
+            pid,
+            additional_output,
+        )
+    )
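
Taken together, the widened constructor and the new log_memory_consumption helper can be exercised roughly as in the sketch below; the paths, layer name, and slice data are illustrative assumptions, and the 1 GB estimate assumes the 32-voxel bucket edge (BLOCK_LEN) used elsewhere in this module:

import logging

import numpy as np

from wkcuber.mag import Mag
from wkcuber.utils import BufferedSliceWriter, log_memory_consumption

logging.basicConfig(level=logging.INFO)  # make log_memory_consumption's logging.info output visible

writer = BufferedSliceWriter(
    "testoutput/dataset",  # dataset_path (placeholder)
    "color",               # layer_name (placeholder)
    np.uint8,              # dtype; every buffered slice must match it
    (0, 0, 0),             # origin, passed as a plain tuple or list
    buffer_size=32,        # flush to disk after 32 aggregated slices
    # file_len=32 buckets per dimension, each 32 voxels wide:
    # (32 * 32)^3 voxels = 1024^3 bytes, i.e. roughly 1 GB per wkw file for uint8 data.
    file_len=32,
    mag=Mag("1"),
)
try:
    for z in range(64):
        writer.write_slice(z, np.zeros((512, 512), dtype=np.uint8))
        if z % 32 == 0:
            log_memory_consumption("after slice {}".format(z))
finally:
    writer.close()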

0 commit comments
