Skip to content

Commit 6320d6c

Browse files
Refactor variant filtering to subset of bcftools
Previous filtering logic was not comprehensive. This is the starting point for fixing various issues with it. Closes #177 Closes #173
1 parent a750ded commit 6320d6c

File tree

8 files changed

+492
-429
lines changed

8 files changed

+492
-429
lines changed

tests/test_bcftools_validation.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,25 +46,25 @@ def run_vcztools(args: str, expect_error=False) -> tuple[str, str]:
4646
("view --no-version", "sample.vcf.gz"),
4747
("view --no-version", "chr22.vcf.gz"),
4848
("view --no-version", "msprime_diploid.vcf.gz"),
49+
("view --no-version -i 'ID == \"rs6054257\"'", "sample.vcf.gz"),
4950
("view --no-version -i 'INFO/DP > 10'", "sample.vcf.gz"),
50-
("view --no-version -i 'FMT/DP >= 5 && FMT/GQ > 10'", "sample.vcf.gz"),
51-
("view --no-version -i 'FMT/DP >= 5 & FMT/GQ>10'", "sample.vcf.gz"),
52-
(
53-
"view --no-version -i '(QUAL > 10 || FMT/GQ>10) && POS > 100000'",
54-
"sample.vcf.gz"
55-
),
56-
(
57-
"view --no-version -i '(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000'",
58-
"sample.vcf.gz"
59-
),
60-
(
61-
"view --no-version -e '(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000'",
62-
"sample.vcf.gz"
63-
),
64-
(
65-
"view --no-version -G",
66-
"sample.vcf.gz"
67-
),
51+
# Filters based on FMT values are currently disabled.
52+
# https://github.com/sgkit-dev/vcztools/issues/180
53+
# ("view --no-version -i 'FMT/DP >= 5 && FMT/GQ > 10'", "sample.vcf.gz"),
54+
# ("view --no-version -i 'FMT/DP >= 5 & FMT/GQ>10'", "sample.vcf.gz"),
55+
# (
56+
# "view --no-version -i '(QUAL > 10 || FMT/GQ>10) && POS > 100000'",
57+
# "sample.vcf.gz"
58+
# ),
59+
# (
60+
# "view --no-version -i '(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000'",
61+
# "sample.vcf.gz"
62+
# ),
63+
# (
64+
# "view --no-version -e '(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000'",
65+
# "sample.vcf.gz"
66+
# ),
67+
("view --no-version -G", "sample.vcf.gz"),
6868
(
6969
"view --no-update --no-version --samples-file "
7070
"tests/data/txt/samples.txt",
@@ -83,10 +83,14 @@ def run_vcztools(args: str, expect_error=False) -> tuple[str, str]:
8383
("view --no-version -s ^NA00003,NA00002", "sample.vcf.gz"),
8484
("view --no-version -s ^NA00003,NA00002,NA00003", "sample.vcf.gz"),
8585
("view --no-version -S ^tests/data/txt/samples.txt", "sample.vcf.gz"),
86-
]
86+
],
87+
# This is necessary when trying to run individual tests, as the arguments above
88+
# make for unworkable command lines
89+
# ids=range(26),
8790
)
8891
# fmt: on
8992
def test_vcf_output(tmp_path, args, vcf_file):
93+
# print("args:", args)
9094
original = pathlib.Path("tests/data/vcf") / vcf_file
9195
vcz = vcz_path_cache(original)
9296

@@ -102,11 +106,10 @@ def test_vcf_output(tmp_path, args, vcf_file):
102106

103107
assert_vcfs_close(bcftools_out_file, vcztools_out_file)
104108

109+
105110
@pytest.mark.parametrize(
106111
("args", "vcf_file"),
107-
[
108-
("view --no-version", "sample.vcf.gz")
109-
],
112+
[("view --no-version", "sample.vcf.gz")],
110113
)
111114
def test_vcf_output_with_output_option(tmp_path, args, vcf_file):
112115
vcf_path = pathlib.Path("tests/data/vcf") / vcf_file
@@ -151,6 +154,14 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file):
151154
(r"query -f 'GQ:[ %GQ] \t GT:[ %GT]\n'", "sample.vcf.gz"),
152155
(r"query -f '[%CHROM:%POS %SAMPLE %GT\n]'", "sample.vcf.gz"),
153156
(r"query -f '[%SAMPLE %GT %DP\n]'", "sample.vcf.gz"),
157+
(
158+
r"query -f '[%POS %SAMPLE %GT %DP %GQ\n]' -i 'INFO/DP >= 5'",
159+
"sample.vcf.gz",
160+
),
161+
(
162+
r"query -f '[%POS %QUAL\n]' -i'(QUAL > 10 && POS > 100000)'",
163+
"sample.vcf.gz",
164+
),
154165
],
155166
)
156167
def test_output(tmp_path, args, vcf_name):

