Skip to content

Commit cf1947f

Browse files
Improve local alleles implementation to reduce memory
1 parent 7b72ffd commit cf1947f

File tree

2 files changed

+100
-27
lines changed

2 files changed

+100
-27
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,24 @@ def compute_lpl_field(pl, la):
584584
return lpl
585585

586586

587+
@dataclasses.dataclass
588+
class LocalisableFieldDescriptor:
589+
array_name: str
590+
vcf_field: str
591+
sanitise: callable
592+
convert: callable
593+
594+
595+
localisable_fields = [
596+
LocalisableFieldDescriptor(
597+
"call_LAD", "FORMAT/AD", icf.sanitise_int_array, compute_lad_field
598+
),
599+
LocalisableFieldDescriptor(
600+
"call_LPL", "FORMAT/PL", icf.sanitise_int_array, compute_lpl_field
601+
),
602+
]
603+
604+
587605
@dataclasses.dataclass
588606
class VcfZarrWriteSummary(core.JsonDataclass):
589607
num_partitions: int
@@ -816,6 +834,7 @@ def encode_partition(self, partition_index):
816834
self.encode_genotypes_partition(partition_index)
817835
if self.has_local_alleles():
818836
self.encode_local_alleles_partition(partition_index)
837+
self.encode_local_allele_fields_partition(partition_index)
819838

820839
final_path = self.partition_path(partition_index)
821840
logger.info(f"Finalising {partition_index} at {final_path}")
@@ -892,17 +911,6 @@ def encode_local_alleles_partition(self, partition_index):
892911
call_LA_array = self.init_partition_array(partition_index, "call_LA")
893912
call_LA = core.BufferedArray(call_LA_array, partition.start)
894913

895-
call_LAD_array = self.init_partition_array(partition_index, "call_LAD")
896-
call_LAD = core.BufferedArray(call_LAD_array, partition.start)
897-
call_AD_source = self.icf.fields["FORMAT/AD"].iter_values(
898-
partition.start, partition.stop
899-
)
900-
call_LPL_array = self.init_partition_array(partition_index, "call_LPL")
901-
call_LPL = core.BufferedArray(call_LPL_array, partition.start)
902-
call_PL_source = self.icf.fields["FORMAT/PL"].iter_values(
903-
partition.start, partition.stop
904-
)
905-
906914
gt_array = zarr.open_array(
907915
store=self.wip_partition_array_path(partition_index, "call_genotype"),
908916
mode="r",
@@ -914,24 +922,37 @@ def encode_local_alleles_partition(self, partition_index):
914922
j = call_LA.next_buffer_row()
915923
call_LA.buff[j] = la
916924

917-
ad = next(call_AD_source)
918-
ad = icf.sanitise_int_array(ad, 2, ad.dtype)
919-
k = call_LAD.next_buffer_row()
920-
assert j == k
921-
call_LAD.buff[j] = compute_lad_field(ad, la)
922-
923-
pl = next(call_PL_source)
924-
pl = icf.sanitise_int_array(pl, 2, pl.dtype)
925-
k = call_LPL.next_buffer_row()
926-
assert j == k
927-
call_LPL.buff[j] = compute_lpl_field(pl, la)
928-
929925
call_LA.flush()
930926
self.finalise_partition_array(partition_index, "call_LA")
931-
call_LAD.flush()
932-
self.finalise_partition_array(partition_index, "call_LAD")
933-
call_LPL.flush()
934-
self.finalise_partition_array(partition_index, "call_LPL")
927+
928+
def encode_local_allele_fields_partition(self, partition_index):
929+
partition = self.metadata.partitions[partition_index]
930+
la_array = zarr.open_array(
931+
store=self.wip_partition_array_path(partition_index, "call_LA"),
932+
mode="r",
933+
)
934+
field_map = self.schema.field_map()
935+
# We got through the localisable fields one-by-one so that we don't need to
936+
# keep several large arrays in memory at once for each partition.
937+
for descriptor in localisable_fields:
938+
if descriptor.array_name not in field_map:
939+
continue
940+
assert field_map[descriptor.array_name].vcf_field is None
941+
942+
array = self.init_partition_array(partition_index, descriptor.array_name)
943+
buff = core.BufferedArray(array, partition.start)
944+
source = self.icf.fields[descriptor.vcf_field].iter_values(
945+
partition.start, partition.stop
946+
)
947+
for la in core.first_dim_slice_iter(
948+
la_array, partition.start, partition.stop
949+
):
950+
raw_value = next(source)
951+
value = descriptor.sanitise(raw_value, 2, raw_value.dtype)
952+
j = buff.next_buffer_row()
953+
buff.buff[j] = descriptor.convert(value, la)
954+
buff.flush()
955+
self.finalise_partition_array(partition_index, "array_name")
935956

936957
def encode_alleles_partition(self, partition_index):
937958
array_name = "variant_allele"

tests/test_vcf_examples.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,59 @@ def test_vcf_field_description(self, ds, field, description):
444444
assert ds[field].attrs["description"] == description
445445

446446

447+
class TestSmallExampleLocalAlleles:
448+
data_path = "tests/data/vcf/sample.vcf.gz"
449+
450+
@pytest.fixture(scope="class")
451+
def ds(self, tmp_path_factory):
452+
out = tmp_path_factory.mktemp("data") / "example.vcf.zarr"
453+
vcf2zarr.convert([self.data_path], out, local_alleles=True)
454+
return sg.load_dataset(out)
455+
456+
def test_call_LA(self, ds):
457+
call_genotype = np.array(
458+
[
459+
[[0, 0], [0, 0], [0, 1]],
460+
[[0, 0], [0, 0], [0, 1]],
461+
[[0, 0], [1, 0], [1, 1]],
462+
[[0, 0], [0, 1], [0, 0]],
463+
[[1, 2], [2, 1], [2, 2]],
464+
[[0, 0], [0, 0], [0, 0]],
465+
[[0, 1], [0, 2], [-1, -1]],
466+
[[0, 0], [0, 0], [-1, -1]],
467+
# FIXME this depends on "mixed ploidy" interpretation.
468+
[[0, -2], [0, 1], [0, 2]],
469+
],
470+
dtype="i1",
471+
)
472+
nt.assert_array_equal(ds["call_genotype"], call_genotype)
473+
nt.assert_array_equal(ds["call_genotype_mask"], call_genotype < 0)
474+
475+
call_LA = np.array(
476+
[
477+
[[0, -2], [0, -2], [0, 1]],
478+
[[0, -2], [0, -2], [0, 1]],
479+
[[0, -2], [0, 1], [1, -2]],
480+
[[0, -2], [0, 1], [0, -2]],
481+
[[1, 2], [1, 2], [2, -2]],
482+
[[0, -2], [0, -2], [0, -2]],
483+
[[0, 1], [0, 2], [-2, -2]],
484+
[[0, -2], [0, -2], [-2, -2]],
485+
[[0, -2], [0, 1], [0, 2]],
486+
],
487+
)
488+
nt.assert_array_equal(ds.call_LA.values, call_LA)
489+
490+
@pytest.mark.parametrize("field", ["call_LPL", "call_LAD"])
491+
def test_no_localised_fields(self, ds, field):
492+
assert field not in ds
493+
494+
447495
class TestLocalAllelesExample:
496+
# Note this example has a mixture of local and non local fields, and uses
497+
# the VCF standard LA field. It's not actually testing our implementation
498+
# of localisation.
499+
448500
data_path = "tests/data/vcf/local_alleles.vcf.gz"
449501

450502
@pytest.fixture(scope="class")

0 commit comments

Comments
 (0)