Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/y0/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,9 +1019,11 @@ def __getitem__(self, interventions: VariableHint) -> ProbabilityMetaBuilder:
class Product(Expression):
"""Represent the product of several probability expressions."""

expressions: tuple[Expression, ...]
expressions: frozenset[Expression]

def __post_init__(self) -> None:
if not isinstance(self.expressions, frozenset):
raise TypeError("Products must be given with a frozenset")
if len(self.expressions) < 2:
raise ValueError("Product() must two or more expressions")

Expand Down Expand Up @@ -1060,36 +1062,40 @@ def safe(cls, expressions: Expression | Iterable[Expression]) -> Expression:
return One()
if len(expressions) == 1:
return expressions[0]
return cls(expressions=tuple(sorted(expressions)))
return cls(expressions=frozenset(expressions))

@property
def _sorted_expressions(self) -> list[Expression]:
return sorted(self.expressions)

def _get_key(self): # type:ignore
inner_keys = (sexpr._get_key() for sexpr in self.expressions)
inner_keys = (sexpr._get_key() for sexpr in self._sorted_expressions)
return 2, *inner_keys

def to_text(self) -> str:
"""Output this product in the internal string format."""
return " ".join(expression.to_text() for expression in self.expressions)
return " ".join(expression.to_text() for expression in self._sorted_expressions)

def to_y0(self) -> str:
"""Output this product instance as y0 internal DSL code."""
return " * ".join(expr.to_y0() for expr in self.expressions)
return " * ".join(expr.to_y0() for expr in self._sorted_expressions)

def to_latex(self) -> str:
"""Output this product in the LaTeX string format."""
return " ".join(expression.to_latex() for expression in self.expressions)
return " ".join(expression.to_latex() for expression in self._sorted_expressions)

def __mul__(self, other: Expression) -> Expression:
if isinstance(other, Zero):
return other
if isinstance(other, Product):
return Product.safe((*self.expressions, *other.expressions))
return Product.safe(self.expressions | other.expressions)
elif isinstance(other, Fraction):
return Fraction(self * other.numerator, other.denominator)
else:
return Product.safe((*self.expressions, other))
return Product.safe(self.expressions | {other})

def _iter_variables(self) -> Iterable[Variable]:
"""Get the union of the variables used in each expresison in this product."""
"""Get the union of the variables used in each expression in this product."""
for expression in self.expressions:
yield from expression._iter_variables()

Expand Down Expand Up @@ -1347,11 +1353,13 @@ def simplify(self) -> Expression:
if self.numerator == self.denominator:
return One()
if isinstance(self.numerator, Product) and isinstance(self.denominator, Product):
return self._simplify_parts(self.numerator.expressions, self.denominator.expressions)
return self._simplify_parts(
self.numerator._sorted_expressions, self.denominator._sorted_expressions
)
elif isinstance(self.numerator, Product):
return self._simplify_parts(self.numerator.expressions, [self.denominator])
return self._simplify_parts(self.numerator._sorted_expressions, [self.denominator])
elif isinstance(self.denominator, Product):
return self._simplify_parts([self.numerator], self.denominator.expressions)
return self._simplify_parts([self.numerator], self.denominator._sorted_expressions)
return self

@classmethod
Expand Down Expand Up @@ -1398,6 +1406,7 @@ def _simplify_parts_helper(
)


@dataclass(frozen=True)
class One(Expression):
"""The multiplicative identity (1)."""

Expand Down Expand Up @@ -1430,6 +1439,7 @@ def _iter_variables(self) -> Iterable[Variable]:
return iter([])


@dataclass(frozen=True)
class Zero(Expression):
"""The additive identity (0)."""

