Skip to content

Commit 24d2de7

Browse files
Remove old localise code and example
The tests depended on complex semantics of mixed local and non-local fields, which doesn't seem worth the effort. Fix up the du test.
1 parent ff056d9 commit 24d2de7

File tree

6 files changed

+10
-269
lines changed

6 files changed

+10
-269
lines changed

bio2zarr/cli.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ def show_work_summary(work_summary, json):
221221
@compressor
222222
@progress
223223
@worker_processes
224-
@local_alleles
225224
def explode(
226225
vcfs,
227226
icf_path,
@@ -231,7 +230,6 @@ def explode(
231230
compressor,
232231
progress,
233232
worker_processes,
234-
local_alleles,
235233
):
236234
"""
237235
Convert VCF(s) to intermediate columnar format
@@ -245,7 +243,6 @@ def explode(
245243
column_chunk_size=column_chunk_size,
246244
compressor=get_compressor(compressor),
247245
show_progress=progress,
248-
local_alleles=local_alleles,
249246
)
250247

251248

@@ -260,7 +257,6 @@ def explode(
260257
@verbose
261258
@progress
262259
@worker_processes
263-
@local_alleles
264260
def dexplode_init(
265261
vcfs,
266262
icf_path,
@@ -272,7 +268,6 @@ def dexplode_init(
272268
verbose,
273269
progress,
274270
worker_processes,
275-
local_alleles,
276271
):
277272
"""
278273
Initial step for distributed conversion of VCF(s) to intermediate columnar format
@@ -289,7 +284,6 @@ def dexplode_init(
289284
worker_processes=worker_processes,
290285
compressor=get_compressor(compressor),
291286
show_progress=progress,
292-
local_alleles=local_alleles,
293287
)
294288
show_work_summary(work_summary, json)
295289

bio2zarr/vcf2zarr/icf.py

Lines changed: 3 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def make_field_def(name, vcf_type, vcf_number):
217217
return fields
218218

219219

220-
def scan_vcf(path, target_num_partitions, *, local_alleles):
220+
def scan_vcf(path, target_num_partitions):
221221
with vcf_utils.IndexedVcf(path) as indexed_vcf:
222222
vcf = indexed_vcf.vcf
223223
filters = []
@@ -237,10 +237,6 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
237237
pass_filter = filters.pop(pass_index)
238238
filters.insert(0, pass_filter)
239239

240-
# Indicates whether vcf2zarr can introduce local alleles
241-
can_localize = False
242-
should_add_laa_field = True
243-
should_add_lpl_field = True
244240
fields = fixed_vcf_field_definitions()
245241
for h in vcf.header_iter():
246242
if h["HeaderType"] in ["INFO", "FORMAT"]:
@@ -249,36 +245,6 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
249245
field.vcf_type = "Integer"
250246
field.vcf_number = "."
251247
fields.append(field)
252-
if field.category == "FORMAT":
253-
if field.name == "PL":
254-
can_localize = True
255-
if field.name == "LAA":
256-
should_add_laa_field = False
257-
if field.name == "LPL":
258-
should_add_lpl_field = False
259-
260-
if local_alleles and can_localize:
261-
if should_add_laa_field:
262-
laa_field = VcfField(
263-
category="FORMAT",
264-
name="LAA",
265-
vcf_type="Integer",
266-
vcf_number=".",
267-
description="1-based indices into ALT, indicating which alleles"
268-
" are relevant (local) for the current sample",
269-
summary=VcfFieldSummary(),
270-
)
271-
fields.append(laa_field)
272-
if should_add_lpl_field:
273-
lpl_field = VcfField(
274-
category="FORMAT",
275-
name="LPL",
276-
vcf_type="Integer",
277-
vcf_number="LG",
278-
description="Local-allele representation of PL",
279-
summary=VcfFieldSummary(),
280-
)
281-
fields.append(lpl_field)
282248

283249
try:
284250
contig_lengths = vcf.seqlens
@@ -315,14 +281,7 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
315281
return metadata, vcf.raw_header
316282

317283

318-
def scan_vcfs(
319-
paths,
320-
show_progress,
321-
target_num_partitions,
322-
worker_processes=1,
323-
*,
324-
local_alleles,
325-
):
284+
def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
326285
logger.info(
327286
f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
328287
f" partitions."
@@ -346,7 +305,6 @@ def scan_vcfs(
346305
scan_vcf,
347306
path,
348307
max(1, target_num_partitions // len(paths)),
349-
local_alleles=local_alleles,
350308
)
351309
results = list(pwm.results_as_completed())
352310

@@ -505,104 +463,6 @@ def sanitise_value_int_2d(buff, j, value):
505463
buff[j, :, : value.shape[1]] = value
506464

507465

508-
def compute_laa_field(variant) -> np.ndarray:
509-
"""
510-
Computes the value of the LAA field for each sample given a variant.
511-
512-
The LAA field is a list of one-based indices into the ALT alleles
513-
that indicates which alternate alleles are observed in the sample.
514-
515-
This method infers which alleles are observed from the GT field.
516-
"""
517-
sample_count = variant.num_called + variant.num_unknown
518-
alt_allele_count = len(variant.ALT)
519-
allele_count = alt_allele_count + 1
520-
allele_counts = np.zeros((sample_count, allele_count), dtype=int)
521-
522-
if "GT" in variant.FORMAT:
523-
# The last element of each sample's genotype indicates the phasing
524-
# and is not an allele.
525-
genotypes = variant.genotype.array()[:, :-1]
526-
genotypes.clip(0, None, out=genotypes)
527-
genotype_allele_counts = np.apply_along_axis(
528-
np.bincount, axis=1, arr=genotypes, minlength=allele_count
529-
)
530-
allele_counts += genotype_allele_counts
531-
532-
allele_counts[:, 0] = 0 # We don't count the reference allele
533-
max_row_length = 1
534-
535-
def nonzero_pad(arr: np.ndarray, *, length: int):
536-
nonlocal max_row_length
537-
alleles = arr.nonzero()[0]
538-
max_row_length = max(max_row_length, len(alleles))
539-
pad_length = length - len(alleles)
540-
return np.pad(
541-
alleles,
542-
(0, pad_length),
543-
mode="constant",
544-
constant_values=constants.INT_FILL,
545-
)
546-
547-
alleles = np.apply_along_axis(
548-
nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count)
549-
)
550-
alleles = alleles[:, :max_row_length]
551-
552-
return alleles
553-
554-
555-
def compute_lpl_field(variant, laa_val: np.ndarray) -> np.ndarray:
556-
assert laa_val is not None
557-
558-
la_val = np.zeros((laa_val.shape[0], laa_val.shape[1] + 1), dtype=laa_val.dtype)
559-
la_val[:, 1:] = laa_val
560-
ploidy = variant.ploidy
561-
562-
if "PL" not in variant.FORMAT:
563-
sample_count = variant.num_called + variant.num_unknown
564-
local_allele_count = la_val.shape[1]
565-
566-
if ploidy == 1:
567-
local_genotype_count = local_allele_count
568-
elif ploidy == 2:
569-
local_genotype_count = local_allele_count * (local_allele_count + 1) // 2
570-
else:
571-
raise ValueError(f"Cannot handle ploidy = {ploidy}")
572-
573-
return np.full((sample_count, local_genotype_count), constants.INT_MISSING)
574-
575-
# Compute a and b
576-
if ploidy == 1:
577-
a = la_val
578-
b = np.zeros_like(la_val)
579-
elif ploidy == 2:
580-
repeats = np.arange(1, la_val.shape[1] + 1)
581-
b = np.repeat(la_val, repeats, axis=1)
582-
arange_tile = np.tile(np.arange(la_val.shape[1]), (la_val.shape[1], 1))
583-
tril_indices = np.tril_indices_from(arange_tile)
584-
a_index = np.tile(arange_tile[tril_indices], (b.shape[0], 1))
585-
row_index = np.arange(la_val.shape[0]).reshape(-1, 1)
586-
a = la_val[row_index, a_index]
587-
else:
588-
raise ValueError(f"Cannot handle ploidy = {ploidy}")
589-
590-
# Compute n, the local indices of the PL field
591-
n = (b * (b + 1) / 2 + a).astype(int)
592-
593-
pl_val = variant.format("PL")
594-
pl_val[pl_val == constants.VCF_INT_MISSING] = constants.INT_MISSING
595-
# When the PL value is missing in all samples, pl_val has shape (sample_count, 1).
596-
# In that case, we need to broadcast the PL value.
597-
if pl_val.shape[1] < n.shape[1]:
598-
pl_val = np.broadcast_to(pl_val, n.shape)
599-
row_index = np.arange(pl_val.shape[0]).reshape(-1, 1)
600-
lpl_val = pl_val[row_index, n]
601-
lpl_val[b == constants.INT_FILL] = constants.INT_FILL
602-
603-
return lpl_val
604-
605-
606466
missing_value_map = {
607467
"Integer": constants.INT_MISSING,
608468
"Float": constants.FLOAT32_MISSING,
@@ -1107,14 +967,11 @@ def init(
1107967
target_num_partitions=None,
1108968
show_progress=False,
1109969
compressor=None,
1110-
local_alleles=None,
1111970
):
1112971
if self.path.exists():
1113972
raise ValueError(f"ICF path already exists: {self.path}")
1114973
if compressor is None:
1115974
compressor = ICF_DEFAULT_COMPRESSOR
1116-
if local_alleles is None:
1117-
local_alleles = False
1118975
vcfs = [pathlib.Path(vcf) for vcf in vcfs]
1119976
target_num_partitions = max(target_num_partitions, len(vcfs))
1120977

@@ -1124,7 +981,6 @@ def init(
1124981
worker_processes=worker_processes,
1125982
show_progress=show_progress,
1126983
target_num_partitions=target_num_partitions,
1127-
local_alleles=local_alleles,
1128984
)
1129985
check_field_clobbering(icf_metadata)
1130986
self.metadata = icf_metadata
@@ -1207,17 +1063,6 @@ def process_partition(self, partition_index):
12071063
else:
12081064
format_fields.append(field)
12091065

1210-
format_field_names = [format_field.name for format_field in format_fields]
1211-
if "LAA" in format_field_names and "LPL" in format_field_names:
1212-
laa_index = format_field_names.index("LAA")
1213-
lpl_index = format_field_names.index("LPL")
1214-
# LAA needs to come before LPL
1215-
if lpl_index < laa_index:
1216-
format_fields[laa_index], format_fields[lpl_index] = (
1217-
format_fields[lpl_index],
1218-
format_fields[laa_index],
1219-
)
1220-
12211066
last_position = None
12221067
with IcfPartitionWriter(
12231068
self.metadata,
@@ -1245,18 +1090,8 @@ def process_partition(self, partition_index):
12451090
else:
12461091
val = variant.genotype.array()
12471092
tcw.append("FORMAT/GT", val)
1248-
laa_val = None
12491093
for field in format_fields:
1250-
if field.name == "LAA":
1251-
if "LAA" not in variant.FORMAT:
1252-
laa_val = compute_laa_field(variant)
1253-
else:
1254-
laa_val = variant.format("LAA")
1255-
val = laa_val
1256-
elif field.name == "LPL" and "LPL" not in variant.FORMAT:
1257-
val = compute_lpl_field(variant, laa_val)
1258-
else:
1259-
val = variant.format(field.name)
1094+
val = variant.format(field.name)
12601095
tcw.append(field.full_name, val)
12611096

12621097
# Note: an issue with updating the progress per variant here like
@@ -1352,7 +1187,6 @@ def explode(
13521187
worker_processes=1,
13531188
show_progress=False,
13541189
compressor=None,
1355-
local_alleles=None,
13561190
):
13571191
writer = IntermediateColumnarFormatWriter(icf_path)
13581192
writer.init(
@@ -1363,7 +1197,6 @@ def explode(
13631197
show_progress=show_progress,
13641198
column_chunk_size=column_chunk_size,
13651199
compressor=compressor,
1366-
local_alleles=local_alleles,
13671200
)
13681201
writer.explode(worker_processes=worker_processes, show_progress=show_progress)
13691202
writer.finalise()
@@ -1379,7 +1212,6 @@ def explode_init(
13791212
worker_processes=1,
13801213
show_progress=False,
13811214
compressor=None,
1382-
local_alleles=None,
13831215
):
13841216
writer = IntermediateColumnarFormatWriter(icf_path)
13851217
return writer.init(
@@ -1389,7 +1221,6 @@ def explode_init(
13891221
show_progress=show_progress,
13901222
column_chunk_size=column_chunk_size,
13911223
compressor=compressor,
1392-
local_alleles=local_alleles,
13931224
)
13941225

13951226

-944 Bytes
Binary file not shown.
-118 Bytes
Binary file not shown.

tests/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ def test_examples(self, chunk_size, size, start, stop):
237237
# It works in CI on Linux, but it'll probably break at some point.
238238
# It's also necessary to update these numbers each time a new data
239239
# file gets added
240-
("tests/data", 4977391),
241-
("tests/data/vcf", 4965254),
240+
("tests/data", 4976329),
241+
("tests/data/vcf", 4964192),
242242
("tests/data/vcf/sample.vcf.gz", 1089),
243243
],
244244
)

0 commit comments

Comments
 (0)