tests/test_filter.py

Lines changed: 167 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,136 +1,203 @@
11
import pathlib
22

33
import numpy as np
4+
import numpy.testing as nt
45
import pyparsing as pp
56
import pytest
67
import zarr
7-
from numpy.testing import assert_array_equal
88

99
from tests.utils import vcz_path_cache
10-
from vcztools.filter import FilterExpressionEvaluator, FilterExpressionParser
10+
from vcztools import filter as filter_mod
1111

1212

1313
class TestFilterExpressionParser:
14-
@pytest.fixture()
15-
def identifier_parser(self, parser):
16-
return parser._identifier_parser
14+
@pytest.mark.parametrize(
15+
"expression",
16+
[
17+
"",
18+
"| |",
19+
"a +",
20+
'"stri + 2',
21+
],
22+
)
23+
def test_invalid_expressions(self, expression):
24+
parser = filter_mod.make_bcftools_filter_parser(map_vcf_identifiers=False)
25+
with pytest.raises(pp.ParseException):
26+
parser.parse_string(expression, parse_all=True)
1727

18-
@pytest.fixture()
19-
def parser(self):
20-
return FilterExpressionParser()
2128

29+
class TestFilterExpressionSample:
2230
@pytest.mark.parametrize(
2331
("expression", "expected_result"),
2432
[
25-
("1", [1]),
26-
("1.0", [1.0]),
27-
("1e-4", [0.0001]),
28-
('"String"', ["String"]),
29-
("POS", ["POS"]),
30-
("INFO/DP", ["INFO/DP"]),
31-
("FORMAT/GT", ["FORMAT/GT"]),
32-
("FMT/GT", ["FMT/GT"]),
33-
("GT", ["GT"]),
33+
("POS < 1000", [1, 1, 0, 0, 0, 0, 0, 0, 1]),
34+
("INFO/DP > 10", [0, 0, 1, 1, 0, 1, 0, 0, 0]),
35+
# Not supporting format fields for now: #180
36+
# ("FMT/GQ > 20", [0, 0, 1, 1, 1, 1, 1, 0, 0]),
37+
# ("FMT/DP >= 5 && FMT/GQ > 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]),
38+
# ("GT > 0", [1, 1, 1, 1, 1, 0, 1, 0, 1]),
39+
# ("GT > 0 & FMT/HQ >= 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]),
40+
# ("FMT/DP >= 5 & FMT/GQ>10", [0, 0, 1, 0, 1, 0, 0, 0, 0]),
41+
# ("QUAL > 10 || FMT/GQ>10", [0, 0, 1, 1, 1, 1, 1, 0, 0]),
42+
# ("(QUAL > 10 || FMT/GQ>10) && POS > 100000", [0, 0, 0, 0, 1, 1, 1, 0, 0]),
43+
# ("(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000",
44+
# [0, 0, 0, 0, 0, 1, 0, 0, 0]),
3445
],
3546
)
36-
def test_valid_identifiers(self, identifier_parser, expression, expected_result):
37-
assert identifier_parser(expression).as_list() == expected_result
47+
def test(self, expression, expected_result):
48+
original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz"
49+
vcz = vcz_path_cache(original)
50+
root = zarr.open(vcz, mode="r")
51+
data = {field: root[field][:] for field in root.keys()}
52+
filter_expr = filter_mod.FilterExpression(
53+
field_names=set(root), include=expression
54+
)
55+
result = filter_expr.evaluate(data)
56+
nt.assert_array_equal(result, expected_result)
57+
58+
filter_expr = filter_mod.FilterExpression(
59+
field_names=set(root), exclude=expression
60+
)
61+
result = filter_expr.evaluate(data)
62+
nt.assert_array_equal(result, np.logical_not(expected_result))
3863

64+
65+
def numpify_values(data):
66+
return {k: np.array(v) for k, v in data.items()}
67+
68+
69+
class TestFilterExpression:
3970
@pytest.mark.parametrize(
40-
"expression",
71+
("expression", "data", "expected"),
4172
[
42-
"",
43-
"FORMAT/ GT",
44-
"format / GT",
45-
"fmt / GT",
46-
"info / DP",
47-
"'String'",
73+
("POS<5", {"variant_position": [1, 5, 6, 10]}, [1, 0, 0, 0]),
74+
("INFO/XX>=10", {"variant_XX": [1, 5, 6, 10]}, [0, 0, 0, 1]),
75+
("INFO/XX / 2 >=5", {"variant_XX": [1, 5, 6, 10]}, [0, 0, 0, 1]),
76+
("POS<5 | POS>8", {"variant_position": [1, 5, 6, 10]}, [1, 0, 0, 1]),
77+
(
78+
"POS<0 & POS<1 & POS<2 & POS<3 & POS<4",
79+
{"variant_position": range(10)},
80+
np.zeros(10, dtype=bool),
81+
),
4882
],
4983
)
50-
def test_invalid_identifiers(self, identifier_parser, expression):
51-
with pytest.raises(pp.ParseException):
52-
identifier_parser(expression)
84+
def test_evaluate(self, expression, data, expected):
85+
fee = filter_mod.FilterExpression(include=expression)
86+
result = fee.evaluate(numpify_values(data))
87+
nt.assert_array_equal(result, expected)
88+
5389

