3636 "Fraction" ,
3737 "Expression" ,
3838 "One" ,
39+ "Zero" ,
3940 "Q" ,
4041 "QFactor" ,
4142 "A" ,
@@ -636,7 +637,11 @@ def is_markov_kernel(self) -> bool:
636637 return self .distribution .is_markov_kernel ()
637638
638639 def __mul__ (self , other : Expression ) -> Expression :
639- if isinstance (other , Product ):
640+ if isinstance (other , Zero ):
641+ return other
642+ elif isinstance (other , One ):
643+ return self
644+ elif isinstance (other , Product ):
640645 return Product ((self , * other .expressions ))
641646 elif isinstance (other , Fraction ):
642647 return Fraction (self * other .numerator , other .denominator )
@@ -841,6 +846,8 @@ def to_latex(self):
841846 return " " .join (expression .to_latex () for expression in self .expressions )
842847
843848 def __mul__ (self , other : Expression ):
849+ if isinstance (other , Zero ):
850+ return other
844851 if isinstance (other , Product ):
845852 return Product ((* self .expressions , * other .expressions ))
846853 elif isinstance (other , Fraction ):
@@ -929,7 +936,9 @@ def to_y0(self):
929936 return f"Sum[{ ranges } ]({ s } )"
930937
931938 def __mul__ (self , expression : Expression ):
932- if isinstance (expression , Product ):
939+ if isinstance (expression , Zero ):
940+ return expression
941+ elif isinstance (expression , Product ):
933942 return Product ((self , * expression .expressions ))
934943 else :
935944 return Product ((self , expression ))
@@ -969,6 +978,10 @@ class Fraction(Expression):
969978 #: The expression in the denominator of the fraction
970979 denominator : Expression
971980
981+ def __post_init__ (self ):
982+ if isinstance (self .denominator , Zero ):
983+ raise ZeroDivisionError
984+
972985 def to_text (self ) -> str :
973986 """Output this fraction in the internal string format."""
974987 return f"frac_{{{ self .numerator .to_text ()} }}{{{ self .denominator .to_text ()} }}"
@@ -982,8 +995,10 @@ def to_y0(self, parens: bool = True) -> str:
982995 s = f"({ self .numerator .to_y0 ()} / { self .denominator .to_y0 ()} )"
983996 return f"({ s } )" if parens else s
984997
985- def __mul__ (self , expression : Expression ) -> Fraction :
986- if isinstance (expression , Fraction ):
998+ def __mul__ (self , expression : Expression ) -> Expression :
999+ if isinstance (expression , Zero ):
1000+ return expression
1001+ elif isinstance (expression , Fraction ):
9871002 return Fraction (
9881003 self .numerator * expression .numerator ,
9891004 self .denominator * expression .denominator ,
@@ -1015,6 +1030,8 @@ def simplify(self) -> Expression:
10151030 """Simplify this fraction."""
10161031 if isinstance (self .denominator , One ):
10171032 return self .numerator
1033+ if isinstance (self .numerator , Zero ):
1034+ return self .numerator
10181035 if isinstance (self .numerator , One ):
10191036 if isinstance (self .denominator , Fraction ):
10201037 return self .denominator .flip ().simplify ()
@@ -1111,6 +1128,40 @@ def _iter_variables(self) -> Iterable[Variable]:
11111128 return iter ([])
11121129
11131130
1131+ class Zero (Expression ):
1132+ """The additive identity (0)."""
1133+
1134+ def to_text (self ) -> str :
1135+ """Output this identity variable in the internal string format."""
1136+ return "0"
1137+
1138+ def to_latex (self ) -> str :
1139+ """Output this identity instance in the LaTeX string format."""
1140+ return "0"
1141+
1142+ def to_y0 (self ) -> str :
1143+ """Output this identity instance as y0 internal DSL code."""
1144+ return "Zero()"
1145+
1146+ def __rmul__ (self , expression : Expression ) -> Expression :
1147+ return self
1148+
1149+ def __mul__ (self , expression : Expression ) -> Expression :
1150+ return self
1151+
1152+ def __truediv__ (self , other : Expression ) -> Expression :
1153+ if isinstance (other , Zero ):
1154+ raise ZeroDivisionError
1155+ return self
1156+
1157+ def __eq__ (self , other ):
1158+ return isinstance (other , Zero ) # all zeros are equal
1159+
1160+ def _iter_variables (self ) -> Iterable [Variable ]:
1161+ """Get the set of variables used in this expression."""
1162+ return iter ([])
1163+
1164+
11141165class QBuilder (Protocol [T_co ]):
11151166 """A protocol for annotating the special class getitem functionality of the :class:`QFactor` class."""
11161167
0 commit comments