@@ -555,9 +555,14 @@ class Expression(Element, ABC):
555555 def __mul__ (self , other ):
556556 pass
557557
558- def __truediv__ (self , expression : Expression ) -> Fraction :
558+ def __truediv__ (self , expression : Expression ) -> Expression :
559559 """Divide this expression by another and create a fraction."""
560- return Fraction (self , expression )
560+ if isinstance (expression , One ):
561+ return self
562+ elif isinstance (expression , Fraction ):
563+ return Fraction (self * expression .denominator , expression .numerator )
564+ else :
565+ return Fraction (self , expression )
561566
562567 def marginalize (self , ranges : VariableHint ) -> Fraction :
563568 """Return this expression, marginalized by the given variables.
@@ -987,13 +992,13 @@ def __mul__(self, expression: Expression) -> Fraction:
987992 return Fraction (self .numerator * expression , self .denominator )
988993
989994 def __truediv__ (self , expression : Expression ) -> Fraction :
990- if isinstance (expression , Fraction ):
995+ if isinstance (expression , One ):
996+ return self
997+ elif isinstance (expression , Fraction ):
991998 return Fraction (
992999 self .numerator * expression .denominator ,
9931000 self .denominator * expression .numerator ,
9941001 )
995- elif isinstance (expression , One ):
996- return self
9971002 else :
9981003 return Fraction (self .numerator , self .denominator * expression )
9991004
@@ -1002,12 +1007,19 @@ def _iter_variables(self) -> Iterable[Variable]:
10021007 yield from self .numerator ._iter_variables ()
10031008 yield from self .denominator ._iter_variables ()
10041009
1010+ def flip (self ) -> Fraction :
1011+ """Exchange the numerator and denominator."""
1012+ return Fraction (self .denominator , self .numerator )
1013+
10051014 def simplify (self ) -> Expression :
10061015 """Simplify this fraction."""
10071016 if isinstance (self .denominator , One ):
10081017 return self .numerator
10091018 if isinstance (self .numerator , One ):
1010- return self
1019+ if isinstance (self .denominator , Fraction ):
1020+ return self .denominator .flip ().simplify ()
1021+ else :
1022+ return self
10111023 if self .numerator == self .denominator :
10121024 return One ()
10131025 if isinstance (self .numerator , Product ) and isinstance (self .denominator , Product ):
@@ -1183,12 +1195,6 @@ def __mul__(self, other: Expression):
11831195 else :
11841196 return Product ((self , other ))
11851197
1186- def __truediv__ (self , expression : Expression ) -> Fraction :
1187- if isinstance (expression , Fraction ):
1188- return Fraction (self * expression .denominator , expression .numerator )
1189- else :
1190- return super ().__truediv__ (expression )
1191-
11921198 def _iter_variables (self ) -> Iterable [Variable ]:
11931199 yield from self .codomain
11941200 yield from self .domain
0 commit comments