Skip to content

Commit 6c58c9e

Browse files
Fix various issues with indexing, partial refactor
1 parent d7a24e2 commit 6c58c9e

File tree

2 files changed

+199
-60
lines changed

2 files changed

+199
-60
lines changed

bio2zarr/vcf_utils.py

Lines changed: 151 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import Any, Dict, Optional, Sequence, Union
2+
import contextlib
3+
import struct
24
import re
35
import pathlib
46
import itertools
@@ -7,6 +9,7 @@
79
import fsspec
810
import numpy as np
911
from cyvcf2 import VCF
12+
import cyvcf2
1013
import humanfriendly
1114

1215
from bio2zarr.typing import PathType
@@ -49,6 +52,21 @@ def region_string(contig: str, start: int, end: Optional[int] = None) -> str:
4952
else:
5053
return f"{contig}:{start}-"
5154

55+
@dataclass
class Region:
    """A genomic region: a contig name plus optional 1-based, inclusive
    start and end coordinates.

    ``str(Region("chr1", 5, 10))`` is ``"chr1:5-10"``; omitting ``end``
    yields an open-ended region (``"chr1:5-"``); omitting both yields the
    bare contig name.
    """

    contig: str
    start: Optional[int] = None
    end: Optional[int] = None

    def __str__(self) -> str:
        s = self.contig
        if self.start is not None:
            s += f":{self.start}-"
        if self.end is not None:
            # NOTE(review): if end is set while start is None the output is
            # ambiguous ("chr110"); all current callers supply start with end.
            s += str(self.end)
        return s

    @classmethod
    def parse(cls, s: str) -> "Region":
        """Parse a region string such as "chr1", "chr1:5-" or "chr1:5-10".

        Assumes the contig name itself contains no ":" — TODO confirm for
        exotic contig names (e.g. HLA alleles).
        """
        contig, _, coords = s.partition(":")
        if not coords:
            return cls(contig)
        start, _, end = coords.partition("-")
        return cls(contig, int(start), int(end) if end else None)
5270

