Skip to content

Commit 2069889

Browse files
Add introspection and referenced fields to eval tree
1 parent 6320d6c commit 2069889

File tree

3 files changed

+73
-14
lines changed

3 files changed

+73
-14
lines changed

tests/test_filter.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,33 @@ def test_evaluate(self, expression, data, expected):
8686
result = fee.evaluate(numpify_values(data))
8787
nt.assert_array_equal(result, expected)
8888

89+
@pytest.mark.parametrize(
90+
("expr", "expected"),
91+
[
92+
("a == b", {"variant_a", "variant_b"}),
93+
("a == b + c", {"variant_a", "variant_b", "variant_c"}),
94+
("(a + 1) < (b + c) - d / a", {f"variant_{x}" for x in "abcd"}),
95+
],
96+
)
97+
def test_referenced_fields(self, expr, expected):
98+
fe = filter_mod.FilterExpression(include=expr)
99+
assert fe.referenced_fields == expected
100+
101+
@pytest.mark.parametrize(
102+
("expr", "expected"),
103+
[
104+
("a == b", "(variant_a)==(variant_b)"),
105+
("a + 1", "(variant_a)+(1)"),
106+
("a + 1 + 2", "(variant_a)+(1)+(2)"),
107+
("a + (1 + 2)", "(variant_a)+((1)+(2))"),
108+
("POS<10", "(variant_position)<(10)"),
109+
('CHROM=="chr1"', "(variant_contig)==('chr1')"),
110+
],
111+
)
112+
def test_repr(self, expr, expected):
113+
fe = filter_mod.FilterExpression(include=expr)
114+
assert repr(fe.parse_result[0]) == expected
115+
89116

90117
class TestBcftoolsParser:
91118
@pytest.mark.parametrize(

vcztools/filter.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,28 @@ class Constant(EvaluationNode):
2929
def eval(self, data):
3030
return self.tokens
3131

32+
def __repr__(self):
33+
return repr(self.tokens)
34+
35+
def referenced_fields(self):
36+
return frozenset()
37+
3238

3339
class Identifier(EvaluationNode):
3440
def __init__(self, mapper, tokens):
3541
self.field_name = mapper(tokens[0])
3642
logger.debug(f"Mapped {tokens[0]} to {self.field_name}")
43+
# TODO add errors for unsupported things like call_ fields etc.
3744

3845
def eval(self, data):
3946
return data[self.field_name]
4047

48+
def __repr__(self):
49+
return self.field_name
50+
51+
def referenced_fields(self):
52+
return frozenset([self.field_name])
53+
4154

4255
class BinaryOperator(EvaluationNode):
4356
op_map = {
@@ -55,19 +68,31 @@ class BinaryOperator(EvaluationNode):
5568
}
5669

5770
def eval(self, data):
58-
# start by eval()'ing the first operand
59-
ret = self.tokens[0].eval(data)
60-
61-
# get following operators and operands in pairs
71+
# get the operators and operands in pairs
72+
operands = self.tokens[0::2]
6273
ops = self.tokens[1::2]
63-
operands = self.tokens[2::2]
64-
for op, operand in zip(ops, operands):
65-
# print(f"Eval {op}, {ret}, {operand}")
66-
# update cumulative value by add/subtract/mult/divide the next operand
74+
# start by eval()'ing the first operand
75+
ret = operands[0].eval(data)
76+
for op, operand in zip(ops, operands[1:]):
6777
arith_fn = self.op_map[op]
6878
ret = arith_fn(ret, operand.eval(data))
6979
return ret
7080

81+
def __repr__(self):
82+
ops = self.tokens[1::2]
83+
operands = self.tokens[0::2]
84+
ret = f"({repr(operands[0])})"
85+
for op, operand in zip(ops, operands[1:]):
86+
ret += f"{op}({repr(operand)})"
87+
return ret
88+
89+
def referenced_fields(self):
90+
operands = self.tokens[0::2]
91+
ret = operands[0].referenced_fields()
92+
for operand in operands[1:]:
93+
ret |= operand.referenced_fields()
94+
return ret
95+
7196

7297
class ComparisonOperator(EvaluationNode):
7398
op_map = {
@@ -85,6 +110,14 @@ def eval(self, data):
85110
comparison_fn = self.op_map[op]
86111
return comparison_fn(op1.eval(data), op2.eval(data))
87112

113+
def __repr__(self):
114+
op1, op, op2 = self.tokens
115+
return f"({repr(op1)}){op}({repr(op2)})"
116+
117+
def referenced_fields(self):
118+
op1, _, op2 = self.tokens
119+
return op1.referenced_fields() | op2.referenced_fields()
120+
88121

89122
def _identity(x):
90123
return x
@@ -110,6 +143,7 @@ def make_bcftools_filter_parser(all_fields=None, map_vcf_identifiers=True):
110143
filter_expression = pp.infix_notation(
111144
constant | identifier,
112145
[
146+
# FIXME Does bcftools support unary minus?
113147
# ("-", 1, pp.OpAssoc.RIGHT, ),
114148
(pp.one_of("* /"), 2, pp.OpAssoc.LEFT, BinaryOperator),
115149
(pp.one_of("+ -"), 2, pp.OpAssoc.LEFT, BinaryOperator),
@@ -128,6 +162,7 @@ def __init__(self, *, field_names=None, include=None, exclude=None):
128162
if field_names is None:
129163
field_names = set()
130164
self.parse_result = None
165+
self.referenced_fields = set()
131166
self.invert = False
132167
expr = None
133168
if include is not None and exclude is not None:
@@ -144,9 +179,8 @@ def __init__(self, *, field_names=None, include=None, exclude=None):
144179
if expr is not None:
145180
parser = make_bcftools_filter_parser(field_names)
146181
self.parse_result = parser.parse_string(expr, parse_all=True)
147-
148-
# Setting to None for now so that we retrieve all fields
149-
self.referenced_fields = None
182+
# This isn't a very good pattern, fix
183+
self.referenced_fields = self.parse_result[0].referenced_fields()
150184

151185
def evaluate(self, chunk_data):
152186
if self.parse_result is None:

vcztools/query.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,7 @@ def generate(root):
272272
# NOTE: this should be done at the top-level when we've
273273
# figured out what fields need to be retrieved from both
274274
# the parsed query and filter expressions.
275-
reader = retrieval.VariantChunkReader(
276-
root, fields=filter_expr.referenced_fields
277-
)
275+
reader = retrieval.VariantChunkReader(root)
278276
for v_chunk in range(root["variant_position"].cdata_shape[0]):
279277
# print("Read v_chunk", v_chunk)
280278
chunk_data = reader[v_chunk]

0 commit comments

Comments
 (0)