22
33"""Predicates for expressions."""
44
5- from .dsl import Expression , Fraction , Probability , Product , Sum
5+ from .dsl import Expression , Fraction , Probability , Product , Sum , Variable
6+ from .graph import NxMixedGraph
67
78__all__ = [
89 "has_markov_postcondition" ,
10+ "is_good_control" ,
11+ "is_bad_control" ,
912]
1013
1114
@@ -28,3 +31,26 @@ def has_markov_postcondition(expression: Expression) -> bool:
2831 )
2932 else :
3033 raise TypeError
34+
35+
36+ def _control_precondition (graph : NxMixedGraph , query : Probability , variable : Variable ):
37+ if missing := query .get_variables ().difference (graph .nodes ()):
38+ raise ValueError (f"Query variables missing: { missing } " )
39+ if variable not in graph .nodes ():
40+ raise ValueError (f"Test variable missing: { variable } " )
41+ # TODO does this need to be extended to check that the
42+ # query and variable aren't counterfactual?
43+
44+
45+ def is_good_control (graph : NxMixedGraph , query : Probability , variable : Variable ) -> bool :
46+ """Return if the variable is a good control."""
47+ _control_precondition (graph , query , variable )
48+
49+ raise NotImplementedError
50+
51+
52+ def is_bad_control (graph : NxMixedGraph , query : Probability , variable : Variable ) -> bool :
53+ """Return if the variable is a good control."""
54+ _control_precondition (graph , query , variable )
55+
56+ raise NotImplementedError
0 commit comments