5371
def get_tabix_path(
5472
vcf_path: PathType, storage_options: Optional[Dict[str, str]] = None
@@ -263,6 +281,19 @@ class CSIIndex:
263281
record_counts: Sequence[int]
264282
n_no_coor: int
265283

284+
def parse_vcf_aux(self):
    """Extract contig names from the auxiliary Tabix payload of a CSI index.

    Only meaningful when the index covers a VCF (non-empty ``aux``).
    """
    assert len(self.aux) > 0
    # The leading 28 bytes look like a Tabix-style header (seven
    # little-endian int32s), but its n_ref field does not match the contig
    # count, so that prefix is skipped rather than interpreted.
    # values = struct.Struct("<7i").unpack(self.aux[:28])
    # tabix_header = Header(*values, 0)
    raw_names = self.aux[28:]
    # Names are \0-terminated; drop the empty trailing chunk left after the
    # final terminator by the split.
    return [chunk.decode("utf-8") for chunk in raw_names.split(b"\x00")[:-1]]
296+
266297
def offsets(self) -> Any:
267298
pseudo_bin = bin_limit(self.min_shift, self.depth) + 1
268299

@@ -388,12 +419,6 @@ class Header:
388419
l_nm: int
389420

390421

391-
# @dataclass
392-
# class Chunk:
393-
# cnk_beg: int
394-
# cnk_end: int
395-
396-
397422
@dataclass
398423
class TabixBin:
399424
bin: int
@@ -503,3 +528,123 @@ def read_tabix(
503528
return TabixIndex(
504529
header, sequence_names, bins, linear_indexes, record_counts, n_no_coor
505530
)
531+
532+
533+
534+
class IndexedVcf:
    """A VCF/BCF file together with its .tbi or .csi index.

    Exposes per-contig record counts and index-driven partitioning of the
    file into genomic regions of roughly equal compressed size.
    """

    def __init__(self, path, index_path=None):
        # TODO(review): index_path=None currently fails on .suffix below;
        # auto-detection via get_tabix_path/get_csi_path still needs to be
        # wired in here (see the original commented-out sketch).
        self.vcf_path = path
        self.index_path = index_path
        self.file_type = None
        self.index_type = None
        if index_path.suffix == ".csi":
            self.index_type = "csi"
        elif index_path.suffix == ".tbi":
            # Tabix indexes only ever cover VCF, so the file type is known.
            self.index_type = "tabix"
            self.file_type = "vcf"
        else:
            raise ValueError("TODO")
        self.index = read_index(self.index_path)
        self.sequence_names = None
        if self.index_type == "csi":
            # A CSI index may cover VCF or BCF; a non-empty "aux" payload
            # carries Tabix-style metadata and implies VCF.
            self.file_type = "bcf"
            if len(self.index.aux) > 0:
                self.file_type = "vcf"
                self.sequence_names = self.index.parse_vcf_aux()
            else:
                # BCF CSI indexes do not store contig names; read them from
                # the file header instead.
                with contextlib.closing(cyvcf2.VCF(path)) as vcf:
                    self.sequence_names = vcf.seqnames
        else:
            self.sequence_names = self.index.sequence_names

    def contig_record_counts(self):
        """Return a dict mapping contig name to indexed record count.

        For BCF files, contigs with zero records are dropped, since BCF
        headers typically declare many contigs the file never touches.
        """
        d = dict(zip(self.sequence_names, self.index.record_counts))
        if self.file_type == "bcf":
            d = {k: v for k, v in d.items() if v > 0}
        return d

    def partition_into_regions(
        self,
        num_parts: Optional[int] = None,
        target_part_size: Union[None, int, str] = None,
    ):
        """Split the file into a list of ``Region`` objects of roughly equal
        compressed size, using the index's file offsets.

        Exactly one of ``num_parts`` or ``target_part_size`` must be given;
        ``target_part_size`` may be an int (bytes) or a humanfriendly size
        string such as ``"100MB"``.

        Raises ValueError on missing/conflicting/non-positive arguments.
        """
        if num_parts is None and target_part_size is None:
            raise ValueError("One of num_parts or target_part_size must be specified")
        if num_parts is not None and target_part_size is not None:
            raise ValueError(
                "Only one of num_parts or target_part_size may be specified"
            )
        if num_parts is not None and num_parts < 1:
            raise ValueError("num_parts must be positive")
        if target_part_size is not None:
            if isinstance(target_part_size, int):
                target_part_size_bytes = target_part_size
            else:
                target_part_size_bytes = humanfriendly.parse_size(target_part_size)
            if target_part_size_bytes < 1:
                raise ValueError("target_part_size must be positive")

        # Calculate the desired part file boundaries.
        file_length = get_file_length(self.vcf_path)
        if num_parts is not None:
            target_part_size_bytes = file_length // num_parts
        else:
            # target_part_size was given, so target_part_size_bytes is set.
            num_parts = ceildiv(file_length, target_part_size_bytes)
        part_lengths = np.array([i * target_part_size_bytes for i in range(num_parts)])

        file_offsets, region_contig_indexes, region_positions = self.index.offsets()

        # Search the file offsets to find which indexes the part lengths fall at.
        ind = np.searchsorted(file_offsets, part_lengths)

        # Drop parts past the end of the indexed offsets (these are covered
        # by the final, open-ended region), then drop duplicate boundaries.
        ind = np.delete(ind, ind >= len(file_offsets))
        ind = np.unique(ind)

        # Contig index and start position for each part boundary.
        region_contigs = region_contig_indexes[ind]
        region_starts = region_positions[ind]

        regions = []
        for i in range(len(region_starts)):
            contig = self.sequence_names[region_contigs[i]]
            start = region_starts[i]

            if i == len(region_starts) - 1:  # final region is open-ended
                regions.append(Region(contig, start))
            else:
                next_contig = self.sequence_names[region_contigs[i + 1]]
                next_start = region_starts[i + 1]
                end = next_start - 1  # subtract one since positions are inclusive
                if next_contig == contig:  # contig doesn't change
                    regions.append(Region(contig, start, end))
                else:
                    # Contig changes within this part, so it spans several
                    # regions: the tail of the current contig, any contigs
                    # skipped entirely, then the head of the next contig.
                    regions.append(Region(contig, start))
                    for ri in range(region_contigs[i] + 1, region_contigs[i + 1]):
                        # BUG FIX: previously appended the bare name string,
                        # breaking the "every result is a Region" contract.
                        regions.append(Region(self.sequence_names[ri]))
                    regions.append(Region(next_contig, 1, end))

        # Cover any trailing contigs with data that no part boundary reached.
        for ri in range(region_contigs[-1] + 1, len(self.sequence_names)):
            if self.index.record_counts[ri] > 0:
                regions.append(Region(self.sequence_names[ri]))

        return regions

tests/test_vcf_utils.py

Lines changed: 48 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
from cyvcf2 import VCF
5+
import numpy as np
56

67
from bio2zarr import vcf_utils
78

@@ -16,57 +17,53 @@
1617

1718
data_path = pathlib.Path("tests/data/vcf/")
1819

19-
# bcftools index -s
20-
@pytest.mark.parametrize(["index_file", "expected"], [
21-
("sample.vcf.gz.tbi", [2, 6, 1]),
22-
("sample.bcf.csi", [2, 6, 1]),
23-
("sample_no_genotypes.vcf.gz.csi", [2, 6, 1]),
24-
("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", [3450, 16460]),
25-
("CEUTrio.20.21.gatk3.4.g.bcf.csi", [3450, 16460]),
26-
("1kg_2020_chrM.vcf.gz.tbi", [23]),
27-
("1kg_2020_chrM.vcf.gz.csi", [23]),
28-
# ("1kg_2020_chrM.bcf.csi", [23]),
29-
# ("1kg_2020_chr20_annotations.bcf.csi", [21]),
30-
])
31-
def test_index_record_count(index_file, expected):
32-
index = vcf_utils.read_index(data_path / index_file)
33-
assert index.record_counts == expected
34-
35-
36-
37-
# class TestCEUTrio2021VcfExample:
38-
# data_path = "tests/data/vcf/CEUTrio.20.21.gatk3.4.g.vcf.bgz"
39-
40-
# @pytest.fixture(scope="class")
41-
# def index(self):
42-
# tabix_path = get_tabix_path(self.data_path)
43-
# return read_tabix(tabix_path)
44-
45-
# def test_record_counts(self, index):
46-
# assert index.record_counts == [3450, 16460]
47-
# # print(index)
48-
# # # print(index.sequence_names)
49-
# # print(index.record_counts)
50-
# # for i, contig in enumerate(tabix.sequence_names):
51-
# # assert tabix.record_counts[i] == count_variants(vcf_path, contig)
52-
53-
# # def test_one_region(self, index):
54-
# # parts = partition_into_regions(self.data_path, num_parts=1)
55-
# # assert parts == ["20:1-", "21"]
56-
57-
58-
# class TestCEUTrio2021BcfExample(TestCEUTrio2021VcfExample):
59-
# data_path = "tests/data/vcf/CEUTrio.20.21.gatk3.4.g.bcf"
6020

61-
# @pytest.fixture(scope="class")
62-
# def index(self):
63-
# csi_path = get_csi_path(self.data_path)
64-
# return read_csi(csi_path)
21+
# values computed using bcftools index -s
@pytest.mark.parametrize(
    ["index_file", "expected"],
    [
        ("sample.vcf.gz.tbi", {"19": 2, "20": 6, "X": 1}),
        ("sample.bcf.csi", {"19": 2, "20": 6, "X": 1}),
        ("sample_no_genotypes.vcf.gz.csi", {"19": 2, "20": 6, "X": 1}),
        ("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", {"20": 3450, "21": 16460}),
        ("CEUTrio.20.21.gatk3.4.g.bcf.csi", {"20": 3450, "21": 16460}),
        ("1kg_2020_chrM.vcf.gz.tbi", {"chrM": 23}),
        ("1kg_2020_chrM.vcf.gz.csi", {"chrM": 23}),
        ("1kg_2020_chrM.bcf.csi", {"chrM": 23}),
        ("1kg_2020_chr20_annotations.bcf.csi", {"chr20": 21}),
        ("NA12878.prod.chr20snippet.g.vcf.gz.tbi", {"20": 301778}),
    ],
)
def test_index_record_count(index_file, expected):
    # The data file path is the index path minus its final extension
    # (e.g. "sample.vcf.gz.tbi" -> "sample.vcf.gz").
    vcf_path = data_path / index_file.rsplit(".", 1)[0]
    indexed_vcf = vcf_utils.IndexedVcf(vcf_path, data_path / index_file)
    assert indexed_vcf.contig_record_counts() == expected
41+
42+
43+
@pytest.mark.parametrize(
    ["index_file", "expected"],
    [
        ("sample.vcf.gz.tbi", ["19:1-", "20", "X"]),
        ("sample.bcf.csi", ["19:1-", "20", "X"]),
        ("sample_no_genotypes.vcf.gz.csi", ["19:1-", "20", "X"]),
        ("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", ["20:1-", "21"]),
        ("CEUTrio.20.21.gatk3.4.g.bcf.csi", ["20:1-", "21"]),
        ("1kg_2020_chrM.vcf.gz.tbi", ["chrM:1-"]),
        ("1kg_2020_chrM.vcf.gz.csi", ["chrM:1-"]),
        ("1kg_2020_chrM.bcf.csi", ["chrM:1-"]),
        ("1kg_2020_chr20_annotations.bcf.csi", ["chr20:49153-"]),
        ("NA12878.prod.chr20snippet.g.vcf.gz.tbi", ["20:1-"]),
    ],
)
def test_partition_into_one_part(index_file, expected):
    # The data file path is the index path minus its final extension.
    vcf_path = data_path / index_file.rsplit(".", 1)[0]
    indexed_vcf = vcf_utils.IndexedVcf(vcf_path, data_path / index_file)
    regions = indexed_vcf.partition_into_regions(num_parts=1)
    # Every partition entry must be a Region, and stringify as expected.
    for region in regions:
        assert isinstance(region, vcf_utils.Region)
    assert [str(region) for region in regions] == expected
6564

6665

6766
class TestCsiIndex:
68-
69-
7067
@pytest.mark.parametrize(
7168
"filename",
7269
["CEUTrio.20.21.gatk3.4.g.vcf.bgz", "CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi"],
@@ -77,10 +74,12 @@ def test_invalid_csi(self, filename):
7774

7875

7976
class TestTabixIndex:
80-
8177
@pytest.mark.parametrize(
8278
"filename",
83-
["CEUTrio.20.21.gatk3.4.g.vcf.bgz", "CEUTrio.20.21.gatk3.4.g.bcf.csi", ],
79+
[
80+
"CEUTrio.20.21.gatk3.4.g.vcf.bgz",
81+
"CEUTrio.20.21.gatk3.4.g.bcf.csi",
82+
],
8483
)
8584
def test_invalid_tbi(self, filename):
8685
with pytest.raises(ValueError, match=r"File not in Tabix format."):
@@ -159,11 +158,6 @@ def test_invalid_arguments(self):
159158
with pytest.raises(ValueError, match=r"target_part_size must be positive"):
160159
partition_into_regions(vcf_path, target_part_size=0)
161160

162-
def test_one_part(self):
163-
vcf_path = data_path / "CEUTrio.20.21.gatk3.4.g.vcf.bgz"
164-
parts = partition_into_regions(vcf_path, num_parts=1)
165-
assert parts == ["20:1-", "21"]
166-
167161
@pytest.mark.skip("TODO")
168162
def test_missing_index(self, temp_path):
169163
vcf_path = data_path / "CEUTrio.20.21.gatk3.4.g.vcf.bgz"

0 commit comments

Comments
 (0)