Skip to content

Commit abed893

Browse files
Fixups for PL
1 parent 0f89efb commit abed893

File tree

3 files changed

+154
-78
lines changed

3 files changed

+154
-78
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 79 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -197,20 +197,21 @@ def convert_local_allele_field_types(fields):
197197
gt = fields_by_name["call_genotype"]
198198
if gt.shape[-1] != 2:
199199
raise ValueError("Local alleles only supported on diploid data")
200-
# TODO check if LAA is already in here
200+
201+
# TODO check if LA is already in here
201202

202203
shape = gt.shape[:-1]
203204
chunks = gt.chunks[:-1]
204205

205-
laa = ZarrArraySpec.new(
206+
la = ZarrArraySpec.new(
206207
vcf_field=None,
207-
name="call_LAA",
208+
name="call_LA",
208209
dtype="i1",
209210
shape=gt.shape,
210211
chunks=gt.chunks,
211212
dimensions=gt.dimensions, # FIXME
212213
description=(
213-
"1-based indices into ALT, indicating which alleles"
214+
"0-based indices into REF+ALT, indicating which alleles"
214215
" are relevant (local) for the current sample"
215216
),
216217
)
@@ -224,16 +225,16 @@ def convert_local_allele_field_types(fields):
224225
ad.description += " (local-alleles)"
225226
# TODO fix dimensions
226227

227-
# pl = fields_by_name.get("call_PL", None)
228-
# if pl is not None:
229-
# # TODO check if call_LPL is in the list already
230-
# pl.name = "call_LPL"
231-
# pl.vcf_field = None
232-
# pl.shape = (*shape, 3)
233-
# pl.chunks = (*chunks, 3)
234-
# pl.description += " (local-alleles)"
235-
# # TODO fix dimensions
236-
return [*fields, laa]
228+
pl = fields_by_name.get("call_PL", None)
229+
if pl is not None:
230+
# TODO check if call_LPL is in the list already
231+
pl.name = "call_LPL"
232+
pl.vcf_field = None
233+
pl.shape = (*shape, 3)
234+
pl.chunks = (*chunks, 3)
235+
pl.description += " (local-alleles)"
236+
# TODO fix dimensions
237+
return [*fields, la]
237238

238239

239240
@dataclasses.dataclass
@@ -523,50 +524,66 @@ def fromdict(d):
523524
return ret
524525

525526

526-
def compute_laa_field(genotypes):
527+
def compute_la_field(genotypes):
527528
"""
528-
Computes the value of the LAA field for each sample given the genotypes
529-
for a variant.
530-
531-
The LAA field is a list of one-based indices into the ALT alleles
532-
that indicates which alternate alleles are observed in the sample.
529+
Computes the value of the LA field for each sample given the genotypes
530+
for a variant. The LA field lists the unique alleles observed for
531+
each sample, including the REF.
533532
"""
534533
v = 2**31 - 1
535534
if np.any(genotypes >= v):
536535
raise ValueError("Extreme allele value not supported")
537536
G = genotypes.astype(np.int32)
538537
if len(G) > 0:
539-
# Anything <=0 gets mapped to -2 (pad) in the output, which comes last.
538+
# Anything < 0 gets mapped to -2 (pad) in the output, which comes last.
540539
# So, to get this sorting correctly, we remap to the largest value for
541540
# sorting, then map back. We promote the genotypes up to 32 bit for convenience
542541
# here, assuming that we'll never have a allele of 2**31 - 1.
543542
assert np.all(G != v)
544-
G[G <= 0] = v
543+
G[G < 0] = v
545544
G.sort(axis=1)
546-
# Equal non-zero values result in padding also
547545
G[G[:, 0] == G[:, 1], 1] = -2
546+
# Equal values result in padding also
548547
G[G == v] = -2
549548
return G.astype(genotypes.dtype)
550549

551550

552-
def compute_lad_field(ad, laa):
553-
try:
554-
lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype)
555-
ref_ref = np.where((laa[:, 0] == -2) & (laa[:, 1] == -2))[0]
556-
lad[ref_ref, 0] = ad[ref_ref, 0]
557-
ref_alt = np.where((laa[:, 0] != -2) & (laa[:, 1] == -2))[0]
558-
lad[ref_alt, 0] = ad[ref_alt, 0]
559-
lad[ref_alt, 1] = ad[ref_alt, laa[ref_alt, 0]]
560-
alt_alt = np.where((laa[:, 0] != -2) & (laa[:, 1] != -2))[0]
561-
lad[alt_alt, 0] = ad[alt_alt, laa[alt_alt, 0]]
562-
lad[alt_alt, 1] = ad[alt_alt, laa[alt_alt, 1]]
563-
except Exception as e:
564-
print("ad = ", ad)
565-
print("laa = ", laa)
566-
raise e
551+
def compute_lad_field(ad, la):
552+
assert ad.shape[0] == la.shape[0]
553+
assert la.shape[1] == 2
554+
lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype)
555+
homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
556+
lad[homs, 0] = ad[homs, la[homs, 0]]
557+
hets = np.where(la[:, 1] != -2)
558+
lad[hets, 0] = ad[hets, la[hets, 0]]
559+
lad[hets, 1] = ad[hets, la[hets, 1]]
567560
return lad
568561

