Skip to content

Commit 50697c0

Browse files
committed
Improvements
1 parent cc7d63f commit 50697c0

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

symengine/printing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
from symengine.lib.symengine_wrapper import ccode, _sympify
1+
from symengine.lib.symengine_wrapper import ccode, _sympify, Basic
22

33
class CCodePrinter:
44

55
def doprint(self, expr, assign_to=None):
6+
if not isinstance(assign_to, (Basic, type(None), str)):
7+
raise TypeError("{0} cannot assign to object of type {1}".format(
8+
type(self).__name__, type(assign_to)))
9+
610
expr = _sympify(expr)
711
if not assign_to:
812
if expr.is_Matrix:

symengine/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ install(FILES __init__.py
88
test_number.py
99
test_matrices.py
1010
test_ntheory.py
11+
test_printing.py
1112
test_sage.py
1213
test_series_expansion.py
1314
test_subs.py

symengine/tests/test_printing.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from symengine.utilities import raises
2+
from symengine.lib.symengine_wrapper import (ccode, Symbol, sqrt, Pow, Max, sin, Integer, MutableDenseMatrix)
3+
from symengine.printing import CCodePrinter
4+
5+
def test_ccode():
6+
x = Symbol("x")
7+
y = Symbol("y")
8+
assert ccode(x) == "x"
9+
assert ccode(x**3) == "pow(x, 3)"
10+
assert ccode(x**(y**3)) == "pow(x, pow(y, 3))"
11+
assert ccode(x**-1.0) == "pow(x, -1.0)"
12+
assert ccode(Max(x, x*x)) == "max(x, pow(x, 2))"
13+
assert ccode(sin(x)) == "sin(x)"
14+
assert ccode(Integer(67)) == "67"
15+
assert ccode(Integer(-1)) == "-1"
16+
17+
def test_CCodePrinter():
18+
x = Symbol("x")
19+
y = Symbol("y")
20+
myprinter = CCodePrinter()
21+
22+
assert myprinter.doprint(1+x, "bork") == "bork = 1 + x;"
23+
assert myprinter.doprint(1*x) == "x"
24+
assert myprinter.doprint(MutableDenseMatrix(1, 2, [x, y]), "larry") == "larry[0] = x;\nlarry[1] = y;"
25+
raises(TypeError, lambda: myprinter.doprint(sin(x), Integer))
26+
raises(RuntimeError, lambda: myprinter.doprint(MutableDenseMatrix(1, 2, [x, y])))
27+

0 commit comments

Comments
 (0)