Skip to content

Commit 4accfc5

Browse files
Simple implementation of compute_laa and tests
1 parent d347e37 commit 4accfc5

File tree

2 files changed

+44
-31
lines changed

2 files changed

+44
-31
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -513,44 +513,27 @@ def fromdict(d):
513513
return ret
514514

515515

516-
def compute_laa_field(genotypes, alleles) -> np.ndarray:
516+
def compute_laa_field(genotypes) -> np.ndarray:
517517
"""
518518
Computes the value of the LAA field for each sample given the genotypes
519519
for a variant.
520520
521521
The LAA field is a list of one-based indices into the ALT alleles
522522
that indicates which alternate alleles are observed in the sample.
523523
"""
524-
alt_allele_count = len(alleles) - 1
525-
allele_counts = np.zeros((genotypes.shape[0], len(alleles)), dtype=int)
526-
527-
genotypes = genotypes.clip(0, None)
528-
genotype_allele_counts = np.apply_along_axis(
529-
np.bincount, axis=1, arr=genotypes, minlength=len(alleles)
530-
)
531-
allele_counts += genotype_allele_counts
532-
533-
allele_counts[:, 0] = 0 # We don't count the reference allele
534-
max_row_length = 1
535-
536-
def nonzero_pad(arr: np.ndarray, *, length: int):
537-
nonlocal max_row_length
538-
alleles = arr.nonzero()[0]
539-
max_row_length = max(max_row_length, len(alleles))
540-
pad_length = length - len(alleles)
541-
return np.pad(
542-
alleles,
543-
(0, pad_length),
544-
mode="constant",
545-
constant_values=constants.INT_FILL,
546-
)
547-
548-
alleles = np.apply_along_axis(
549-
nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count)
550-
)
551-
alleles = alleles[:, :max_row_length]
552-
553-
return alleles
524+
v = 2**31 - 1
525+
if np.any(genotypes >= v):
526+
raise ValueError("Extreme allele value not supported")
527+
G = genotypes.astype(np.int32)
528+
# Anything <=0 gets mapped to -2 (pad) in the output, which comes last.
529+
# So, to get this sorting correctly, we remap to the largest value for
530+
# sorting, then map back. We promote the genotypes up to 32 bit for convenience
531+
# here, assuming that we'll never have a allele of 2**31 - 1.
532+
assert np.all(G != v)
533+
G[G <= 0] = v
534+
G.sort(axis=1)
535+
G[G == v] = -2
536+
return G.astype(genotypes.dtype)
554537

555538

556539
@dataclasses.dataclass

tests/test_local_alleles.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import numpy.testing as nt
3+
import pytest
4+
5+
from bio2zarr.vcf2zarr.vcz import compute_laa_field
6+
7+
8+
class TestComputeLAA:
9+
@pytest.mark.parametrize(
10+
("genotypes", "expected"),
11+
[
12+
([[]], [[]]),
13+
([[0, 0]], [[-2, -2]]),
14+
([[0, 0], [0, 0]], [[-2, -2], [-2, -2]]),
15+
([[0, 1], [3, 2], [3, 0]], [[1, -2], [2, 3], [3, -2]]),
16+
([[0, 0], [2, 3]], [[-2, -2], [2, 3]]),
17+
([[2, 3], [0, 0]], [[2, 3], [-2, -2]]),
18+
([[128, 0], [6, 5]], [[128, -2], [5, 6]]),
19+
([[0, -1], [-1, 5]], [[-2, -2], [5, -2]]),
20+
],
21+
)
22+
def test_simple_examples(self, genotypes, expected):
23+
G = np.array(genotypes)
24+
result = compute_laa_field(G)
25+
nt.assert_array_equal(result, expected)
26+
27+
def test_extreme_value(self):
28+
G = np.array([[0, 2**32 - 1]])
29+
with pytest.raises(ValueError, match="Extreme"):
30+
compute_laa_field(G)

0 commit comments

Comments
 (0)