569562

563+
def pl_index(a, b):
564+
"""
565+
Returns the PL index for alleles a and b.
566+
"""
567+
return b * (b + 1) // 2 + a
568+
569+
570+
def compute_lpl_field(pl, la):
571+
lpl = np.full((pl.shape[0], 3), -2, dtype=pl.dtype)
572+
573+
homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
574+
a = la[homs, 0]
575+
lpl[homs, 0] = pl[homs, pl_index(a, a)]
576+
577+
hets = np.where(la[:, 1] != -2)[0]
578+
a = la[hets, 0]
579+
b = la[hets, 1]
580+
lpl[hets, 0] = pl[hets, pl_index(a, a)]
581+
lpl[hets, 1] = pl[hets, pl_index(a, b)]
582+
lpl[hets, 2] = pl[hets, pl_index(b, b)]
583+
584+
return lpl
585+
586+
570587
@dataclasses.dataclass
571588
class VcfZarrWriteSummary(core.JsonDataclass):
572589
num_partitions: int
@@ -601,7 +618,7 @@ def has_genotypes(self):
601618

602619
def has_local_alleles(self):
603620
for field in self.schema.fields:
604-
if field.name == "call_LAA" and field.vcf_field is None:
621+
if field.name == "call_LA" and field.vcf_field is None:
605622
return True
606623
return False
607624

@@ -872,35 +889,49 @@ def encode_genotypes_partition(self, partition_index):
872889

873890
def encode_local_alleles_partition(self, partition_index):
874891
partition = self.metadata.partitions[partition_index]
875-
call_LAA_array = self.init_partition_array(partition_index, "call_LAA")
876-
call_LAA = core.BufferedArray(call_LAA_array, partition.start)
892+
call_LA_array = self.init_partition_array(partition_index, "call_LA")
893+
call_LA = core.BufferedArray(call_LA_array, partition.start)
877894

878895
call_LAD_array = self.init_partition_array(partition_index, "call_LAD")
879896
call_LAD = core.BufferedArray(call_LAD_array, partition.start)
880897
call_AD_source = self.icf.fields["FORMAT/AD"].iter_values(
881898
partition.start, partition.stop
882899
)
900+
call_LPL_array = self.init_partition_array(partition_index, "call_LPL")
901+
call_LPL = core.BufferedArray(call_LPL_array, partition.start)
902+
call_PL_source = self.icf.fields["FORMAT/PL"].iter_values(
903+
partition.start, partition.stop
904+
)
905+
883906
gt_array = zarr.open_array(
884907
store=self.wip_partition_array_path(partition_index, "call_genotype"),
885908
mode="r",
886909
)
887910
for genotypes in core.first_dim_slice_iter(
888911
gt_array, partition.start, partition.stop
889912
):
890-
laa = compute_laa_field(genotypes)
891-
j = call_LAA.next_buffer_row()
892-
call_LAA.buff[j] = laa
913+
la = compute_la_field(genotypes)
914+
j = call_LA.next_buffer_row()
915+
call_LA.buff[j] = la
893916

894917
ad = next(call_AD_source)
918+
ad = icf.sanitise_int_array(ad, 2, ad.dtype)
895919
k = call_LAD.next_buffer_row()
896920
assert j == k
897-
lad = compute_lad_field(ad, laa)
898-
call_LAD.buff[j] = lad
921+
call_LAD.buff[j] = compute_lad_field(ad, la)
922+
923+
pl = next(call_PL_source)
924+
pl = icf.sanitise_int_array(pl, 2, pl.dtype)
925+
k = call_LPL.next_buffer_row()
926+
assert j == k
927+
call_LPL.buff[j] = compute_lpl_field(pl, la)
899928

