Skip to content

Commit 4696196

Browse files
cthoytdjinnome
andauthored
Update marginalize and conditional functions in DSL (#100)
Co-Authored-By: Jeremy Zucker <[email protected]>
1 parent 08a342c commit 4696196

File tree

4 files changed

+21
-7
lines changed

4 files changed

+21
-7
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ It can also be used to manipulate expressions:
6060
```python
6161
from y0.dsl import P, A, B, Sum
6262

63-
P(A, B).marginalize(A) == P(A, B) / Sum[A](P(A, B))
63+
P(A, B).marginalize(A) == Sum[A](P(A, B))
64+
P(A, B).conditional(A) == P(A, B) / Sum[A](P(A, B))
6465
```
6566

6667
DSL objects can be converted into strings with `str()` and parsed back

src/y0/algorithm/identify/id_c.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def idc(identification: Identification) -> Expression:
2424
return idc(identification.exchange_observation_with_action(condition))
2525

2626
# Run ID algorithm
27-
return identify(identification.uncondition()).marginalize(identification.outcomes)
27+
return identify(identification.uncondition()).conditional(identification.outcomes)
2828

2929

3030
def rule_2_of_do_calculus_applies(identification: Identification, condition: Variable) -> bool:

src/y0/dsl.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -585,16 +585,29 @@ def __truediv__(self, expression: Expression) -> Expression:
585585
else:
586586
return Fraction(self, expression)
587587

588-
def marginalize(self, ranges: VariableHint) -> Fraction:
589-
"""Return this expression, marginalized by the given variables.
588+
def conditional(self, ranges: VariableHint) -> Fraction:
589+
"""Return this expression, conditioned by the given variables.
590590
591591
:param ranges: A variable or list of variables over which to marginalize this expression
592592
:returns: A fraction in which the denominator is represents the sum over the given ranges
593593
594594
>>> from y0.dsl import P, A, B
595-
>>> assert P(A, B).marginalize(A) == P(A, B) / Sum[A](P(A, B))
595+
>>> assert P(A, B).conditional(A) == P(A, B) / Sum[A](P(A, B))
596+
>>> assert P(A, B, C).conditional([A, B]) == P(A, B, C) / Sum[A, B](P(A, B, C))
596597
"""
597-
return Fraction(self, Sum(expression=self, ranges=_upgrade_variables(ranges)))
598+
return Fraction(self, self.marginalize(ranges))
599+
600+
def marginalize(self, ranges: VariableHint) -> Sum:
601+
"""Return this expression, marginalizing out the given variables.
602+
603+
:param ranges: A variable or list of variables over which to marginalize this expression
604+
:returns: The expression but summed over the given variables
605+
606+
>>> from y0.dsl import P, A, B, C
607+
>>> assert P(A, B).marginalize(A) == Sum[A](P(A, B))
608+
>>> assert P(A, B, C).marginalize([A, B]) == Sum[A, B](P(A, B, C))
609+
"""
610+
return Sum(expression=self, ranges=_upgrade_variables(ranges))
598611

599612

600613
@dataclass(frozen=True, repr=False)

src/y0/mutate/chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,4 @@ def bayes_expand(p: Probability) -> Fraction:
102102
103103
.. note:: This expansion will create a different but equal expression to :func:`fraction_expand`.
104104
"""
105-
return p.uncondition().marginalize(p.children)
105+
return p.uncondition().conditional(p.children)

0 commit comments

Comments
 (0)