
Commit d9bf347

Merge pull request #23 from jeromekelleher/more-refactoring2
More refactoring of core utilities
2 parents 9193834 + 434d60a commit d9bf347

File tree

4 files changed: +326 -175 lines changed

bio2zarr/core.py

Lines changed: 121 additions & 15 deletions
@@ -1,15 +1,39 @@
 import dataclasses
 import contextlib
 import concurrent.futures as cf
+import multiprocessing
+import threading
 import logging
+import functools
+import time
 
 import zarr
 import numpy as np
+import tqdm
 
 
 logger = logging.getLogger(__name__)
 
 
+class SynchronousExecutor(cf.Executor):
+    def submit(self, fn, /, *args, **kwargs):
+        future = cf.Future()
+        future.set_result(fn(*args, **kwargs))
+        return future
+
+
+def wait_on_futures(futures):
+    for future in cf.as_completed(futures):
+        exception = future.exception()
+        if exception is not None:
+            raise exception
+
+
+def cancel_futures(futures):
+    for future in futures:
+        future.cancel()
+
+
 @dataclasses.dataclass
 class BufferedArray:
     array: zarr.Array
@@ -32,6 +56,9 @@ def async_flush(self, executor, offset, buff_stop=None):
         return async_flush_array(executor, self.buff[:buff_stop], self.array, offset)
 
 
+# TODO: factor these functions into the BufferedArray class
+
+
 def sync_flush_array(np_buffer, zarr_array, offset):
     zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
 
@@ -72,7 +99,9 @@ def flush_chunk(start, stop):
 
 
 class ThreadedZarrEncoder(contextlib.AbstractContextManager):
-    def __init__(self, buffered_arrays, encoder_threads):
+    # TODO (maybe) add option with encoder_threads=None to run synchronously for
+    # debugging using a mock Executor
+    def __init__(self, buffered_arrays, encoder_threads=1):
         self.buffered_arrays = buffered_arrays
         self.executor = cf.ThreadPoolExecutor(max_workers=encoder_threads)
         self.chunk_length = buffered_arrays[0].chunk_length
@@ -89,18 +118,10 @@ def next_buffer_row(self):
             self.next_row = 0
         return self.next_row
 
-    def wait_on_futures(self):
-        for future in cf.as_completed(self.futures):
-            exception = future.exception()
-            if exception is not None:
-                raise exception
-
     def swap_buffers(self):
-        self.wait_on_futures()
+        wait_on_futures(self.futures)
         self.futures = []
         for ba in self.buffered_arrays:
-            # TODO add debug log
-            # print("Scheduling", ba.array, offset, buff_stop)
             self.futures.extend(
                 ba.async_flush(self.executor, self.array_offset, self.next_row)
             )
@@ -111,10 +132,95 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             # Normal exit condition
             self.next_row += 1
             self.swap_buffers()
-            self.wait_on_futures()
-            # TODO add arguments to wait and cancel_futures appropriate
-            # for the an error condition occuring here. Generally need
-            # to think about the error exit condition here (like running
-            # out of disk space) to see what the right behaviour is.
+            wait_on_futures(self.futures)
+        else:
+            cancel_futures(self.futures)
+        self.executor.shutdown()
+        return False
+
+
+@dataclasses.dataclass
+class ProgressConfig:
+    total: int = 0
+    units: str = ""
+    title: str = ""
+    show: bool = False
+    poll_interval: float = 0.001
+
+
+_progress_counter = multiprocessing.Value("Q", 0)
+
+
+def update_progress(inc):
+    with _progress_counter.get_lock():
+        _progress_counter.value += inc
+
+
+def get_progress():
+    with _progress_counter.get_lock():
+        val = _progress_counter.value
+    return val
+
+
+def set_progress(value):
+    with _progress_counter.get_lock():
+        _progress_counter.value = value
+
+
+def progress_thread_worker(config):
+    pbar = tqdm.tqdm(
+        total=config.total,
+        desc=config.title,
+        unit_scale=True,
+        unit=config.units,
+        smoothing=0.1,
+        disable=not config.show,
+    )
+
+    while (current := get_progress()) < config.total:
+        inc = current - pbar.n
+        pbar.update(inc)
+        time.sleep(config.poll_interval)
+    pbar.close()
+
+
+class ParallelWorkManager(contextlib.AbstractContextManager):
+    def __init__(self, worker_processes=1, progress_config=None):
+        if worker_processes <= 0:
+            # NOTE: this is only for testing, not for production use!
+            self.executor = SynchronousExecutor()
+        else:
+            self.executor = cf.ProcessPoolExecutor(
+                max_workers=worker_processes,
+            )
+        set_progress(0)
+        if progress_config is None:
+            progress_config = ProgressConfig()
+        self.bar_thread = threading.Thread(
+            target=progress_thread_worker,
+            args=(progress_config,),
+            name="progress",
+            daemon=True,
+        )
+        self.bar_thread.start()
+        self.progress_config = progress_config
+        self.futures = []
+
+    def submit(self, *args, **kwargs):
+        self.futures.append(self.executor.submit(*args, **kwargs))
+
+    def results_as_completed(self):
+        for future in cf.as_completed(self.futures):
+            yield future.result()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is None:
+            wait_on_futures(self.futures)
+            set_progress(self.progress_config.total)
+            timeout = None
+        else:
+            cancel_futures(self.futures)
+            timeout = 0
+        self.bar_thread.join(timeout)
         self.executor.shutdown()
         return False
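
Note: the diff above introduces a general-purpose ParallelWorkManager together with a shared progress counter and ProgressConfig. The sketch below shows roughly how the new API fits together; it is based only on the code in this diff, and the worker function, inputs, and totals are hypothetical. The module-level progress counter is visible to ProcessPoolExecutor workers only when the platform's multiprocessing start method inherits it (e.g. fork).

import bio2zarr.core as core

def scan_chunk(chunk_size):
    # Hypothetical worker: a picklable top-level function, so that
    # ProcessPoolExecutor can run it. It reports progress as it goes.
    for _ in range(chunk_size):
        core.update_progress(1)
    return chunk_size

chunks = [1000, 2000, 500]  # hypothetical work items
config = core.ProgressConfig(
    total=sum(chunks), units="items", title="Scan", show=True
)
with core.ParallelWorkManager(worker_processes=4, progress_config=config) as pwm:
    for size in chunks:
        pwm.submit(scan_chunk, size)
    results = list(pwm.results_as_completed())

# Passing worker_processes=0 selects SynchronousExecutor instead, which runs
# each submitted call immediately in-process (per the NOTE, for testing only).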
