55import unittest
66
77from y0 .controls import is_bad_control , is_good_control
8- from y0 .dsl import U1 , U2 , M , P , Variable , X , Y , Z
8+ from y0 .dsl import U1 , U2 , M , P , U , W , X , Y , Z
99from y0 .graph import NxMixedGraph
1010
11- U = Variable ("U" )
12- query = P (Y @ X )
13-
1411model_1 = NxMixedGraph .from_edges (directed = [(Z , X ), (Z , Y ), (X , Y )])
1512model_2 = NxMixedGraph .from_edges (directed = [(U , Z ), (Z , X ), (X , Y ), (U , Y )])
1613model_3 = NxMixedGraph .from_edges (directed = [(U , X ), (U , Z ), (Z , Y ), (X , Y )])
2724 model_6 ,
2825]
2926
30- # M-bias
27+ # bad control, M-bias
3128model_7 = NxMixedGraph .from_edges (directed = [(U1 , Z ), (U2 , Z ), (U1 , X ), (U2 , Y ), (X , Y )])
32- # Bias amplification
29+ # bad control, Bias amplification
3330model_10 = NxMixedGraph .from_edges (directed = [(Z , X ), (U , X ), (U , Y ), (X , Y )])
34- #
31+ # bad control
3532model_11 = NxMixedGraph .from_edges (directed = [(X , Z ), (Z , Y )])
3633model_11_variation = NxMixedGraph .from_edges (directed = [(X , Z ), (U , Z ), (Z , Y ), (U , Y )])
3734model_12 = NxMixedGraph .from_edges (directed = [(X , M ), (M , Y ), (M , Z )])
38- # Selection bias
35+ # bad control, Selection bias
3936model_16 = NxMixedGraph .from_edges (directed = [(X , Z ), (U , Z ), (U , Y ), (X , Y )])
4037model_17 = NxMixedGraph .from_edges (directed = [(X , Z ), (Y , Z ), (X , Y )])
41- # case-control bias
38+ # bad control, case-control bias
4239model_18 = NxMixedGraph .from_edges (directed = [(X , Y ), (Y , Z )])
4340
4441bad_test_models = [
5249 model_18 ,
5350]
5451
52+ # neutral control, possibly good for precision
53+ model_8 = NxMixedGraph .from_edges (directed = [(X , Y ), (Z , Y )])
54+ # neutral control, possibly bad for precision
55+ model_9 = NxMixedGraph .from_edges (directed = [(Z , X ), (X , Y )])
56+ # neutral control, possibly good for precision
57+ model_13 = NxMixedGraph .from_edges (directed = [(X , W ), (Z , W ), (W , Y )])
58+ # neutral control, possibly helpful in the case of selection bias
59+ model_14 = NxMixedGraph .from_edges (directed = [(X , Y ), (X , Z )])
60+ model_15 = NxMixedGraph .from_edges (directed = [(X , Z ), (Z , W ), (X , Y ), (U , W ), (U , Y )])
61+
62+ neutral_test_models = [
63+ model_8 ,
64+ model_9 ,
65+ model_13 ,
66+ model_14 ,
67+ model_15 ,
68+ ]
69+
5570
5671class TestControls (unittest .TestCase ):
5772 """Test case for good, bad, and neutral controls."""
@@ -60,20 +75,16 @@ def test_good_controls(self):
6075 """Test good controls."""
6176 for model in good_test_models :
6277 with self .subTest ():
63- self .assertTrue (is_good_control (model , query , Z ))
64- for model in bad_test_models :
78+ self .assertTrue (is_good_control (model , P ( Y @ X ) , Z ))
79+ for model in bad_test_models + neutral_test_models :
6580 with self .subTest ():
66- self .assertFalse (is_good_control (model , query , Z ))
67-
68- # TODO need alternative negative examples
81+ self .assertFalse (is_good_control (model , P (Y @ X ), Z ))
6982
7083 def test_bad_controls (self ):
7184 """Test bad controls."""
72- for model in good_test_models :
85+ for model in good_test_models + neutral_test_models :
7386 with self .subTest ():
74- self .assertFalse (is_bad_control (model , query , Z ))
87+ self .assertFalse (is_bad_control (model , P ( Y @ X ) , Z ))
7588 for model in bad_test_models :
7689 with self .subTest ():
77- self .assertTrue (is_bad_control (model , query , Z ))
78-
79- # TODO need alternative negative examples
90+ self .assertTrue (is_bad_control (model , P (Y @ X ), Z ))
0 commit comments