@@ -239,6 +239,7 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
239
239
# Indicates whether vcf2zarr can introduce local alleles
240
240
can_localize = False
241
241
should_add_laa_field = True
242
+ should_add_lpl_field = True
242
243
fields = fixed_vcf_field_definitions ()
243
244
for h in vcf .header_iter ():
244
245
if h ["HeaderType" ] in ["INFO" , "FORMAT" ]:
@@ -252,18 +253,31 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
252
253
can_localize = True
253
254
if field .name == "LAA" :
254
255
should_add_laa_field = False
255
-
256
- if local_alleles and can_localize and should_add_laa_field :
257
- laa_field = VcfField (
258
- category = "FORMAT" ,
259
- name = "LAA" ,
260
- vcf_type = "Integer" ,
261
- vcf_number = "." ,
262
- description = "1-based indices into ALT, indicating which alleles"
263
- " are relevant (local) for the current sample" ,
264
- summary = VcfFieldSummary (),
265
- )
266
- fields .append (laa_field )
256
+ if field .name == "LPL" :
257
+ should_add_lpl_field = False
258
+
259
+ if local_alleles and can_localize :
260
+ if should_add_laa_field :
261
+ laa_field = VcfField (
262
+ category = "FORMAT" ,
263
+ name = "LAA" ,
264
+ vcf_type = "Integer" ,
265
+ vcf_number = "." ,
266
+ description = "1-based indices into ALT, indicating which alleles"
267
+ " are relevant (local) for the current sample" ,
268
+ summary = VcfFieldSummary (),
269
+ )
270
+ fields .append (laa_field )
271
+ if should_add_lpl_field :
272
+ lpl_field = VcfField (
273
+ category = "FORMAT" ,
274
+ name = "LPL" ,
275
+ vcf_type = "Integer" ,
276
+ vcf_number = "LG" ,
277
+ description = "Local-allele representation of PL" ,
278
+ summary = VcfFieldSummary (),
279
+ )
280
+ fields .append (lpl_field )
267
281
268
282
try :
269
283
contig_lengths = vcf .seqlens
@@ -579,6 +593,56 @@ def nonzero_pad(arr: np.ndarray, *, length: int):
579
593
return alleles
580
594
581
595
596
+ def compute_lpl_field (variant , laa_val : np .ndarray ) -> np .ndarray :
597
+ assert laa_val is not None
598
+
599
+ la_val = np .zeros ((laa_val .shape [0 ], laa_val .shape [1 ] + 1 ), dtype = laa_val .dtype )
600
+ la_val [:, 1 :] = laa_val
601
+ ploidy = variant .ploidy
602
+
603
+ if "PL" not in variant .FORMAT :
604
+ sample_count = variant .num_called + variant .num_unknown
605
+ local_allele_count = la_val .shape [1 ]
606
+
607
+ if ploidy == 1 :
608
+ local_genotype_count = local_allele_count
609
+ elif ploidy == 2 :
610
+ local_genotype_count = local_allele_count * (local_allele_count + 1 ) // 2
611
+ else :
612
+ raise ValueError (f"Cannot handle ploidy = { ploidy } " )
613
+
614
+ return np .full ((sample_count , local_genotype_count ), constants .INT_MISSING )
615
+
616
+ # Compute a and b
617
+ if ploidy == 1 :
618
+ a = la_val
619
+ b = np .zeros_like (la_val )
620
+ elif ploidy == 2 :
621
+ repeats = np .arange (1 , la_val .shape [1 ] + 1 )
622
+ b = np .repeat (la_val , repeats , axis = 1 )
623
+ arange_tile = np .tile (np .arange (la_val .shape [1 ]), (la_val .shape [1 ], 1 ))
624
+ tril_indices = np .tril_indices_from (arange_tile )
625
+ a_index = np .tile (arange_tile [tril_indices ], (b .shape [0 ], 1 ))
626
+ row_index = np .arange (la_val .shape [0 ]).reshape (- 1 , 1 )
627
+ a = la_val [row_index , a_index ]
628
+ else :
629
+ raise ValueError (f"Cannot handle ploidy = { ploidy } " )
630
+
631
+ # Compute n, the local indices of the PL field
632
+ n = (b * (b + 1 ) / 2 + a ).astype (int )
633
+
634
+ pl_val = variant .format ("PL" )
635
+ pl_val [pl_val == constants .VCF_INT_MISSING ] = constants .INT_MISSING
636
+ # When the PL value is missing in all samples, pl_val has shape (sample_count, 1).
637
+ # In that case, we need to broadcast the PL value.
638
+ if pl_val .shape [1 ] < n .shape [1 ]:
639
+ pl_val = np .broadcast_to (pl_val , n .shape )
640
+ row_index = np .arange (pl_val .shape [0 ]).reshape (- 1 , 1 )
641
+ lpl_val = pl_val [row_index , n ]
642
+
643
+ return lpl_val
644
+
645
+
582
646
missing_value_map = {
583
647
"Integer" : constants .INT_MISSING ,
584
648
"Float" : constants .FLOAT32_MISSING ,
@@ -1183,6 +1247,25 @@ def process_partition(self, partition_index):
1183
1247
else :
1184
1248
format_fields .append (field )
1185
1249
1250
+ # We need to determine LAA before LPL
1251
+ try :
1252
+ laa_index = next (
1253
+ index
1254
+ for index , format_field in enumerate (format_fields )
1255
+ if format_field .name == "LAA"
1256
+ )
1257
+ lpl_index = next (
1258
+ index
1259
+ for index , format_field in enumerate (format_fields )
1260
+ if format_field .name == "LPL"
1261
+ )
1262
+
1263
+ if lpl_index < laa_index :
1264
+ format_fields .insert (laa_index + 1 , format_fields [lpl_index ])
1265
+ format_fields .pop (lpl_index )
1266
+ except StopIteration :
1267
+ pass
1268
+
1186
1269
last_position = None
1187
1270
with IcfPartitionWriter (
1188
1271
self .metadata ,
@@ -1209,12 +1292,16 @@ def process_partition(self, partition_index):
1209
1292
else :
1210
1293
val = variant .genotype .array ()
1211
1294
tcw .append ("FORMAT/GT" , val )
1295
+ laa_val = None
1212
1296
for field in format_fields :
1213
- if (
1214
- field .full_name == "FORMAT/LAA"
1215
- and "LAA" not in variant .FORMAT
1216
- ):
1217
- val = compute_laa_field (variant )
1297
+ if field .name == "LAA" :
1298
+ if "LAA" not in variant .FORMAT :
1299
+ laa_val = compute_laa_field (variant )
1300
+ else :
1301
+ laa_val = variant .format ("LAA" )
1302
+ val = laa_val
1303
+ elif field .name == "LPL" and "LPL" not in variant .FORMAT :
1304
+ val = compute_lpl_field (variant , laa_val )
1218
1305
else :
1219
1306
val = variant .format (field .name )
1220
1307
tcw .append (field .full_name , val )
0 commit comments