Skip to content

Commit 2d6d683

Browse files
committed
Add unit tests
@vartikatewari note we still need some negative tests that correspond to neutral scenarios
1 parent c818187 commit 2d6d683

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

tests/test_controls.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""Tests for good, bad, and neutral controls."""
4+
5+
import unittest
6+
7+
from y0.dsl import M, P, U1, U2, Variable, X, Y, Z
8+
from y0.graph import NxMixedGraph
9+
from y0.predicates import is_bad_control, is_good_control
10+
11+
U = Variable("U")
12+
query = P(Y @ X)
13+
14+
model_1 = NxMixedGraph.from_edges(directed=[(Z, X), (Z, Y), (X, Y)])
15+
model_2 = NxMixedGraph.from_edges(directed=[(U, Z), (Z, X), (X, Y), (U, Y)])
16+
model_3 = NxMixedGraph.from_edges(directed=[(U, X), (U, Z), (Z, Y), (X, Y)])
17+
model_4 = NxMixedGraph.from_edges(directed=[(Z, X), (Z, M), (X, M), (M, Y)])
18+
model_5 = NxMixedGraph.from_edges(directed=[(U, Z), (Z, X), (U, M), (X, M), (M, Y)])
19+
model_6 = NxMixedGraph.from_edges(directed=[(U, X), (U, Z), (Z, M), (X, M), (M, Y)])
20+
21+
good_test_models = [
22+
model_1,
23+
model_2,
24+
model_3,
25+
model_4,
26+
model_5,
27+
model_6,
28+
]
29+
30+
# M-bias
31+
model_7 = NxMixedGraph.from_edges(directed=[(U1, Z), (U2, Z), (U1, X), (U2, Y), (X, Y)])
32+
# Bias amplification
33+
model_10 = NxMixedGraph.from_edges(directed=[(Z, X), (U, X), (U, Y), (X, Y)])
34+
#
35+
model_11 = NxMixedGraph.from_edges(directed=[(X, Z), (Z, Y)])
36+
model_11_variation = NxMixedGraph.from_edges(directed=[(X, Z), (U, Z), (Z, Y), (U, Y)])
37+
model_12 = NxMixedGraph.from_edges(directed=[(X, M), (M, Y), (M, Z)])
38+
# Selection bias
39+
model_16 = NxMixedGraph.from_edges(directed=[(X, Z), (U, Z), (U, Y), (X, Y)])
40+
model_17 = NxMixedGraph.from_edges(directed=[(X, Z), (Y, Z), (X, Y)])
41+
# case-control bias
42+
model_18 = NxMixedGraph.from_edges(directed=[(X, Y), (Y, Z)])
43+
44+
bad_test_models = [
45+
model_7,
46+
model_10,
47+
model_11,
48+
model_11_variation,
49+
model_12,
50+
model_16,
51+
model_17,
52+
model_18,
53+
]
54+
55+
56+
class TestControls(unittest.TestCase):
57+
"""Test case for good, bad, and neutral controls."""
58+
59+
def test_good_controls(self):
60+
"""Test good controls."""
61+
for model in good_test_models:
62+
with self.subTest():
63+
self.assertTrue(is_good_control(model, query, Z))
64+
for model in bad_test_models:
65+
with self.subTest():
66+
self.assertFalse(is_good_control(model, query, Z))
67+
68+
# TODO need alternative negative examples
69+
70+
def test_bad_controls(self):
71+
"""Test bad controls."""
72+
for model in good_test_models:
73+
with self.subTest():
74+
self.assertFalse(is_bad_control(model, query, Z))
75+
for model in bad_test_models:
76+
with self.subTest():
77+
self.assertTrue(is_bad_control(model, query, Z))
78+
79+
# TODO need alternative negative examples

0 commit comments

Comments
 (0)