Skip to content

Commit 8c80f54

Browse files
Pad out tests
1 parent 5c44313 commit 8c80f54

File tree

3 files changed

+180
-132
lines changed

3 files changed

+180
-132
lines changed

bio2zarr/vcf_utils.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,11 @@ def read_tabix(
365365
)
366366

367367

368-
class IndexedVcf:
368+
class IndexedVcf(contextlib.AbstractContextManager):
369369
def __init__(self, vcf_path, index_path=None):
370+
self.vcf = None
370371
vcf_path = pathlib.Path(vcf_path)
372+
# TODO use constants here instead of strings
371373
if index_path is None:
372374
index_path = vcf_path.with_suffix(vcf_path.suffix + ".tbi")
373375
if not index_path.exists():
@@ -379,6 +381,7 @@ def __init__(self, vcf_path, index_path=None):
379381

380382
self.vcf_path = vcf_path
381383
self.index_path = index_path
384+
# TODO use Enums for these
382385
self.file_type = None
383386
self.index_type = None
384387
if index_path.suffix == ".csi":
@@ -387,7 +390,9 @@ def __init__(self, vcf_path, index_path=None):
387390
self.index_type = "tabix"
388391
self.file_type = "vcf"
389392
else:
390-
raise ValueError("TODO")
393+
raise ValueError("Only .tbi or .csi indexes are supported.")
394+
self.vcf = cyvcf2.VCF(vcf_path)
395+
self.vcf.set_index(str(self.index_path))
391396
self.sequence_names = None
392397
if self.index_type == "csi":
393398
# Determine the file-type based on the "aux" field.
@@ -403,12 +408,28 @@ def __init__(self, vcf_path, index_path=None):
403408
self.index = read_tabix(self.index_path)
404409
self.sequence_names = self.index.sequence_names
405410

411+
def __exit__(self, exc_type, exc_val, exc_tb):
412+
if self.vcf is not None:
413+
self.vcf.close()
414+
self.vcf = None
415+
return False
416+
406417
def contig_record_counts(self):
407418
d = dict(zip(self.sequence_names, self.index.record_counts))
408419
if self.file_type == "bcf":
409420
d = {k: v for k, v in d.items() if v > 0}
410421
return d
411422

