Skip to content

Commit 92d19c8

Browse files
First pass at LAD
1 parent 50a87b9 commit 92d19c8

File tree

3 files changed

+130
-32
lines changed

3 files changed

+130
-32
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ def convert_local_allele_field_types(fields):
199199
raise ValueError("Local alleles only supported on diploid data")
200200
# TODO check if LAA is already in here
201201

202-
# shape = gt.shape[:-1]
203-
# chunks = gt.chunks[:-1]
202+
shape = gt.shape[:-1]
203+
chunks = gt.chunks[:-1]
204204

205205
laa = ZarrArraySpec.new(
206206
vcf_field=None,
@@ -214,6 +214,16 @@ def convert_local_allele_field_types(fields):
214214
" are relevant (local) for the current sample"
215215
),
216216
)
217+
ad = fields_by_name.get("call_AD", None)
218+
if ad is not None:
219+
# TODO check if call_LAD is in the list already
220+
ad.name = "call_LAD"
221+
ad.vcf_field = None
222+
ad.shape = (*shape, 2)
223+
ad.chunks = (*chunks, 2)
224+
ad.description += " (local-alleles)"
225+
# TODO fix dimensions
226+
217227
# pl = fields_by_name.get("call_PL", None)
218228
# if pl is not None:
219229
# # TODO check if call_LPL is in the list already
@@ -513,7 +523,7 @@ def fromdict(d):
513523
return ret
514524

515525

