Skip to content

Commit 434d60a

Browse files
Abstract progress and multiprocessing
1 parent c496860 commit 434d60a

File tree

3 files changed

+161
-104
lines changed

3 files changed

+161
-104
lines changed

bio2zarr/core.py

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import multiprocessing
55
import threading
66
import logging
7+
import functools
78
import time
89

910
import zarr
@@ -14,6 +15,25 @@
1415
logger = logging.getLogger(__name__)
1516

1617

18+
class SynchronousExecutor(cf.Executor):
    """A concurrent.futures.Executor that runs each task inline on submit().

    Used as a drop-in replacement for ProcessPoolExecutor when
    worker_processes <= 0, so the same code paths can be exercised
    synchronously (e.g. under a debugger or coverage).
    """

    def submit(self, fn, /, *args, **kwargs):
        """Run fn(*args, **kwargs) immediately and return a completed Future."""
        future = cf.Future()
        try:
            future.set_result(fn(*args, **kwargs))
        except Exception as exc:
            # Match the Executor contract: a failing task must store its
            # exception on the Future (surfaced via result()/exception()),
            # not raise directly out of submit(). Without this, error
            # handling behaves differently at workers=0 than workers>=1.
            future.set_exception(exc)
        return future
23+
24+
25+
def wait_on_futures(futures):
    """Block until every future has finished, re-raising the first error seen."""
    for done in cf.as_completed(futures):
        error = done.exception()
        if error is not None:
            raise error
30+
31+
32+
def cancel_futures(futures):
    """Best-effort cancellation of every future in *futures*."""
    for pending in futures:
        pending.cancel()
35+
36+
1737
@dataclasses.dataclass
1838
class BufferedArray:
1939
array: zarr.Array
@@ -98,14 +118,8 @@ def next_buffer_row(self):
98118
self.next_row = 0
99119
return self.next_row
100120

