Skip to content

Commit c222e8b

Browse files
Will-Tyler authored and jeromekelleher committed
Simplify compute_laa_field
1 parent ad46178 commit c222e8b

File tree

1 file changed

+56
-53
lines changed

1 file changed

+56
-53
lines changed

bio2zarr/vcf2zarr/icf.py

Lines changed: 56 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -492,78 +492,81 @@ def sanitise_value_int_2d(buff, j, value):
492492

493493
def compute_laa_field(variant) -> np.ndarray:
494494
sample_count = variant.num_called + variant.num_unknown
495-
alleles = np.empty((sample_count, 0), dtype=int)
495+
alt_allele_count = len(variant.ALT)
496+
allele_count = alt_allele_count + 1
497+
allele_counts = np.zeros((sample_count, allele_count), dtype=int)
496498

497499
if "GT" in variant.FORMAT:
498500
# The last element of each sample's genotype indicates the phasing
499501
# and is not an allele.
500502
genotypes = variant.genotype.array()[:, :-1]
501-
genotypes[genotypes < 0] = 0
502-
alleles = np.concatenate((alleles, genotypes), axis=1)
503+
genotypes.clip(0, None, out=genotypes)
504+
genotype_allele_counts = np.apply_along_axis(
505+
np.bincount, axis=1, arr=genotypes, minlength=allele_count
506+
)
507+
allele_counts += genotype_allele_counts
503508
if "AD" in variant.FORMAT:
509+
depths = variant.format("AD")
510+
depths.clip(0, None, out=depths)
504511

505-
def positive_pad(arr, *, length):
506-
positive = np.where(arr > 0)[0]
507-
pad_length = length - len(positive)
508-
return np.pad(positive, (0, pad_length), mode="constant", constant_values=0)
512+
def bincount_nonzero(arr, *, minlength):
513+
# nonzero returns the indices of the nonzero elements for each axis
514+
return np.bincount(arr.nonzero()[0], minlength=minlength)
509515

510-
depths = variant.format("AD")
511-
depths = np.apply_along_axis(
512-
positive_pad, axis=1, arr=depths, length=depths.shape[1]
516+
depths_allele_counts = np.apply_along_axis(
517+
bincount_nonzero, axis=1, arr=depths, minlength=allele_count
513518
)
514-
alleles = np.concatenate((alleles, depths), axis=1)
519+
allele_counts += depths_allele_counts
515520
if "PL" in variant.FORMAT:
516-
517-
def infer_and_pad(arr: np.ndarray, *, ploidy: int, length: int):
518-
assert ploidy in {1, 2}
519-
indices = arr.nonzero()[0]
520-
521-
if ploidy == 2:
522-
b = np.ceil(np.sqrt(2 * indices + 9 / 4) - 3 / 2)
523-
a = indices - b * (b + 1) / 2
524-
pad_length = length - len(a) - len(b)
525-
else:
526-
a = indices
527-
b = np.empty(0)
528-
pad_length = length - len(a)
529-
return np.pad(
530-
np.concatenate((a, b)),
531-
(0, pad_length),
532-
mode="constant",
533-
constant_values=0,
534-
)
535-
536521
likelihoods = variant.format("PL")
522+
likelihoods.clip(0, None, out=likelihoods)
523+
# n is the indices of the nonzero likelihoods
524+
n = np.tile(np.arange(likelihoods.shape[1]), (likelihoods.shape[0], 1))
525+
assert n.shape == likelihoods.shape
526+
n[likelihoods <= 0] = 0
537527
ploidy = variant.ploidy
538-
likelihoods = np.apply_along_axis(
539-
infer_and_pad,
540-
axis=1,
541-
arr=likelihoods,
542-
length=ploidy * likelihoods.shape[1],
543-
ploidy=ploidy,
544-
)
545-
alleles = np.concatenate((alleles, likelihoods), axis=1)
546-
547-
max_unique_size = 1
548-
549-
def unique_pad(arr: np.ndarray, *, length: int):
550-
unique = np.unique(arr)
551528

552-
if unique[0] == 0:
553-
# We don't include the reference allele
554-
unique = unique[1:]
529+
if ploidy == 1:
530+
a = n
531+
b = np.zeros_like(a)
532+
elif ploidy == 2:
533+
# We have n = b(b+1) / 2 + a
534+
# We need to compute a and b
535+
b = np.ceil(np.sqrt(2 * n + 9 / 4) - 3 / 2).astype(int)
536+
a = (n - b * (b + 1) / 2).astype(int)
537+
else:
538+
# TODO: Handle all possible ploidy
539+
raise ValueError(f"Cannot handle ploidy = {ploidy}")
555540

556-
nonlocal max_unique_size
557-
max_unique_size = max(max_unique_size, len(unique))
558-
pad_length = length - len(unique)
541+
a_counts = np.apply_along_axis(
542+
np.bincount, axis=1, arr=a, minlength=allele_count
543+
)
544+
b_counts = np.apply_along_axis(
545+
np.bincount, axis=1, arr=b, minlength=allele_count
546+
)
547+
assert a_counts.shape == b_counts.shape == allele_counts.shape
548+
allele_counts += a_counts
549+
allele_counts += b_counts
550+
551+
allele_counts[:, 0] = 0 # We don't count the reference allele
552+
max_row_length = 1
553+
554+
def nonzero_pad(arr: np.ndarray, *, length: int):
555+
nonlocal max_row_length
556+
alleles = arr.nonzero()[0]
557+
max_row_length = max(max_row_length, len(alleles))
558+
pad_length = length - len(alleles)
559559
return np.pad(
560-
unique, (0, pad_length), mode="constant", constant_values=constants.INT_FILL
560+
alleles,
561+
(0, pad_length),
562+
mode="constant",
563+
constant_values=constants.INT_FILL,
561564
)
562565

563566
alleles = np.apply_along_axis(
564-
unique_pad, axis=1, arr=alleles, length=alleles.shape[1]
567+
nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count)
565568
)
566-
alleles = alleles[:, :max_unique_size]
569+
alleles = alleles[:, :max_row_length]
567570

568571
return alleles
569572

0 commit comments

Comments
 (0)