1
- import concurrent .futures as cf
2
1
import collections
3
2
import dataclasses
4
3
import functools
@@ -143,6 +142,22 @@ class VcfMetadata:
143
142
partitions : list = None
144
143
contig_lengths : list = None
145
144
145
@property
def info_fields(self):
    """Return the entries of ``self.fields`` whose category is "INFO".

    A fresh list is built on every access; entries keep the order they
    have in ``self.fields``.
    """
    return [field for field in self.fields if field.category == "INFO"]
152
+
153
@property
def format_fields(self):
    """Return the entries of ``self.fields`` whose category is "FORMAT".

    A fresh list is built on every access; entries keep the order they
    have in ``self.fields``.
    """
    return [field for field in self.fields if field.category == "FORMAT"]
160
+
146
161
@property
def num_records(self):
    """Return the sum of all values in ``self.contig_record_counts``."""
    per_contig_counts = self.contig_record_counts.values()
    return sum(per_contig_counts)
@@ -629,35 +644,26 @@ def reset(self):
629
644
self .chunk_index += 1
630
645
631
646
632
- class ThreadedColumnWriter (contextlib .AbstractContextManager ):
647
+ class ColumnWriter (contextlib .AbstractContextManager ):
633
648
def __init__(
    self,
    vcf_metadata,
    out_path,
    partition_index,
    *,
    chunk_size=1,
):
    """Set up one buffer per field in ``vcf_metadata.fields``.

    Parameters:
        vcf_metadata: object providing ``samples`` (only its length is
            used here) and ``fields`` (each with a ``full_name``).
        out_path: passed through to ``PickleChunkedVcfField``.
        partition_index: stored and later handed to ``write_chunk``.
        chunk_size: per-buffer flush threshold, in megabytes.
    """
    self.partition_index = partition_index
    # chunk_size is in megabytes; keep the threshold in bytes.
    self.max_buffered_bytes = chunk_size * 2**20
    assert self.max_buffered_bytes > 0

    self.buffers = {}
    sample_count = len(vcf_metadata.samples)
    for vcf_field in vcf_metadata.fields:
        # Buffers are keyed by the field's full name, which is the key
        # used to look them up when values are appended.
        self.buffers[vcf_field.full_name] = FieldBuffer(
            PickleChunkedVcfField(vcf_field, out_path),
            VcfValueTransformer.factory(vcf_field, sample_count),
        )
661
667
662
668
@property
663
669
def field_summaries (self ):
@@ -676,37 +682,19 @@ def append(self, name, value):
676
682
if buff .buffered_bytes >= self .max_buffered_bytes :
677
683
self ._flush_buffer (name , buff )
678
684
679
- def _service_futures (self ):
680
- max_waiting = 2 * self .encoder_threads
681
- while len (self .futures ) > max_waiting :
682
- futures_done , _ = cf .wait (self .futures , return_when = cf .FIRST_COMPLETED )
683
- for future in futures_done :
684
- exception = future .exception ()
685
- if exception is not None :
686
- raise exception
687
- self .futures .remove (future )
688
-
689
685
def _flush_buffer(self, name, buff):
    """Write the current contents of ``buff`` as one chunk, then reset it.

    ``name`` is used only for the debug log line. The write is performed
    directly in this call (the log text still says "Schedule").
    """
    logger.debug(f"Schedule write {name}:{self.partition_index}.{buff.chunk_index}")
    chunk_writer = buff.field.write_chunk
    chunk_writer(self.partition_index, buff.chunk_index, buff.buff)
    buff.reset()
700
693
701
694
def __exit__(self, exc_type, exc_val, exc_tb):
    """Flush every buffer on a clean exit; never suppress an exception.

    When the ``with`` body raised (``exc_type`` is not None) nothing is
    flushed. Always returns False so any exception propagates.
    """
    if exc_type is not None:
        return False
    for field_name, field_buffer in self.buffers.items():
        self._flush_buffer(field_name, field_buffer)
    return False
711
699
712
700
@@ -812,41 +800,31 @@ def convert_partition(
812
800
partition_index ,
813
801
out_path ,
814
802
* ,
815
- encoder_threads = 4 ,
816
803
column_chunk_size = 16 ,
817
804
):
818
805
partition = vcf_metadata .partitions [partition_index ]
819
806
logger .info (
820
807
f"Start p{ partition_index } { partition .vcf_path } __{ partition .region } "
821
808
)
822
-
823
- info_fields = []
809
+ info_fields = vcf_metadata .info_fields
824
810
format_fields = []
825
811
has_gt = False
826
- for field in vcf_metadata .fields :
827
- if field .category == "INFO" :
828
- info_fields .append (field )
829
- elif field .category == "FORMAT" :
830
- if field .name == "GT" :
831
- has_gt = True
832
- else :
833
- format_fields .append (field )
834
-
835
- # FIXME it looks like this is actually a bit pointless now that we
836
- # can split up into multiple regions within the VCF. It's simpler
837
- # and easier to explain and predict performance if we just do
838
- # everything synchronously. We can keep the same interface,
839
- # just remove the "Threaded" bit and simplify.
840
- with ThreadedColumnWriter (
812
+ for field in vcf_metadata .format_fields :
813
+ if field .name == "GT" :
814
+ has_gt = True
815
+ else :
816
+ format_fields .append (field )
817
+
818
+ with ColumnWriter (
841
819
vcf_metadata ,
842
820
out_path ,
843
821
partition_index ,
844
- encoder_threads = 0 ,
845
822
chunk_size = column_chunk_size ,
846
823
) as tcw :
847
824
with vcf_utils .IndexedVcf (partition .vcf_path ) as ivcf :
848
825
num_records = 0
849
826
for variant in ivcf .variants (partition .region ):
827
+ num_records += 1
850
828
tcw .append ("CHROM" , variant .CHROM )
851
829
tcw .append ("POS" , variant .POS )
852
830
tcw .append ("QUAL" , variant .QUAL )
@@ -865,12 +843,10 @@ def convert_partition(
865
843
except KeyError :
866
844
pass
867
845
tcw .append (field .full_name , val )
868
-
869
846
# Note: an issue with updating the progress per variant here like this
870
847
# is that we get a significant pause at the end of the counter while
871
848
# all the "small" fields get flushed. Possibly not much to be done about it.
872
849
core .update_progress (1 )
873
- num_records += 1
874
850
875
851
logger .info (
876
852
f"Finish p{ partition_index } { partition .vcf_path } __{ partition .region } ="
0 commit comments