Skip to content

Commit df24a55

Browse files
Will-Tylerjeromekelleher
authored andcommitted
Parallelize chunk loading
1 parent 0e18357 commit df24a55

File tree

1 file changed

+126
-60
lines changed

1 file changed

+126
-60
lines changed

vcztools/vcf_writer.py

Lines changed: 126 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import concurrent.futures
12
import functools
23
import io
34
import re
@@ -292,74 +293,139 @@ def c_chunk_to_vcf(
292293
drop_genotypes,
293294
no_update,
294295
):
295-
chrom = contigs[get_vchunk_array(root.variant_contig, v_chunk, v_mask_chunk)]
296-
# TODO check we don't truncate silently by doing this
297-
pos = get_vchunk_array(root.variant_position, v_chunk, v_mask_chunk).astype(
298-
np.int32
299-
)
300-
id = get_vchunk_array(root.variant_id, v_chunk, v_mask_chunk).astype("S")
301-
alleles = get_vchunk_array(root.variant_allele, v_chunk, v_mask_chunk)
302-
ref = alleles[:, 0].astype("S")
303-
alt = alleles[:, 1:].astype("S")
304-
qual = get_vchunk_array(root.variant_quality, v_chunk, v_mask_chunk)
305-
filter_ = get_vchunk_array(root.variant_filter, v_chunk, v_mask_chunk)
306-
307-
num_variants = len(pos)
308-
if len(id.shape) == 1:
309-
id = id.reshape((num_variants, 1))
310-
311-
# TODO gathering fields and doing IO will be done separately later so that
312-
# we avoid retrieving stuff we don't need.
296+
chrom = None
297+
pos = None
298+
id = None
299+
alleles = None
300+
qual = None
301+
filter_ = None
313302
format_fields = {}
314303
info_fields = {}
315304
num_samples = len(samples_selection) if samples_selection is not None else None
316-
for name, array in root.items():
317-
if (
318-
name.startswith("call_")
319-
and not name.startswith("call_genotype")
320-
and num_samples != 0
321-
):
322-
vcf_name = name[len("call_") :]
323-
format_fields[vcf_name] = get_vchunk_array(
324-
array, v_chunk, v_mask_chunk, samples_selection
325-
)
326-
if num_samples is None:
327-
num_samples = array.shape[1]
328-
elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES:
329-
vcf_name = name[len("variant_") :]
330-
info_fields[vcf_name] = get_vchunk_array(array, v_chunk, v_mask_chunk)
331-
332305
gt = None
333306
gt_phased = None
334307

335-
if "call_genotype" in root and not drop_genotypes:
336-
array = root["call_genotype"]
308+
def load_chrom():
309+
nonlocal chrom
310+
chrom = contigs[get_vchunk_array(root.variant_contig, v_chunk, v_mask_chunk)]
311+
312+
def load_pos():
313+
nonlocal pos
314+
# TODO check we don't truncate silently by doing this
315+
pos = get_vchunk_array(root.variant_position, v_chunk, v_mask_chunk).astype(
316+
np.int32
317+
)
318+
319+
def load_id():
320+
nonlocal id
321+
id = get_vchunk_array(root.variant_id, v_chunk, v_mask_chunk).astype("S")
322+
323+
def load_alleles():
324+
nonlocal alleles
325+
alleles = get_vchunk_array(root.variant_allele, v_chunk, v_mask_chunk)
326+
327+
def load_qual():
328+
nonlocal qual
329+
qual = get_vchunk_array(root.variant_quality, v_chunk, v_mask_chunk)
330+
331+
def load_filter():
332+
nonlocal filter_
333+
filter_ = get_vchunk_array(root.variant_filter, v_chunk, v_mask_chunk)
334+
335+
def load_format_field(name, zarray):
336+
nonlocal format_fields, v_chunk, v_mask_chunk, samples_selection
337+
vcf_name = name[len("call_") :]
338+
format_fields[vcf_name] = get_vchunk_array(
339+
zarray, v_chunk, v_mask_chunk, samples_selection
340+
)
341+
342+
def load_info_field(name, zarray):
343+
nonlocal info_fields, v_chunk, v_mask_chunk
344+
vcf_name = name[len("variant_") :]
345+
info_fields[vcf_name] = get_vchunk_array(zarray, v_chunk, v_mask_chunk)
346+
347+
def load_gt():
348+
pass
337349

350+
def load_gt_phased():
351+
pass
352+
353+
if "call_genotype" in root and not drop_genotypes:
338354
if samples_selection is not None and num_samples != 0:
339-
gt = get_vchunk_array(array, v_chunk, v_mask_chunk, samples_selection)
355+
356+
def load_gt():
357+
nonlocal gt
358+
gt = get_vchunk_array(
359+
root["call_genotype"], v_chunk, v_mask_chunk, samples_selection
360+
)
340361
else:
341-
gt = get_vchunk_array(array, v_chunk, v_mask_chunk)
342362

