
Commit bd53416

Factor out encoder threads for PCVcf
1 parent 0260e41 commit bd53416

1 file changed: +27 −51 lines

bio2zarr/vcf.py

Lines changed: 27 additions & 51 deletions
@@ -1,4 +1,3 @@
-import concurrent.futures as cf
 import collections
 import dataclasses
 import functools
@@ -143,6 +142,22 @@ class VcfMetadata:
     partitions: list = None
     contig_lengths: list = None
 
+    @property
+    def info_fields(self):
+        fields = []
+        for field in self.fields:
+            if field.category == "INFO":
+                fields.append(field)
+        return fields
+
+    @property
+    def format_fields(self):
+        fields = []
+        for field in self.fields:
+            if field.category == "FORMAT":
+                fields.append(field)
+        return fields
+
     @property
     def num_records(self):
         return sum(self.contig_record_counts.values())
@@ -629,35 +644,26 @@ def reset(self):
         self.chunk_index += 1
 
 
-class ThreadedColumnWriter(contextlib.AbstractContextManager):
+class ColumnWriter(contextlib.AbstractContextManager):
     def __init__(
         self,
         vcf_metadata,
         out_path,
         partition_index,
         *,
-        encoder_threads=0,
         chunk_size=1,
     ):
-        self.encoder_threads = encoder_threads
         self.partition_index = partition_index
         # chunk_size is in megabytes
         self.max_buffered_bytes = chunk_size * 2**20
         assert self.max_buffered_bytes > 0
 
-        if encoder_threads <= 0:
-            # NOTE: this is only for testing, not for production use!
-            self.executor = core.SynchronousExecutor()
-        else:
-            self.executor = cf.ThreadPoolExecutor(max_workers=encoder_threads)
-
         self.buffers = {}
         num_samples = len(vcf_metadata.samples)
         for vcf_field in vcf_metadata.fields:
             field = PickleChunkedVcfField(vcf_field, out_path)
             transformer = VcfValueTransformer.factory(vcf_field, num_samples)
             self.buffers[vcf_field.full_name] = FieldBuffer(field, transformer)
-        self.futures = set()
 
     @property
     def field_summaries(self):
@@ -676,37 +682,19 @@ def append(self, name, value):
         if buff.buffered_bytes >= self.max_buffered_bytes:
             self._flush_buffer(name, buff)
 
-    def _service_futures(self):
-        max_waiting = 2 * self.encoder_threads
-        while len(self.futures) > max_waiting:
-            futures_done, _ = cf.wait(self.futures, return_when=cf.FIRST_COMPLETED)
-            for future in futures_done:
-                exception = future.exception()
-                if exception is not None:
-                    raise exception
-                self.futures.remove(future)
-
     def _flush_buffer(self, name, buff):
-        self._service_futures()
         logger.debug(f"Schedule write {name}:{self.partition_index}.{buff.chunk_index}")
-        future = self.executor.submit(
-            buff.field.write_chunk,
+        buff.field.write_chunk(
             self.partition_index,
             buff.chunk_index,
             buff.buff,
         )
-        self.futures.add(future)
         buff.reset()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type is None:
-            # Normal exit condition
             for name, buff in self.buffers.items():
                 self._flush_buffer(name, buff)
-            core.wait_on_futures(self.futures)
-        else:
-            core.cancel_futures(self.futures)
-        self.executor.shutdown()
         return False
 
 
@@ -812,41 +800,31 @@ def convert_partition(
     partition_index,
     out_path,
     *,
-    encoder_threads=4,
     column_chunk_size=16,
 ):
     partition = vcf_metadata.partitions[partition_index]
     logger.info(
         f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
     )
-
-    info_fields = []
+    info_fields = vcf_metadata.info_fields
     format_fields = []
     has_gt = False
-    for field in vcf_metadata.fields:
-        if field.category == "INFO":
-            info_fields.append(field)
-        elif field.category == "FORMAT":
-            if field.name == "GT":
-                has_gt = True
-            else:
-                format_fields.append(field)
-
-    # FIXME it looks like this is actually a bit pointless now that we
-    # can split up into multiple regions within the VCF. It's simpler
-    # and easier to explain and predict performance if we just do
-    # everything syncronously. We can keep the same interface,
-    # just remove the "Threaded" bit and simplify.
-    with ThreadedColumnWriter(
+    for field in vcf_metadata.format_fields:
+        if field.name == "GT":
+            has_gt = True
+        else:
+            format_fields.append(field)
+
+    with ColumnWriter(
         vcf_metadata,
         out_path,
         partition_index,
-        encoder_threads=0,
         chunk_size=column_chunk_size,
     ) as tcw:
         with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
             num_records = 0
             for variant in ivcf.variants(partition.region):
+                num_records += 1
                 tcw.append("CHROM", variant.CHROM)
                 tcw.append("POS", variant.POS)
                 tcw.append("QUAL", variant.QUAL)
@@ -865,12 +843,10 @@ def convert_partition(
                     except KeyError:
                         pass
                     tcw.append(field.full_name, val)
-
                 # Note: an issue with updating the progress per variant here like this
                 # is that we get a significant pause at the end of the counter while
                 # all the "small" fields get flushed. Possibly not much to be done about it.
                 core.update_progress(1)
-                num_records += 1
 
     logger.info(
         f"Finish p{partition_index} {partition.vcf_path}__{partition.region}="
