Skip to content

Commit 59f7094

Browse files
cthoytdjinnome
andauthored
Update expression in line 1 of identification (#89)
* Update expression in line 1 of identification Co-Authored-By: Jeremy Zucker <[email protected]> * Update tests * Cleanup Co-authored-by: Jeremy Zucker <[email protected]>
1 parent eca49e1 commit 59f7094

File tree

2 files changed

+62
-13
lines changed

2 files changed

+62
-13
lines changed

src/y0/algorithm/identify/id_std.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def line_1(identification: Identification) -> Expression:
8383
outcomes = identification.outcomes
8484
vertices = set(identification.graph.nodes())
8585
return Sum.safe(
86-
expression=P(vertices),
86+
expression=identification.estimand,
8787
ranges=vertices.difference(outcomes),
8888
)
8989

tests/test_algorithm/test_id_alg.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import itertools as itt
66
import unittest
77

8-
from y0.algorithm.identify import Unidentifiable, idc, identify
8+
import y0.examples
9+
from y0.algorithm.identify import Identification, Query, Unidentifiable, idc, identify
910
from y0.algorithm.identify.id_std import (
1011
line_1,
1112
line_2,
@@ -16,12 +17,16 @@
1617
line_7,
1718
)
1819
from y0.dsl import (
20+
W1,
21+
W2,
1922
Y1,
23+
Y2,
2024
Expression,
25+
M,
2126
P,
27+
Probability,
2228
Product,
2329
Sum,
24-
Variable,
2530
X,
2631
Y,
2732
Z,
@@ -37,16 +42,21 @@
3742
line_6_example,
3843
line_7_example,
3944
)
45+
from y0.graph import NxMixedGraph
4046
from y0.mutate import canonicalize
4147

4248
P_XY = P(X, Y)
4349
P_XYZ = P(X, Y, Z)
44-
M = Variable("M")
4550

4651

4752
class TestIdentify(unittest.TestCase):
4853
"""Test cases from https://github.com/COVID-19-Causal-Reasoning/Y0/blob/master/ID_whittemore.ipynb."""
4954

55+
def assert_identify(self, expected: Expression, graph: NxMixedGraph, query: Probability):
56+
"""Assert the ID algorithm returns the expected result."""
57+
id_in = Identification(Query.from_expression(query), graph)
58+
self.assert_expr_equal(expected, identify(id_in))
59+
5060
def assert_expr_equal(self, expected: Expression, actual: Expression) -> None:
5161
"""Assert that two expressions are the same."""
5262
expected_outcomes, expected_treatments = get_outcomes_and_treatments(query=expected)
@@ -101,7 +111,7 @@ def test_line_2(self):
101111
line_2(identification["id_in"][0]),
102112
)
103113
self.assert_expr_equal(
104-
Sum.safe(expression=P(Y, Z), ranges=[Z]),
114+
Sum.safe(expression=Sum.safe(expression=P(Y, X, Z), ranges=[X]), ranges=[Z]),
105115
identify(identification["id_in"][0]),
106116
)
107117

@@ -149,7 +159,7 @@ def test_line_4(self):
149159
[
150160
P(M | (Z, X)),
151161
P(Y | (M, Z, X)),
152-
Sum(P(Z)),
162+
Sum.safe(expression=P(Z, X, M, Y), ranges=[X, M, Y]),
153163
]
154164
),
155165
),
@@ -202,12 +212,51 @@ def test_line_7(self):
202212
"""
203213
for identification in line_7_example.identifications:
204214
id_out = identification["id_out"][0]
205-
206-
self.assertEqual(
207-
id_out,
208-
line_7(identification["id_in"][0]),
209-
)
215+
id_in = identification["id_in"][0]
216+
self.assertEqual(id_out, line_7(id_in))
210217
self.assert_expr_equal(
211-
Sum(P(Y1)),
212-
identify(identification["id_in"][0]),
218+
Sum.safe(expression=P(Y1 | (W1, X)) * P(W1), ranges=[W1]), identify(id_in)
213219
)
220+
221+
def test_figure_2a(self):
222+
"""Test Figure 2A. from Shpitser *et al.*, (2008)."""
223+
graph = y0.examples.figure_2a_example.graph
224+
# expr = "[ sum_{} P(Y|X) ]"
225+
# frac_expr = P_XY / Sum[Y](P_XY)
226+
cond_expr = P(Y | X)
227+
self.assert_identify(cond_expr, graph, P(Y @ X))
228+
229+
def test_figure_2b(self):
230+
"""Test Figure 2B. from Shpitser *et al.*, (2008)."""
231+
graph = y0.examples.figure_2b_example.graph
232+
# expr = "[ sum_{Z} P(Z|X) P(Y|X,Z) ]"
233+
# frac_expr = Sum[Z](Sum[Y](P_XY) / (Sum[Z](Sum[Y](P_XY))) * (P_XY / Sum[Y](P_XY)))
234+
cond_expr = Sum[Z](P(Z | X) * P(Y | X, Z))
235+
self.assert_identify(cond_expr, graph, P(Y @ X))
236+
237+
def test_figure_2d(self):
238+
"""Test Figure 2D from Shpitser *et al.*, (2008).
239+
240+
.. note:: frac_expr = Sum[Z](Sum[X, Y](P_XYZ) * P_XYZ / Sum[Y](P_XYZ))
241+
"""
242+
graph = y0.examples.complete_hierarchy_figure_2d_example.graph
243+
expr = Sum[Z](P(Y | X, Z) * Sum[X, Y](P(X, Y, Z)))
244+
self.assert_identify(expr, graph, P(Y @ X))
245+
246+
def test_figure_2e(self):
247+
"""Test Figure 2E from Shpitser *et al.*, (2008)."""
248+
graph = y0.examples.complete_hierarchy_figure_2e_example.graph
249+
# expr = "[ sum_{Z} [ sum_{} P(Z|X) ] [ sum_{} [ sum_{X} P(X) P(Y|X,Z) ] ] ]"
250+
# frac_expr = Sum[Z](Sum[Y](P_XYZ) / Sum[Z](Sum[Y](P_XYZ))) * Sum[X](
251+
# P_XYZ * Sum[Y, Z](P_XYZ) / Sum[Y](P_XYZ) / Sum[X](Sum[Y, Z](P_XYZ))
252+
# )
253+
cond_expr = Sum[Z](Sum[X](P(Y | X, Z) * P(X)) * P(Z | X))
254+
self.assert_identify(cond_expr, graph, P(Y @ X))
255+
256+
def test_figure_3a(self):
257+
"""Test Figure 3A (A graph hedge-less for ``P(y1,y2|do(x))``) from Shpitser *et al.*, (2008)."""
258+
graph = y0.examples.complete_hierarchy_figure_3a_example.graph
259+
cond_expr = Sum[W2](
260+
Sum[W1, X, Y1, Y2](P(W1, W2, X, Y1, Y2)) * Sum[W1](P(W1) * P(Y1 | W1, X)) * P(Y2 | W2)
261+
)
262+
self.assert_identify(cond_expr, graph, P(Y1 @ X, Y2 @ X))

0 commit comments

Comments
 (0)