@@ -138,11 +138,15 @@ class VcfMetadata:
138
138
format_version : str
139
139
samples : list
140
140
contig_names : list
141
+ contig_record_counts : dict
141
142
filters : list
142
143
fields : list
143
- contig_lengths : list = None
144
144
partitions : list = None
145
- num_records : int = 0
145
+ contig_lengths : list = None
146
+
147
+ @property
148
+ def num_records (self ):
149
+ return sum (self .contig_record_counts .values ())
146
150
147
151
@staticmethod
148
152
def fromdict (d ):
@@ -179,19 +183,10 @@ def make_field_def(name, vcf_type, vcf_number):
179
183
]
180
184
return fields
181
185
182
-
183
- # TODO refactor this to use the ProcessPoolExecutor, and the IndexedVCF class
184
- def scan_vcfs (paths , show_progress , target_num_partitions ):
185
- partitions = []
186
- vcf_metadata = None
187
- header = None
188
- logger .info (f"Scanning { len (paths )} VCFs" )
189
- total_records = 0
190
- for path in tqdm .tqdm (paths , desc = "Scan " , disable = not show_progress ):
191
- # TODO use contextlib.closing on this
192
- vcf = cyvcf2 .VCF (path )
193
- logger .debug (f"Scanning { path } " )
194
-
186
+ def scan_vcf (path , target_num_partitions ):
187
+ logger .debug (f"Scanning { path } " )
188
+ with vcf_utils .IndexedVcf (path ) as indexed_vcf :
189
+ vcf = indexed_vcf .vcf
195
190
filters = [
196
191
h ["ID" ]
197
192
for h in vcf .header_iter ()
@@ -214,43 +209,68 @@ def scan_vcfs(paths, show_progress, target_num_partitions):
214
209
metadata = VcfMetadata (
215
210
samples = vcf .samples ,
216
211
contig_names = vcf .seqnames ,
212
+ contig_record_counts = indexed_vcf .contig_record_counts (),
217
213
filters = filters ,
214
+ # TODO use the mapping dictionary
218
215
fields = fields ,
216
+ partitions = [],
219
217
# FIXME do something systematic with this
220
- format_version = "0.1"
218
+ format_version = "0.1" ,
221
219
)
222
220
try :
223
221
metadata .contig_lengths = vcf .seqlens
224
222
except AttributeError :
225
223
pass
226
224
227
- if vcf_metadata is None :
228
- vcf_metadata = metadata
229
- # We just take the first header, assuming the others
230
- # are compatible.
231
- header = vcf .raw_header
232
- else :
233
- if metadata != vcf_metadata :
234
- raise ValueError ("Incompatible VCF chunks" )
235
- vcf_metadata .num_records += vcf .num_records
236
-
237
- # TODO: Move all our usage of the VCF class behind the IndexedVCF
238
- # so that we open the VCF once, and we explicitly set the index.
239
- # Otherwise cyvcf2 will do things behind our backs.
240
- indexed_vcf = vcf_utils .IndexedVcf (path )
241
225
regions = indexed_vcf .partition_into_regions (num_parts = target_num_partitions )
242
226
for region in regions :
243
- partitions .append (
227
+ metadata . partitions .append (
244
228
VcfPartition (
245
229
vcf_path = str (path ),
246
230
region = region ,
247
231
)
248
232
)
233
+ core .update_progress (1 )
234
+ return metadata , vcf .raw_header
235
+
236
+
237
+ def scan_vcfs (paths , show_progress , target_num_partitions , worker_processes = 1 ):
238
+ logger .info (f"Scanning { len (paths )} VCFs" )
239
+ progress_config = core .ProgressConfig (
240
+ total = len (paths ),
241
+ units = "files" ,
242
+ title = "Scan" ,
243
+ show = show_progress ,
244
+ )
245
+ with core .ParallelWorkManager (worker_processes , progress_config ) as pwm :
246
+ for path in paths :
247
+ pwm .submit (scan_vcf , path , target_num_partitions )
248
+ results = list (pwm .results_as_completed ())
249
+
250
+ # Sort to make the ordering deterministic
251
+ results .sort (key = lambda t : t [0 ].partitions [0 ].vcf_path )
252
+ # We just take the first header, assuming the others
253
+ # are compatible.
254
+ all_partitions = []
255
+ contig_record_counts = collections .Counter ()
256
+ for metadata , _ in results :
257
+ all_partitions .extend (metadata .partitions )
258
+ metadata .partitions .clear ()
259
+ contig_record_counts += metadata .contig_record_counts
260
+ metadata .contig_record_counts .clear ()
261
+
262
+ vcf_metadata , header = results [0 ]
263
+ for metadata , _ in results [1 :]:
264
+ if metadata != vcf_metadata :
265
+ raise ValueError ("Incompatible VCF chunks" )
266
+
267
+ vcf_metadata .contig_record_counts = dict (contig_record_counts )
268
+
249
269
# Sort by contig (in the order they appear in the header) first,
250
270
# then by start coordinate
251
- contig_index_map = {contig : j for j , contig in enumerate (vcf . seqnames )}
252
- partitions .sort (key = lambda x : (contig_index_map [x .region .contig ], x .region .start ))
253
- vcf_metadata .partitions = partitions
271
+ contig_index_map = {contig : j for j , contig in enumerate (metadata . contig_names )}
272
+ all_partitions .sort (key = lambda x : (contig_index_map [x .region .contig ], x .region .start ))
273
+ vcf_metadata .partitions = all_partitions
254
274
return vcf_metadata , header
255
275
256
276
@@ -627,7 +647,7 @@ def __init__(
627
647
# NOTE: this is only for testing, not for production use!
628
648
self .executor = core .SynchronousExecutor ()
629
649
else :
630
- self .executor = cf .ProcessPoolExecutor (max_workers = encoder_threads )
650
+ self .executor = cf .ThreadPoolExecutor (max_workers = encoder_threads )
631
651
632
652
self .buffers = {}
633
653
num_samples = len (vcf_metadata .samples )
@@ -748,7 +768,7 @@ def total_uncompressed_bytes(self):
748
768
749
769
@functools .cached_property
750
770
def num_records (self ):
751
- return self .metadata .num_records
771
+ return sum ( self .metadata .contig_record_counts . values ())
752
772
753
773
@property
754
774
def num_partitions (self ):
@@ -883,6 +903,7 @@ def convert(
883
903
target_num_partitions = max (1 , worker_processes * 4 )
884
904
vcf_metadata , header = scan_vcfs (
885
905
vcfs ,
906
+ worker_processes = worker_processes ,
886
907
show_progress = show_progress ,
887
908
target_num_partitions = target_num_partitions ,
888
909
)
0 commit comments