423+
def count_variants(self, region):
424+
return sum(1 for _ in self.variants(region))
425+
426+
def variants(self, region):
427+
# Need to filter because of indels overlapping the region
428+
start = 1 if region.start is None else region.start
429+
for var in self.vcf(str(region)):
430+
if var.POS >= start:
431+
yield var
432+
412433
def partition_into_regions(
413434
self,
414435
num_parts: Optional[int] = None,

tests/test_vcf_utils.py

Lines changed: 156 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -11,121 +11,138 @@
1111
data_path = pathlib.Path("tests/data/vcf/")
1212

1313

14-
# values computed using bcftools index -s
15-
@pytest.mark.parametrize(
16-
["index_file", "expected"],
17-
[
18-
("sample.vcf.gz.tbi", {"19": 2, "20": 6, "X": 1}),
19-
("sample.bcf.csi", {"19": 2, "20": 6, "X": 1}),
20-
("sample_no_genotypes.vcf.gz.csi", {"19": 2, "20": 6, "X": 1}),
21-
("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", {"20": 3450, "21": 16460}),
22-
("CEUTrio.20.21.gatk3.4.g.bcf.csi", {"20": 3450, "21": 16460}),
23-
("1kg_2020_chrM.vcf.gz.tbi", {"chrM": 23}),
24-
("1kg_2020_chrM.vcf.gz.csi", {"chrM": 23}),
25-
("1kg_2020_chrM.bcf.csi", {"chrM": 23}),
26-
("1kg_2020_chr20_annotations.bcf.csi", {"chr20": 21}),
27-
("NA12878.prod.chr20snippet.g.vcf.gz.tbi", {"20": 301778}),
28-
("multi_contig.vcf.gz.tbi", {str(j): 933 for j in range(5)}),
29-
],
30-
)
31-
def test_index_record_count(index_file, expected):
32-
vcf_path = data_path / (".".join(list(index_file.split("."))[:-1]))
33-
indexed_vcf = vcf_utils.IndexedVcf(vcf_path, data_path / index_file)
34-
assert indexed_vcf.contig_record_counts() == expected
35-
36-
37-
@pytest.mark.parametrize(
38-
["index_file", "expected"],
39-
[
40-
("sample.vcf.gz.tbi", ["19:1-", "20", "X"]),
41-
("sample.bcf.csi", ["19:1-", "20", "X"]),
42-
("sample_no_genotypes.vcf.gz.csi", ["19:1-", "20", "X"]),
43-
("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", ["20:1-", "21"]),
44-
("CEUTrio.20.21.gatk3.4.g.bcf.csi", ["20:1-", "21"]),
45-
("1kg_2020_chrM.vcf.gz.tbi", ["chrM:1-"]),
46-
("1kg_2020_chrM.vcf.gz.csi", ["chrM:1-"]),
47-
("1kg_2020_chrM.bcf.csi", ["chrM:1-"]),
48-
("1kg_2020_chr20_annotations.bcf.csi", ["chr20:49153-"]),
49-
("NA12878.prod.chr20snippet.g.vcf.gz.tbi", ["20:1-"]),
50-
("multi_contig.vcf.gz.tbi", ["0:1-"] + [str(j) for j in range(1, 5)]),
51-
],
52-
)
53-
def test_partition_into_one_part(index_file, expected):
54-
vcf_path = data_path / (".".join(list(index_file.split("."))[:-1]))
55-
indexed_vcf = vcf_utils.IndexedVcf(vcf_path, data_path / index_file)
56-
regions = indexed_vcf.partition_into_regions(num_parts=1)
57-
assert all(isinstance(r, vcf_utils.Region) for r in regions)
58-
assert [str(r) for r in regions] == expected
59-
60-
61-
def test_tabix_multi_chrom_bug():
62-
index_file = "multi_contig.vcf.gz.tbi"
63-
vcf_path = data_path / (".".join(list(index_file.split("."))[:-1]))
64-
indexed_vcf = vcf_utils.IndexedVcf(vcf_path, data_path / index_file)
65-
regions = indexed_vcf.partition_into_regions(num_parts=10)
66-
# An earlier version of the code returned this, i.e. with a duplicate
67-
# for 4 with end coord of 0
68-
# ["0:1-", "1", "2", "3", "4:1-0", "4:1-"]
69-
expected = ["0:1-", "1", "2", "3", "4:1-"]
70-
assert [str(r) for r in regions] == expected
71-
72-
73-
@pytest.mark.skip("TODO")
74-
class TestCsiIndex:
14+
def assert_part_counts_non_zero(part_counts, index_file):
15+
# We may have one zero count value at the end in Tabix indexes.
16+
# Should probably try to get rid of it, but probably no harm
17+
# https://github.com/jeromekelleher/bio2zarr/issues/45
18+
if index_file.endswith(".tbi"):
19+
assert np.all(part_counts[:-1] > 0)
20+
else:
21+
assert np.all(part_counts > 0)
22+
23+
24+
class TestIndexedVcf:
25+
def get_instance(self, index_file):
26+
vcf_path = data_path / (".".join(list(index_file.split("."))[:-1]))
27+
return vcf_utils.IndexedVcf(vcf_path, data_path / index_file)
28+
29+
def test_context_manager_success(self):
30+
# Nominal case
31+
with vcf_utils.IndexedVcf(data_path / "sample.bcf") as iv:
32+
assert iv.vcf is not None
33+
assert iv.vcf is None
34+
35+
def test_context_manager_error(self):
36+
with pytest.raises(ValueError, match="Cannot find"):
37+
with vcf_utils.IndexedVcf(data_path / "no-such-file.bcf"):
38+
pass
39+
40+
# values computed using bcftools index -s
7541
@pytest.mark.parametrize(
76-
"filename",
77-
["CEUTrio.20.21.gatk3.4.g.vcf.bgz", "CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi"],
42+
["index_file", "expected"],
43+
[
44+
("sample.vcf.gz.tbi", {"19": 2, "20": 6, "X": 1}),
45+
("sample.bcf.csi", {"19": 2, "20": 6, "X": 1}),
46+
("sample_no_genotypes.vcf.gz.csi", {"19": 2, "20": 6, "X": 1}),
47+
("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", {"20": 3450, "21": 16460}),
48+
("CEUTrio.20.21.gatk3.4.g.bcf.csi", {"20": 3450, "21": 16460}),
49+
("1kg_2020_chrM.vcf.gz.tbi", {"chrM": 23}),
50+
("1kg_2020_chrM.vcf.gz.csi", {"chrM": 23}),
51+
("1kg_2020_chrM.bcf.csi", {"chrM": 23}),
52+
("1kg_2020_chr20_annotations.bcf.csi", {"chr20": 21}),
53+
("NA12878.prod.chr20snippet.g.vcf.gz.tbi", {"20": 301778}),
54+
("multi_contig.vcf.gz.tbi", {str(j): 933 for j in range(5)}),
55+
],
7856
)
79-
def test_invalid_csi(self, filename):
80-
with pytest.raises(ValueError, match=r"File not in CSI format."):
81-
read_csi(data_path / filename)
57+
def test_contig_record_counts(self, index_file, expected):
58+
indexed_vcf = self.get_instance(index_file)
59+
assert indexed_vcf.contig_record_counts() == expected
8260

83-
84-
@pytest.mark.skip("TODO")
85-
class TestTabixIndex:
8661
@pytest.mark.parametrize(
87-
"filename",
62+
["index_file", "expected"],
8863
[
89-
"CEUTrio.20.21.gatk3.4.g.vcf.bgz",
90-
"CEUTrio.20.21.gatk3.4.g.bcf.csi",
64+
("sample.vcf.gz.tbi", ["19:1-", "20", "X"]),
65+
("sample.bcf.csi", ["19:1-", "20", "X"]),
66+
("sample_no_genotypes.vcf.gz.csi", ["19:1-", "20", "X"]),
67+
("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", ["20:1-", "21"]),
68+
("CEUTrio.20.21.gatk3.4.g.bcf.csi", ["20:1-", "21"]),
69+
("1kg_2020_chrM.vcf.gz.tbi", ["chrM:1-"]),
70+
("1kg_2020_chrM.vcf.gz.csi", ["chrM:1-"]),
71+
("1kg_2020_chrM.bcf.csi", ["chrM:1-"]),
72+
("1kg_2020_chr20_annotations.bcf.csi", ["chr20:49153-"]),
73+
("NA12878.prod.chr20snippet.g.vcf.gz.tbi", ["20:1-"]),
74+
("multi_contig.vcf.gz.tbi", ["0:1-"] + [str(j) for j in range(1, 5)]),
9175
],
9276
)
93-
def test_invalid_tbi(self, filename):
94-
with pytest.raises(ValueError, match=r"File not in Tabix format."):
95-
read_tabix(data_path / filename)
77+
def test_partition_into_one_part(self, index_file, expected):
78+
indexed_vcf = self.get_instance(index_file)
79+
regions = indexed_vcf.partition_into_regions(num_parts=1)
80+
assert all(isinstance(r, vcf_utils.Region) for r in regions)
81+
assert [str(r) for r in regions] == expected
9682

97-
98-
@pytest.mark.skip("TODO")
99-
class TestPartitionIntoRegions:
10083
@pytest.mark.parametrize(
101-
"vcf_file",
84+
["index_file", "num_expected", "total_records"],
10285
[
103-
"CEUTrio.20.21.gatk3.4.g.bcf",
104-
"CEUTrio.20.21.gatk3.4.g.vcf.bgz",
105-
"NA12878.prod.chr20snippet.g.vcf.gz",
86+
("sample.vcf.gz.tbi", 3, 9),
87+
("sample.bcf.csi", 3, 9),
88+
("sample_no_genotypes.vcf.gz.csi", 3, 9),
89+
("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", 18, 19910),
90+
("CEUTrio.20.21.gatk3.4.g.bcf.csi", 3, 19910),
91+
("1kg_2020_chrM.vcf.gz.tbi", 1, 23),
92+
("1kg_2020_chrM.vcf.gz.csi", 1, 23),
93+
("1kg_2020_chrM.bcf.csi", 1, 23),
94+
("1kg_2020_chr20_annotations.bcf.csi", 1, 21),
95+
("NA12878.prod.chr20snippet.g.vcf.gz.tbi", 59, 301778),
96+
("multi_contig.vcf.gz.tbi", 5, 5 * 933),
10697
],
10798
)
108-
def test_num_parts(self, vcf_file):
109-
vcf_path = data_path / vcf_file
110-
regions = partition_into_regions(vcf_path, num_parts=4)
111-
112-
assert regions is not None
113-
part_variant_counts = [count_variants(vcf_path, region) for region in regions]
114-
total_variants = count_variants(vcf_path)
115-
116-
assert sum(part_variant_counts) == total_variants
117-
118-
def test_num_parts_large(self):
119-
vcf_path = data_path / "CEUTrio.20.21.gatk3.4.g.vcf.bgz"
120-
121-
regions = partition_into_regions(vcf_path, num_parts=100)
122-
assert regions is not None
123-
assert len(regions) == 18
124-
125-
part_variant_counts = [count_variants(vcf_path, region) for region in regions]
126-
total_variants = count_variants(vcf_path)
99+
def test_partition_into_max_parts(self, index_file, num_expected, total_records):
100+
indexed_vcf = self.get_instance(index_file)
101+
regions = indexed_vcf.partition_into_regions(num_parts=1000)
102+
assert all(isinstance(r, vcf_utils.Region) for r in regions)
103+
# print(regions)
104+
assert len(regions) == num_expected
105+
part_variant_counts = np.array(
106+
[indexed_vcf.count_variants(region) for region in regions]
107+
)
108+
assert np.sum(part_variant_counts) == total_records
109+
assert_part_counts_non_zero(part_variant_counts, index_file)
127110

128-
assert sum(part_variant_counts) == total_variants
111+
@pytest.mark.parametrize(
112+
["index_file", "total_records"],
113+
[
114+
("sample.vcf.gz.tbi", 9),
115+
("sample.bcf.csi", 9),
116+
("sample_no_genotypes.vcf.gz.csi", 9),
117+
("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi", 19910),
118+
("CEUTrio.20.21.gatk3.4.g.bcf.csi", 19910),
119+
("1kg_2020_chrM.vcf.gz.tbi", 23),
120+
("1kg_2020_chrM.vcf.gz.csi", 23),
121+
("1kg_2020_chrM.bcf.csi", 23),
122+
("1kg_2020_chr20_annotations.bcf.csi", 21),
123+
("NA12878.prod.chr20snippet.g.vcf.gz.tbi", 301778),
124+
("multi_contig.vcf.gz.tbi", 5 * 933),
125+
],
126+
)
127+
@pytest.mark.parametrize("num_parts", [2, 3, 4, 5, 16, 33])
128+
def test_partition_into_n_parts(self, index_file, total_records, num_parts):
129+
indexed_vcf = self.get_instance(index_file)
130+
regions = indexed_vcf.partition_into_regions(num_parts=num_parts)
131+
assert all(isinstance(r, vcf_utils.Region) for r in regions)
132+
part_variant_counts = np.array(
133+
[indexed_vcf.count_variants(region) for region in regions]
134+
)
135+
assert np.sum(part_variant_counts) == total_records
136+
assert_part_counts_non_zero(part_variant_counts, index_file)
137+
138+
def test_tabix_multi_chrom_bug(self):
139+
indexed_vcf = self.get_instance("multi_contig.vcf.gz.tbi")
140+
regions = indexed_vcf.partition_into_regions(num_parts=10)
141+
# An earlier version of the code returned this, i.e. with a duplicate
142+
# for 4 with end coord of 0
143+
# ["0:1-", "1", "2", "3", "4:1-0", "4:1-"]
144+
expected = ["0:1-", "1", "2", "3", "4:1-"]
145+
assert [str(r) for r in regions] == expected
129146

130147
@pytest.mark.parametrize(
131148
"target_part_size",
@@ -136,48 +153,60 @@ def test_num_parts_large(self):
136153
],
137154
)
138155
def test_target_part_size(self, target_part_size):
139-
vcf_path = data_path / "CEUTrio.20.21.gatk3.4.g.vcf.bgz"
140-
141-
regions = partition_into_regions(vcf_path, target_part_size=target_part_size)
142-
assert regions is not None
156+
indexed_vcf = self.get_instance("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi")
157+
regions = indexed_vcf.partition_into_regions(target_part_size=target_part_size)
143158
assert len(regions) == 5
144-
145-
part_variant_counts = [count_variants(vcf_path, region) for region in regions]
159+
part_variant_counts = [indexed_vcf.count_variants(region) for region in regions]
146160
assert part_variant_counts == [3450, 3869, 4525, 7041, 1025]
147-
total_variants = count_variants(vcf_path)
161+
assert sum(part_variant_counts) == 19910
148162