900-
call_LAA.flush()
901-
self.finalise_partition_array(partition_index, "call_LAA")
929+
call_LA.flush()
930+
self.finalise_partition_array(partition_index, "call_LA")
902931
call_LAD.flush()
903932
self.finalise_partition_array(partition_index, "call_LAD")
933+
call_LPL.flush()
934+
self.finalise_partition_array(partition_index, "call_LPL")
904935

905936
def encode_alleles_partition(self, partition_index):
906937
array_name = "variant_allele"

tests/test_local_alleles.py

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,63 +2,107 @@
22
import numpy.testing as nt
33
import pytest
44

5-
from bio2zarr.vcf2zarr.vcz import compute_laa_field, compute_lad_field
5+
from bio2zarr.vcf2zarr import vcz
66

77

8-
class TestComputeLAA:
8+
class TestComputeLA:
99
@pytest.mark.parametrize(
1010
("genotypes", "expected"),
1111
[
1212
([], []),
13-
([[0, 0]], [[-2, -2]]),
14-
([[0, 0], [0, 0]], [[-2, -2], [-2, -2]]),
15-
([[1, 1], [0, 0]], [[1, -2], [-2, -2]]),
16-
([[0, 1], [3, 2], [3, 0]], [[1, -2], [2, 3], [3, -2]]),
17-
([[0, 0], [2, 3]], [[-2, -2], [2, 3]]),
18-
([[2, 3], [0, 0]], [[2, 3], [-2, -2]]),
19-
([[128, 0], [6, 5]], [[128, -2], [5, 6]]),
20-
([[0, -1], [-1, 5]], [[-2, -2], [5, -2]]),
13+
([[0, 0]], [[0, -2]]),
14+
([[0, 0], [0, 0]], [[0, -2], [0, -2]]),
15+
([[1, 1], [0, 0]], [[1, -2], [0, -2]]),
16+
([[0, 1], [3, 2], [3, 0]], [[0, 1], [2, 3], [0, 3]]),
17+
([[0, 0], [2, 3]], [[0, -2], [2, 3]]),
18+
([[2, 3], [0, 0]], [[2, 3], [0, -2]]),
19+
([[128, 0], [6, 5]], [[0, 128], [5, 6]]),
20+
([[0, -1], [-1, 5]], [[0, -2], [5, -2]]),
21+
([[-1, -1], [-1, 5]], [[-2, -2], [5, -2]]),
2122
],
2223
)
2324
def test_simple_examples(self, genotypes, expected):
2425
G = np.array(genotypes)
25-
result = compute_laa_field(G)
26+
result = vcz.compute_la_field(G)
2627
nt.assert_array_equal(result, expected)
2728

2829
def test_extreme_value(self):
2930
G = np.array([[0, 2**32 - 1]])
3031
with pytest.raises(ValueError, match="Extreme"):
31-
compute_laa_field(G)
32+
vcz.compute_la_field(G)
3233

3334

