Skip to content

Commit fa0323f

Browse files
committed
Update serialization of cf variables
1 parent 4696196 commit fa0323f

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

src/y0/dsl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,12 @@ def to_latex(self) -> str:
345345

346346
def to_y0(self) -> str:
347347
"""Output this counterfactual variable instance as y0 internal DSL code."""
348-
prefix = "~" if self.star else ""
348+
if self.star is None:
349+
prefix = ""
350+
elif self.star:
351+
prefix = "+"
352+
else:
353+
prefix = "-"
349354
if len(self.interventions) == 1:
350355
return f"{prefix}{self.name} @ {self.interventions[0].to_y0()}"
351356
else:

tests/test_dsl.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,34 @@ def test_counterfactual_variable(self):
119119
def test_star_counterfactual(self):
120120
"""Tests for generalized counterfactual variables."""
121121
for expr, expected in [
122+
# Single variable
122123
(P(Y @ X), "P(Y @ X)"),
123-
(P(~Y @ X), "P(~Y @ X)"),
124+
(P(Y @ -X), "P(Y @ X)"),
124125
(P(Y @ ~X), "P(Y @ ~X)"),
125-
(P(~Y @ ~X), "P(~Y @ ~X)"),
126+
(P(Y @ +X), "P(Y @ ~X)"),
127+
#
128+
(P(-Y @ X), "P(-Y @ X)"),
129+
(P(-Y @ -X), "P(-Y @ X)"),
130+
(P(-Y @ ~X), "P(-Y @ ~X)"),
131+
(P(-Y @ +X), "P(-Y @ ~X)"),
132+
#
133+
(P(~Y @ X), "P(+Y @ X)"),
134+
(P(~Y @ -X), "P(+Y @ X)"),
135+
(P(~Y @ ~X), "P(+Y @ ~X)"),
136+
(P(~Y @ +X), "P(+Y @ ~X)"),
137+
#
138+
(P(+Y @ X), "P(+Y @ X)"),
139+
(P(+Y @ -X), "P(+Y @ X)"),
140+
(P(+Y @ ~X), "P(+Y @ ~X)"),
141+
(P(+Y @ +X), "P(+Y @ ~X)"),
142+
#
126143
(P(Y @ X | ~X, ~Y), "P(Y @ X | ~X, ~Y)"),
127-
(P(~(Y @ ~X) | X, Y), "P(~Y @ ~X | X, Y)"),
128-
(P(~Y @ ~X | X, Y), "P(~Y @ ~X | X, Y)"), # should be same as above
144+
(P(Y @ -X | ~X, ~Y), "P(Y @ X | ~X, ~Y)"),
145+
(P(Y @ +X | ~X, ~Y), "P(Y @ ~X | ~X, ~Y)"),
146+
(P(Y @ ~X | ~X, ~Y), "P(Y @ ~X | ~X, ~Y)"),
147+
#
148+
(P(~(Y @ ~X) | X, Y), "P(+Y @ ~X | X, Y)"),
149+
(P(~Y @ ~X | X, Y), "P(+Y @ ~X | X, Y)"), # should be same as above
129150
]:
130151
with self.subTest(expr=expected):
131152
self.assert_exp(expr, expected)

0 commit comments

Comments
 (0)