|
5 | 5 | import itertools as itt |
6 | 6 | import unittest |
7 | 7 |
|
8 | | -from y0.algorithm.identify import Unidentifiable, idc, identify |
| 8 | +import y0.examples |
| 9 | +from y0.algorithm.identify import Identification, Query, Unidentifiable, idc, identify |
9 | 10 | from y0.algorithm.identify.id_std import ( |
10 | 11 | line_1, |
11 | 12 | line_2, |
|
16 | 17 | line_7, |
17 | 18 | ) |
18 | 19 | from y0.dsl import ( |
| 20 | + W1, |
| 21 | + W2, |
19 | 22 | Y1, |
| 23 | + Y2, |
20 | 24 | Expression, |
| 25 | + M, |
21 | 26 | P, |
| 27 | + Probability, |
22 | 28 | Product, |
23 | 29 | Sum, |
24 | | - Variable, |
25 | 30 | X, |
26 | 31 | Y, |
27 | 32 | Z, |
|
37 | 42 | line_6_example, |
38 | 43 | line_7_example, |
39 | 44 | ) |
| 45 | +from y0.graph import NxMixedGraph |
40 | 46 | from y0.mutate import canonicalize |
41 | 47 |
|
42 | 48 | P_XY = P(X, Y) |
43 | 49 | P_XYZ = P(X, Y, Z) |
44 | | -M = Variable("M") |
45 | 50 |
|
46 | 51 |
|
47 | 52 | class TestIdentify(unittest.TestCase): |
48 | 53 | """Test cases from https://github.com/COVID-19-Causal-Reasoning/Y0/blob/master/ID_whittemore.ipynb.""" |
49 | 54 |
|
| 55 | + def assert_identify(self, expected: Expression, graph: NxMixedGraph, query: Probability): |
| 56 | + """Assert the ID algorithm returns the expected result.""" |
| 57 | + id_in = Identification(Query.from_expression(query), graph) |
| 58 | + self.assert_expr_equal(expected, identify(id_in)) |
| 59 | + |
50 | 60 | def assert_expr_equal(self, expected: Expression, actual: Expression) -> None: |
51 | 61 | """Assert that two expressions are the same.""" |
52 | 62 | expected_outcomes, expected_treatments = get_outcomes_and_treatments(query=expected) |
@@ -101,7 +111,7 @@ def test_line_2(self): |
101 | 111 | line_2(identification["id_in"][0]), |
102 | 112 | ) |
103 | 113 | self.assert_expr_equal( |
104 | | - Sum.safe(expression=P(Y, Z), ranges=[Z]), |
| 114 | + Sum.safe(expression=Sum.safe(expression=P(Y, X, Z), ranges=[X]), ranges=[Z]), |
105 | 115 | identify(identification["id_in"][0]), |
106 | 116 | ) |
107 | 117 |
|
@@ -149,7 +159,7 @@ def test_line_4(self): |
149 | 159 | [ |
150 | 160 | P(M | (Z, X)), |
151 | 161 | P(Y | (M, Z, X)), |
152 | | - Sum(P(Z)), |
| 162 | + Sum.safe(expression=P(Z, X, M, Y), ranges=[X, M, Y]), |
153 | 163 | ] |
154 | 164 | ), |
155 | 165 | ), |
@@ -202,12 +212,51 @@ def test_line_7(self): |
202 | 212 | """ |
203 | 213 | for identification in line_7_example.identifications: |
204 | 214 | id_out = identification["id_out"][0] |
205 | | - |
206 | | - self.assertEqual( |
207 | | - id_out, |
208 | | - line_7(identification["id_in"][0]), |
209 | | - ) |
| 215 | + id_in = identification["id_in"][0] |
| 216 | + self.assertEqual(id_out, line_7(id_in)) |
210 | 217 | self.assert_expr_equal( |
211 | | - Sum(P(Y1)), |
212 | | - identify(identification["id_in"][0]), |
| 218 | + Sum.safe(expression=P(Y1 | (W1, X)) * P(W1), ranges=[W1]), identify(id_in) |
213 | 219 | ) |
| 220 | + |
| 221 | + def test_figure_2a(self): |
| 222 | + """Test Figure 2A. from Shpitser *et al.*, (2008).""" |
| 223 | + graph = y0.examples.figure_2a_example.graph |
| 224 | + # expr = "[ sum_{} P(Y|X) ]" |
| 225 | + # frac_expr = P_XY / Sum[Y](P_XY) |
| 226 | + cond_expr = P(Y | X) |
| 227 | + self.assert_identify(cond_expr, graph, P(Y @ X)) |
| 228 | + |
| 229 | + def test_figure_2b(self): |
| 230 | + """Test Figure 2B. from Shpitser *et al.*, (2008).""" |
| 231 | + graph = y0.examples.figure_2b_example.graph |
| 232 | + # expr = "[ sum_{Z} P(Z|X) P(Y|X,Z) ]" |
| 233 | + # frac_expr = Sum[Z](Sum[Y](P_XY) / (Sum[Z](Sum[Y](P_XY))) * (P_XY / Sum[Y](P_XY))) |
| 234 | + cond_expr = Sum[Z](P(Z | X) * P(Y | X, Z)) |
| 235 | + self.assert_identify(cond_expr, graph, P(Y @ X)) |
| 236 | + |
| 237 | + def test_figure_2d(self): |
| 238 | + """Test Figure 2D from Shpitser *et al.*, (2008). |
| 239 | +
|
| 240 | + .. note:: frac_expr = Sum[Z](Sum[X, Y](P_XYZ) * P_XYZ / Sum[Y](P_XYZ)) |
| 241 | + """ |
| 242 | + graph = y0.examples.complete_hierarchy_figure_2d_example.graph |
| 243 | + expr = Sum[Z](P(Y | X, Z) * Sum[X, Y](P(X, Y, Z))) |
| 244 | + self.assert_identify(expr, graph, P(Y @ X)) |
| 245 | + |
| 246 | + def test_figure_2e(self): |
| 247 | + """Test Figure 2E from Shpitser *et al.*, (2008).""" |
| 248 | + graph = y0.examples.complete_hierarchy_figure_2e_example.graph |
| 249 | + # expr = "[ sum_{Z} [ sum_{} P(Z|X) ] [ sum_{} [ sum_{X} P(X) P(Y|X,Z) ] ] ]" |
| 250 | + # frac_expr = Sum[Z](Sum[Y](P_XYZ) / Sum[Z](Sum[Y](P_XYZ))) * Sum[X]( |
| 251 | + # P_XYZ * Sum[Y, Z](P_XYZ) / Sum[Y](P_XYZ) / Sum[X](Sum[Y, Z](P_XYZ)) |
| 252 | + # ) |
| 253 | + cond_expr = Sum[Z](Sum[X](P(Y | X, Z) * P(X)) * P(Z | X)) |
| 254 | + self.assert_identify(cond_expr, graph, P(Y @ X)) |
| 255 | + |
| 256 | + def test_figure_3a(self): |
| 257 | + """Test Figure 3A (A graph hedge-less for ``P(y1,y2|do(x))``) from Shpitser *et al.*, (2008).""" |
| 258 | + graph = y0.examples.complete_hierarchy_figure_3a_example.graph |
| 259 | + cond_expr = Sum[W2]( |
| 260 | + Sum[W1, X, Y1, Y2](P(W1, W2, X, Y1, Y2)) * Sum[W1](P(W1) * P(Y1 | W1, X)) * P(Y2 | W2) |
| 261 | + ) |
| 262 | + self.assert_identify(cond_expr, graph, P(Y1 @ X, Y2 @ X)) |
0 commit comments