Skip to content

Commit 659646d

Browse files
Will-Tyler authored and jeromekelleher committed
Add LPL field during explode step
1 parent 4366ec1 commit 659646d

File tree

3 files changed

+107
-17
lines changed

3 files changed

+107
-17
lines changed

bio2zarr/vcf2zarr/icf.py

Lines changed: 104 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -239,6 +239,7 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
239239
# Indicates whether vcf2zarr can introduce local alleles
240240
can_localize = False
241241
should_add_laa_field = True
242+
should_add_lpl_field = True
242243
fields = fixed_vcf_field_definitions()
243244
for h in vcf.header_iter():
244245
if h["HeaderType"] in ["INFO", "FORMAT"]:
@@ -252,18 +253,31 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
252253
can_localize = True
253254
if field.name == "LAA":
254255
should_add_laa_field = False
255-
256-
if local_alleles and can_localize and should_add_laa_field:
257-
laa_field = VcfField(
258-
category="FORMAT",
259-
name="LAA",
260-
vcf_type="Integer",
261-
vcf_number=".",
262-
description="1-based indices into ALT, indicating which alleles"
263-
" are relevant (local) for the current sample",
264-
summary=VcfFieldSummary(),
265-
)
266-
fields.append(laa_field)
256+
if field.name == "LPL":
257+
should_add_lpl_field = False
258+
259+
if local_alleles and can_localize:
260+
if should_add_laa_field:
261+
laa_field = VcfField(
262+
category="FORMAT",
263+
name="LAA",
264+
vcf_type="Integer",
265+
vcf_number=".",
266+
description="1-based indices into ALT, indicating which alleles"
267+
" are relevant (local) for the current sample",
268+
summary=VcfFieldSummary(),
269+
)
270+
fields.append(laa_field)
271+
if should_add_lpl_field:
272+
lpl_field = VcfField(
273+
category="FORMAT",
274+
name="LPL",
275+
vcf_type="Integer",
276+
vcf_number="LG",
277+
description="Local-allele representation of PL",
278+
summary=VcfFieldSummary(),
279+
)
280+
fields.append(lpl_field)
267281

268282
try:
269283
contig_lengths = vcf.seqlens
@@ -579,6 +593,56 @@ def nonzero_pad(arr: np.ndarray, *, length: int):
579593
return alleles
580594

581595

596+
def _local_allele_pairs(la_val: np.ndarray, ploidy: int):
    """Return the allele pair (a, b) behind every local genotype, in PL order."""
    if ploidy == 1:
        # A haploid genotype is a single allele, so the pair degenerates to
        # (a, 0) and the genotype index below reduces to a itself.
        return la_val, np.zeros_like(la_val)
    if ploidy == 2:
        # Walking the lower triangle row by row enumerates the index pairs
        # (a_index <= b_index) as 0/0, 0/1, 1/1, 0/2, 1/2, 2/2, ... which is
        # the order in which diploid genotype likelihoods are laid out.
        b_index, a_index = np.tril_indices(la_val.shape[1])
        return la_val[:, a_index], la_val[:, b_index]
    raise ValueError(f"Cannot handle ploidy = {ploidy}")


