@@ -197,6 +197,8 @@ def convert_local_allele_field_types(fields):
197197 gt = fields_by_name ["call_genotype" ]
198198 if gt .shape [- 1 ] != 2 :
199199 raise ValueError ("Local alleles only supported on diploid data" )
200+ # TODO check if LAA is already in here
201+
200202 shape = gt .shape [:- 1 ]
201203 chunks = gt .chunks [:- 1 ]
202204
@@ -214,6 +216,7 @@ def convert_local_allele_field_types(fields):
214216 )
215217 pl = fields_by_name .get ("call_PL" , None )
216218 if pl is not None :
219+ # TODO check if call_LPL is in the list already
217220 pl .name = "call_LPL"
218221 pl .vcf_field = None
219222 pl .shape = (* shape , 3 )
@@ -510,6 +513,46 @@ def fromdict(d):
510513 return ret
511514
512515
516+ def compute_laa_field (genotypes , alleles ) -> np .ndarray :
517+ """
518+ Computes the value of the LAA field for each sample given the genotypes
519+ for a variant.
520+
521+ The LAA field is a list of one-based indices into the ALT alleles
522+ that indicates which alternate alleles are observed in the sample.
523+ """
524+ alt_allele_count = len (alleles ) - 1
525+ allele_counts = np .zeros ((genotypes .shape [0 ], len (alleles )), dtype = int )
526+
527+ genotypes = genotypes .clip (0 , None )
528+ genotype_allele_counts = np .apply_along_axis (
529+ np .bincount , axis = 1 , arr = genotypes , minlength = len (alleles )
530+ )
531+ allele_counts += genotype_allele_counts
532+
533+ allele_counts [:, 0 ] = 0 # We don't count the reference allele
534+ max_row_length = 1
535+
536+ def nonzero_pad (arr : np .ndarray , * , length : int ):
537+ nonlocal max_row_length
538+ alleles = arr .nonzero ()[0 ]
539+ max_row_length = max (max_row_length , len (alleles ))
540+ pad_length = length - len (alleles )
541+ return np .pad (
542+ alleles ,
543+ (0 , pad_length ),
544+ mode = "constant" ,
545+ constant_values = constants .INT_FILL ,
546+ )
547+
548+ alleles = np .apply_along_axis (
549+ nonzero_pad , axis = 1 , arr = allele_counts , length = max (1 , alt_allele_count )
550+ )
551+ alleles = alleles [:, :max_row_length ]
552+
553+ return alleles
554+
555+
513556@dataclasses .dataclass
514557class VcfZarrWriteSummary (core .JsonDataclass ):
515558 num_partitions : int
@@ -542,6 +585,12 @@ def has_genotypes(self):
542585 return True
543586 return False
544587
588+ def has_local_alleles (self ):
589+ for field in self .schema .fields :
590+ if field .name == "call_LAA" and field .vcf_field is None :
591+ return True
592+ return False
593+
545594 #######################
546595 # init
547596 #######################
@@ -734,6 +783,8 @@ def encode_partition(self, partition_index):
734783 self .encode_array_partition (array_spec , partition_index )
735784 if self .has_genotypes ():
736785 self .encode_genotypes_partition (partition_index )
786+ if self .has_local_alleles ():
787+ self .encode_local_alleles_partition (partition_index )
737788
738789 final_path = self .partition_path (partition_index )
739790 logger .info (f"Finalising { partition_index } at { final_path } " )
@@ -805,6 +856,30 @@ def encode_genotypes_partition(self, partition_index):
805856 self .finalise_partition_array (partition_index , "call_genotype_mask" )
806857 self .finalise_partition_array (partition_index , "call_genotype_phased" )
807858
859+ def encode_local_alleles_partition (self , partition_index ):
860+ call_LAA_array = self .init_partition_array (partition_index , "call_LAA" )
861+ partition = self .metadata .partitions [partition_index ]
862+ call_LAA = core .BufferedArray (call_LAA_array , partition .start )
863+
864+ gt_array = zarr .open_array (
865+ store = self .wip_partition_array_path (partition_index , "call_genotype" ),
866+ mode = "r" ,
867+ )
868+ alleles_array = zarr .open_array (
869+ store = self .wip_partition_array_path (partition_index , "variant_allele" ),
870+ mode = "r" ,
871+ )
872+ for chunk_index in range (gt_array .cdata_shape [0 ]):
873+ A = alleles_array .blocks [chunk_index ]
874+ G = gt_array .blocks [chunk_index ]
875+ for alleles , var in zip (A , G ):
876+ j = call_LAA .next_buffer_row ()
877+ # TODO we should probably compute LAAs by chunk for efficiency
878+ call_LAA .buff [j ] = compute_laa_field (var , alleles )
879+
880+ call_LAA .flush ()
881+ self .finalise_partition_array (partition_index , "call_LAA" )
882+
808883 def encode_alleles_partition (self , partition_index ):
809884 array_name = "variant_allele"
810885 alleles_array = self .init_partition_array (partition_index , array_name )
0 commit comments