Skip to content

Commit 4cafd8b

Browse files
committed
Add support for 32 and 16 bit numpy floats
This commit adds support for 32 and 16 bit floats numpy as input types. Since it doesn't look like there is a native symengine type for working with single (or half) precision floats outside of using MPFR (which seemed like the wrong approach here, it might make sense for np.float128 in a follow up though) this implicitly casts the np.float16 and np.float32 types to doubles and just behaves as float/np.float64. Fixes #351
1 parent ea1b771 commit 4cafd8b

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ def sympy2symengine(a, raise_error=False):
295295
return RealDouble(float(str(a)))
296296
ELSE:
297297
return RealDouble(float(str(a)))
298+
elif isinstance(a, np.float16):
299+
return RealDouble(a)
300+
elif isinstance(a, np.float32):
301+
return RealDouble(a)
298302
elif a is sympy.I:
299303
return I
300304
elif a is sympy.E:
@@ -558,6 +562,10 @@ def _sympify(a, raise_error=True):
558562
return Integer(a)
559563
elif isinstance(a, float):
560564
return RealDouble(a)
565+
elif isinstance(a, np.float16):
566+
return RealDouble(a)
567+
elif isinstance(a, np.float32):
568+
return RealDouble(a)
561569
elif isinstance(a, complex):
562570
return ComplexDouble(a)
563571
elif hasattr(a, '_symengine_'):

symengine/tests/test_subs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from symengine.utilities import raises
24
from symengine import Symbol, sin, cos, sqrt, Add, function_symbol
35

@@ -56,3 +58,15 @@ def test_xreplace():
5658
y = Symbol("y")
5759
f = sin(cos(x))
5860
assert f.xreplace({x: y}) == sin(cos(y))
61+
62+
63+
def test_float32():
64+
x = Symbol("x")
65+
expr = x * 2
66+
assert expr.subs({x: np.float32(2)}) == 4.0
67+
68+
69+
def test_float16():
70+
x = Symbol("x")
71+
expr = x * 2
72+
assert expr.subs({x: np.float16(2)}) == 4.0

0 commit comments

Comments
 (0)