516-
def compute_laa_field(genotypes) -> np.ndarray:
526+
def compute_laa_field(genotypes):
517527
"""
518528
Computes the value of the LAA field for each sample given the genotypes
519529
for a variant.
@@ -539,6 +549,19 @@ def compute_laa_field(genotypes) -> np.ndarray:
539549
return G.astype(genotypes.dtype)
540550

541551

552+
def compute_lad_field(ad, laa):
    """
    Computes the value of the LAD field for each sample given the AD values
    and the LAA values for a variant.

    ``ad`` is a 2D array-like with one row per sample, indexed along its
    second axis by allele index. ``laa`` is a 2D array-like with one row per
    sample and exactly two columns holding allele indexes, where -2 marks an
    unused slot.

    The result always has shape ``(num_samples, 2)`` and the dtype of ``ad``.
    For each sample:

    - ``laa == (-2, -2)``: ``(ad[0], -2)``
    - ``laa == (a, -2)``: ``(ad[0], ad[a])``
    - ``laa == (a, b)``: ``(ad[a], ad[b])``

    A row of the form ``(-2, b)`` matches none of these cases and is left
    entirely at the fill value -2.

    Raises ValueError if the inputs are not 2D, if ``laa`` does not have two
    columns, or if the two inputs disagree on the number of samples.
    """
    fill = -2
    ad = np.asarray(ad)
    laa = np.asarray(laa)
    if ad.ndim != 2:
        raise ValueError("ad must be a 2D (samples, alleles) array")
    if laa.ndim != 2 or laa.shape[1] != 2:
        raise ValueError("laa must be a 2D array with exactly two columns")
    if ad.shape[0] != laa.shape[0]:
        # Without this check a shorter laa silently leaves the trailing
        # samples at the fill value.
        raise ValueError("ad and laa must have the same number of samples")

    lad = np.full((ad.shape[0], 2), fill, dtype=ad.dtype)
    first = laa[:, 0]
    second = laa[:, 1]
    has_first = first != fill
    has_second = second != fill

    # No local alleles at all: only the first AD column is carried over.
    none = np.flatnonzero(~has_first & ~has_second)
    lad[none, 0] = ad[none, 0]

    # One local allele: first AD column plus the value for that allele.
    one = np.flatnonzero(has_first & ~has_second)
    lad[one, 0] = ad[one, 0]
    lad[one, 1] = ad[one, first[one]]

    # Two local alleles: both output slots are taken by those alleles, so
    # the first AD column is not represented in the output.
    two = np.flatnonzero(has_first & has_second)
    lad[two, 0] = ad[two, first[two]]
    lad[two, 1] = ad[two, second[two]]
    return lad
563+
564+
542565
@dataclasses.dataclass
543566
class VcfZarrWriteSummary(core.JsonDataclass):
544567
num_partitions: int
@@ -843,21 +866,35 @@ def encode_genotypes_partition(self, partition_index):
843866
self.finalise_partition_array(partition_index, "call_genotype_phased")
844867

845868
def encode_local_alleles_partition(self, partition_index):
    """
    Encode the local-allele arrays for one partition.

    call_LAA is computed from the already-written call_genotype
    work-in-progress array of this partition. call_LAD is computed from the
    FORMAT/AD values together with the LAA row of the same variant.
    """
    partition = self.metadata.partitions[partition_index]
    call_LAA_array = self.init_partition_array(partition_index, "call_LAA")
    call_LAA = core.BufferedArray(call_LAA_array, partition.start)

    # The schema conversion only creates call_LAD when a call_AD field
    # exists, so the LAD path is skipped when the input has no FORMAT/AD.
    # NOTE(review): assumes self.icf.fields supports ``in`` like a mapping
    # (it is subscripted by field name below) - confirm.
    encode_lad = "FORMAT/AD" in self.icf.fields
    if encode_lad:
        call_LAD_array = self.init_partition_array(partition_index, "call_LAD")
        call_LAD = core.BufferedArray(call_LAD_array, partition.start)
        call_AD_source = self.icf.fields["FORMAT/AD"].iter_values(
            partition.start, partition.stop
        )

    gt_array = zarr.open_array(
        store=self.wip_partition_array_path(partition_index, "call_genotype"),
        mode="r",
    )
    for chunk_index in range(gt_array.cdata_shape[0]):
        for genotypes in gt_array.blocks[chunk_index]:
            laa = compute_laa_field(genotypes)
            j = call_LAA.next_buffer_row()
            call_LAA.buff[j] = laa

            if encode_lad:
                # The AD iterator is advanced in lock-step with the
                # genotype rows, one value per variant.
                # NOTE(review): presumably iter_values can yield a missing
                # value for variants without AD - confirm and handle.
                ad = next(call_AD_source)
                k = call_LAD.next_buffer_row()
                call_LAD.buff[k] = compute_lad_field(ad, laa)

    call_LAA.flush()
    self.finalise_partition_array(partition_index, "call_LAA")
    if encode_lad:
        call_LAD.flush()
        self.finalise_partition_array(partition_index, "call_LAD")
861898

862899
def encode_alleles_partition(self, partition_index):
863900
array_name = "variant_allele"

tests/test_local_alleles.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy.testing as nt
33
import pytest
44

5-
from bio2zarr.vcf2zarr.vcz import compute_laa_field
5+
from bio2zarr.vcf2zarr.vcz import compute_laa_field, compute_lad_field
66

77

88
class TestComputeLAA:
@@ -29,3 +29,36 @@ def test_extreme_value(self):
2929
G = np.array([[0, 2**32 - 1]])
3030
with pytest.raises(ValueError, match="Extreme"):
3131
compute_laa_field(G)
32+
33+
34+
class TestComputeLAD:
    """
    Unit tests for compute_lad_field, which maps per-variant AD values and
    LAA allele indexes to a two-column LAD array padded with -2.
    """

    @pytest.mark.parametrize(
        ("ad", "laa", "expected"),
        [
            # 0/0 calls
            ([[10, 0]], [[-2, -2]], [[10, -2]]),
            ([[10, 0, 0], [11, 0, 0]], [[-2, -2], [-2, -2]], [[10, -2], [11, -2]]),
            # 0/1 calls
            ([[10, 11]], [[1, -2]], [[10, 11]]),
            ([[10, 11], [12, 0]], [[1, -2], [-2, -2]], [[10, 11], [12, -2]]),
            # 0/2 calls
            ([[10, 0, 11]], [[2, -2]], [[10, 11]]),
            ([[10, 0, 11], [10, 11, 0]], [[2, -2], [1, -2]], [[10, 11], [10, 11]]),
            (
                [[10, 0, 11], [10, 11, 0], [12, 0, 0]],
                [[2, -2], [1, -2], [-2, -2]],
                [[10, 11], [10, 11], [12, -2]],
            ),
            # 1/2 calls
            ([[0, 10, 11]], [[1, 2]], [[10, 11]]),
            ([[0, 10, 11], [12, 0, 13]], [[1, 2], [2, -2]], [[10, 11], [12, 13]]),
            (
                [[0, 10, 11], [12, 0, 13], [14, 0, 0]],
                [[1, 2], [2, -2], [-2, -2]],
                [[10, 11], [12, 13], [14, -2]],
            ),
        ],
    )
    def test_simple_examples(self, ad, laa, expected):
        result = compute_lad_field(np.array(ad), np.array(laa))
        nt.assert_array_equal(result, expected)

    @pytest.mark.parametrize("dtype", [np.int8, np.int16, np.int32])
    def test_dtype_and_shape(self, dtype):
        # The output takes the dtype of the AD input and always has two
        # columns, however many alleles AD has.
        ad = np.array([[10, 0, 11], [12, 0, 0]], dtype=dtype)
        laa = np.array([[2, -2], [-2, -2]], dtype=dtype)
        result = compute_lad_field(ad, laa)
        assert result.dtype == dtype
        assert result.shape == (2, 2)
        nt.assert_array_equal(result, [[10, 11], [12, -2]])

tests/test_vcf_examples.py

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -617,33 +617,61 @@ def test_variant_MLEAC(self, ds):
617617
)
618618
nt.assert_array_equal(ds.variant_MLEAC.values, MLEAC)
619619

620-
def test_call_AD(self, ds):
621-
call_AD = [
622-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
623-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
624-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
625-
[[446, 0, 0, -2, -2], [393, 0, 0, -2, -2], [486, 0, 0, -2, -2]],
626-
[[446, 0, 0, 0, -2], [393, 0, 0, 0, -2], [486, 0, 0, 0, -2]],
627-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
628-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
629-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
630-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
631-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
632-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
633-
[[446, 0, 0, 0, 0], [393, 0, 0, 0, 0], [486, 0, 0, 0, 0]],
634-
[[446, 0, 0, 0, 0], [393, 0, 0, 0, 0], [486, 0, 0, 0, 0]],
635-
[[446, 0, 0, 0, 0], [393, 0, 0, 0, 0], [486, 0, 0, 0, 0]],
636-
[[446, 0, 0, 0, -2], [393, 0, 0, 0, -2], [486, 0, 0, 0, -2]],
637-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
638-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
639-
[[446, 0, 0, -2, -2], [393, 0, 0, -2, -2], [486, 0, 0, -2, -2]],
640-
[[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
641-
[[446, 0, 0, -2, -2], [393, 0, 0, -2, -2], [486, 0, 0, -2, -2]],
642-
[[446, 0, 0, 0, -2], [393, 0, 0, 0, -2], [486, 0, 0, 0, -2]],
643-
[[446, 0, 0, 0, 0], [393, 0, 0, 0, 0], [486, 0, 0, 0, 0]],
644-
[[446, 0, 0, 0, -2], [393, 0, 0, 0, -2], [486, 0, 0, 0, -2]],
620+
# def test_call_AD(self, ds):
621+
# call_AD = [
622+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
623+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
624+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
625+
# [[446, 0, 0, -2, -2], [393, 0, 0, -2, -2], [486, 0, 0, -2, -2]],
626+
# [[446, 0, 0, 0, -2], [393, 0, 0, 0, -2], [486, 0, 0, 0, -2]],
627+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
628+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
629+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
630+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
631+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
632+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
633+
# [[446, 0, 0, 0, 0], [393, 0, 0, 0, 0], [486, 0, 0, 0, 0]],
634+
# [[446, 0, 0, 0, 0], [393, 0, 0, 0, 0], [486, 0, 0, 0, 0]],
635+
# [[446, 0, 0, 0, 0], [393, 0, 0, 0, 0], [486, 0, 0, 0, 0]],
636+
# [[446, 0, 0, 0, -2], [393, 0, 0, 0, -2], [486, 0, 0, 0, -2]],
637+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
638+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
639+
# [[446, 0, 0, -2, -2], [393, 0, 0, -2, -2], [486, 0, 0, -2, -2]],
640+
# [[446, 0, -2, -2, -2], [393, 0, -2, -2, -2], [486, 0, -2, -2, -2]],
641+
# [[446, 0, 0, -2, -2], [393, 0, 0, -2, -2], [486, 0, 0, -2, -2]],
642+
# [[446, 0, 0, 0, -2], [393, 0, 0, 0, -2], [486, 0, 0, 0, -2]],
643+
# [[446, 0, 0, 0, 0], [393, 0, 0, 0, 0], [486, 0, 0, 0, 0]],
644+
# [[446, 0, 0, 0, -2], [393, 0, 0, 0, -2], [486, 0, 0, 0, -2]],
645+
# ]
646+
# nt.assert_array_equal(ds.call_AD.values, call_AD)
647+
648+
def test_call_LAD(self, ds):
    # The expected value is the same three-sample block for each of the
    # 23 variants: the first AD value followed by the -2 fill.
    num_variants = 23
    per_variant = [[446, -2], [393, -2], [486, -2]]
    call_LAD = [per_variant for _ in range(num_variants)]
    nt.assert_array_equal(ds.call_LAD.values, call_LAD)
647675

648676
def test_call_LAA(self, ds):
649677
# All the genotypes are 0/0

0 commit comments

Comments
 (0)