@@ -486,25 +486,54 @@ def sanitise_value_int_2d(buff, j, value):
486
486
buff [j , :, : value .shape [1 ]] = value
487
487
488
488
489
- def compute_laa_field (variant ) -> list [ list [ int ]] :
489
+ def compute_laa_field (variant ) -> np . ndarray :
490
490
sample_count = variant .num_called + variant .num_unknown
491
- laa_val = [ set () for _ in range (sample_count )]
491
+ alleles = np . empty ( (sample_count , 0 ), dtype = int )
492
492
493
493
if "GT" in variant .FORMAT :
494
- for sample_index , genotype in enumerate ( variant . genotypes ):
495
- # The last element in the genotype is not an allele.
496
- for allele in genotype [: - 1 ]:
497
- if allele > 0 :
498
- laa_val [ sample_index ]. add ( allele )
494
+ # The last element of each sample's genotype indicates the phasing
495
+ # and is not an allele.
496
+ genotypes = variant . genotype . array ()[:, : - 1 ]
497
+ genotypes [ genotypes < 0 ] = 0
498
+ alleles = np . concatenate (( alleles , genotypes ), axis = 1 )
499
499
if "AD" in variant .FORMAT :
500
- for sample_index , ad in enumerate (variant .format ("AD" )):
501
- # The first depth in AD is for the reference allele.
502
- alt_alleles = set (
503
- allele_index for allele_index , depth in enumerate (ad [1 :]) if depth > 0
504
- )
505
- laa_val [sample_index ] |= alt_alleles
506
500
507
- return [sorted (laa ) for laa in laa_val ]
501
+ def positive_pad (arr , * , length ):
502
+ positive = np .where (arr > 0 )[0 ]
503
+ pad_length = length - len (positive )
504
+ return np .pad (positive , (0 , pad_length ), mode = "constant" , constant_values = 0 )
505
+
506
+ depths = variant .format ("AD" )
507
+ depths = np .apply_along_axis (
508
+ positive_pad , axis = 1 , arr = depths , length = depths .shape [1 ]
509
+ )
510
+ alleles = np .concatenate ((alleles , depths ), axis = 1 )
511
+ if "PL" in variant .FORMAT :
512
+ # TODO
513
+ pass
514
+
515
+ max_unique_size = 1
516
+
517
+ def unique_pad (arr : np .ndarray , * , length : int ):
518
+ unique = np .unique (arr )
519
+
520
+ if unique [0 ] == 0 :
521
+ # We don't include the reference allele
522
+ unique = unique [1 :]
523
+
524
+ nonlocal max_unique_size
525
+ max_unique_size = max (max_unique_size , len (unique ))
526
+ pad_length = length - len (unique )
527
+ return np .pad (
528
+ unique , (0 , pad_length ), mode = "constant" , constant_values = constants .INT_FILL
529
+ )
530
+
531
+ alleles = np .apply_along_axis (
532
+ unique_pad , axis = 1 , arr = alleles , length = alleles .shape [0 ]
533
+ )
534
+ alleles = alleles [:, :max_unique_size ]
535
+
536
+ return alleles
508
537
509
538
510
539
missing_value_map = {
@@ -1138,21 +1167,10 @@ def process_partition(self, partition_index):
1138
1167
for field in format_fields :
1139
1168
if field .full_name == "FORMAT/LAA" :
1140
1169
laa_val = compute_laa_field (variant )
1141
- # Convert laa_val to a NumPy array
1142
- max_laa_len = max (len (laa ) for laa in laa_val )
1143
- # At minimum, we want to have at least one value per sample
1144
- # so that the field is present.
1145
- max_laa_len = max (1 , max_laa_len )
1146
- laa_val = [
1147
- sorted (laa )
1148
- + [constants .INT_FILL ] * (max_laa_len - len (laa ))
1149
- for laa in laa_val
1150
- ]
1151
- laa_val = np .array (laa_val )
1152
1170
tcw .append ("FORMAT/LAA" , laa_val )
1153
- continue
1154
- val = variant .format (field .name )
1155
- tcw .append (field .full_name , val )
1171
+ else :
1172
+ val = variant .format (field .name )
1173
+ tcw .append (field .full_name , val )
1156
1174
1157
1175
# Note: an issue with updating the progress per variant here like
1158
1176
# this is that we get a significant pause at the end of the counter
0 commit comments