Skip to content

Commit 2b508ad

Browse files
Simple implementation of compute_laa and tests
Fixup Close LAA loophole
1 parent d347e37 commit 2b508ad

File tree

2 files changed

+50
-40
lines changed

2 files changed

+50
-40
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 19 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -513,44 +513,30 @@ def fromdict(d):
513513
return ret
514514

515515

516-
def compute_laa_field(genotypes, alleles) -> np.ndarray:
516+
def compute_laa_field(genotypes) -> np.ndarray:
517517
"""
518518
Computes the value of the LAA field for each sample given the genotypes
519519
for a variant.
520520
521521
The LAA field is a list of one-based indices into the ALT alleles
522522
that indicates which alternate alleles are observed in the sample.
523523
"""
524-
alt_allele_count = len(alleles) - 1
525-
allele_counts = np.zeros((genotypes.shape[0], len(alleles)), dtype=int)
526-
527-
genotypes = genotypes.clip(0, None)
528-
genotype_allele_counts = np.apply_along_axis(
529-
np.bincount, axis=1, arr=genotypes, minlength=len(alleles)
530-
)
531-
allele_counts += genotype_allele_counts
532-
533-
allele_counts[:, 0] = 0 # We don't count the reference allele
534-
max_row_length = 1
535-
536-
def nonzero_pad(arr: np.ndarray, *, length: int):
537-
nonlocal max_row_length
538-
alleles = arr.nonzero()[0]
539-
max_row_length = max(max_row_length, len(alleles))
540-
pad_length = length - len(alleles)
541-
return np.pad(
542-
alleles,
543-
(0, pad_length),
544-
mode="constant",
545-
constant_values=constants.INT_FILL,
546-
)
547-
548-
alleles = np.apply_along_axis(
549-
nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count)
550-
)
551-
alleles = alleles[:, :max_row_length]
552-
553-
return alleles
524+
v = 2**31 - 1
525+
if np.any(genotypes >= v):
526+
raise ValueError("Extreme allele value not supported")
527+
G = genotypes.astype(np.int32)
528+
if len(G) > 0:
529+
# Anything <=0 gets mapped to -2 (pad) in the output, which comes last.
530+
# So, to get this sorting correctly, we remap to the largest value for
531+
# sorting, then map back. We promote the genotypes up to 32 bit for convenience
532+
# here, assuming that we'll never have a allele of 2**31 - 1.
533+
assert np.all(G != v)
534+
G[G <= 0] = v
535+
G.sort(axis=1)
536+
# Equal non-zero values result in padding also
537+
G[G[:, 0] == G[:, 1], 1] = -2
538+
G[G == v] = -2
539+
return G.astype(genotypes.dtype)
554540

555541

556542
@dataclasses.dataclass
@@ -865,17 +851,10 @@ def encode_local_alleles_partition(self, partition_index):
865851
store=self.wip_partition_array_path(partition_index, "call_genotype"),
866852
mode="r",
867853
)
868-
alleles_array = zarr.open_array(
869-
store=self.wip_partition_array_path(partition_index, "variant_allele"),
870-
mode="r",
871-
)
872854
for chunk_index in range(gt_array.cdata_shape[0]):
873-
A = alleles_array.blocks[chunk_index]
874-
G = gt_array.blocks[chunk_index]
875-
for alleles, var in zip(A, G):
855+
for genotypes in gt_array.blocks[chunk_index]:
876856
j = call_LAA.next_buffer_row()
877-
# TODO we should probably compute LAAs by chunk for efficiency
878-
call_LAA.buff[j] = compute_laa_field(var, alleles)
857+
call_LAA.buff[j] = compute_laa_field(genotypes)
879858

880859
call_LAA.flush()
881860
self.finalise_partition_array(partition_index, "call_LAA")

tests/test_local_alleles.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import numpy as np
2+
import numpy.testing as nt
3+
import pytest
4+
5+
from bio2zarr.vcf2zarr.vcz import compute_laa_field
6+
7+
8+
class TestComputeLAA:
9+
@pytest.mark.parametrize(
10+
("genotypes", "expected"),
11+
[
12+
([], []),
13+
([[0, 0]], [[-2, -2]]),
14+
([[0, 0], [0, 0]], [[-2, -2], [-2, -2]]),
15+
([[1, 1], [0, 0]], [[1, -2], [-2, -2]]),
16+
([[0, 1], [3, 2], [3, 0]], [[1, -2], [2, 3], [3, -2]]),
17+
([[0, 0], [2, 3]], [[-2, -2], [2, 3]]),
18+
([[2, 3], [0, 0]], [[2, 3], [-2, -2]]),
19+
([[128, 0], [6, 5]], [[128, -2], [5, 6]]),
20+
([[0, -1], [-1, 5]], [[-2, -2], [5, -2]]),
21+
],
22+
)
23+
def test_simple_examples(self, genotypes, expected):
24+
G = np.array(genotypes)
25+
result = compute_laa_field(G)
26+
nt.assert_array_equal(result, expected)
27+
28+
def test_extreme_value(self):
29+
G = np.array([[0, 2**32 - 1]])
30+
with pytest.raises(ValueError, match="Extreme"):
31+
compute_laa_field(G)

0 commit comments

Comments
 (0)