def compute_lpl_field(variant, laa_val: np.ndarray) -> np.ndarray:
    """
    Compute the LPL (local-allele PL) values of a variant from its PL field.

    For every sample, the local alleles are the reference allele (index 0)
    followed by the entries of that sample's ``laa_val`` row. Each local
    genotype is mapped to the index of the corresponding genotype in the
    full PL vector, ``b * (b + 1) / 2 + a``, and the PL value at that index
    is copied into the result.

    :param variant: The variant record. The function reads ``ploidy``,
        ``FORMAT``, ``num_called``, ``num_unknown`` and ``format("PL")``.
        The array returned by ``format("PL")`` is modified in place
        (VCF missing values are replaced with ``constants.INT_MISSING``).
    :param laa_val: 2-D integer array with one row per sample holding the
        1-based indices into ALT of the alleles that are local to the
        sample. Negative entries are treated as padding/missing sentinels,
        since a real 1-based index can never be negative.
    :return: 2-D array with one row per sample and one column per local
        genotype. If the variant has no PL field, every entry is
        ``constants.INT_MISSING``. Otherwise, genotypes that involve a
        negative ``laa_val`` entry carry that sentinel value instead of a
        likelihood.
    :raises ValueError: If ``laa_val`` is None or the ploidy is not 1 or 2.
    """
    # Explicit check rather than ``assert`` so that it survives ``python -O``.
    if laa_val is None:
        raise ValueError("LAA values are required to compute the LPL field")

    # Prepend the reference allele (0) to each sample's local alleles. The
    # index arithmetic is done in 64-bit integers so that a narrow LAA dtype
    # cannot overflow in b * (b + 1).
    la_val = np.zeros((laa_val.shape[0], laa_val.shape[1] + 1), dtype=np.int64)
    la_val[:, 1:] = laa_val

    a, b = _local_allele_pairs(la_val, variant.ploidy)

    if "PL" not in variant.FORMAT:
        sample_count = variant.num_called + variant.num_unknown
        # a has exactly one column per local genotype.
        return np.full((sample_count, a.shape[1]), constants.INT_MISSING)

    # Genotypes built from a padding/missing LAA entry have no PL index.
    # Left alone, the negative sentinel would go through the formula below
    # and then wrap around under NumPy's negative indexing, silently copying
    # an unrelated PL value into the padded slot.
    padded = (a < 0) | (b < 0)

    # Compute n, the index into PL of every local genotype.
    n = b * (b + 1) // 2 + a
    n[padded] = 0  # Any in-bounds placeholder; overwritten below.

    pl_val = variant.format("PL")
    pl_val[pl_val == constants.VCF_INT_MISSING] = constants.INT_MISSING
    # When the PL value is missing in all samples, pl_val has shape
    # (sample_count, 1). In that case, we need to broadcast the PL value.
    if pl_val.shape[1] < n.shape[1]:
        pl_val = np.broadcast_to(pl_val, n.shape)
    row_index = np.arange(pl_val.shape[0]).reshape(-1, 1)
    # Advanced indexing returns a fresh array, so it is safe to write to it
    # even when pl_val is a read-only broadcast view.
    lpl_val = pl_val[row_index, n]

    # Carry the LAA sentinel over into the slots that have no genotype.
    sentinel = np.where(b < 0, b, a)
    lpl_val[padded] = sentinel[padded]

    return lpl_val
644+
645+
582646
missing_value_map = {
583647
"Integer": constants.INT_MISSING,
584648
"Float": constants.FLOAT32_MISSING,
@@ -1183,6 +1247,25 @@ def process_partition(self, partition_index):
11831247
else:
11841248
format_fields.append(field)
11851249

1250+
# We need to determine LAA before LPL
1251+
try:
1252+
laa_index = next(
1253+
index
1254+
for index, format_field in enumerate(format_fields)
1255+
if format_field.name == "LAA"
1256+
)
1257+
lpl_index = next(
1258+
index
1259+
for index, format_field in enumerate(format_fields)
1260+
if format_field.name == "LPL"
1261+
)
1262+
1263+
if lpl_index < laa_index:
1264+
format_fields.insert(laa_index + 1, format_fields[lpl_index])
1265+
format_fields.pop(lpl_index)
1266+
except StopIteration:
1267+
pass
1268+
11861269
last_position = None
11871270
with IcfPartitionWriter(
11881271
self.metadata,
@@ -1209,12 +1292,16 @@ def process_partition(self, partition_index):
12091292
else:
12101293
val = variant.genotype.array()
12111294
tcw.append("FORMAT/GT", val)
1295+
laa_val = None
12121296
for field in format_fields:
1213-
if (
1214-
field.full_name == "FORMAT/LAA"
1215-
and "LAA" not in variant.FORMAT
1216-
):
1217-
val = compute_laa_field(variant)
1297+
if field.name == "LAA":
1298+
if "LAA" not in variant.FORMAT:
1299+
laa_val = compute_laa_field(variant)
1300+
else:
1301+
laa_val = variant.format("LAA")
1302+
val = laa_val
1303+
elif field.name == "LPL" and "LPL" not in variant.FORMAT:
1304+
val = compute_lpl_field(variant, laa_val)
12181305
else:
12191306
val = variant.format(field.name)
12201307
tcw.append(field.full_name, val)

bio2zarr/vcf2zarr/verification.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,8 @@ def verify(vcf_path, zarr_path, show_progress=False):
172172
vcf_name = colname.split("_", 1)[1]
173173
if vcf_name == "LAA" and vcf_name not in format_headers:
174174
continue # LAA could have been computed during the explode step.
175+
if vcf_name == "LPL" and vcf_name not in format_headers:
176+
continue # LPL could have been computed during the explode step.
175177
vcf_type = format_headers[vcf_name]["Type"]
176178
vcf_number = format_headers[vcf_name]["Number"]
177179
format_fields[vcf_name] = vcf_type, vcf_number, iter(root[colname])

tests/test_icf.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,7 @@ class TestLocalAllelesExample:
103103
"FORMAT/GQ",
104104
"FORMAT/GT",
105105
"FORMAT/LAA",
106+
"FORMAT/LPL",
106107
"FORMAT/PL",
107108
"ID",
108109
"INFO/AA",

0 commit comments

Comments
 (0)