4
4
import multiprocessing
5
5
import threading
6
6
import logging
7
- import functools
8
7
import time
9
8
10
9
import zarr
@@ -159,32 +158,6 @@ def set_progress(value):
159
158
_progress_counter .value = value
160
159
161
160
162
- def progress_thread_worker (config ):
163
- pbar = tqdm .tqdm (
164
- total = config .total ,
165
- desc = f"{ config .title :>7} " ,
166
- unit_scale = True ,
167
- unit = config .units ,
168
- smoothing = 0.1 ,
169
- disable = not config .show ,
170
- )
171
-
172
- while (current := get_progress ()) < config .total :
173
- inc = current - pbar .n
174
- pbar .update (inc )
175
- time .sleep (config .poll_interval )
176
- # TODO figure out why we're sometimes going over total
177
- # if get_progress() != config.total:
178
- # print("HOW DID THIS HAPPEN!!")
179
- # print(get_progress())
180
- # print(config)
181
- # assert get_progress() == config.total
182
- inc = config .total - pbar .n
183
- pbar .update (inc )
184
- pbar .close ()
185
- # print("EXITING PROGRESS THREAD")
186
-
187
-
188
161
class ParallelWorkManager (contextlib .AbstractContextManager ):
189
162
def __init__ (self , worker_processes = 1 , progress_config = None ):
190
163
if worker_processes <= 0 :
@@ -194,18 +167,42 @@ def __init__(self, worker_processes=1, progress_config=None):
194
167
self .executor = cf .ProcessPoolExecutor (
195
168
max_workers = worker_processes ,
196
169
)
170
+ self .futures = []
171
+
197
172
set_progress (0 )
198
173
if progress_config is None :
199
174
progress_config = ProgressConfig ()
200
- self .bar_thread = threading .Thread (
201
- target = progress_thread_worker ,
202
- args = (progress_config ,),
203
- name = "progress" ,
204
- daemon = True ,
205
- )
206
- self .bar_thread .start ()
207
175
self .progress_config = progress_config
208
- self .futures = []
176
+ self .progress_bar = tqdm .tqdm (
177
+ total = progress_config .total ,
178
+ desc = f"{ progress_config .title :>7} " ,
179
+ unit_scale = True ,
180
+ unit = progress_config .units ,
181
+ smoothing = 0.1 ,
182
+ disable = not progress_config .show ,
183
+ )
184
+ self .completed = False
185
+ self .completed_lock = threading .Lock ()
186
+ self .progress_thread = threading .Thread (
187
+ target = self ._update_progress_worker ,
188
+ name = "progress-update" ,
189
+ )
190
+ self .progress_thread .start ()
191
+
192
+ def _update_progress (self ):
193
+ current = get_progress ()
194
+ inc = current - self .progress_bar .n
195
+ # print("UPDATE PROGRESS: current = ", current, self.progress_config.total, inc)
196
+ self .progress_bar .update (inc )
197
+
198
+ def _update_progress_worker (self ):
199
+ completed = False
200
+ while not completed :
201
+ self ._update_progress ()
202
+ time .sleep (self .progress_config .poll_interval )
203
+ with self .completed_lock :
204
+ completed = self .completed
205
+ logger .debug ("Exit progress thread" )
209
206
210
207
def submit (self , * args , ** kwargs ):
211
208
self .futures .append (self .executor .submit (* args , ** kwargs ))
@@ -217,13 +214,19 @@ def results_as_completed(self):
217
214
def __exit__ (self , exc_type , exc_val , exc_tb ):
218
215
if exc_type is None :
219
216
wait_on_futures (self .futures )
220
- # Note: this doesn't seem to be working correctly. If
221
- # we set a timeout of None we get deadlocks
222
- set_progress (self .progress_config .total )
223
- timeout = 0.1
224
217
else :
225
218
cancel_futures (self .futures )
226
- timeout = 0
227
- self .bar_thread .join (timeout )
228
- self .executor .shutdown ()
219
+ # There's probably a much cleaner way of doing this with a Condition
220
+ # or something, but this seems to work OK for now. This setup might
221
+ # make small conversions a bit laggy as we wait on the sleep interval
222
+ # though.
223
+ with self .completed_lock :
224
+ self .completed = True
225
+ self .executor .shutdown (wait = False )
226
+ # FIXME there's currently some thing weird happening at the end of
227
+ # Encode 1D for 1kg-p3. The progress bar disappears, like we're
228
+ # setting a total of zero or something.
229
+ self .progress_thread .join ()
230
+ self ._update_progress ()
231
+ self .progress_bar .close ()
229
232
return False
0 commit comments