Skip to content

Commit ec22640

Browse files
Work on refactor progress code
1 parent 1629def commit ec22640

File tree

2 files changed

+134
-37
lines changed

2 files changed

+134
-37
lines changed

bio2zarr/core.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import dataclasses
22
import contextlib
33
import concurrent.futures as cf
4+
import multiprocessing
5+
import threading
46
import logging
7+
import time
58

69
import zarr
710
import numpy as np
11+
import tqdm
812

913

1014
logger = logging.getLogger(__name__)
@@ -118,3 +122,73 @@ def __exit__(self, exc_type, exc_val, exc_tb):
118122
future.cancel()
119123
self.executor.shutdown()
120124
return False
125+
126+
127+
128+
# Shared cross-process progress counter ("Q" = unsigned 64-bit). Workers in
# the pool receive this via init_workers() and bump it as they make progress.
progress_counter = multiprocessing.Value("Q", 0)


import os  # NOTE(review): move to the top-of-file import block


def update_progress(inc):
    """Atomically add *inc* units of work to the shared progress counter.

    Safe to call from any worker process: the counter's lock serialises
    concurrent increments.
    """
    # Fixed: the original ignored ``inc`` and always added 1, and left a
    # debug ``print`` of the worker pid in place.
    with progress_counter.get_lock():
        progress_counter.value += inc
136+
137+
def progress_thread_worker(config):
    """Render a tqdm bar that tracks the shared ``progress_counter``.

    Polls the counter every 0.1s and feeds the delta to tqdm until the
    configured total is reached, then closes the bar.

    :param config: a ProgressConfig giving total, units and title.
    """
    pbar = tqdm.tqdm(
        total=config.total,
        desc=config.title,
        unit_scale=True,
        unit=config.units,
        smoothing=0.1,
    )
    while (current := progress_counter.value) < config.total:
        pbar.update(current - pbar.n)
        time.sleep(0.1)
    # Fixed: flush the final increment before closing, otherwise the bar
    # could be closed one poll interval short of the true counter value.
    pbar.update(progress_counter.value - pbar.n)
    pbar.close()
148+
149+
150+
def init_workers(counter):
    """Pool-worker initializer: install the shared progress counter.

    Rebinds this module's ``progress_counter`` global in the child process
    so that update_progress() increments the parent's counter.
    """
    global progress_counter
    progress_counter = counter
153+
154+
155+
@dataclasses.dataclass
class ProgressConfig:
    """Settings for one progress bar: work total, unit label and title."""

    total: int  # total units of work the bar should reach
    units: str  # tqdm unit label, e.g. "vars"
    title: str  # description shown next to the bar
160+
161+
162+
class ParallelWorkManager(contextlib.AbstractContextManager):
    """Context manager owning a worker process pool and optional progress bar.

    The pool's workers are initialised with the shared ``progress_counter``
    so they can report via update_progress(); when a ProgressConfig is
    supplied, a daemon thread renders the corresponding tqdm bar.
    """

    def __init__(self, worker_processes=1, progress_config=None):
        self.executor = cf.ProcessPoolExecutor(
            max_workers=worker_processes,
            initializer=init_workers,
            initargs=(progress_counter,),
        )
        self.bar_thread = None
        if progress_config is not None:
            # Fixed: the original stored the thread only in a local
            # ``bar_thread``, so __exit__'s join on self.bar_thread was
            # dead code and the bar was never waited on.
            self.bar_thread = threading.Thread(
                target=progress_thread_worker,
                args=(progress_config,),
                name="progress",
                daemon=True,
            )
            self.bar_thread.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.executor.shutdown()
        if self.bar_thread is not None:
            # Bounded join: the bar thread only exits once the counter hits
            # the configured total, which may never happen if a worker
            # failed. It is a daemon thread, so a timed-out join is safe.
            self.bar_thread.join(timeout=1)
        # Propagate any exception raised inside the with-block.
        return False
194+

bio2zarr/vcf.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,7 @@ def convert(
729729
vcfs, out_path, *, column_chunk_size=16, worker_processes=1, show_progress=False
730730
):
731731
out_path = pathlib.Path(out_path)
732+
# TODO make scan work in parallel using general progress code too
732733
vcf_metadata = scan_vcfs(vcfs, show_progress=show_progress)
733734
pcvcf = PickleChunkedVcf(out_path, vcf_metadata)
734735
pcvcf.mkdirs()
@@ -741,29 +742,16 @@ def convert(
741742
f"Exploding {pcvcf.num_columns} columns {total_variants} variants "
742743
f"{pcvcf.num_samples} samples"
743744
)
744-
global progress_counter
745-
progress_counter = multiprocessing.Value("Q", 0)
746-
747-
# start update progress bar process
748-
bar_thread = None
745+
progress_config = None
749746
if show_progress:
750-
bar_thread = threading.Thread(
751-
target=update_bar,
752-
args=(progress_counter, total_variants, "Explode", "vars"),
753-
name="progress",
754-
daemon=True,
755-
)
756-
bar_thread.start()
747+
progress_config = core.ProgressConfig(
748+
total=total_variants, units="vars", title="Explode")
749+
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
757750

