Skip to content

Commit 130eac8

Browse files
Refactor scan process, small changes to vcf metadata
1 parent 3e617bb commit 130eac8

File tree

1 file changed

+57
-36
lines changed

1 file changed

+57
-36
lines changed

bio2zarr/vcf.py

Lines changed: 57 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -138,11 +138,15 @@ class VcfMetadata:
138138
format_version: str
139139
samples: list
140140
contig_names: list
141+
contig_record_counts: dict
141142
filters: list
142143
fields: list
143-
contig_lengths: list = None
144144
partitions: list = None
145-
num_records: int = 0
145+
contig_lengths: list = None
146+
147+
@property
148+
def num_records(self):
149+
return sum(self.contig_record_counts.values())
146150

147151
@staticmethod
148152
def fromdict(d):
@@ -179,19 +183,10 @@ def make_field_def(name, vcf_type, vcf_number):
179183
]
180184
return fields
181185

182-
183-
# TODO refactor this to use the ProcessPoolExecutor, and the IndexedVCF class
184-
def scan_vcfs(paths, show_progress, target_num_partitions):
185-
partitions = []
186-
vcf_metadata = None
187-
header = None
188-
logger.info(f"Scanning {len(paths)} VCFs")
189-
total_records = 0
190-
for path in tqdm.tqdm(paths, desc="Scan ", disable=not show_progress):
191-
# TODO use contextlib.closing on this
192-
vcf = cyvcf2.VCF(path)
193-
logger.debug(f"Scanning {path}")
194-
186+
def scan_vcf(path, target_num_partitions):
187+
logger.debug(f"Scanning {path}")
188+
with vcf_utils.IndexedVcf(path) as indexed_vcf:
189+
vcf = indexed_vcf.vcf
195190
filters = [
196191
h["ID"]
197192
for h in vcf.header_iter()
@@ -214,43 +209,68 @@ def scan_vcfs(paths, show_progress, target_num_partitions):
214209
metadata = VcfMetadata(
215210
samples=vcf.samples,
216211
contig_names=vcf.seqnames,
212+
contig_record_counts=indexed_vcf.contig_record_counts(),
217213
filters=filters,
214+
# TODO use the mapping dictionary
218215
fields=fields,
216+
partitions=[],
219217
# FIXME do something systematic with this
220-
format_version="0.1"
218+
format_version="0.1",
221219
)
222220
try:
223221
metadata.contig_lengths = vcf.seqlens
224222
except AttributeError:
225223
pass
226224

227-
if vcf_metadata is None:
228-
vcf_metadata = metadata
229-
# We just take the first header, assuming the others
230-
# are compatible.
231-
header = vcf.raw_header
232-
else:
233-
if metadata != vcf_metadata:
234-
raise ValueError("Incompatible VCF chunks")
235-
vcf_metadata.num_records += vcf.num_records
236-
237-
# TODO: Move all our usage of the VCF class behind the IndexedVCF
238-
# so that we open the VCF once, and we explicitly set the index.
239-
# Otherwise cyvcf2 will do things behind our backs.
240-
indexed_vcf = vcf_utils.IndexedVcf(path)
241225
regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
242226
for region in regions:
243-
partitions.append(
227+
metadata.partitions.append(
244228
VcfPartition(
245229
vcf_path=str(path),
246230
region=region,
247231
)
248232
)
233+
core.update_progress(1)
234+
return metadata, vcf.raw_header
235+
236+
237+
def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
    """Scan every VCF in ``paths`` and merge the per-file results.

    Each path is submitted to ``scan_vcf`` through a
    ``core.ParallelWorkManager`` created with ``worker_processes`` and a
    progress configuration counting files. The per-file partitions and
    per-contig record counts are pooled; everything else in the per-file
    metadata must compare equal across files.

    Args:
        paths: Sized collection of VCF paths to scan.
        show_progress: Forwarded as ``show`` to ``core.ProgressConfig``.
        target_num_partitions: Forwarded to ``scan_vcf`` for every path.
        worker_processes: Forwarded to ``core.ParallelWorkManager``.

    Returns:
        A ``(vcf_metadata, header)`` tuple. ``vcf_metadata`` is the metadata
        of the first file (in path order) with the pooled partitions and
        summed contig record counts attached; ``header`` is the second
        element returned by ``scan_vcf`` for that same file.

    Raises:
        ValueError: If ``paths`` is empty, or if the per-file metadata
            objects do not compare equal.
    """
    # Fail early with a clear message; otherwise the empty case only
    # surfaces later as an IndexError on ``results[0]``.
    if len(paths) == 0:
        raise ValueError("No VCF paths provided to scan")

    logger.info(f"Scanning {len(paths)} VCFs")
    progress_config = core.ProgressConfig(
        total=len(paths),
        units="files",
        title="Scan",
        show=show_progress,
    )
    with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
        for path in paths:
            pwm.submit(scan_vcf, path, target_num_partitions)
        results = list(pwm.results_as_completed())

    # Results arrive in completion order; sort by the path recorded on each
    # file's first partition to make the ordering deterministic.
    results.sort(key=lambda t: t[0].partitions[0].vcf_path)

    # Pool the per-file fields, emptying them on each metadata object so
    # that they do not take part in the compatibility comparison below.
    all_partitions = []
    # NOTE: ``Counter +=`` discards entries whose running total is not
    # positive, so contigs with zero records do not appear in the result.
    contig_record_counts = collections.Counter()
    for metadata, _ in results:
        all_partitions.extend(metadata.partitions)
        metadata.partitions.clear()
        contig_record_counts += metadata.contig_record_counts
        metadata.contig_record_counts.clear()

    # We just take the first header, assuming the others
    # are compatible.
    vcf_metadata, header = results[0]
    for metadata, _ in results[1:]:
        if metadata != vcf_metadata:
            raise ValueError("Incompatible VCF chunks")

    vcf_metadata.contig_record_counts = dict(contig_record_counts)

    # Sort by contig (in the order they appear in the header) first,
    # then by start coordinate. Use the representative metadata rather
    # than whatever the loop variable above happened to be left bound to.
    contig_index_map = {
        contig: j for j, contig in enumerate(vcf_metadata.contig_names)
    }
    all_partitions.sort(
        key=lambda x: (contig_index_map[x.region.contig], x.region.start)
    )
    vcf_metadata.partitions = all_partitions
    return vcf_metadata, header
255275

256276

@@ -627,7 +647,7 @@ def __init__(
627647
# NOTE: this is only for testing, not for production use!
628648
self.executor = core.SynchronousExecutor()
629649
else:
630-
self.executor = cf.ProcessPoolExecutor(max_workers=encoder_threads)
650+
self.executor = cf.ThreadPoolExecutor(max_workers=encoder_threads)
631651

632652
self.buffers = {}
633653
num_samples = len(vcf_metadata.samples)
@@ -748,7 +768,7 @@ def total_uncompressed_bytes(self):
748768

749769
@functools.cached_property
750770
def num_records(self):
751-
return self.metadata.num_records
771+
return sum(self.metadata.contig_record_counts.values())
752772

753773
@property
754774
def num_partitions(self):
@@ -883,6 +903,7 @@ def convert(
883903
target_num_partitions = max(1, worker_processes * 4)
884904
vcf_metadata, header = scan_vcfs(
885905
vcfs,
906+
worker_processes=worker_processes,
886907
show_progress=show_progress,
887908
target_num_partitions=target_num_partitions,
888909
)

0 commit comments

Comments
 (0)