343-
if not no_update and samples_selection is not None:
344-
# Recompute INFO/AC and INFO/AN
345-
info_fields |= _compute_info_fields(gt, alt)
346-
if num_samples == 0:
347-
gt = None
363+
def load_gt():
364+
nonlocal gt
365+
gt = get_vchunk_array(root["call_genotype"], v_chunk, v_mask_chunk)
366+
348367
if (
349368
"call_genotype_phased" in root
350369
and not drop_genotypes
351370
and (samples_selection is None or num_samples > 0)
352371
):
353-
array = root["call_genotype_phased"]
354-
gt_phased = get_vchunk_array(
355-
array, v_chunk, v_mask_chunk, samples_selection
356-
)
372+
373+
def load_gt_phased():
374+
nonlocal gt_phased
375+
gt_phased = get_vchunk_array(
376+
root["call_genotype_phased"],
377+
v_chunk,
378+
v_mask_chunk,
379+
samples_selection,
380+
)
357381
else:
358-
gt_phased = np.zeros_like(gt, dtype=bool)
359382

383+
def load_gt_phased():
384+
nonlocal gt_phased
385+
gt_phased = np.zeros_like(gt, dtype=bool)
386+
387+
with concurrent.futures.ThreadPoolExecutor() as executor:
388+
executor.submit(load_chrom)
389+
executor.submit(load_pos)
390+
executor.submit(load_id)
391+
executor.submit(load_alleles)
392+
executor.submit(load_qual)
393+
executor.submit(load_filter)
394+
395+
for name, zarray in root.items():
396+
if (
397+
name.startswith("call_")
398+
and not name.startswith("call_genotype")
399+
and num_samples != 0
400+
):
401+
executor.submit(load_format_field, name, zarray)
402+
if num_samples is None:
403+
num_samples = zarray.shape[1]
404+
elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES:
405+
executor.submit(load_info_field, name, zarray)
406+
407+
executor.submit(load_gt)
408+
executor.submit(load_gt_phased)
409+
410+
ref = alleles[:, 0].astype("S")
411+
alt = alleles[:, 1:].astype("S")
412+
413+
if len(id.shape) == 1:
414+
id = id.reshape((-1, 1))
415+
if (
416+
not no_update
417+
and samples_selection is not None
418+
and "call_genotype" in root
419+
and not drop_genotypes
420+
):
421+
# Recompute INFO/AC and INFO/AN
422+
info_fields |= _compute_info_fields(gt, alt)
423+
if num_samples == 0:
424+
gt = None
360425
if gt is not None and num_samples is None:
361426
num_samples = gt.shape[1]
362427

428+
num_variants = len(pos)
363429
encoder = _vcztools.VcfEncoder(
364430
num_variants,
365431
num_samples if num_samples is not None else 0,
@@ -375,21 +441,21 @@ def c_chunk_to_vcf(
375441
# print(encoder.arrays)
376442
if gt is not None:
377443
encoder.add_gt_field(gt, gt_phased)
378-
for name, array in info_fields.items():
444+
for name, zarray in info_fields.items():
379445
# print(array.dtype.kind)
380-
if array.dtype.kind in ("O", "U"):
381-
array = array.astype("S")
382-
if len(array.shape) == 1:
383-
array = array.reshape((num_variants, 1))
384-
encoder.add_info_field(name, array)
446+
if zarray.dtype.kind in ("O", "U"):
447+
zarray = zarray.astype("S")
448+
if len(zarray.shape) == 1:
449+
zarray = zarray.reshape((num_variants, 1))
450+
encoder.add_info_field(name, zarray)
385451

386452
if num_samples != 0:
387-
for name, array in format_fields.items():
388-
if array.dtype.kind in ("O", "U"):
389-
array = array.astype("S")
390-
if len(array.shape) == 2:
391-
array = array.reshape((num_variants, num_samples, 1))
392-
encoder.add_format_field(name, array)
453+
for name, zarray in format_fields.items():
454+
if zarray.dtype.kind in ("O", "U"):
455+
zarray = zarray.astype("S")
456+
if len(zarray.shape) == 2:
457+
zarray = zarray.reshape((num_variants, num_samples, 1))
458+
encoder.add_format_field(name, zarray)
393459
# TODO: (1) make a guess at this based on number of fields and samples,
394460
# and (2) log a DEBUG message when we have to double.
395461
buflen = 1024

0 commit comments

Comments
 (0)