Skip to content

Commit fde9aba

Browse files
tomwhite authored and jeromekelleher committed
Per-sample filtering
1 parent ce78b2f commit fde9aba

File tree

3 files changed

+145
-41
lines changed

3 files changed

+145
-41
lines changed

tests/test_bcftools_validation.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,23 @@ def run_vcztools(args: str, expect_error=False) -> tuple[str, str]:
5252
("view --no-version -i 'INFO/DP > 10'", "sample.vcf.gz"),
5353
# Filters based on FMT values (per-sample filtering).
5454
# https://github.com/sgkit-dev/vcztools/issues/180
55-
# ("view --no-version -i 'FMT/DP >= 5 && FMT/GQ > 10'", "sample.vcf.gz"),
56-
# ("view --no-version -i 'FMT/DP >= 5 & FMT/GQ>10'", "sample.vcf.gz"),
57-
# (
58-
# "view --no-version -i '(QUAL > 10 || FMT/GQ>10) && POS > 100000'",
59-
# "sample.vcf.gz"
60-
# ),
61-
# (
62-
# "view --no-version -i '(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000'",
63-
# "sample.vcf.gz"
64-
# ),
65-
# (
66-
# "view --no-version -e '(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000'",
67-
# "sample.vcf.gz"
68-
# ),
55+
("view --no-version -i 'FMT/DP >= 5'", "sample.vcf.gz"),
56+
("view --no-version -i 'FMT/DP >= 5 && FMT/GQ > 10'", "sample.vcf.gz"),
57+
("view --no-version -i 'FMT/DP >= 5 & FMT/GQ>10'", "sample.vcf.gz"),
58+
("view --no-version -i 'FMT/DP>5 && FMT/GQ<45'", "sample.vcf.gz"),
59+
("view --no-version -i 'FMT/DP>5 & FMT/GQ<45'", "sample.vcf.gz"),
60+
(
61+
"view --no-version -i '(QUAL > 10 || FMT/GQ>10) && POS > 100000'",
62+
"sample.vcf.gz"
63+
),
64+
(
65+
"view --no-version -i '(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000'",
66+
"sample.vcf.gz"
67+
),
68+
(
69+
"view --no-version -e '(FMT/DP >= 8 | FMT/GQ>40) && POS > 100000'",
70+
"sample.vcf.gz"
71+
),
6972
("view --no-version -G", "sample.vcf.gz"),
7073
(
7174
"view --no-update --no-version --samples-file "
@@ -88,7 +91,7 @@ def run_vcztools(args: str, expect_error=False) -> tuple[str, str]:
8891
],
8992
# This is necessary when trying to run individual tests, as the arguments above
9093
# make for unworkable command lines
91-
# ids=range(26),
94+
# ids=range(28),
9295
)
9396
# fmt: on
9497
def test_vcf_output(tmp_path, args, vcf_file):
@@ -175,6 +178,23 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file):
175178
# (r"query -f '%AC{1}\n' -i 'AC[1]>10' ", "sample.vcf.gz"),
176179
# TODO fill out more of these when support for more stuff is available
177180
# in filtering
181+
# Per-sample query tests
182+
(
183+
r"query -f '[%CHROM %POS %SAMPLE %GT %DP %GQ\n]' -i 'FMT/DP>3'",
184+
"sample.vcf.gz"
185+
),
186+
(
187+
r"query -f '[%CHROM %POS %SAMPLE %GT %DP %GQ\n]' -i 'FMT/GQ>30'",
188+
"sample.vcf.gz"
189+
),
190+
(
191+
r"query -f '[%CHROM %POS %SAMPLE %GT %DP %GQ\n]' -i 'FMT/DP>3 & FMT/GQ>30'",
192+
"sample.vcf.gz"
193+
),
194+
(
195+
r"query -f '[%CHROM %POS %SAMPLE %GT %DP %GQ\n]' -i 'FMT/DP>3 && FMT/GQ>30'",
196+
"sample.vcf.gz"
197+
),
178198
],
179199
)
180200
def test_output(tmp_path, args, vcf_name):

tests/test_filter.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,20 @@ class TestFilterExpressionSample:
7171
("POS < 1000", [1, 1, 0, 0, 0, 0, 0, 0, 1]),
7272
("INFO/DP > 10", [0, 0, 1, 1, 0, 1, 0, 0, 0]),
7373
# Not supporting format fields for now: #180
74-
# ("FMT/GQ > 20", [0, 0, 1, 1, 1, 1, 1, 0, 0]),
74+
(
75+
"FMT/GQ > 20",
76+
[
77+
[0, 0, 0],
78+
[0, 0, 0],
79+
[1, 1, 1],
80+
[1, 0, 1],
81+
[1, 0, 1],
82+
[1, 1, 1],
83+
[0, 0, 1],
84+
[0, 0, 0],
85+
[0, 0, 0],
86+
],
87+
),
7588
# ("FMT/DP >= 5 && FMT/GQ > 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]),
7689
# ("GT > 0", [1, 1, 1, 1, 1, 0, 1, 0, 1]),
7790
# ("GT > 0 & FMT/HQ >= 10", [0, 0, 1, 1, 1, 0, 0, 0, 0]),
@@ -124,18 +137,6 @@ def test_evaluate(self, expression, data, expected):
124137
result = fee.evaluate(numpify_values(data))
125138
nt.assert_array_equal(result, expected)
126139

