@@ -585,16 +585,29 @@ def __truediv__(self, expression: Expression) -> Expression:
585585 else :
586586 return Fraction (self , expression )
587587
588- def marginalize (self , ranges : VariableHint ) -> Fraction :
589- """Return this expression, marginalized by the given variables.
588+ def conditional (self , ranges : VariableHint ) -> Fraction :
589+ """Return this expression, conditioned by the given variables.
590590
591591 :param ranges: A variable or list of variables over which to marginalize this expression
592592 :returns: A fraction in which the denominator is represents the sum over the given ranges
593593
594594 >>> from y0.dsl import P, A, B
595- >>> assert P(A, B).marginalize(A) == P(A, B) / Sum[A](P(A, B))
595+ >>> assert P(A, B).conditional(A) == P(A, B) / Sum[A](P(A, B))
596+ >>> assert P(A, B, C).conditional([A, B]) == P(A, B, C) / Sum[A, B](P(A, B, C))
596597 """
597- return Fraction (self , Sum (expression = self , ranges = _upgrade_variables (ranges )))
598+ return Fraction (self , self .marginalize (ranges ))
599+
600+ def marginalize (self , ranges : VariableHint ) -> Sum :
601+ """Return this expression, marginalizing out the given variables.
602+
603+ :param ranges: A variable or list of variables over which to marginalize this expression
604+ :returns: The expression but summed over the given variables
605+
606+ >>> from y0.dsl import P, A, B, C
607+ >>> assert P(A, B).marginalize(A) == Sum[A](P(A, B))
608+ >>> assert P(A, B, C).marginalize([A, B]) == Sum[A, B](P(A, B, C))
609+ """
610+ return Sum (expression = self , ranges = _upgrade_variables (ranges ))
598611
599612
600613@dataclass (frozen = True , repr = False )
0 commit comments