1
1
import dataclasses
2
2
import contextlib
3
3
import concurrent .futures as cf
4
+ import multiprocessing
5
+ import threading
4
6
import logging
7
+ import functools
8
+ import time
5
9
6
10
import zarr
7
11
import numpy as np
12
+ import tqdm
8
13
9
14
10
15
logger = logging .getLogger (__name__ )
11
16
12
17
18
class SynchronousExecutor(cf.Executor):
    """Executor that runs submitted callables immediately in the caller's thread.

    Intended for debugging: scheduling is deterministic and single-threaded,
    while callers keep using the regular ``Executor``/``Future`` interface.
    """

    def submit(self, fn, /, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` now and return an already-resolved Future."""
        future = cf.Future()
        # Mirror real executor semantics: an exception raised by fn must be
        # stored on the future (surfaced via .exception()/.result()), not
        # propagate out of submit() itself.
        try:
            future.set_result(fn(*args, **kwargs))
        except BaseException as exc:
            future.set_exception(exc)
        return future
23
+
24
+
25
def wait_on_futures(futures):
    """Block until every future finishes, re-raising the first failure seen."""
    for completed in cf.as_completed(futures):
        error = completed.exception()
        if error is None:
            continue
        raise error
30
+
31
+
32
def cancel_futures(futures):
    """Best-effort cancellation: ask every future to cancel itself.

    Futures that are already running or finished simply refuse the request.
    """
    for pending in futures:
        pending.cancel()
35
+
36
+
13
37
@dataclasses .dataclass
14
38
class BufferedArray :
15
39
array : zarr .Array
@@ -32,6 +56,9 @@ def async_flush(self, executor, offset, buff_stop=None):
32
56
return async_flush_array (executor , self .buff [:buff_stop ], self .array , offset )
33
57
34
58
59
+ # TODO: factor these functions into the BufferedArray class
60
+
61
+
35
62
def sync_flush_array(np_buffer, zarr_array, offset):
    """Copy the whole numpy buffer into ``zarr_array`` starting at row ``offset``."""
    stop = offset + np_buffer.shape[0]
    zarr_array[offset:stop] = np_buffer
37
64
@@ -72,7 +99,9 @@ def flush_chunk(start, stop):
72
99
73
100
74
101
class ThreadedZarrEncoder (contextlib .AbstractContextManager ):
75
- def __init__ (self , buffered_arrays , encoder_threads ):
102
+ # TODO (maybe) add option with encoder_threads=None to run synchronously for
103
+ # debugging using a mock Executor
104
+ def __init__ (self , buffered_arrays , encoder_threads = 1 ):
76
105
self .buffered_arrays = buffered_arrays
77
106
self .executor = cf .ThreadPoolExecutor (max_workers = encoder_threads )
78
107
self .chunk_length = buffered_arrays [0 ].chunk_length
@@ -89,18 +118,10 @@ def next_buffer_row(self):
89
118
self .next_row = 0
90
119
return self .next_row
91
120
92
- def wait_on_futures (self ):
93
- for future in cf .as_completed (self .futures ):
94
- exception = future .exception ()
95
- if exception is not None :
96
- raise exception
97
-
98
121
    def swap_buffers(self):
        """Flush every buffered array, then reset the pending-futures list.

        Waits for the flushes scheduled by the previous swap before scheduling
        new ones, so at most one round of flushes per array is in flight.
        """
        # Drain outstanding flush work first; re-raises any worker exception.
        wait_on_futures(self.futures)
        self.futures = []
        for ba in self.buffered_arrays:
            # async_flush writes rows [0, next_row) of ba's buffer to its zarr
            # array at array_offset and returns the futures it scheduled.
            self.futures.extend(
                ba.async_flush(self.executor, self.array_offset, self.next_row)
            )
@@ -111,10 +132,95 @@ def __exit__(self, exc_type, exc_val, exc_tb):
111
132
# Normal exit condition
112
133
self .next_row += 1
113
134
self .swap_buffers ()
114
- self .wait_on_futures ()
115
- # TODO add arguments to wait and cancel_futures appropriate
116
- # for the an error condition occuring here. Generally need
117
- # to think about the error exit condition here (like running
118
- # out of disk space) to see what the right behaviour is.
135
+ wait_on_futures (self .futures )
136
+ else :
137
+ cancel_futures (self .futures )
138
+ self .executor .shutdown ()
139
+ return False
140
+
141
+
142
@dataclasses.dataclass
class ProgressConfig:
    """Settings for the background progress-bar thread."""

    # Total number of units the bar should reach when the work is complete.
    total: int = 0
    # Unit label passed through to tqdm.
    units: str = ""
    # Text shown as the bar's description/prefix.
    title: str = ""
    # When False the bar is created but not displayed.
    show: bool = False
    # Seconds between polls of the shared progress counter.
    poll_interval: float = 0.001
149
+
150
+
151
# Process-wide progress counter ("Q" = unsigned 64-bit), shareable with
# worker processes; all access goes through its lock via the helpers below.
_progress_counter = multiprocessing.Value("Q", 0)


def update_progress(inc):
    """Atomically add ``inc`` to the shared progress counter."""
    counter = _progress_counter
    with counter.get_lock():
        counter.value += inc


def get_progress():
    """Atomically read and return the shared progress counter."""
    counter = _progress_counter
    with counter.get_lock():
        current = counter.value
    return current


def set_progress(value):
    """Atomically overwrite the shared progress counter with ``value``."""
    counter = _progress_counter
    with counter.get_lock():
        counter.value = value
168
+
169
+
170
def progress_thread_worker(config):
    """Mirror the shared progress counter into a tqdm progress bar.

    Polls ``get_progress()`` every ``config.poll_interval`` seconds and
    advances the bar by the observed delta; returns once the counter reaches
    ``config.total``.

    Args:
        config: ProgressConfig supplying the bar's total, labels, visibility
            and poll interval.
    """
    pbar = tqdm.tqdm(
        total=config.total,
        desc=config.title,
        unit_scale=True,
        unit=config.units,
        smoothing=0.1,
        disable=not config.show,
    )

    while (current := get_progress()) < config.total:
        pbar.update(current - pbar.n)
        time.sleep(config.poll_interval)
    # The loop exits as soon as the counter reaches the total, before that
    # final value has been drawn; bring the bar up to 100% before closing so
    # it never ends showing less than the completed amount.
    pbar.update(config.total - pbar.n)
    pbar.close()
185
+
186
+
187
class ParallelWorkManager(contextlib.AbstractContextManager):
    """Context manager bundling a worker pool with a progress-bar thread.

    Work is scheduled via ``submit()``; on clean exit all futures are awaited,
    on error they are cancelled. A daemon thread mirrors the shared progress
    counter into a progress bar for the lifetime of the manager.
    """

    def __init__(self, worker_processes=1, progress_config=None):
        # worker_processes <= 0 selects the in-caller-thread executor so work
        # runs synchronously and deterministically.
        if worker_processes <= 0:
            # NOTE: this is only for testing, not for production use!
            self.executor = SynchronousExecutor()
        else:
            self.executor = cf.ProcessPoolExecutor(
                max_workers=worker_processes,
            )
        # Reset the shared counter so a previous run's value cannot leak in.
        set_progress(0)
        if progress_config is None:
            progress_config = ProgressConfig()
        # Daemon thread: it must not keep the process alive if we exit early.
        self.bar_thread = threading.Thread(
            target=progress_thread_worker,
            args=(progress_config,),
            name="progress",
            daemon=True,
        )
        self.bar_thread.start()
        self.progress_config = progress_config
        # Futures for all work submitted through this manager.
        self.futures = []

    def submit(self, *args, **kwargs):
        """Schedule work on the pool; arguments are passed straight to
        ``executor.submit`` (first positional argument is the callable)."""
        self.futures.append(self.executor.submit(*args, **kwargs))

    def results_as_completed(self):
        """Yield each submitted task's result in completion order.

        Re-raises a task's exception if that task failed.
        """
        for future in cf.as_completed(self.futures):
            yield future.result()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            # Clean exit: surface any worker failure, then push the counter to
            # its total so the progress thread's poll loop terminates.
            wait_on_futures(self.futures)
            set_progress(self.progress_config.total)
            timeout = None
        else:
            # Error exit: drop pending work and don't block on the bar thread
            # (timeout=0); it is a daemon, so it dies with the process.
            cancel_futures(self.futures)
            timeout = 0
        self.bar_thread.join(timeout)
        self.executor.shutdown()
        return False
0 commit comments