127-
@pytest.mark.parametrize(
128-
"expression",
129-
[
130-
"FORMAT/AD > 30",
131-
"FMT/AD > 30",
132-
"GT > 30",
133-
],
134-
)
135-
def test_sample_evaluation_unsupported(self, expression):
136-
with pytest.raises(filter_mod.UnsupportedSampleFilteringError):
137-
filter_mod.FilterExpression(include=expression)
138-
139140
@pytest.mark.parametrize(
140141
("expr", "expected"),
141142
[
@@ -300,13 +301,53 @@ def test_boolean_operator_expressions(self, expr, expected):
300301
("a == b", {"a": [0, 1], "b": [1, 1]}, [False, True]),
301302
("a = b", {"a": [0, 1], "b": [1, 1]}, [False, True]),
302303
("a & b", {"a": [0, 1], "b": [1, 1]}, [False, True]),
303-
("a && b", {"a": [0, 1], "b": [1, 1]}, [False, True]),
304304
("a | b", {"a": [0, 1], "b": [1, 1]}, [True, True]),
305-
("a || b", {"a": [0, 1], "b": [1, 1]}, [True, True]),
306305
("(a < 2) & (b > 1)", {"a": [0, 1], "b": [1, 2]}, [False, True]),
307306
# AND has precedence over OR
308307
("t | f & f", {"t": [1], "f": [0]}, [True or False and False]),
309308
("(t | f) & f", {"t": [1], "f": [0]}, [(True or False) and False]),
309+
(
310+
"call_a && call_b",
311+
{
312+
"call_a": [
313+
[0, 0, 0, 0],
314+
[0, 0, 1, 1],
315+
[0, 0, 0, 0],
316+
],
317+
"call_b": [
318+
[0, 0, 0, 0],
319+
[0, 1, 0, 1],
320+
[1, 1, 1, 1],
321+
],
322+
},
323+
[
324+
[False, False, False, False],
325+
[False, True, True, True],
326+
# all False since condition a is not met (all 0)
327+
[False, False, False, False],
328+
],
329+
),
330+
(
331+
"call_a || call_b",
332+
{
333+
"call_a": [
334+
[0, 0, 0, 0],
335+
[0, 0, 1, 1],
336+
[0, 0, 0, 0],
337+
],
338+
"call_b": [
339+
[0, 0, 0, 0],
340+
[0, 1, 0, 1],
341+
[1, 1, 1, 1],
342+
],
343+
},
344+
[
345+
[False, False, False, False],
346+
# all True since variant site is included
347+
[True, True, True, True],
348+
[True, True, True, True],
349+
],
350+
),
310351
],
311352
)
312353
def test_boolean_operator_expressions_data(self, expr, data, expected):

vcztools/filter.py

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ class UnsupportedFileReferenceError(UnsupportedFilteringFeatureError):
4747
feature = "File references"
4848

4949

50-
class UnsupportedSampleFilteringError(UnsupportedFilteringFeatureError):
51-
issue = "180"
52-
feature = "Per-sample filter expressions"
53-
54-
5550
class UnsupportedFunctionsError(UnsupportedFilteringFeatureError):
5651
issue = "190"
5752
feature = "Function evaluation"
@@ -110,13 +105,11 @@ def __init__(self, tokens):
110105
class Identifier(EvaluationNode):
111106
def __init__(self, mapper, tokens):
112107
self.field_name = mapper(tokens[0])
113-
if self.field_name.startswith("call_"):
114-
raise UnsupportedSampleFilteringError()
115108
logger.debug(f"Mapped {tokens[0]} to {self.field_name}")
116109

117110
def eval(self, data):
118111
value = np.asarray(data[self.field_name])
119-
if len(value.shape) > 1:
112+
if not self.field_name.startswith("call_") and len(value.shape) > 1:
120113
raise Unsupported2DFieldsError()
121114
return value
122115

@@ -160,6 +153,57 @@ def referenced_fields(self):
160153
return operand.referenced_fields()
161154

162155

