@@ -197,6 +197,8 @@ def convert_local_allele_field_types(fields):
197
197
gt = fields_by_name ["call_genotype" ]
198
198
if gt .shape [- 1 ] != 2 :
199
199
raise ValueError ("Local alleles only supported on diploid data" )
200
+ # TODO check if LAA is already in here
201
+
200
202
shape = gt .shape [:- 1 ]
201
203
chunks = gt .chunks [:- 1 ]
202
204
@@ -214,6 +216,7 @@ def convert_local_allele_field_types(fields):
214
216
)
215
217
pl = fields_by_name .get ("call_PL" , None )
216
218
if pl is not None :
219
+ # TODO check if call_LPL is in the list already
217
220
pl .name = "call_LPL"
218
221
pl .vcf_field = None
219
222
pl .shape = (* shape , 3 )
@@ -511,6 +514,46 @@ def fromdict(d):
511
514
return ret
512
515
513
516
517
def compute_laa_field(genotypes, alleles) -> np.ndarray:
    """
    Computes the value of the LAA field for each sample given the genotypes
    for a variant.

    The LAA field is a list of one-based indices into the ALT alleles
    that indicates which alternate alleles are observed in the sample.

    Parameters:
        genotypes: integer array of shape (num_samples, ploidy) holding
            allele indices for one variant. Negative values are clipped
            to 0 and therefore treated like the reference allele.
        alleles: sequence of alleles for the variant (reference first);
            only its length is used.

    Returns:
        An integer array with one row per sample. Each row lists, in
        ascending order, the alternate allele indices observed in that
        sample, padded on the right with ``constants.INT_FILL``. The
        number of columns is the largest number of distinct alternate
        alleles seen in any sample, and is at least 1.

    Raises:
        ValueError: if ``alleles`` is empty, or if a genotype refers to
            an allele index that is not present in ``alleles``.
    """
    num_alleles = len(alleles)
    if num_alleles < 1:
        raise ValueError("At least the reference allele is required")

    calls = np.asarray(genotypes).clip(0, None)
    num_samples = calls.shape[0]
    if calls.size > 0 and calls.max() >= num_alleles:
        raise ValueError(
            f"Genotype allele index {calls.max()} is out of range for "
            f"{num_alleles} alleles"
        )

    # observed[i, a] is True when sample i carries allele a. This is built
    # with array indexing rather than np.apply_along_axis, which raises
    # ValueError when there are zero samples.
    observed = np.zeros((num_samples, num_alleles), dtype=bool)
    sample_rows = np.arange(num_samples)[:, np.newaxis]
    observed[sample_rows, calls] = True
    observed[:, 0] = False  # We don't count the reference allele

    # Replace unobserved alleles by a sentinel larger than any real index
    # so that sorting packs the observed indices to the left of each row.
    sentinel = num_alleles
    packed = np.where(observed, np.arange(num_alleles, dtype=np.intp), sentinel)
    packed.sort(axis=1)
    # The reference column is never observed, so at most num_alleles - 1
    # entries per row can be real; keep at least one column.
    packed = packed[:, : max(1, num_alleles - 1)]
    packed[packed == sentinel] = constants.INT_FILL

    # Trim to the widest row actually needed (but never below one column).
    max_row_length = max(1, int(observed.sum(axis=1).max(initial=0)))
    return packed[:, :max_row_length]
555
+
556
+
514
557
@dataclasses .dataclass
515
558
class VcfZarrWriteSummary (core .JsonDataclass ):
516
559
num_partitions : int
@@ -543,6 +586,12 @@ def has_genotypes(self):
543
586
return True
544
587
return False
545
588
589
def has_local_alleles(self):
    """
    Return True if any field in the schema is named ``call_LAA`` and has
    no associated VCF field (``vcf_field is None``), False otherwise.
    """
    return any(
        spec.name == "call_LAA" and spec.vcf_field is None
        for spec in self.schema.fields
    )
594
+
546
595
#######################
547
596
# init
548
597
#######################
@@ -729,6 +778,8 @@ def encode_partition(self, partition_index):
729
778
self .encode_array_partition (array_spec , partition_index )
730
779
if self .has_genotypes ():
731
780
self .encode_genotypes_partition (partition_index )
781
+ if self .has_local_alleles ():
782
+ self .encode_local_alleles_partition (partition_index )
732
783
733
784
final_path = self .partition_path (partition_index )
734
785
logger .info (f"Finalising { partition_index } at { final_path } " )
@@ -800,6 +851,30 @@ def encode_genotypes_partition(self, partition_index):
800
851
self .finalise_partition_array (partition_index , "call_genotype_mask" )
801
852
self .finalise_partition_array (partition_index , "call_genotype_phased" )
802
853
854
def encode_local_alleles_partition(self, partition_index):
    """
    Compute and write the ``call_LAA`` array for one partition.

    Reads the work-in-progress ``call_genotype`` and ``variant_allele``
    arrays of the partition chunk by chunk, derives the LAA values for
    each variant with ``compute_laa_field`` and writes them through a
    buffered array before finalising ``call_LAA``.
    """
    call_LAA_array = self.init_partition_array(partition_index, "call_LAA")
    partition = self.metadata.partitions[partition_index]
    call_LAA = core.BufferedArray(call_LAA_array, partition.start)

    gt_array = zarr.open_array(
        store=self.wip_partition_array_path(partition_index, "call_genotype"),
        mode="r",
    )
    alleles_array = zarr.open_array(
        store=self.wip_partition_array_path(partition_index, "variant_allele"),
        mode="r",
    )
    for chunk_index in range(gt_array.cdata_shape[0]):
        chunk_alleles = alleles_array.blocks[chunk_index]
        chunk_genotypes = gt_array.blocks[chunk_index]
        for variant_alleles, variant_genotypes in zip(chunk_alleles, chunk_genotypes):
            j = call_LAA.next_buffer_row()
            # TODO we should probably compute LAAs by chunk for efficiency
            laa = compute_laa_field(variant_genotypes, variant_alleles)
            # compute_laa_field trims its result to the widest row it saw,
            # so the column count varies between variants. Assigning the
            # result directly would let NumPy broadcast a single column
            # across every column of the buffer row, repeating the allele
            # index instead of padding. Fill first, then write the leading
            # columns explicitly.
            # NOTE(review): assumes buff is a NumPy array laid out as
            # (variants, samples, local alleles) - confirm.
            call_LAA.buff[j] = constants.INT_FILL
            call_LAA.buff[j, :, : laa.shape[1]] = laa

    call_LAA.flush()
    self.finalise_partition_array(partition_index, "call_LAA")
877
+
803
878
def encode_alleles_partition (self , partition_index ):
804
879
array_name = "variant_allele"
805
880
alleles_array = self .init_partition_array (partition_index , array_name )
0 commit comments