Skip to content

Commit a51345a

Browse files
committed
Update precondition tests
1 parent 5b26752 commit a51345a

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/y0/controls.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212

1313

1414
def _control_precondition(graph: NxMixedGraph, query: Probability, variable: Variable):
15-
if missing := query.get_variables().difference(graph.nodes()):
16-
raise ValueError(f"Query variables missing: {missing}")
1715
if variable not in graph.nodes():
1816
raise ValueError(f"Test variable missing: {variable}")
1917
# TODO does this need to be extended to check that the
@@ -29,7 +27,6 @@ def is_good_control(graph: NxMixedGraph, query: Probability, variable: Variable)
2927
:return: If the variable is a good control
3028
"""
3129
_control_precondition(graph, query, variable)
32-
3330
raise NotImplementedError
3431

3532

@@ -46,5 +43,4 @@ def is_bad_control(graph: NxMixedGraph, query: Probability, variable: Variable)
4643
:return: If the variable is a bad control
4744
"""
4845
_control_precondition(graph, query, variable)
49-
5046
raise NotImplementedError

tests/test_controls.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import unittest
66

77
from y0.controls import is_bad_control, is_good_control
8-
from y0.dsl import U1, U2, M, P, U, W, X, Y, Z
8+
from y0.dsl import U1, U2, A, M, P, U, W, X, Y, Z
99
from y0.graph import NxMixedGraph
1010

1111
model_1 = NxMixedGraph.from_edges(directed=[(Z, X), (Z, Y), (X, Y)])
@@ -71,6 +71,13 @@
7171
class TestControls(unittest.TestCase):
7272
"""Test case for good, bad, and neutral controls."""
7373

74+
def test_preconditions(self):
75+
"""Test the preconditions are checked properly for good controls."""
76+
for func in is_good_control, is_bad_control:
77+
with self.subTest(name=func.__name__):
78+
with self.assertRaises(ValueError):
79+
func(model_1, P(Y @ X), A)
80+
7481
def test_good_controls(self):
7582
"""Test good controls."""
7683
for model in good_test_models:

0 commit comments

Comments
 (0)