156+
def double_and(a, b):
    """Evaluate the ``&&`` filter operator on variant and/or per-sample masks.

    Operands with fewer than two dimensions (scalars and 1D arrays) are
    treated as variant-level values. Two-dimensional operands are treated as
    per-sample masks with one row per variant site (axis 1 is reduced with
    ``np.any`` to obtain the per-variant result).

    Parameters
    ----------
    a, b : array_like
        Operands, each with at most two dimensions. They are coerced with
        ``np.asarray`` so Python scalars and sequences are accepted too.

    Returns
    -------
    Boolean result of ``np.logical_and`` when neither operand is 2D,
    otherwise a 2D boolean mask with the broadcast shape of the operands.

    Raises
    ------
    NotImplementedError
        If either operand has more than two dimensions.
    """
    a = np.asarray(a)
    b = np.asarray(b)

    # Neither operand is per-sample: plain element-wise AND (broadcasting
    # takes care of a scalar paired with a 1D variant mask).
    if a.ndim < 2 and b.ndim < 2:
        return np.logical_and(a, b)

    # Promote a variant-level operand to a single column so that it
    # broadcasts across the samples of the 2D operand. reshape(-1, 1) maps a
    # 1D array of length n to (n, 1) and a scalar to (1, 1).
    if a.ndim < 2 and b.ndim == 2:
        a = a.reshape(-1, 1)
    elif a.ndim == 2 and b.ndim < 2:
        b = b.reshape(-1, 1)

    if a.ndim != 2 or b.ndim != 2:
        raise NotImplementedError(
            f"&& not implemented for dimensions {a.ndim} and {b.ndim}"
        )

    # a variant site is included only if both conditions are met
    # but not necessarily in the same sample
    variant_mask = np.logical_and(np.any(a, axis=1), np.any(b, axis=1))
    variant_mask = np.expand_dims(variant_mask, axis=1)
    # a sample is included if either condition is met
    sample_mask = np.logical_or(a, b)
    # but if a variant site is not included then none of its samples should be
    return np.logical_and(variant_mask, sample_mask)
180+
181+
182+
def double_or(a, b):
    """Evaluate the ``||`` filter operator on variant and/or per-sample masks.

    Operands with fewer than two dimensions (scalars and 1D arrays) are
    treated as variant-level values. Two-dimensional operands are treated as
    per-sample masks with one row per variant site (axis 1 is reduced with
    ``np.any`` to obtain the per-variant result).

    Parameters
    ----------
    a, b : array_like
        Operands, each with at most two dimensions. They are coerced with
        ``np.asarray`` so Python scalars and sequences are accepted too.

    Returns
    -------
    Boolean result of ``np.logical_or`` when neither operand is 2D,
    otherwise a 2D boolean mask with the broadcast shape of the operands in
    which every sample of an included variant site is set.

    Raises
    ------
    NotImplementedError
        If either operand has more than two dimensions.
    """
    a = np.asarray(a)
    b = np.asarray(b)

    # Neither operand is per-sample: plain element-wise OR (broadcasting
    # takes care of a scalar paired with a 1D variant mask).
    if a.ndim < 2 and b.ndim < 2:
        return np.logical_or(a, b)

    # Promote a variant-level operand to a single column so that it
    # broadcasts across the samples of the 2D operand. reshape(-1, 1) maps a
    # 1D array of length n to (n, 1) and a scalar to (1, 1).
    if a.ndim < 2 and b.ndim == 2:
        a = a.reshape(-1, 1)
    elif a.ndim == 2 and b.ndim < 2:
        b = b.reshape(-1, 1)

    if a.ndim != 2 or b.ndim != 2:
        raise NotImplementedError(
            f"|| not implemented for dimensions {a.ndim} and {b.ndim}"
        )

    # a variant site is included if either condition is met in any sample
    variant_mask = np.logical_or(np.any(a, axis=1), np.any(b, axis=1))
    variant_mask = np.expand_dims(variant_mask, axis=1)
    # a sample is included if either condition is met
    sample_mask = np.logical_or(a, b)
    # but if a variant site is included then all of its samples should be.
    # Every True in sample_mask already implies its row of variant_mask is
    # True, so sample_mask's role here is to supply the full output shape.
    return np.logical_or(variant_mask, sample_mask)
205+
206+
163207
class BinaryOperator(EvaluationNode):
164208
op_map = {
165209
"*": operator.mul,
@@ -170,9 +214,8 @@ class BinaryOperator(EvaluationNode):
170214
# circuit optimisations
171215
"&": np.logical_and,
172216
"|": np.logical_or,
173-
# As we're only supporting 1D values for now, these are the same thing
174-
"&&": np.logical_and,
175-
"||": np.logical_or,
217+
"&&": double_and,
218+
"||": double_or,
176219
}
177220

178221
def eval(self, data):

0 commit comments

Comments
 (0)