Skip to content

Commit c496860

Browse files
Abstract parallel jobs with progress to core
1 parent ec22640 commit c496860

File tree

2 files changed

+56
-158
lines changed

2 files changed

+56
-158
lines changed

bio2zarr/core.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ def swap_buffers(self):
3535
def async_flush(self, executor, offset, buff_stop=None):
3636
return async_flush_array(executor, self.buff[:buff_stop], self.array, offset)
3737

38+
3839
# TODO: factor these functions into the BufferedArray class
3940

41+
4042
def sync_flush_array(np_buffer, zarr_array, offset):
4143
zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
4244

@@ -124,71 +126,59 @@ def __exit__(self, exc_type, exc_val, exc_tb):
124126
return False
125127

126128

129+
@dataclasses.dataclass
130+
class ProgressConfig:
131+
total: int
132+
units: str
133+
title: str
134+
127135

128-
progress_counter = multiprocessing.Value("Q", 0)
136+
_progress_counter = multiprocessing.Value("Q", 0)
129137

130-
import os
131138

132139
def update_progress(inc):
133-
print("update progress", os.getpid(), inc)
134-
with progress_counter.get_lock():
135-
progress_counter.value += 1
140+
with _progress_counter.get_lock():
141+
_progress_counter.value += inc
142+
136143

137144
def progress_thread_worker(config):
138145
pbar = tqdm.tqdm(
139-
total=config.total, desc=config.title, unit_scale=True, unit=config.units,
140-
smoothing=0.1
146+
total=config.total,
147+
desc=config.title,
148+
unit_scale=True,
149+
unit=config.units,
150+
smoothing=0.1,
141151
)
142152

143-
while (current := progress_counter.value) < config.total:
153+
while (current := _progress_counter.value) < config.total:
144154
inc = current - pbar.n
145155
pbar.update(inc)
146156
time.sleep(0.1)
147157
pbar.close()
148158

149159

150-
def init_workers(counter):
151-
global progress_counter
152-
progress_counter = counter
153-
154-
155-
@dataclasses.dataclass
156-
class ProgressConfig:
157-
total: int
158-
units: str
159-
title: str
160-
161-
162160
class ParallelWorkManager(contextlib.AbstractContextManager):
163-
164161
def __init__(self, worker_processes=1, progress_config=None):
165162
self.executor = cf.ProcessPoolExecutor(
166163
max_workers=worker_processes,
167-
initializer=init_workers,
168-
initargs=(progress_counter,),
169164
)
170165