3435
class TestComputeLAD:
3536
@pytest.mark.parametrize(
36-
("ad", "laa", "expected"),
37+
("ad", "la", "expected"),
3738
[
39+
# Missing data
40+
([[0, 0]], [[-2, -2]], [[-2, -2]]),
3841
# 0/0 calls
39-
([[10, 0]], [[-2, -2]], [[10, -2]]),
40-
([[10, 0, 0], [11, 0, 0]], [[-2, -2], [-2, -2]], [[10, -2], [11, -2]]),
42+
([[10, 0]], [[0, -2]], [[10, -2]]),
43+
([[10, 0, 0]], [[0, -2]], [[10, -2]]),
44+
([[10, 0, 0], [11, 0, 0]], [[0, -2], [0, -2]], [[10, -2], [11, -2]]),
4145
# 0/1 calls
42-
([[10, 11]], [[1, -2]], [[10, 11]]),
43-
([[10, 11], [12, 0]], [[1, -2], [-2, -2]], [[10, 11], [12, -2]]),
46+
([[10, 11]], [[0, 1]], [[10, 11]]),
47+
([[10, 11], [12, 0]], [[0, 1], [0, -2]], [[10, 11], [12, -2]]),
4448
# 0/2 calls
45-
([[10, 0, 11]], [[2, -2]], [[10, 11]]),
46-
([[10, 0, 11], [10, 11, 0]], [[2, -2], [1, -2]], [[10, 11], [10, 11]]),
49+
([[10, 0, 11]], [[0, 2]], [[10, 11]]),
50+
([[10, 0, 11], [10, 11, 0]], [[0, 2], [0, 1]], [[10, 11], [10, 11]]),
4751
(
4852
[[10, 0, 11], [10, 11, 0], [12, 0, 0]],
49-
[[2, -2], [1, -2], [-2, -2]],
53+
[[0, 2], [0, 1], [0, -2]],
5054
[[10, 11], [10, 11], [12, -2]],
5155
),
5256
# 1/2 calls
5357
([[0, 10, 11]], [[1, 2]], [[10, 11]]),
54-
([[0, 10, 11], [12, 0, 13]], [[1, 2], [2, -2]], [[10, 11], [12, 13]]),
58+
([[0, 10, 11], [12, 0, 13]], [[1, 2], [0, 2]], [[10, 11], [12, 13]]),
5559
(
5660
[[0, 10, 11], [12, 0, 13], [14, 0, 0]],
57-
[[1, 2], [2, -2], [-2, -2]],
61+
[[1, 2], [0, 2], [0, -2]],
5862
[[10, 11], [12, 13], [14, -2]],
5963
),
6064
],
6165
)
62-
def test_simple_examples(self, ad, laa, expected):
63-
result = compute_lad_field(np.array(ad), np.array(laa))
66+
def test_simple_examples(self, ad, la, expected):
67+
result = vcz.compute_lad_field(np.array(ad), np.array(la))
68+
nt.assert_array_equal(result, expected)
69+
70+
71+
# PL translation indexes:
72+
# a b i
73+
# 0 0 0
74+
# 0 1 1
75+
# 0 2 3
76+
# 0 3 6
77+
# 1 1 2
78+
# 1 2 4
79+
# 1 3 7
80+
# 2 2 5
81+
# 2 3 8
82+
# 3 3 9
83+
84+
85+
class TestComputeLPL:
86+
@pytest.mark.parametrize(
87+
("pl", "la", "expected"),
88+
[
89+
# Missing
90+
([range(3)], [[-2, -2]], [[-2, -2, -2]]),
91+
# 0/0 calls
92+
([range(3)], [[0, -2]], [[0, -2, -2]]),
93+
# 0/0 calls
94+
([[-1, -1, -1]], [[0, -2]], [[-1, -2, -2]]),
95+
# 1/1 calls
96+
([range(3)], [[1, -2]], [[2, -2, -2]]),
97+
([range(3), range(3)], [[0, -2], [1, -2]], [[0, -2, -2], [2, -2, -2]]),
98+
# 2/2 calls
99+
([range(6)], [[2, -2]], [[5, -2, -2]]),
100+
# 0/1 calls
101+
([range(3)], [[0, 1]], [[0, 1, 2]]),
102+
# 0/2 calls
103+
([range(6)], [[0, 2]], [[0, 3, 5]]),
104+
],
105+
)
106+
def test_simple_examples(self, pl, la, expected):
107+
result = vcz.compute_lpl_field(np.array(pl), np.array(la))
64108
nt.assert_array_equal(result, expected)

tests/test_vcf_examples.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -725,14 +725,15 @@ def test_call_LAD(self, ds):
725725
]
726726
nt.assert_array_equal(ds.call_LAD.values, call_LAD)
727727

728-
def test_call_LAA(self, ds):
728+
def test_call_LA(self, ds):
729729
# All the genotypes are 0/0
730-
call_LAA = np.full((23, 3, 2), -2)
731-
nt.assert_array_equal(ds.call_LAA.values, call_LAA)
730+
call_LA = np.full((23, 3, 2), -2)
731+
call_LA[:, :, 0] = 0
732+
nt.assert_array_equal(ds.call_LA.values, call_LA)
732733

733-
# def test_call_LPL(self, ds):
734-
# call_LPL = np.tile([0, -2, -2], (23, 3, 1))
735-
# nt.assert_array_equal(ds.call_LPL.values, call_LPL)
734+
def test_call_LPL(self, ds):
735+
call_LPL = np.tile([0, -2, -2], (23, 3, 1))
736+
nt.assert_array_equal(ds.call_LPL.values, call_LPL)
736737

737738

738739
class Test1000G2020AnnotationsExample:

0 commit comments

Comments
 (0)