101-
def wait_on_futures(self):
102-
for future in cf.as_completed(self.futures):
103-
exception = future.exception()
104-
if exception is not None:
105-
raise exception
106-
107121
def swap_buffers(self):
108-
self.wait_on_futures()
122+
wait_on_futures(self.futures)
109123
self.futures = []
110124
for ba in self.buffered_arrays:
111125
self.futures.extend(
@@ -118,19 +132,20 @@ def __exit__(self, exc_type, exc_val, exc_tb):
118132
# Normal exit condition
119133
self.next_row += 1
120134
self.swap_buffers()
121-
self.wait_on_futures()
135+
wait_on_futures(self.futures)
122136
else:
123-
for future in self.futures:
124-
future.cancel()
137+
cancel_futures(self.futures)
125138
self.executor.shutdown()
126139
return False
127140

128141

129142
@dataclasses.dataclass
class ProgressConfig:
    """Configuration for the background progress-bar thread.

    The zero-valued defaults describe a disabled, empty bar, so a
    ProgressConfig() can be used when no progress reporting is wanted.
    """

    # Target number of units; the monitor thread exits once reached.
    total: int = 0
    # Unit label shown by the bar (e.g. "vars", "b").
    units: str = ""
    # Description displayed next to the bar.
    title: str = ""
    # When False the bar is created but hidden (disable=not show).
    show: bool = False
    # Seconds between counter polls by the monitor thread.
    poll_interval: float = 0.001
134149

135150

136151
_progress_counter = multiprocessing.Value("Q", 0)
@@ -141,44 +156,71 @@ def update_progress(inc):
141156
_progress_counter.value += inc
142157

143158

159+
def get_progress():
    """Return the current value of the shared cross-process progress counter."""
    with _progress_counter.get_lock():
        return _progress_counter.value
163+
164+
165+
def set_progress(value):
    """Overwrite the shared cross-process progress counter with *value*."""
    with _progress_counter.get_lock():
        _progress_counter.value = value
168+
169+
144170
def progress_thread_worker(config):
    """Drive a tqdm bar from the shared counter until config.total is reached.

    Runs on a daemon thread; polls get_progress() every
    config.poll_interval seconds and pushes the delta into the bar.
    The bar is hidden entirely when config.show is False.
    """
    bar = tqdm.tqdm(
        total=config.total,
        desc=config.title,
        unit_scale=True,
        unit=config.units,
        smoothing=0.1,
        disable=not config.show,
    )
    current = get_progress()
    while current < config.total:
        # tqdm tracks an absolute position in bar.n; feed it the increment.
        bar.update(current - bar.n)
        time.sleep(config.poll_interval)
        current = get_progress()
    bar.close()
158185

159186

160187
class ParallelWorkManager(contextlib.AbstractContextManager):
    """Context manager owning a worker pool, its futures, and a progress thread.

    Work is submitted via submit(); on clean exit all futures are waited on
    (re-raising any worker error) before the pool is shut down.  On an
    exceptional exit pending futures are cancelled instead.  A daemon
    progress thread is always started; whether it displays anything is
    controlled by progress_config.show.
    """

    def __init__(self, worker_processes=1, progress_config=None):
        # worker_processes <= 0 selects the inline executor so work runs
        # synchronously in this process.
        if worker_processes <= 0:
            # NOTE: this is only for testing, not for production use!
            self.executor = SynchronousExecutor()
        else:
            self.executor = cf.ProcessPoolExecutor(
                max_workers=worker_processes,
            )
        # Reset the module-global counter so a previous manager's progress
        # does not leak into this one.
        set_progress(0)
        if progress_config is None:
            progress_config = ProgressConfig()
        self.bar_thread = threading.Thread(
            target=progress_thread_worker,
            args=(progress_config,),
            name="progress",
            daemon=True,
        )
        self.bar_thread.start()
        self.progress_config = progress_config
        self.futures = []

    def submit(self, *args, **kwargs):
        """Submit a task to the executor, tracking its future for exit handling."""
        self.futures.append(self.executor.submit(*args, **kwargs))

    def results_as_completed(self):
        """Yield results of all submitted futures in completion order.

        Re-raises a worker's exception via future.result().
        """
        for future in cf.as_completed(self.futures):
            yield future.result()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            # Normal exit: surface any worker error, then force the counter
            # to total so the progress thread's poll loop terminates, and
            # join it without a timeout.
            wait_on_futures(self.futures)
            set_progress(self.progress_config.total)
            timeout = None
        else:
            # Error exit: drop pending work and don't wait for the bar.
            cancel_futures(self.futures)
            timeout = 0
        self.bar_thread.join(timeout)
        self.executor.shutdown()
        # Returning False propagates any in-flight exception to the caller.
        return False

bio2zarr/vcf.py

Lines changed: 38 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
import concurrent.futures as cf
22
import dataclasses
3-
import multiprocessing
43
import functools
54
import logging
65
import os
7-
import threading
86
import pathlib
9-
import time
107
import pickle
118
import sys
129
import shutil
@@ -111,15 +108,6 @@ def assert_prefix_float_equal_2d(vcf_val, zarr_val):
111108
# nt.assert_array_equal(v, z[:k])
112109

113110

114-
# TODO rename to wait_and_check_futures
115-
def flush_futures(futures):
116-
# Make sure previous futures have completed
117-
for future in cf.as_completed(futures):
118-
exception = future.exception()
119-
if exception is not None:
120-
raise exception
121-
122-
123111
@dataclasses.dataclass
124112
class VcfFieldSummary:
125113
num_chunks: int = 0
@@ -742,26 +730,19 @@ def convert(
742730
f"Exploding {pcvcf.num_columns} columns {total_variants} variants "
743731
f"{pcvcf.num_samples} samples"
744732
)
745-
progress_config = None
746-
if show_progress:
747-
progress_config = core.ProgressConfig(
748-
total=total_variants, units="vars", title="Explode"
749-
)
733+
progress_config = core.ProgressConfig(
734+
total=total_variants, units="vars", title="Explode", show=show_progress
735+
)
750736
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
751-
futures = []
752737
for j, partition in enumerate(vcf_metadata.partitions):
753-
futures.append(
754-
pwm.executor.submit(
755-
PickleChunkedVcf.convert_partition,
756-
vcf_metadata,
757-
j,
758-
out_path,
759-
column_chunk_size=column_chunk_size,
760-
)
738+
pwm.submit(
739+
PickleChunkedVcf.convert_partition,
740+
vcf_metadata,
741+
j,
742+
out_path,
743+
column_chunk_size=column_chunk_size,
761744
)
762-
partition_summaries = [
763-
future.result() for future in cf.as_completed(futures)
764-
]
745+
partition_summaries = list(pwm.results_as_completed())
765746

766747
for field in vcf_metadata.fields:
767748
for summary in partition_summaries:
@@ -1258,31 +1239,28 @@ def convert(
12581239
for variable in conversion_spec.variables[:]:
12591240
sgvcf.create_array(variable)
12601241

1261-
progress_config = None
1262-
if show_progress:
1263-
progress_config = core.ProgressConfig(
1264-
total=pcvcf.total_uncompressed_bytes, title="Encode", units="b"
1265-
)
1242+
progress_config = core.ProgressConfig(
1243+
total=pcvcf.total_uncompressed_bytes,
1244+
title="Encode",
1245+
units="b",
1246+
show=show_progress,
1247+
)
12661248
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1267-
futures = [
1268-
pwm.executor.submit(
1269-
sgvcf.encode_samples,
1270-
pcvcf,
1271-
conversion_spec.sample_id,
1272-
conversion_spec.chunk_width,
1273-
),
1274-
pwm.executor.submit(sgvcf.encode_alleles, pcvcf),
1275-
pwm.executor.submit(sgvcf.encode_id, pcvcf),
1276-
pwm.executor.submit(
1277-
sgvcf.encode_contig,
1278-
pcvcf,
1279-
conversion_spec.contig_id,
1280-
conversion_spec.contig_length,
1281-
),
1282-
pwm.executor.submit(
1283-
sgvcf.encode_filters, pcvcf, conversion_spec.filter_id
1284-
),
1285-
]
1249+
pwm.submit(
1250+
sgvcf.encode_samples,
1251+
pcvcf,
1252+
conversion_spec.sample_id,
1253+
conversion_spec.chunk_width,
1254+
)
1255+
pwm.submit(sgvcf.encode_alleles, pcvcf)
1256+
pwm.submit(sgvcf.encode_id, pcvcf)
1257+
pwm.submit(
1258+
sgvcf.encode_contig,
1259+
pcvcf,
1260+
conversion_spec.contig_id,
1261+
conversion_spec.contig_length,
1262+
)
1263+
pwm.submit(sgvcf.encode_filters, pcvcf, conversion_spec.filter_id)
12861264
has_gt = False
12871265
for variable in conversion_spec.variables[:]:
12881266
if variable.vcf_field is not None:
@@ -1292,21 +1270,14 @@ def convert(
12921270
# long wait for the largest GT columns to finish.
12931271
# Straightforward to do because we can chunk-align the work
12941272
# packages.
1295-
future = pwm.executor.submit(sgvcf.encode_column, pcvcf, variable)
1296-
futures.append(future)
1273+
pwm.submit(sgvcf.encode_column, pcvcf, variable)
12971274
else:
12981275
if variable.name == "call_genotype":
12991276
has_gt = True
13001277
if has_gt:
13011278
# TODO add mixed ploidy
1302-
futures.append(pwm.executor.submit(sgvcf.encode_genotypes, pcvcf))
1303-
1304-
flush_futures(futures)
1279+
pwm.executor.submit(sgvcf.encode_genotypes, pcvcf)
13051280

1306-
# FIXME can't join the bar_thread because we never get to the correct
1307-
# number of bytes
1308-
# if bar_thread is not None:
1309-
# bar_thread.join()
13101281
zarr.consolidate_metadata(write_path)
13111282
# Atomic swap, now we've completely finished.
13121283
logger.info(f"Moving to final path {path}")
@@ -1617,14 +1588,9 @@ def convert_plink(
16171588
partitions.append((last_stop, m))
16181589
# print(partitions)
16191590

1620-
progress_config = None
1621-
if show_progress:
1622-
progress_config = core.ProgressConfig(total=m, title="Convert", units="vars")
1591+
progress_config = core.ProgressConfig(
1592+
total=m, title="Convert", units="vars", show=show_progress
1593+
)
16231594
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1624-
futures = [
1625-
pwm.executor.submit(
1626-
encode_bed_partition_genotypes, bed_path, zarr_path, start, end
1627-
)
1628-
for start, end in partitions
1629-
]
1630-
flush_futures(futures)
1595+
for start, end in partitions:
1596+
pwm.submit(encode_bed_partition_genotypes, bed_path, zarr_path, start, end)

tests/test_core.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,52 @@ def test_error_in_encode(self):
9393
# an error in the futures. In reality these will happen
9494
# when we run out of disk space, but this is hard to simulate
9595
ba.buff = np.array(["not an integer"])
96+
97+
98+
class TestParallelWorkManager:
    """Tests for core.ParallelWorkManager covering progress tracking,
    result collection, and worker-error propagation."""

    @pytest.mark.parametrize("total", [1, 10, 2**63])
    @pytest.mark.parametrize("workers", [0, 1])
    def test_one_future_progress(self, total, workers):
        # A single task advances the shared counter by the full total.
        progress_config = core.ProgressConfig(total=total)
        with core.ParallelWorkManager(workers, progress_config) as pwm:
            pwm.submit(core.update_progress, total)
        assert core.get_progress() == total

    @pytest.mark.parametrize("total", [1, 10, 1000])
    @pytest.mark.parametrize("workers", [0, 1, 2, 3])
    def test_n_futures_progress(self, total, workers):
        # Many unit increments from concurrent workers sum correctly.
        progress_config = core.ProgressConfig(total=total)
        with core.ParallelWorkManager(workers, progress_config) as pwm:
            for _ in range(total):
                pwm.submit(core.update_progress, 1)
        assert core.get_progress() == total

    @pytest.mark.parametrize("total", [1, 10, 20])
    @pytest.mark.parametrize("workers", [0, 1, 2, 3])
    def test_results_as_completed(self, total, workers):
        # Results arrive in completion order; compare as a set.
        with core.ParallelWorkManager(workers) as pwm:
            for j in range(total):
                pwm.submit(frozenset, range(j))
            results = set(pwm.results_as_completed())
            assert results == set(frozenset(range(j)) for j in range(total))

    @pytest.mark.parametrize("total", [1, 10, 20])
    @pytest.mark.parametrize("workers", [1, 2, 3])
    def test_error_in_workers_as_completed(self, total, workers):
        # A failing task's exception surfaces when results are consumed.
        with pytest.raises(TypeError):
            with core.ParallelWorkManager(workers) as pwm:
                for j in range(total):
                    pwm.submit(frozenset, range(j))
                # Raises a TypeError:
                pwm.submit(frozenset, j)
                set(pwm.results_as_completed())

    @pytest.mark.parametrize("total", [1, 10, 20])
    @pytest.mark.parametrize("workers", [1, 2, 3])
    def test_error_in_workers_on_exit(self, total, workers):
        # Same failure, but surfaced by wait_on_futures during __exit__.
        with pytest.raises(TypeError):
            with core.ParallelWorkManager(workers) as pwm:
                for j in range(total):
                    pwm.submit(frozenset, range(j))
                # Raises a TypeError:
                pwm.submit(frozenset, j)

0 commit comments

Comments
 (0)