Skip to content

Commit 7dc38eb

Browse files
authored
Merge pull request #352 from mtreinish/add-float32
Add support for 32 and 16 bit numpy floats
2 parents ea1b771 + 65ceced commit 7dc38eb

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,8 @@ def _sympify(a, raise_error=True):
558558
return Integer(a)
559559
elif isinstance(a, float):
560560
return RealDouble(a)
561+
elif have_numpy and isinstance(a, (np.float16, np.float32)):
562+
return RealDouble(a)
561563
elif isinstance(a, complex):
562564
return ComplexDouble(a)
563565
elif hasattr(a, '_symengine_'):

symengine/tests/test_subs.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import unittest
2+
13
from symengine.utilities import raises
2-
from symengine import Symbol, sin, cos, sqrt, Add, function_symbol
4+
from symengine import Symbol, sin, cos, sqrt, Add, function_symbol, have_numpy
35

46

57
def test_basic():
@@ -56,3 +58,19 @@ def test_xreplace():
5658
y = Symbol("y")
5759
f = sin(cos(x))
5860
assert f.xreplace({x: y}) == sin(cos(y))
61+
62+
63+
@unittest.skipUnless(have_numpy, "Numpy not installed")
64+
def test_float32():
65+
import numpy as np
66+
x = Symbol("x")
67+
expr = x * 2
68+
assert expr.subs({x: np.float32(2)}) == 4.0
69+
70+
71+
@unittest.skipUnless(have_numpy, "Numpy not installed")
72+
def test_float16():
73+
import numpy as np
74+
x = Symbol("x")
75+
expr = x * 2
76+
assert expr.subs({x: np.float16(2)}) == 4.0

0 commit comments

Comments
 (0)