758-
with cf.ProcessPoolExecutor(
759-
max_workers=worker_processes,
760-
initializer=init_workers,
761-
initargs=(progress_counter,),
762-
) as executor:
763751
futures = []
764752
for j, partition in enumerate(vcf_metadata.partitions):
765753
futures.append(
766-
executor.submit(
754+
pwm.executor.submit(
767755
PickleChunkedVcf.convert_partition,
768756
vcf_metadata,
769757
j,
@@ -775,9 +763,43 @@ def convert(
775763
future.result() for future in cf.as_completed(futures)
776764
]
777765

778-
assert progress_counter.value == total_variants
779-
if bar_thread is not None:
780-
bar_thread.join()
766+
# global progress_counter
767+
# progress_counter = multiprocessing.Value("Q", 0)
768+
769+
# # start update progress bar process
770+
# bar_thread = None
771+
# if show_progress:
772+
# bar_thread = threading.Thread(
773+
# target=update_bar,
774+
# args=(progress_counter, total_variants, "Explode", "vars"),
775+
# name="progress",
776+
# daemon=True,
777+
# )
778+
# bar_thread.start()
779+
780+
# with cf.ProcessPoolExecutor(
781+
# max_workers=worker_processes,
782+
# initializer=init_workers,
783+
# initargs=(progress_counter,),
784+
# ) as executor:
785+
# futures = []
786+
# for j, partition in enumerate(vcf_metadata.partitions):
787+
# futures.append(
788+
# executor.submit(
789+
# PickleChunkedVcf.convert_partition,
790+
# vcf_metadata,
791+
# j,
792+
# out_path,
793+
# column_chunk_size=column_chunk_size,
794+
# )
795+
# )
796+
# partition_summaries = [
797+
# future.result() for future in cf.as_completed(futures)
798+
# ]
799+
800+
# assert progress_counter.value == total_variants
801+
# if bar_thread is not None:
802+
# bar_thread.join()
781803

782804
for field in vcf_metadata.fields:
783805
for summary in partition_summaries:
@@ -862,11 +884,11 @@ def service_futures(max_waiting=2 * flush_threads):
862884

863885
service_futures()
864886

887+
865888
# Note: an issue with updating the progress per variant here like this
866889
# is that we get a significant pause at the end of the counter while
867890
# all the "small" fields get flushed. Possibly not much to be done about it.
868-
with progress_counter.get_lock():
869-
progress_counter.value += 1
891+
core.update_progress(1)
870892

871893
for col in columns.values():
872894
col.flush()
@@ -876,21 +898,21 @@ def service_futures(max_waiting=2 * flush_threads):
876898
return summaries
877899

878900

879-
def update_bar(progress_counter, total, title, units):
880-
pbar = tqdm.tqdm(
881-
total=total, desc=title, unit_scale=True, unit=units, smoothing=0.1
882-
)
901+
# def update_bar(progress_counter, total, title, units):
902+
# pbar = tqdm.tqdm(
903+
# total=total, desc=title, unit_scale=True, unit=units, smoothing=0.1
904+
# )
883905

884-
while (current := progress_counter.value) < total:
885-
inc = current - pbar.n
886-
pbar.update(inc)
887-
time.sleep(0.1)
888-
pbar.close()
906+
# while (current := progress_counter.value) < total:
907+
# inc = current - pbar.n
908+
# pbar.update(inc)
909+
# time.sleep(0.1)
910+
# pbar.close()
889911

890912

891-
def init_workers(counter):
892-
global progress_counter
893-
progress_counter = counter
913+
# def init_workers(counter):
914+
# global progress_counter
915+
# progress_counter = counter
894916

895917

896918
def explode(
@@ -1418,7 +1440,9 @@ def convert_vcf(
14181440
)
14191441

14201442

1421-
def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_variant, encoder_threads=8):
1443+
def encode_bed_partition_genotypes(
1444+
bed_path, zarr_path, start_variant, end_variant, encoder_threads=8
1445+
):
14221446
bed = bed_reader.open_bed(bed_path, num_threads=1)
14231447

14241448
store = zarr.DirectoryStore(zarr_path)
@@ -1432,7 +1456,6 @@ def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_varia
14321456
buffered_arrays = [gt, gt_phased, gt_mask]
14331457

14341458
with core.ThreadedZarrEncoder(buffered_arrays, encoder_threads) as te:
1435-
14361459
start = start_variant
14371460
while start < end_variant:
14381461
stop = min(start + chunk_length, end_variant)

0 commit comments

Comments
 (0)