149-
assert sum(part_variant_counts) == total_variants
150-
151-
def test_invalid_arguments(self):
152-
vcf_path = data_path / "CEUTrio.20.21.gatk3.4.g.vcf.bgz"
163+
def test_partition_invalid_arguments(self):
164+
indexed_vcf = self.get_instance("CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi")
153165

154166
with pytest.raises(
155167
ValueError, match=r"One of num_parts or target_part_size must be specified"
156168
):
157-
partition_into_regions(vcf_path)
169+
indexed_vcf.partition_into_regions()
158170

159171
with pytest.raises(
160172
ValueError,
161173
match=r"Only one of num_parts or target_part_size may be specified",
162174
):
163-
partition_into_regions(vcf_path, num_parts=4, target_part_size=100_000)
175+
indexed_vcf.partition_into_regions(num_parts=4, target_part_size=100_000)
164176

165177
with pytest.raises(ValueError, match=r"num_parts must be positive"):
166-
partition_into_regions(vcf_path, num_parts=0)
178+
indexed_vcf.partition_into_regions(num_parts=0)
167179

168180
with pytest.raises(ValueError, match=r"target_part_size must be positive"):
169-
partition_into_regions(vcf_path, target_part_size=0)
170-
171-
@pytest.mark.skip("TODO")
172-
def test_missing_index(self, temp_path):
173-
vcf_path = data_path / "CEUTrio.20.21.gatk3.4.g.vcf.bgz"
174-
with pytest.raises(ValueError, match=r"Cannot find .tbi or .csi file."):
175-
partition_into_regions(vcf_path, num_parts=2)
181+
indexed_vcf.partition_into_regions(target_part_size=0)
176182

