Skip to content

Commit c818187

Browse files
committed
Add signature and preconditions
@vartikatewari heads up, it's happening!
1 parent c038713 commit c818187

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

src/y0/predicates.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
"""Predicates for expressions."""
44

5-
from .dsl import Expression, Fraction, Probability, Product, Sum
5+
from .dsl import Expression, Fraction, Probability, Product, Sum, Variable
6+
from .graph import NxMixedGraph
67

78
__all__ = [
89
"has_markov_postcondition",
10+
"is_good_control",
11+
"is_bad_control",
912
]
1013

1114

@@ -28,3 +31,26 @@ def has_markov_postcondition(expression: Expression) -> bool:
2831
)
2932
else:
3033
raise TypeError
34+
35+
36+
def _control_precondition(graph: NxMixedGraph, query: Probability, variable: Variable):
37+
if missing := query.get_variables().difference(graph.nodes()):
38+
raise ValueError(f"Query variables missing: {missing}")
39+
if variable not in graph.nodes():
40+
raise ValueError(f"Test variable missing: {variable}")
41+
# TODO does this need to be extended to check that the
42+
# query and variable aren't counterfactual?
43+
44+
45+
def is_good_control(graph: NxMixedGraph, query: Probability, variable: Variable) -> bool:
46+
"""Return if the variable is a good control."""
47+
_control_precondition(graph, query, variable)
48+
49+
raise NotImplementedError
50+
51+
52+
def is_bad_control(graph: NxMixedGraph, query: Probability, variable: Variable) -> bool:
53+
"""Return if the variable is a good control."""
54+
_control_precondition(graph, query, variable)
55+
56+
raise NotImplementedError

0 commit comments

Comments
 (0)