@@ -492,78 +492,81 @@ def sanitise_value_int_2d(buff, j, value):
492
492
493
493
def _mark_observed(observed, rows, alleles):
    """Set ``observed[rows[i], alleles[i]]`` to True for every ``i``.

    Raises:
        ValueError: if any allele index does not fit in the columns of
            ``observed``.
    """
    if alleles.size > 0 and alleles.max() >= observed.shape[1]:
        raise ValueError(
            f"Allele index {alleles.max()} is out of range for "
            f"{observed.shape[1]} alleles"
        )
    observed[rows, alleles] = True


def compute_laa_field(variant) -> np.ndarray:
    """Compute, per sample, the alternate allele indices seen in the variant.

    An allele counts as seen for a sample when at least one of the following
    holds (each field is consulted only if it is listed in
    ``variant.FORMAT``):

    * ``GT``: the allele index appears in the sample's genotype call;
    * ``AD``: the value at the allele's position is positive;
    * ``PL``: a genotype containing the allele has a positive value.

    Negative values in any of these fields are ignored, and the reference
    allele (index 0) is never reported.

    Args:
        variant: object exposing ``num_called``, ``num_unknown``, ``ALT``,
            ``FORMAT``, ``genotype.array()``, ``format(name)`` and ``ploidy``
            (presumably a cyvcf2 ``Variant`` -- confirm against the caller).

    Returns:
        A 2-D integer array with one row per sample. Each row holds the
        sample's alternate allele indices in ascending order, padded on the
        right with ``constants.INT_FILL``. The number of columns is the
        largest number of alleles seen in any one sample, and at least 1.

    Raises:
        ValueError: if ``PL`` is present and the ploidy is neither 1 nor 2,
            or if a field refers to an allele index beyond ``len(ALT)``.
    """
    sample_count = variant.num_called + variant.num_unknown
    alt_allele_count = len(variant.ALT)
    allele_count = alt_allele_count + 1
    # observed[s, k] is True when sample s shows evidence of allele k.
    # Only presence matters for the result, so no counts are kept.
    observed = np.zeros((sample_count, allele_count), dtype=bool)

    if "GT" in variant.FORMAT:
        # The last element of each sample's genotype indicates the phasing
        # and is not an allele.
        genotypes = variant.genotype.array()[:, :-1]
        # Negative entries and the reference allele contribute nothing.
        called_alt = genotypes > 0
        rows = np.nonzero(called_alt)[0]
        _mark_observed(observed, rows, genotypes[called_alt])

    if "AD" in variant.FORMAT:
        depths = variant.format("AD")
        # The column position of a positive depth is the allele index.
        rows, columns = np.nonzero(depths > 0)
        _mark_observed(observed, rows, columns)

    if "PL" in variant.FORMAT:
        likelihoods = variant.format("PL")
        ploidy = variant.ploidy
        genotype_index = np.arange(likelihoods.shape[1])

        if ploidy == 1:
            # For haploid calls the genotype index is the allele index.
            alleles_per_genotype = (genotype_index,)
        elif ploidy == 2:
            # Genotype index n encodes the allele pair (a, b), a <= b, as
            # n = b(b+1) / 2 + a. Decode once per column, not per sample.
            b = np.ceil(np.sqrt(2 * genotype_index + 9 / 4) - 3 / 2).astype(int)
            # b(b+1) is always even, so integer division is exact.
            a = genotype_index - b * (b + 1) // 2
            alleles_per_genotype = (a, b)
        else:
            # TODO: Handle all possible ploidy
            raise ValueError(f"Cannot handle ploidy = { ploidy } ")

        rows, columns = np.nonzero(likelihoods > 0)
        for allele_of_genotype in alleles_per_genotype:
            _mark_observed(observed, rows, allele_of_genotype[columns])

    observed[:, 0] = False  # We don't report the reference allele

    # Width of the output: the busiest row, but never narrower than 1.
    # ``initial=0`` keeps this well defined when there are no samples.
    row_length = max(1, int(observed.sum(axis=1).max(initial=0)))
    alleles = np.full((sample_count, row_length), constants.INT_FILL, dtype=np.intp)

    # np.nonzero walks the array in row-major order, so within each row the
    # allele indices come out ascending; the running count of observed
    # alleles gives each one its slot in the output row.
    rows, columns = np.nonzero(observed)
    slots = np.cumsum(observed, axis=1)[rows, columns] - 1
    alleles[rows, slots] = columns

    return alleles
569
572
0 commit comments