177-
bogus_index_path = path_for_test(
178-
shared_datadir, "CEUTrio.20.21.gatk3.4.noindex.g.vcf.bgz.index", True
179-
)
183+
def test_bad_index(self):
180184
with pytest.raises(
181185
ValueError, match=r"Only .tbi or .csi indexes are supported."
182186
):
183-
partition_into_regions(vcf_path, index_path=bogus_index_path, num_parts=2)
187+
# We don't actually go out the filesystem before checking so can
188+
# be anything
189+
vcf_utils.IndexedVcf("x", "y")
190+
191+
192+
class TestCsiIndex:
193+
@pytest.mark.parametrize(
194+
"filename",
195+
["CEUTrio.20.21.gatk3.4.g.vcf.bgz", "CEUTrio.20.21.gatk3.4.g.vcf.bgz.tbi"],
196+
)
197+
def test_invalid_csi(self, filename):
198+
with pytest.raises(ValueError, match=r"File not in CSI format."):
199+
vcf_utils.read_csi(data_path / filename)
200+
201+
202+
class TestTabixIndex:
203+
@pytest.mark.parametrize(
204+
"filename",
205+
[
206+
"CEUTrio.20.21.gatk3.4.g.vcf.bgz",
207+
"CEUTrio.20.21.gatk3.4.g.bcf.csi",
208+
],
209+
)
210+
def test_invalid_tbi(self, filename):
211+
with pytest.raises(ValueError, match=r"File not in Tabix format."):
212+
vcf_utils.read_tabix(data_path / filename)

tests/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,4 @@ def get_region_start(region: str) -> int:
5252
def count_variants(path: PathType, region: Optional[str] = None) -> int:
5353
"""Count the number of variants in a VCF file."""
5454
with open_vcf(path) as vcf:
55-
if region is not None:
56-
vcf = vcf(region)
57-
return sum(1 for _ in region_filter(vcf, region))
55+
return sum(1 for _ in region_filter(vcf(str(region)), str(region)))

0 commit comments

Comments
 (0)