Skip to content

Commit 43c7f99

Browse files
authored
Add zero element to DSL (#96)
* Add zero element to DSL Closes #95 * Add operation overrides * Cleanup division * Finalize zero implementation * Fix zero divided by zero * Update chain.py
1 parent 3c996a6 commit 43c7f99

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

src/y0/dsl.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"Fraction",
3737
"Expression",
3838
"One",
39+
"Zero",
3940
"Q",
4041
"QFactor",
4142
"A",
@@ -636,7 +637,11 @@ def is_markov_kernel(self) -> bool:
636637
return self.distribution.is_markov_kernel()
637638

638639
def __mul__(self, other: Expression) -> Expression:
639-
if isinstance(other, Product):
640+
if isinstance(other, Zero):
641+
return other
642+
elif isinstance(other, One):
643+
return self
644+
elif isinstance(other, Product):
640645
return Product((self, *other.expressions))
641646
elif isinstance(other, Fraction):
642647
return Fraction(self * other.numerator, other.denominator)
@@ -841,6 +846,8 @@ def to_latex(self):
841846
return " ".join(expression.to_latex() for expression in self.expressions)
842847

843848
def __mul__(self, other: Expression):
849+
if isinstance(other, Zero):
850+
return other
844851
if isinstance(other, Product):
845852
return Product((*self.expressions, *other.expressions))
846853
elif isinstance(other, Fraction):
@@ -929,7 +936,9 @@ def to_y0(self):
929936
return f"Sum[{ranges}]({s})"
930937

931938
def __mul__(self, expression: Expression):
932-
if isinstance(expression, Product):
939+
if isinstance(expression, Zero):
940+
return expression
941+
elif isinstance(expression, Product):
933942
return Product((self, *expression.expressions))
934943
else:
935944
return Product((self, expression))
@@ -969,6 +978,10 @@ class Fraction(Expression):
969978
#: The expression in the denominator of the fraction
970979
denominator: Expression
971980

981+
def __post_init__(self):
982+
if isinstance(self.denominator, Zero):
983+
raise ZeroDivisionError
984+
972985
def to_text(self) -> str:
973986
"""Output this fraction in the internal string format."""
974987
return f"frac_{{{self.numerator.to_text()}}}{{{self.denominator.to_text()}}}"
@@ -982,8 +995,10 @@ def to_y0(self, parens: bool = True) -> str:
982995
s = f"({self.numerator.to_y0()} / {self.denominator.to_y0()})"
983996
return f"({s})" if parens else s
984997

985-
def __mul__(self, expression: Expression) -> Fraction:
986-
if isinstance(expression, Fraction):
998+
def __mul__(self, expression: Expression) -> Expression:
999+
if isinstance(expression, Zero):
1000+
return expression
1001+
elif isinstance(expression, Fraction):
9871002
return Fraction(
9881003
self.numerator * expression.numerator,
9891004
self.denominator * expression.denominator,
@@ -1015,6 +1030,8 @@ def simplify(self) -> Expression:
10151030
"""Simplify this fraction."""
10161031
if isinstance(self.denominator, One):
10171032
return self.numerator
1033+
if isinstance(self.numerator, Zero):
1034+
return self.numerator
10181035
if isinstance(self.numerator, One):
10191036
if isinstance(self.denominator, Fraction):
10201037
return self.denominator.flip().simplify()
@@ -1111,6 +1128,40 @@ def _iter_variables(self) -> Iterable[Variable]:
11111128
return iter([])
11121129

11131130

1131+
class Zero(Expression):
1132+
"""The additive identity (0)."""
1133+
1134+
def to_text(self) -> str:
1135+
"""Output this identity variable in the internal string format."""
1136+
return "0"
1137+
1138+
def to_latex(self) -> str:
1139+
"""Output this identity instance in the LaTeX string format."""
1140+
return "0"
1141+
1142+
def to_y0(self) -> str:
1143+
"""Output this identity instance as y0 internal DSL code."""
1144+
return "Zero()"
1145+
1146+
def __rmul__(self, expression: Expression) -> Expression:
1147+
return self
1148+
1149+
def __mul__(self, expression: Expression) -> Expression:
1150+
return self
1151+
1152+
def __truediv__(self, other: Expression) -> Expression:
1153+
if isinstance(other, Zero):
1154+
raise ZeroDivisionError
1155+
return self
1156+
1157+
def __eq__(self, other):
1158+
return isinstance(other, Zero) # all zeros are equal
1159+
1160+
def _iter_variables(self) -> Iterable[Variable]:
1161+
"""Get the set of variables used in this expression."""
1162+
return iter([])
1163+
1164+
11141165
class QBuilder(Protocol[T_co]):
11151166
"""A protocol for annotating the special class getitem functionality of the :class:`QFactor` class."""
11161167

tests/test_dsl.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
X,
2929
Y,
3030
Z,
31+
Zero,
3132
)
3233
from y0.parser import parse_y0
3334

@@ -345,3 +346,41 @@ def test_product(self):
345346
self.assertEqual(Product((p,)), Product.safe({p}))
346347

347348
self.assertEqual(Product((P(X), P(Y))), Product.safe(P(v) for v in [X, Y]))
349+
350+
351+
zero = Zero()
352+
353+
354+
class TestZero(unittest.TestCase):
355+
"""Tests for zero."""
356+
357+
exprs = [
358+
One(),
359+
Zero(),
360+
P(A),
361+
P(A) * P(B),
362+
P(A) / P(B),
363+
Sum.safe(P(A), [A]),
364+
]
365+
366+
def test_divide_failure(self):
367+
"""Test failure is thron on division by zero."""
368+
for expr in self.exprs:
369+
with self.subTest(expr=expr), self.assertRaises(ZeroDivisionError):
370+
expr / zero
371+
372+
def test_identity(self):
373+
"""Test that zero divided by anything is zero."""
374+
for expr in self.exprs:
375+
if isinstance(expr, Zero):
376+
continue # would raise divides by zero
377+
with self.subTest(expr=expr):
378+
self.assertEqual(zero, zero / expr)
379+
380+
def test_multiply(self):
381+
"""Test other operations."""
382+
for expr in self.exprs:
383+
with self.subTest(expr=expr.to_y0(), direction="right"):
384+
self.assertEqual(zero, zero * expr, msg=f"Got {zero * expr}")
385+
with self.subTest(expr=expr.to_y0(), direction="left"):
386+
self.assertEqual(zero, expr * zero, msg=f"Got {expr * zero}")

0 commit comments

Comments
 (0)