Skip to content

Commit 3c996a6

Browse files
committed
Follow-up division improvements
1 parent a348691 commit 3c996a6

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

src/y0/dsl.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,14 @@ class Expression(Element, ABC):
555555
def __mul__(self, other):
556556
pass
557557

558-
def __truediv__(self, expression: Expression) -> Fraction:
558+
def __truediv__(self, expression: Expression) -> Expression:
559559
"""Divide this expression by another and create a fraction."""
560-
return Fraction(self, expression)
560+
if isinstance(expression, One):
561+
return self
562+
elif isinstance(expression, Fraction):
563+
return Fraction(self * expression.denominator, expression.numerator)
564+
else:
565+
return Fraction(self, expression)
561566

562567
def marginalize(self, ranges: VariableHint) -> Fraction:
563568
"""Return this expression, marginalized by the given variables.
@@ -987,13 +992,13 @@ def __mul__(self, expression: Expression) -> Fraction:
987992
return Fraction(self.numerator * expression, self.denominator)
988993

989994
def __truediv__(self, expression: Expression) -> Fraction:
990-
if isinstance(expression, Fraction):
995+
if isinstance(expression, One):
996+
return self
997+
elif isinstance(expression, Fraction):
991998
return Fraction(
992999
self.numerator * expression.denominator,
9931000
self.denominator * expression.numerator,
9941001
)
995-
elif isinstance(expression, One):
996-
return self
9971002
else:
9981003
return Fraction(self.numerator, self.denominator * expression)
9991004

@@ -1002,12 +1007,19 @@ def _iter_variables(self) -> Iterable[Variable]:
10021007
yield from self.numerator._iter_variables()
10031008
yield from self.denominator._iter_variables()
10041009

1010+
def flip(self) -> Fraction:
1011+
"""Exchange the numerator and denominator."""
1012+
return Fraction(self.denominator, self.numerator)
1013+
10051014
def simplify(self) -> Expression:
10061015
"""Simplify this fraction."""
10071016
if isinstance(self.denominator, One):
10081017
return self.numerator
10091018
if isinstance(self.numerator, One):
1010-
return self
1019+
if isinstance(self.denominator, Fraction):
1020+
return self.denominator.flip().simplify()
1021+
else:
1022+
return self
10111023
if self.numerator == self.denominator:
10121024
return One()
10131025
if isinstance(self.numerator, Product) and isinstance(self.denominator, Product):
@@ -1183,12 +1195,6 @@ def __mul__(self, other: Expression):
11831195
else:
11841196
return Product((self, other))
11851197

1186-
def __truediv__(self, expression: Expression) -> Fraction:
1187-
if isinstance(expression, Fraction):
1188-
return Fraction(self * expression.denominator, expression.numerator)
1189-
else:
1190-
return super().__truediv__(expression)
1191-
11921198
def _iter_variables(self) -> Iterable[Variable]:
11931199
yield from self.codomain
11941200
yield from self.domain

src/y0/mutate/chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def fraction_expand(p: Probability) -> Fraction:
8787
.. math::
8888
P(Y_1,\dots,Y_n | X_1, \dots, X_m) = \frac{P(Y_1,\dots,Y_n,X_1,\dots,X_m)}{P(X_1,\dots,X_m)}
8989
"""
90-
return p.uncondition() / P(p.parents)
90+
return Fraction(p.uncondition(), P(p.parents))
9191

9292

9393
def bayes_expand(p: Probability) -> Fraction:

tests/test_mutate/test_simplify.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import unittest
66

7-
from y0.dsl import A, B, One, P, Sum
7+
from y0.dsl import A, B, Fraction, One, P, Sum
88

99
one = One()
1010

@@ -15,12 +15,13 @@ class TestCancel(unittest.TestCase):
1515
def test_simple_identity(self):
1616
"""Test cancelling when the numerator and denominator are the same."""
1717
for label, frac in [
18-
("one", one / one),
18+
("one", Fraction(one, one)),
1919
("prob", P(A) / P(A)),
2020
("sum", Sum(P(A)) / Sum(P(A))),
2121
("product", (P(A) * P(B)) / (P(A) * P(B))),
2222
]:
2323
with self.subTest(type=label):
24+
self.assertIsInstance(frac, Fraction)
2425
self.assertEqual(one, frac.simplify(), msg=f"\n\nActual:{frac}")
2526

2627
def test_fraction_simplify(self):
@@ -29,6 +30,13 @@ def test_fraction_simplify(self):
2930
("leave num.", P(B), (P(A) * P(B)) / P(A)),
3031
("leave den.", one / P(B), P(A) / (P(A) * P(B))),
3132
("unordered", one, (P(A) * P(B)) / (P(B) * P(A))),
33+
("canonical", one / P(A), one / P(A)),
34+
("flipper", P(A), Fraction(one, Fraction(one, P(A)))),
35+
("prob-redundant-one", P(A), Fraction(P(A), one)),
36+
("sum-redundant-one", Sum(P(A)), Fraction(Sum(P(A)), one)),
37+
("frac-redundant-one", P(A) / P(B), Fraction(Fraction(P(A), P(B)), one)),
38+
("prod-redundant-one", P(A) * P(B), Fraction(P(A) * P(B), one)),
3239
]:
3340
with self.subTest(type=label):
41+
self.assertIsInstance(frac, Fraction)
3442
self.assertEqual(expected, frac.simplify(), msg=f"\n\nActual:{frac}")

0 commit comments

Comments
 (0)