90+
class TestBcftoolsParser:
5491
@pytest.mark.parametrize(
55-
("expression", "expected_result"),
92+
"expr",
5693
[
57-
("POS>=100", [["POS", ">=", 100]]),
58-
(
59-
"FMT/DP>10 && FMT/GQ>10",
60-
[[["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]]],
61-
),
62-
("QUAL>10 || FMT/GQ>10", [[["QUAL", ">", 10], "||", ["FMT/GQ", ">", 10]]]),
63-
(
64-
"FMT/DP>10 && FMT/GQ>10 || QUAL > 10",
65-
[
66-
[
67-
[["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]],
68-
"||",
69-
["QUAL", ">", 10],
70-
]
71-
],
72-
),
73-
(
74-
"QUAL>10 || FMT/DP>10 && FMT/GQ>10",
75-
[
76-
[
77-
["QUAL", ">", 10],
78-
"||",
79-
[["FMT/DP", ">", 10], "&&", ["FMT/GQ", ">", 10]],
80-
]
81-
],
82-
),
83-
(
84-
"QUAL>10 | FMT/DP>10 & FMT/GQ>10",
85-
[
86-
[
87-
["QUAL", ">", 10],
88-
"|",
89-
[["FMT/DP", ">", 10], "&", ["FMT/GQ", ">", 10]],
90-
],
91-
],
92-
),
93-
(
94-
"(QUAL>10 || FMT/DP>10) && FMT/GQ>10",
95-
[
96-
[
97-
[["QUAL", ">", 10], "||", ["FMT/DP", ">", 10]],
98-
"&&",
99-
["FMT/GQ", ">", 10],
100-
]
101-
],
102-
),
94+
"2",
95+
'"x"',
96+
'"INFO/STRING"',
97+
"2 + 2",
98+
"(2 + 3) / 2",
99+
"2 / (2 + 3)",
100+
"1 + 1 + 1 + 1 + 1",
101+
"5 * (2 / 3)",
102+
"5 * 2 / 3",
103+
"1 + 2 - 3 / 4 * 5 + 6 * 7 / 8",
104+
"5 / (1 + 2 - 4) / (4 * 5 + 6 * 7 / 8)",
105+
"5 < 2",
106+
"5 > 2",
107+
"0 == 0",
108+
"0 != 0",
109+
"(1 + 2) == 0",
110+
"1 + 2 == 0",
111+
"1 + 2 == 1 + 2 + 3",
112+
"(1 + 2) == (1 + 2 + 3)",
113+
"(1 == 1) != (2 == 2)",
114+
'("x" == "x")',
103115
],
104116
)
105-
def test_valid_expressions(self, parser, expression, expected_result):
106-
assert parser(expression=expression).as_list() == expected_result
117+
def test_python_arithmetic_expressions(self, expr):
118+
parser = filter_mod.make_bcftools_filter_parser()
119+
parsed = parser.parse_string(expr, parse_all=True)
120+
result = parsed[0].eval({})
121+
assert result == eval(expr)
107122

123+
@pytest.mark.parametrize(
124+
("expr", "data"),
125+
[
126+
("a", {"a": 1}),
127+
("a + a", {"a": 1}),
128+
("a + 2 * a - 1", {"a": 7}),
129+
("a - b < a + b", {"a": 7, "b": 6}),
130+
("(a - b) < (a + b)", {"a": 7, "b": 6}),
131+
("(a - b) < (a + b)", {"a": 7.0, "b": 6.666}),
132+
("a == a", {"a": 1}),
133+
('a == "string"', {"a": "string"}),
134+
],
135+
)
136+
def test_python_arithmetic_expressions_data(self, expr, data):
137+
parser = filter_mod.make_bcftools_filter_parser(map_vcf_identifiers=False)
138+
parsed = parser.parse_string(expr, parse_all=True)
139+
result = parsed[0].eval(data)
140+
assert result == eval(expr, data)
108141

109-
class TestFilterExpressionEvaluator:
110142
@pytest.mark.parametrize(
111-
("expression", "expected_result"),
143+
("expr", "data"),
112144
[
113-
("POS < 1000", [1, 1, 0, 0, 0, 0, 0, 0, 1]),
114-
("FMT/GQ > 20", [0, 0, 1, 1, 1, 1, 1, 0, 0]),
115-
("FMT/DP >= 5 && FMT/GQ > 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]),
116-
("FMT/DP >= 5 & FMT/GQ>10", [0, 0, 1, 0, 1, 0, 0, 0, 0]),
117-
("QUAL > 10 || FMT/GQ>10", [0, 0, 1, 1, 1, 1, 1, 0, 0]),
118-
("(QUAL > 10 || FMT/GQ>10) && POS > 100000", [0, 0, 0, 0, 1, 1, 1, 0, 0]),
119-
("(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000", [0, 0, 0, 0, 0, 1, 0, 0, 0]),
120-
("INFO/DP > 10", [0, 0, 1, 1, 0, 1, 0, 0, 0]),
121-
("GT > 0", [1, 1, 1, 1, 1, 0, 1, 0, 1]),
122-
("GT > 0 & FMT/HQ >= 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]),
145+
("a", {"a": [1, 2, 3]}),
146+
("a + a", {"a": [1, 2, 3]}),
147+
("1 + a + a", {"a": [1, 2, 3]}),
148+
("a + b", {"a": [1, 2, 3], "b": [5, 6, 7]}),
149+
("(a + b) < c", {"a": [1, 2, 3], "b": [5, 6, 7], "c": [5, 10, 15]}),
123150
],
124151
)
125-
def test(self, expression, expected_result):
126-
original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz"
127-
vcz = vcz_path_cache(original)
128-
root = zarr.open(vcz, mode="r")
152+
def test_numpy_arithmetic_expressions_data(self, expr, data):
153+
parser = filter_mod.make_bcftools_filter_parser(map_vcf_identifiers=False)
154+
parsed = parser.parse_string(expr, parse_all=True)
155+
npdata = numpify_values(data)
156+
result = parsed[0].eval(npdata)
157+
evaled = eval(expr, npdata)
158+
nt.assert_array_equal(result, evaled)
159+
160+
@pytest.mark.parametrize(
161+
("expr", "expected"),
162+
[
163+
("1 & 1", True),
164+
("0 & 1", False),
165+
("1 & 0", False),
166+
("0 & 0", False),
167+
("1 | 1", True),
168+
("0 | 1", True),
169+
("1 | 0", True),
170+
("0 | 0", False),
171+
("(1 < 2) | 0", True),
172+
("(1 < 2) & 0", False),
173+
],
174+
)
175+
def test_boolean_operator_expressions(self, expr, expected):
176+
parser = filter_mod.make_bcftools_filter_parser()
177+
parsed = parser.parse_string(expr, parse_all=True)
178+
result = parsed[0].eval({})
179+
assert result == expected
180+
181+
@pytest.mark.parametrize(
182+
("expr", "data", "expected"),
183+
[
184+
("a == b", {"a": [0, 1], "b": [1, 1]}, [False, True]),
185+
("a = b", {"a": [0, 1], "b": [1, 1]}, [False, True]),
186+
("a & b", {"a": [0, 1], "b": [1, 1]}, [False, True]),
187+
("a && b", {"a": [0, 1], "b": [1, 1]}, [False, True]),
188+
("a | b", {"a": [0, 1], "b": [1, 1]}, [True, True]),
189+
("a || b", {"a": [0, 1], "b": [1, 1]}, [True, True]),
190+
("(a < 2) & (b > 1)", {"a": [0, 1], "b": [1, 2]}, [False, True]),
191+
],
192+
)
193+
def test_boolean_operator_expressions_data(self, expr, data, expected):
194+
parser = filter_mod.make_bcftools_filter_parser(map_vcf_identifiers=False)
195+
parsed = parser.parse_string(expr, parse_all=True)
196+
result = parsed[0].eval(numpify_values(data))
197+
nt.assert_array_equal(result, expected)
129198

130-
parser = FilterExpressionParser()
131-
parse_results = parser(expression)[0]
132-
evaluator = FilterExpressionEvaluator(parse_results)
133-
assert_array_equal(evaluator(root, 0), expected_result)
134199

135-
invert_evaluator = FilterExpressionEvaluator(parse_results, invert=True)
136-
assert_array_equal(invert_evaluator(root, 0), np.logical_not(expected_result))
200+
class TestAPIErrors:
201+
def test_include_and_exclude(self):
202+
with pytest.raises(ValueError, match="Cannot handle both an include "):
203+
filter_mod.FilterExpression(include="x", exclude="y")

0 commit comments

Comments
 (0)