Skip to content

Commit 481a0d9

Browse files
committed
Add test utils
1 parent 4c3403d commit 481a0d9

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed

tests/utils.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from itertools import zip_longest
2+
from pathlib import Path
3+
4+
import numpy as np
5+
6+
from sgkit.io.vcf.vcf_reader import open_vcf
7+
from sgkit.typing import PathType
8+
9+
10+
def path_for_test(shared_datadir: Path, file: str, is_path: bool = True) -> PathType:
11+
"""Return a test data path whose type is determined by `is_path`.
12+
13+
If `is_path` is True, return a `Path`, otherwise return a `str`.
14+
"""
15+
path: PathType = shared_datadir / file
16+
return path if is_path else str(path)
17+
18+
19+
def assert_vcfs_close(f1, f2, *, rtol=1e-05, atol=1e-03):
20+
"""Like :py:func:`numpy.testing.assert_allclose()`, but for VCF files.
21+
22+
Raises an `AssertionError` if two VCF files are not equal to one another.
23+
Float values in QUAL, INFO, or FORMAT fields are compared up to the
24+
desired tolerance. All other values must match exactly.
25+
26+
Parameters
27+
----------
28+
f1
29+
Path to first VCF to compare.
30+
f2
31+
Path to second VCF to compare.
32+
rtol
33+
Relative tolerance.
34+
atol
35+
Absolute tolerance.
36+
"""
37+
with open_vcf(f1) as vcf1, open_vcf(f2) as vcf2:
38+
assert vcf1.raw_header == vcf2.raw_header
39+
assert vcf1.samples == vcf2.samples
40+
41+
for v1, v2 in zip_longest(vcf1, vcf2):
42+
if v1 is None and v2 is not None:
43+
raise AssertionError(f"Right contains extra variant: {v2}")
44+
if v1 is not None and v2 is None:
45+
raise AssertionError(f"Left contains extra variant: {v1}")
46+
47+
assert v1.CHROM == v2.CHROM, f"CHROM not equal for variants\n{v1}{v2}"
48+
assert v1.POS == v2.POS, f"POS not equal for variants\n{v1}{v2}"
49+
assert v1.ID == v2.ID, f"ID not equal for variants\n{v1}{v2}"
50+
assert v1.REF == v2.REF, f"REF not equal for variants\n{v1}{v2}"
51+
assert v1.ALT == v2.ALT, f"ALT not equal for variants\n{v1}{v2}"
52+
np.testing.assert_allclose(
53+
np.array(v1.QUAL, dtype=np.float32),
54+
np.array(v2.QUAL, dtype=np.float32),
55+
rtol=rtol,
56+
atol=atol,
57+
err_msg=f"QUAL not equal for variants\n{v1}{v2}",
58+
)
59+
assert set(v1.FILTERS) == set(
60+
v2.FILTERS
61+
), f"FILTER not equal for variants\n{v1}{v2}"
62+
63+
assert (
64+
dict(v1.INFO).keys() == dict(v2.INFO).keys()
65+
), f"INFO keys not equal for variants\n{v1}{v2}"
66+
for k in dict(v1.INFO).keys():
67+
# values are python objects (not np arrays)
68+
val1 = v1.INFO[k]
69+
val2 = v2.INFO[k]
70+
if isinstance(val1, float) or (
71+
isinstance(val1, tuple) and any(isinstance(v, float) for v in val1)
72+
):
73+
np.testing.assert_allclose(
74+
np.array(val1, dtype=np.float32),
75+
np.array(val2, dtype=np.float32),
76+
rtol=rtol,
77+
atol=atol,
78+
err_msg=f"INFO {k} not equal for variants\n{v1}{v2}",
79+
)
80+
else:
81+
assert val1 == val2, f"INFO {k} not equal for variants\n{v1}{v2}"
82+
83+
assert v1.FORMAT == v2.FORMAT, f"FORMAT not equal for variants\n{v1}{v2}"
84+
for field in v1.FORMAT:
85+
if field == "GT":
86+
assert (
87+
v1.genotypes == v2.genotypes
88+
), f"GT not equal for variants\n{v1}{v2}"
89+
else:
90+
val1 = v1.format(field)
91+
val2 = v2.format(field)
92+
if val1.dtype.kind == "f":
93+
np.testing.assert_allclose(
94+
val1,
95+
val2,
96+
rtol=rtol,
97+
atol=atol,
98+
err_msg=f"FORMAT {field} not equal for variants\n{v1}{v2}",
99+
)
100+
else:
101+
np.testing.assert_array_equal(
102+
val1,
103+
val2,
104+
err_msg=f"FORMAT {field} not equal for variants\n{v1}{v2}",
105+
)

0 commit comments

Comments
 (0)