Skip to content

Commit 5b26752

Browse files
committed
Add neutral control examples and tests
1 parent 03b67ba commit 5b26752

File tree

1 file changed

+30
-19
lines changed

1 file changed

+30
-19
lines changed

tests/test_controls.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,9 @@
55
import unittest
66

77
from y0.controls import is_bad_control, is_good_control
8-
from y0.dsl import U1, U2, M, P, Variable, X, Y, Z
8+
from y0.dsl import U1, U2, M, P, U, W, X, Y, Z
99
from y0.graph import NxMixedGraph
1010

11-
U = Variable("U")
12-
query = P(Y @ X)
13-
1411
model_1 = NxMixedGraph.from_edges(directed=[(Z, X), (Z, Y), (X, Y)])
1512
model_2 = NxMixedGraph.from_edges(directed=[(U, Z), (Z, X), (X, Y), (U, Y)])
1613
model_3 = NxMixedGraph.from_edges(directed=[(U, X), (U, Z), (Z, Y), (X, Y)])
@@ -27,18 +24,18 @@
2724
model_6,
2825
]
2926

30-
# M-bias
27+
# bad control, M-bias
3128
model_7 = NxMixedGraph.from_edges(directed=[(U1, Z), (U2, Z), (U1, X), (U2, Y), (X, Y)])
32-
# Bias amplification
29+
# bad control, Bias amplification
3330
model_10 = NxMixedGraph.from_edges(directed=[(Z, X), (U, X), (U, Y), (X, Y)])
34-
#
31+
# bad control
3532
model_11 = NxMixedGraph.from_edges(directed=[(X, Z), (Z, Y)])
3633
model_11_variation = NxMixedGraph.from_edges(directed=[(X, Z), (U, Z), (Z, Y), (U, Y)])
3734
model_12 = NxMixedGraph.from_edges(directed=[(X, M), (M, Y), (M, Z)])
38-
# Selection bias
35+
# bad control, Selection bias
3936
model_16 = NxMixedGraph.from_edges(directed=[(X, Z), (U, Z), (U, Y), (X, Y)])
4037
model_17 = NxMixedGraph.from_edges(directed=[(X, Z), (Y, Z), (X, Y)])
41-
# case-control bias
38+
# bad control, case-control bias
4239
model_18 = NxMixedGraph.from_edges(directed=[(X, Y), (Y, Z)])
4340

4441
bad_test_models = [
@@ -52,6 +49,24 @@
5249
model_18,
5350
]
5451

52+
# neutral control, possibly good for precision
53+
model_8 = NxMixedGraph.from_edges(directed=[(X, Y), (Z, Y)])
54+
# neutral control, possibly bad for precision
55+
model_9 = NxMixedGraph.from_edges(directed=[(Z, X), (X, Y)])
56+
# neutral control, possibly good for precision
57+
model_13 = NxMixedGraph.from_edges(directed=[(X, W), (Z, W), (W, Y)])
58+
# neutral control, possibly helpful in the case of selection bias
59+
model_14 = NxMixedGraph.from_edges(directed=[(X, Y), (X, Z)])
60+
model_15 = NxMixedGraph.from_edges(directed=[(X, Z), (Z, W), (X, Y), (U, W), (U, Y)])
61+
62+
neutral_test_models = [
63+
model_8,
64+
model_9,
65+
model_13,
66+
model_14,
67+
model_15,
68+
]
69+
5570

5671
class TestControls(unittest.TestCase):
5772
"""Test case for good, bad, and neutral controls."""
@@ -60,20 +75,16 @@ def test_good_controls(self):
6075
"""Test good controls."""
6176
for model in good_test_models:
6277
with self.subTest():
63-
self.assertTrue(is_good_control(model, query, Z))
64-
for model in bad_test_models:
78+
self.assertTrue(is_good_control(model, P(Y @ X), Z))
79+
for model in bad_test_models + neutral_test_models:
6580
with self.subTest():
66-
self.assertFalse(is_good_control(model, query, Z))
67-
68-
# TODO need alternative negative examples
81+
self.assertFalse(is_good_control(model, P(Y @ X), Z))
6982

7083
def test_bad_controls(self):
7184
"""Test bad controls."""
72-
for model in good_test_models:
85+
for model in good_test_models + neutral_test_models:
7386
with self.subTest():
74-
self.assertFalse(is_bad_control(model, query, Z))
87+
self.assertFalse(is_bad_control(model, P(Y @ X), Z))
7588
for model in bad_test_models:
7689
with self.subTest():
77-
self.assertTrue(is_bad_control(model, query, Z))
78-
79-
# TODO need alternative negative examples
90+
self.assertTrue(is_bad_control(model, P(Y @ X), Z))

0 commit comments

Comments
 (0)