171166
self.bar_thread = None
172167
if progress_config is not None:
173-
bar_thread = threading.Thread(
168+
self.bar_thread = threading.Thread(
174169
target=progress_thread_worker,
175170
args=(progress_config,),
176171
name="progress",
177172
daemon=True,
178173
)
179-
bar_thread.start()
174+
self.bar_thread.start()
180175

181176
def __exit__(self, exc_type, exc_val, exc_tb):
182177
# if exc_type is None:
183-
# # Normal exit condition
184-
# self.next_row += 1
185-
# self.swap_buffers()
186-
# self.wait_on_futures()
178+
# print("normal exit")
187179
# else:
188-
# for future in self.futures:
189-
# future.cancel()
190-
self.executor.shutdown()
180+
    # print("Error occurred")
191181
if self.bar_thread is not None:
192-
self.bar_thread.join()
182+
self.bar_thread.join(timeout=0)
183+
self.executor.shutdown()
193184
return False
194-

bio2zarr/vcf.py

Lines changed: 32 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -745,9 +745,9 @@ def convert(
745745
progress_config = None
746746
if show_progress:
747747
progress_config = core.ProgressConfig(
748-
total=total_variants, units="vars", title="Explode")
748+
total=total_variants, units="vars", title="Explode"
749+
)
749750
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
750-
751751
futures = []
752752
for j, partition in enumerate(vcf_metadata.partitions):
753753
futures.append(
@@ -763,44 +763,6 @@ def convert(
763763
future.result() for future in cf.as_completed(futures)
764764
]
765765

766-
# global progress_counter
767-
# progress_counter = multiprocessing.Value("Q", 0)
768-
769-
# # start update progress bar process
770-
# bar_thread = None
771-
# if show_progress:
772-
# bar_thread = threading.Thread(
773-
# target=update_bar,
774-
# args=(progress_counter, total_variants, "Explode", "vars"),
775-
# name="progress",
776-
# daemon=True,
777-
# )
778-
# bar_thread.start()
779-
780-
# with cf.ProcessPoolExecutor(
781-
# max_workers=worker_processes,
782-
# initializer=init_workers,
783-
# initargs=(progress_counter,),
784-
# ) as executor:
785-
# futures = []
786-
# for j, partition in enumerate(vcf_metadata.partitions):
787-
# futures.append(
788-
# executor.submit(
789-
# PickleChunkedVcf.convert_partition,
790-
# vcf_metadata,
791-
# j,
792-
# out_path,
793-
# column_chunk_size=column_chunk_size,
794-
# )
795-
# )
796-
# partition_summaries = [
797-
# future.result() for future in cf.as_completed(futures)
798-
# ]
799-
800-
# assert progress_counter.value == total_variants
801-
# if bar_thread is not None:
802-
# bar_thread.join()
803-
804766
for field in vcf_metadata.fields:
805767
for summary in partition_summaries:
806768
field.summary.update(summary[field.full_name])
@@ -884,7 +846,6 @@ def service_futures(max_waiting=2 * flush_threads):
884846

885847
service_futures()
886848

887-
888849
# Note: an issue with updating the progress per variant here like this
889850
# is that we get a significant pause at the end of the counter while
890851
# all the "small" fields get flushed. Possibly not much to be done about it.
@@ -898,23 +859,6 @@ def service_futures(max_waiting=2 * flush_threads):
898859
return summaries
899860

900861

901-
# def update_bar(progress_counter, total, title, units):
902-
# pbar = tqdm.tqdm(
903-
# total=total, desc=title, unit_scale=True, unit=units, smoothing=0.1
904-
# )
905-
906-
# while (current := progress_counter.value) < total:
907-
# inc = current - pbar.n
908-
# pbar.update(inc)
909-
# time.sleep(0.1)
910-
# pbar.close()
911-
912-
913-
# def init_workers(counter):
914-
# global progress_counter
915-
# progress_counter = counter
916-
917-
918862
def explode(
919863
vcfs,
920864
out_path,
@@ -1160,9 +1104,9 @@ def encode_column(self, pcvcf, column, encoder_threads=4):
11601104
for value, bytes_read in source_col.iter_values_bytes():
11611105
j = te.next_buffer_row()
11621106
sanitiser(ba.buff, j, value)
1107+
# print(bytes_read, last_bytes_read, value)
11631108
if last_bytes_read != bytes_read:
1164-
with progress_counter.get_lock():
1165-
progress_counter.value += bytes_read - last_bytes_read
1109+
core.update_progress(bytes_read - last_bytes_read)
11661110
last_bytes_read = bytes_read
11671111

11681112
def encode_genotypes(self, pcvcf, encoder_threads=4):
@@ -1181,10 +1125,8 @@ def encode_genotypes(self, pcvcf, encoder_threads=4):
11811125
# TODO check is this the correct semantics when we are padding
11821126
# with mixed ploidies?
11831127
gt_mask.buff[j] = gt.buff[j] < 0
1184-
11851128
if last_bytes_read != bytes_read:
1186-
with progress_counter.get_lock():
1187-
progress_counter.value += bytes_read - last_bytes_read
1129+
core.update_progress(bytes_read - last_bytes_read)
11881130
last_bytes_read = bytes_read
11891131

11901132
def encode_alleles(self, pcvcf):
@@ -1200,10 +1142,10 @@ def encode_alleles(self, pcvcf):
12001142
alleles[j, 0] = ref
12011143
alleles[j, 1 : 1 + len(alt)] = alt
12021144
allele_array[:] = alleles
1203-
1204-
with progress_counter.get_lock():
1205-
for col in [ref_col, alt_col]:
1206-
progress_counter.value += col.vcf_field.summary.uncompressed_size
1145+
size = sum(
1146+
col.vcf_field.summary.uncompressed_size for col in [ref_col, alt_col]
1147+
)
1148+
core.update_progress(size)
12071149
logger.debug("alleles done")
12081150

12091151
def encode_samples(self, pcvcf, sample_id, chunk_width):
@@ -1249,8 +1191,7 @@ def encode_contig(self, pcvcf, contig_names, contig_lengths):
12491191

12501192
array[:] = buff
12511193

1252-
with progress_counter.get_lock():
1253-
progress_counter.value += col.vcf_field.summary.uncompressed_size
1194+
core.update_progress(col.vcf_field.summary.uncompressed_size)
12541195
logger.debug("Contig done")
12551196

12561197
def encode_filters(self, pcvcf, filter_names):
@@ -1277,8 +1218,7 @@ def encode_filters(self, pcvcf, filter_names):
12771218

12781219
array[:] = buff
12791220

1280-
with progress_counter.get_lock():
1281-
progress_counter.value += col.vcf_field.summary.uncompressed_size
1221+
core.update_progress(col.vcf_field.summary.uncompressed_size)
12821222
logger.debug("Filters done")
12831223

12841224
def encode_id(self, pcvcf):
@@ -1298,8 +1238,7 @@ def encode_id(self, pcvcf):
12981238
id_array[:] = id_buff
12991239
id_mask_array[:] = id_mask_buff
13001240

1301-
with progress_counter.get_lock():
1302-
progress_counter.value += col.vcf_field.summary.uncompressed_size
1241+
core.update_progress(col.vcf_field.summary.uncompressed_size)
13031242
logger.debug("ID done")
13041243

13051244
@staticmethod
@@ -1319,41 +1258,30 @@ def convert(
13191258
for variable in conversion_spec.variables[:]:
13201259
sgvcf.create_array(variable)
13211260

1322-
global progress_counter
1323-
progress_counter = multiprocessing.Value("Q", 0)
1324-
1325-
# start update progress bar process
1326-
bar_thread = None
1261+
progress_config = None
13271262
if show_progress:
1328-
bar_thread = threading.Thread(
1329-
target=update_bar,
1330-
args=(progress_counter, pcvcf.total_uncompressed_bytes, "Encode", "b"),
1331-
name="progress",
1332-
daemon=True,
1263+
progress_config = core.ProgressConfig(
1264+
total=pcvcf.total_uncompressed_bytes, title="Encode", units="b"
13331265
)
1334-
bar_thread.start()
1335-
1336-
with cf.ProcessPoolExecutor(
1337-
max_workers=worker_processes,
1338-
initializer=init_workers,
1339-
initargs=(progress_counter,),
1340-
) as executor:
1266+
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
13411267
futures = [
1342-
executor.submit(
1268+
pwm.executor.submit(
13431269
sgvcf.encode_samples,
13441270
pcvcf,
13451271
conversion_spec.sample_id,
13461272
conversion_spec.chunk_width,
13471273
),
1348-
executor.submit(sgvcf.encode_alleles, pcvcf),
1349-
executor.submit(sgvcf.encode_id, pcvcf),
1350-
executor.submit(
1274+
pwm.executor.submit(sgvcf.encode_alleles, pcvcf),
1275+
pwm.executor.submit(sgvcf.encode_id, pcvcf),
1276+
pwm.executor.submit(
13511277
sgvcf.encode_contig,
13521278
pcvcf,
13531279
conversion_spec.contig_id,
13541280
conversion_spec.contig_length,
13551281
),
1356-
executor.submit(sgvcf.encode_filters, pcvcf, conversion_spec.filter_id),
1282+
pwm.executor.submit(
1283+
sgvcf.encode_filters, pcvcf, conversion_spec.filter_id
1284+
),
13571285
]
13581286
has_gt = False
13591287
for variable in conversion_spec.variables[:]:
@@ -1364,14 +1292,14 @@ def convert(
13641292
# long wait for the largest GT columns to finish.
13651293
# Straightforward to do because we can chunk-align the work
13661294
# packages.
1367-
future = executor.submit(sgvcf.encode_column, pcvcf, variable)
1295+
future = pwm.executor.submit(sgvcf.encode_column, pcvcf, variable)
13681296
futures.append(future)
13691297
else:
13701298
if variable.name == "call_genotype":
13711299
has_gt = True
13721300
if has_gt:
13731301
# TODO add mixed ploidy
1374-
futures.append(executor.submit(sgvcf.encode_genotypes, pcvcf))
1302+
futures.append(pwm.executor.submit(sgvcf.encode_genotypes, pcvcf))
13751303

13761304
flush_futures(futures)
13771305

@@ -1471,8 +1399,7 @@ def encode_bed_partition_genotypes(
14711399
dest[values == 1, 0] = 1
14721400
gt_phased.buff[j] = False
14731401
gt_mask.buff[j] = dest == -1
1474-
with progress_counter.get_lock():
1475-
progress_counter.value += 1
1402+
core.update_progress(1)
14761403
start = stop
14771404

14781405

@@ -1669,21 +1596,7 @@ def convert_plink(
16691596
)
16701597
a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)
16711598

1672-
global progress_counter
1673-
progress_counter = multiprocessing.Value("Q", 0)
1674-
1675-
# start update progress bar process
1676-
bar_thread = None
1677-
if show_progress:
1678-
bar_thread = threading.Thread(
1679-
target=update_bar,
1680-
args=(progress_counter, m, "Write", "vars"),
1681-
name="progress",
1682-
daemon=True,
1683-
)
1684-
bar_thread.start()
1685-
1686-
num_chunks = m // chunk_length
1599+
num_chunks = max(1, m // chunk_length)
16871600
worker_processes = min(worker_processes, num_chunks)
16881601
if num_chunks == 1 or worker_processes == 1:
16891602
partitions = [(0, m)]
@@ -1704,19 +1617,14 @@ def convert_plink(
17041617
partitions.append((last_stop, m))
17051618
# print(partitions)
17061619

1707-
with cf.ProcessPoolExecutor(
1708-
max_workers=worker_processes,
1709-
initializer=init_workers,
1710-
initargs=(progress_counter,),
1711-
) as executor:
1620+
progress_config = None
1621+
if show_progress:
1622+
progress_config = core.ProgressConfig(total=m, title="Convert", units="vars")
1623+
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
17121624
futures = [
1713-
executor.submit(
1625+
pwm.executor.submit(
17141626
encode_bed_partition_genotypes, bed_path, zarr_path, start, end
17151627
)
17161628
for start, end in partitions
17171629
]
17181630
flush_futures(futures)
1719-
# print("progress counter = ", m, progress_counter.value)
1720-
assert progress_counter.value == m
1721-
1722-
# print(root["call_genotype"][:])

0 commit comments

Comments
 (0)