@@ -513,44 +513,30 @@ def fromdict(d):
513513 return ret
514514
515515
516- def compute_laa_field (genotypes , alleles ) -> np .ndarray :
516+ def compute_laa_field (genotypes ) -> np .ndarray :
517517 """
518518 Computes the value of the LAA field for each sample given the genotypes
519519 for a variant.
520520
521521 The LAA field is a list of one-based indices into the ALT alleles
522522 that indicates which alternate alleles are observed in the sample.
523523 """
524- alt_allele_count = len (alleles ) - 1
525- allele_counts = np .zeros ((genotypes .shape [0 ], len (alleles )), dtype = int )
526-
527- genotypes = genotypes .clip (0 , None )
528- genotype_allele_counts = np .apply_along_axis (
529- np .bincount , axis = 1 , arr = genotypes , minlength = len (alleles )
530- )
531- allele_counts += genotype_allele_counts
532-
533- allele_counts [:, 0 ] = 0 # We don't count the reference allele
534- max_row_length = 1
535-
536- def nonzero_pad (arr : np .ndarray , * , length : int ):
537- nonlocal max_row_length
538- alleles = arr .nonzero ()[0 ]
539- max_row_length = max (max_row_length , len (alleles ))
540- pad_length = length - len (alleles )
541- return np .pad (
542- alleles ,
543- (0 , pad_length ),
544- mode = "constant" ,
545- constant_values = constants .INT_FILL ,
546- )
547-
548- alleles = np .apply_along_axis (
549- nonzero_pad , axis = 1 , arr = allele_counts , length = max (1 , alt_allele_count )
550- )
551- alleles = alleles [:, :max_row_length ]
552-
553- return alleles
524+ v = 2 ** 31 - 1
525+ if np .any (genotypes >= v ):
526+ raise ValueError ("Extreme allele value not supported" )
527+ G = genotypes .astype (np .int32 )
528+ if len (G ) > 0 :
529+ # Anything <=0 gets mapped to -2 (pad) in the output, which comes last.
530+ # So, to get this sorting correctly, we remap to the largest value for
531+ # sorting, then map back. We promote the genotypes up to 32 bit for convenience
532+ # here, assuming that we'll never have a allele of 2**31 - 1.
533+ assert np .all (G != v )
534+ G [G <= 0 ] = v
535+ G .sort (axis = 1 )
536+ # Equal non-zero values result in padding also
537+ G [G [:, 0 ] == G [:, 1 ], 1 ] = - 2
538+ G [G == v ] = - 2
539+ return G .astype (genotypes .dtype )
554540
555541
556542@dataclasses .dataclass
@@ -865,17 +851,10 @@ def encode_local_alleles_partition(self, partition_index):
865851 store = self .wip_partition_array_path (partition_index , "call_genotype" ),
866852 mode = "r" ,
867853 )
868- alleles_array = zarr .open_array (
869- store = self .wip_partition_array_path (partition_index , "variant_allele" ),
870- mode = "r" ,
871- )
872854 for chunk_index in range (gt_array .cdata_shape [0 ]):
873- A = alleles_array .blocks [chunk_index ]
874- G = gt_array .blocks [chunk_index ]
875- for alleles , var in zip (A , G ):
855+ for genotypes in gt_array .blocks [chunk_index ]:
876856 j = call_LAA .next_buffer_row ()
877- # TODO we should probably compute LAAs by chunk for efficiency
878- call_LAA .buff [j ] = compute_laa_field (var , alleles )
857+ call_LAA .buff [j ] = compute_laa_field (genotypes )
879858
880859 call_LAA .flush ()
881860 self .finalise_partition_array (partition_index , "call_LAA" )
0 commit comments