diff --git a/src/y0/dsl.py b/src/y0/dsl.py index 44241e8bc..e9908ff8c 100644 --- a/src/y0/dsl.py +++ b/src/y0/dsl.py @@ -1019,9 +1019,11 @@ def __getitem__(self, interventions: VariableHint) -> ProbabilityMetaBuilder: class Product(Expression): """Represent the product of several probability expressions.""" - expressions: tuple[Expression, ...] + expressions: frozenset[Expression] def __post_init__(self) -> None: + if not isinstance(self.expressions, frozenset): + raise TypeError("Products must be given with a frozenset") if len(self.expressions) < 2: raise ValueError("Product() must two or more expressions") @@ -1060,36 +1062,40 @@ def safe(cls, expressions: Expression | Iterable[Expression]) -> Expression: return One() if len(expressions) == 1: return expressions[0] - return cls(expressions=tuple(sorted(expressions))) + return cls(expressions=frozenset(expressions)) + + @property + def _sorted_expressions(self) -> list[Expression]: + return sorted(self.expressions) def _get_key(self): # type:ignore - inner_keys = (sexpr._get_key() for sexpr in self.expressions) + inner_keys = (sexpr._get_key() for sexpr in self._sorted_expressions) return 2, *inner_keys def to_text(self) -> str: """Output this product in the internal string format.""" - return " ".join(expression.to_text() for expression in self.expressions) + return " ".join(expression.to_text() for expression in self._sorted_expressions) def to_y0(self) -> str: """Output this product instance as y0 internal DSL code.""" - return " * ".join(expr.to_y0() for expr in self.expressions) + return " * ".join(expr.to_y0() for expr in self._sorted_expressions) def to_latex(self) -> str: """Output this product in the LaTeX string format.""" - return " ".join(expression.to_latex() for expression in self.expressions) + return " ".join(expression.to_latex() for expression in self._sorted_expressions) def __mul__(self, other: Expression) -> Expression: if isinstance(other, Zero): return other if isinstance(other, Product): - return Product.safe((*self.expressions, *other.expressions)) + return Product.safe(self.expressions | other.expressions) elif isinstance(other, Fraction): return Fraction(self * other.numerator, other.denominator) else: - return Product.safe((*self.expressions, other)) + return Product.safe(self.expressions | {other}) def _iter_variables(self) -> Iterable[Variable]: - """Get the union of the variables used in each expresison in this product.""" + """Get the union of the variables used in each expression in this product.""" for expression in self.expressions: yield from expression._iter_variables() @@ -1347,11 +1353,13 @@ def simplify(self) -> Expression: if self.numerator == self.denominator: return One() if isinstance(self.numerator, Product) and isinstance(self.denominator, Product): - return self._simplify_parts(self.numerator.expressions, self.denominator.expressions) + return self._simplify_parts( + self.numerator._sorted_expressions, self.denominator._sorted_expressions + ) elif isinstance(self.numerator, Product): - return self._simplify_parts(self.numerator.expressions, [self.denominator]) + return self._simplify_parts(self.numerator._sorted_expressions, [self.denominator]) elif isinstance(self.denominator, Product): - return self._simplify_parts([self.numerator], self.denominator.expressions) + return self._simplify_parts([self.numerator], self.denominator._sorted_expressions) return self @classmethod @@ -1398,6 +1406,7 @@ def _simplify_parts_helper( ) +@dataclass(frozen=True) class One(Expression): """The multiplicative identity (1).""" @@ -1430,6 +1439,7 @@ def _iter_variables(self) -> Iterable[Variable]: return iter([]) +@dataclass(frozen=True) class Zero(Expression): """The additive identity (0).""" diff --git a/tests/test_dsl.py b/tests/test_dsl.py index 26009274c..a39c005dd 100644 --- a/tests/test_dsl.py +++ b/tests/test_dsl.py @@ -517,19 +517,21 @@ def test_product(self): p2 = P(Z) self.assertEqual(p1, Product.safe(p1)) self.assertEqual(p1, Product.safe([p1])) - self.assertEqual(Product((p1, p2)), Product.safe((p1, p2))) - self.assertEqual(Product((p1, p2)), Product.safe([p1, p2])) - self.assertEqual(Product((P(X), P(Y))), Product.safe(P(v) for v in [X, Y])) + self.assertEqual(Product(frozenset([p1, p2])), Product.safe((p1, p2))) + self.assertEqual(Product(frozenset([p1, p2])), Product.safe([p1, p2])) + self.assertEqual(Product(frozenset([P(X), P(Y)])), Product.safe(P(v) for v in [X, Y])) self.assertEqual(One(), Product.safe([])) self.assertEqual(One(), Product.safe([One()])) self.assertEqual(One(), Product.safe([One(), One()])) - self.assertEqual(Product((P(X), P(Y))), Product.safe((One(), P(X), P(Y)))) - self.assertEqual(Product((P(X), P(Y))), Product.safe((P(X), P(Y)))) - self.assertEqual(Product((P(X), P(Y))), Product.safe((P(X), One(), P(Y)))) - self.assertEqual(Product((P(X), P(Y))), Product.safe((P(X), P(Y), One(), One()))) - self.assertEqual(Product((P(X), P(Y))), Product.safe((P(X), One(), P(Y), One(), One()))) + self.assertEqual(Product(frozenset((P(X), P(Y)))), Product.safe((One(), P(X), P(Y)))) + self.assertEqual(Product(frozenset((P(X), P(Y)))), Product.safe((P(X), P(Y)))) + self.assertEqual(Product(frozenset((P(X), P(Y)))), Product.safe((P(X), One(), P(Y)))) + self.assertEqual(Product(frozenset((P(X), P(Y)))), Product.safe((P(X), P(Y), One(), One()))) + self.assertEqual( + Product(frozenset((P(X), P(Y)))), Product.safe((P(X), One(), P(Y), One(), One())) + ) self.assertEqual(P(X), Product.safe((One(), P(X)))) self.assertEqual(P(X), Product.safe((P(X),))) diff --git a/tests/test_mutate/test_canonicalize.py b/tests/test_mutate/test_canonicalize.py index ce28caf3d..954432f8f 100644 --- a/tests/test_mutate/test_canonicalize.py +++ b/tests/test_mutate/test_canonicalize.py @@ -86,14 +86,18 @@ def test_derived_atomic(self): # self.assert_canonicalize(One(), Sum(One(), ()), ()) self.assert_canonicalize(One(), Product.safe(One()), ()) self.assert_canonicalize(One(), Product.safe([One()]), ()) - self.assert_canonicalize(One(), Product((One(), One())), ()) + with self.assertRaises(ValueError): + Product(frozenset((One(), One()))) + with self.assertRaises(ValueError): + Product(frozenset((Zero(), Zero()))) self.assert_canonicalize(Zero(), Sum.safe(Zero(), (A,)), [A]) self.assert_canonicalize(Zero(), Product.safe(Zero()), ()) self.assert_canonicalize(Zero(), Product.safe([Zero()]), ()) - self.assert_canonicalize(Zero(), Product((P(A), Product((P(B), Zero())))), [A, B]) - self.assert_canonicalize(Zero(), Product((Zero(), Zero())), ()) - self.assert_canonicalize(P(A), Product((One(), P(A))), [A]) - self.assert_canonicalize(Zero(), Product((Zero(), One(), P(A))), [A]) + self.assert_canonicalize( + Zero(), Product(frozenset((P(A), Product(frozenset((P(B), Zero())))))), [A, B] + ) + self.assert_canonicalize(P(A), Product(frozenset((One(), P(A)))), [A]) + self.assert_canonicalize(Zero(), Product(frozenset((Zero(), One(), P(A)))), [A]) # Sum expected = expression = Sum[R](P(A)) @@ -112,10 +116,10 @@ def test_derived_atomic(self): # Nested product expected = P(A) * P(B) * P(C) for b, c in itt.permutations((P(B), P(C))): - expression = Product((P(A), Product((b, c)))) + expression = Product(frozenset((P(A), Product(frozenset((b, c)))))) self.assert_canonicalize(expected, expression, [A, B, C]) - expression = Product((Product((P(A), b)), c)) + expression = Product(frozenset((Product(frozenset((P(A), b))), c))) self.assert_canonicalize(expected, expression, [A, B, C]) # Sum with simple product (only atomic)