@@ -513,44 +513,30 @@ def fromdict(d):
513
513
return ret
514
514
515
515
516
- def compute_laa_field (genotypes , alleles ) -> np .ndarray :
516
+ def compute_laa_field (genotypes ) -> np .ndarray :
517
517
"""
518
518
Computes the value of the LAA field for each sample given the genotypes
519
519
for a variant.
520
520
521
521
The LAA field is a list of one-based indices into the ALT alleles
522
522
that indicates which alternate alleles are observed in the sample.
523
523
"""
524
- alt_allele_count = len (alleles ) - 1
525
- allele_counts = np .zeros ((genotypes .shape [0 ], len (alleles )), dtype = int )
526
-
527
- genotypes = genotypes .clip (0 , None )
528
- genotype_allele_counts = np .apply_along_axis (
529
- np .bincount , axis = 1 , arr = genotypes , minlength = len (alleles )
530
- )
531
- allele_counts += genotype_allele_counts
532
-
533
- allele_counts [:, 0 ] = 0 # We don't count the reference allele
534
- max_row_length = 1
535
-
536
- def nonzero_pad (arr : np .ndarray , * , length : int ):
537
- nonlocal max_row_length
538
- alleles = arr .nonzero ()[0 ]
539
- max_row_length = max (max_row_length , len (alleles ))
540
- pad_length = length - len (alleles )
541
- return np .pad (
542
- alleles ,
543
- (0 , pad_length ),
544
- mode = "constant" ,
545
- constant_values = constants .INT_FILL ,
546
- )
547
-
548
- alleles = np .apply_along_axis (
549
- nonzero_pad , axis = 1 , arr = allele_counts , length = max (1 , alt_allele_count )
550
- )
551
- alleles = alleles [:, :max_row_length ]
552
-
553
- return alleles
524
+ v = 2 ** 31 - 1
525
+ if np .any (genotypes >= v ):
526
+ raise ValueError ("Extreme allele value not supported" )
527
+ G = genotypes .astype (np .int32 )
528
+ if len (G ) > 0 :
529
+ # Anything <=0 gets mapped to -2 (pad) in the output, which comes last.
530
+ # So, to get this sorting correctly, we remap to the largest value for
531
+ # sorting, then map back. We promote the genotypes up to 32 bit for convenience
532
+ # here, assuming that we'll never have a allele of 2**31 - 1.
533
+ assert np .all (G != v )
534
+ G [G <= 0 ] = v
535
+ G .sort (axis = 1 )
536
+ # Equal non-zero values result in padding also
537
+ G [G [:, 0 ] == G [:, 1 ], 1 ] = - 2
538
+ G [G == v ] = - 2
539
+ return G .astype (genotypes .dtype )
554
540
555
541
556
542
@dataclasses .dataclass
@@ -865,17 +851,10 @@ def encode_local_alleles_partition(self, partition_index):
865
851
store = self .wip_partition_array_path (partition_index , "call_genotype" ),
866
852
mode = "r" ,
867
853
)
868
- alleles_array = zarr .open_array (
869
- store = self .wip_partition_array_path (partition_index , "variant_allele" ),
870
- mode = "r" ,
871
- )
872
854
for chunk_index in range (gt_array .cdata_shape [0 ]):
873
- A = alleles_array .blocks [chunk_index ]
874
- G = gt_array .blocks [chunk_index ]
875
- for alleles , var in zip (A , G ):
855
+ for genotypes in gt_array .blocks [chunk_index ]:
876
856
j = call_LAA .next_buffer_row ()
877
- # TODO we should probably compute LAAs by chunk for efficiency
878
- call_LAA .buff [j ] = compute_laa_field (var , alleles )
857
+ call_LAA .buff [j ] = compute_laa_field (genotypes )
879
858
880
859
call_LAA .flush ()
881
860
self .finalise_partition_array (partition_index , "call_LAA" )
0 commit comments