Expand Down
18 changes: 10 additions & 8 deletions tests/test_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,19 +517,21 @@ def test_product(self):
p2 = P(Z)
self.assertEqual(p1, Product.safe(p1))
self.assertEqual(p1, Product.safe([p1]))
self.assertEqual(Product((p1, p2)), Product.safe((p1, p2)))
self.assertEqual(Product((p1, p2)), Product.safe([p1, p2]))
self.assertEqual(Product((P(X), P(Y))), Product.safe(P(v) for v in [X, Y]))
self.assertEqual(Product(frozenset([p1, p2])), Product.safe((p1, p2)))
self.assertEqual(Product(frozenset([p1, p2])), Product.safe([p1, p2]))
self.assertEqual(Product(frozenset([P(X), P(Y)])), Product.safe(P(v) for v in [X, Y]))

self.assertEqual(One(), Product.safe([]))
self.assertEqual(One(), Product.safe([One()]))
self.assertEqual(One(), Product.safe([One(), One()]))

self.assertEqual(Product((P(X), P(Y))), Product.safe((One(), P(X), P(Y))))
self.assertEqual(Product((P(X), P(Y))), Product.safe((P(X), P(Y))))
self.assertEqual(Product((P(X), P(Y))), Product.safe((P(X), One(), P(Y))))
self.assertEqual(Product((P(X), P(Y))), Product.safe((P(X), P(Y), One(), One())))
self.assertEqual(Product((P(X), P(Y))), Product.safe((P(X), One(), P(Y), One(), One())))
self.assertEqual(Product(frozenset((P(X), P(Y)))), Product.safe((One(), P(X), P(Y))))
self.assertEqual(Product(frozenset((P(X), P(Y)))), Product.safe((P(X), P(Y))))
self.assertEqual(Product(frozenset((P(X), P(Y)))), Product.safe((P(X), One(), P(Y))))
self.assertEqual(Product(frozenset((P(X), P(Y)))), Product.safe((P(X), P(Y), One(), One())))
self.assertEqual(
Product(frozenset((P(X), P(Y)))), Product.safe((P(X), One(), P(Y), One(), One()))
)

self.assertEqual(P(X), Product.safe((One(), P(X))))
self.assertEqual(P(X), Product.safe((P(X),)))
Expand Down
18 changes: 11 additions & 7 deletions tests/test_mutate/test_canonicalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,18 @@ def test_derived_atomic(self):
# self.assert_canonicalize(One(), Sum(One(), ()), ())
self.assert_canonicalize(One(), Product.safe(One()), ())
self.assert_canonicalize(One(), Product.safe([One()]), ())
self.assert_canonicalize(One(), Product((One(), One())), ())
with self.assertRaises(ValueError):
Product(frozenset((One(), One())))
with self.assertRaises(ValueError):
Product(frozenset((Zero(), Zero())))
self.assert_canonicalize(Zero(), Sum.safe(Zero(), (A,)), [A])
self.assert_canonicalize(Zero(), Product.safe(Zero()), ())
self.assert_canonicalize(Zero(), Product.safe([Zero()]), ())
self.assert_canonicalize(Zero(), Product((P(A), Product((P(B), Zero())))), [A, B])
self.assert_canonicalize(Zero(), Product((Zero(), Zero())), ())
self.assert_canonicalize(P(A), Product((One(), P(A))), [A])
self.assert_canonicalize(Zero(), Product((Zero(), One(), P(A))), [A])
self.assert_canonicalize(
Zero(), Product(frozenset((P(A), Product(frozenset((P(B), Zero())))))), [A, B]
)
self.assert_canonicalize(P(A), Product(frozenset((One(), P(A)))), [A])
self.assert_canonicalize(Zero(), Product(frozenset((Zero(), One(), P(A)))), [A])

# Sum
expected = expression = Sum[R](P(A))
Expand All @@ -112,10 +116,10 @@ def test_derived_atomic(self):
# Nested product
expected = P(A) * P(B) * P(C)
for b, c in itt.permutations((P(B), P(C))):
expression = Product((P(A), Product((b, c))))
expression = Product(frozenset((P(A), Product(frozenset((b, c))))))
self.assert_canonicalize(expected, expression, [A, B, C])

expression = Product((Product((P(A), b)), c))
expression = Product(frozenset((Product(frozenset((P(A), b))), c)))
self.assert_canonicalize(expected, expression, [A, B, C])

# Sum with simple product (only atomic)
Expand Down
Loading