Skip to content

Commit e831e2d

Browse files
committed
Add tests for LLVM float and longdouble
1 parent dc4cc7c commit e831e2d

File tree

3 files changed

+52
-17
lines changed

3 files changed

+52
-17
lines changed

symengine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .lib.symengine_wrapper import (
2-
have_mpfr, have_mpc, have_flint, have_piranha, have_llvm,
2+
have_mpfr, have_mpc, have_flint, have_piranha, have_llvm, have_llvm_long_double,
33
I, E, pi, oo, zoo, nan, Symbol, Dummy, S, sympify, SympifyError,
44
Integer, Rational, Float, Number, RealNumber, RealDouble, ComplexDouble,
55
add, Add, Mul, Pow, function_symbol,

symengine/lib/symengine_wrapper.pyx

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4058,6 +4058,7 @@ have_mpc = False
40584058
have_piranha = False
40594059
have_flint = False
40604060
have_llvm = False
4061+
have_llvm_long_double = False
40614062

40624063
IF HAVE_SYMENGINE_MPFR:
40634064
have_mpfr = True
@@ -4080,6 +4081,9 @@ IF HAVE_SYMENGINE_FLINT:
40804081
IF HAVE_SYMENGINE_LLVM:
40814082
have_llvm = True
40824083

4084+
IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
4085+
have_llvm_long_double = True
4086+
40834087
def require(obj, t):
40844088
if not isinstance(obj, t):
40854089
raise TypeError("{} required. {} is of type {}".format(t, obj, type(obj)))
@@ -4675,7 +4679,7 @@ def create_low_level_callable(lambdify, *args):
46754679

46764680

46774681
cdef class LambdaDouble(_Lambdify):
4678-
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False):
4682+
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, dtype=None):
46794683
# reject additional arguments
46804684
pass
46814685

@@ -4689,7 +4693,7 @@ cdef class LambdaDouble(_Lambdify):
46894693
cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
46904694
cdef double[::1] c_inp, c_out
46914695
cdef unsigned idx
4692-
c_inp = np.ascontiguousarray(inp.ravel(order=self.order))
4696+
c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
46934697
c_out = out
46944698
for idx in range(nbroadcast):
46954699
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
@@ -4720,7 +4724,7 @@ cdef class LambdaDouble(_Lambdify):
47204724

47214725

47224726
cdef class LambdaComplexDouble(_Lambdify):
4723-
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False):
4727+
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, dtype=None):
47244728
# reject additional arguments
47254729
pass
47264730

@@ -4734,15 +4738,15 @@ cdef class LambdaComplexDouble(_Lambdify):
47344738
cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
47354739
cdef double complex[::1] c_inp, c_out
47364740
cdef unsigned idx
4737-
c_inp = np.ascontiguousarray(inp.ravel(order=self.order))
4741+
c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
47384742
c_out = out
47394743
for idx in range(nbroadcast):
47404744
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
47414745

47424746

47434747
IF HAVE_SYMENGINE_LLVM:
47444748
cdef class LLVMDouble(_LLVMLambdify):
4745-
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3):
4749+
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3, dtype=None):
47464750
self.opt_level = opt_level
47474751

47484752
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
@@ -4767,7 +4771,7 @@ IF HAVE_SYMENGINE_LLVM:
47674771
cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
47684772
cdef double[::1] c_inp, c_out
47694773
cdef unsigned idx
4770-
c_inp = np.ascontiguousarray(inp.ravel(order=self.order))
4774+
c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
47714775
c_out = out
47724776
for idx in range(nbroadcast):
47734777
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
@@ -4801,7 +4805,7 @@ IF HAVE_SYMENGINE_LLVM:
48014805
return addr1, addr2
48024806

48034807
cdef class LLVMFloat(_LLVMLambdify):
4804-
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3):
4808+
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3, dtype=None):
48054809
self.opt_level = opt_level
48064810

48074811
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
@@ -4826,14 +4830,14 @@ IF HAVE_SYMENGINE_LLVM:
48264830
cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
48274831
cdef float[::1] c_inp, c_out
48284832
cdef unsigned idx
4829-
c_inp = np.ascontiguousarray(inp.ravel(order=self.order))
4833+
c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
48304834
c_out = out
48314835
for idx in range(nbroadcast):
48324836
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
48334837

48344838
IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
48354839
cdef class LLVMLongDouble(_LLVMLambdify):
4836-
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3):
4840+
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3, dtype=None):
48374841
self.opt_level = opt_level
48384842

48394843
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
@@ -4858,7 +4862,7 @@ IF HAVE_SYMENGINE_LLVM:
48584862
cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
48594863
cdef long double[::1] c_inp, c_out
48604864
cdef unsigned idx
4861-
c_inp = np.ascontiguousarray(inp.ravel(order=self.order))
4865+
c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
48624866
c_out = out
48634867
for idx in range(nbroadcast):
48644868
self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
@@ -4926,14 +4930,14 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C',
49264930
if backend == "llvm":
49274931
IF HAVE_SYMENGINE_LLVM:
49284932
if dtype == None:
4929-
dtype = np.double
4930-
if dtype == np.double:
4931-
ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
4932-
elif dtype == np.float:
4933-
ret = LLVMFloat(args, *exprs, real=real, order=order, cse=cse, **kwargs)
4933+
dtype = np.float64
4934+
if dtype == np.float64:
4935+
ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse, dtype=np.float64, **kwargs)
4936+
elif dtype == np.float32:
4937+
ret = LLVMFloat(args, *exprs, real=real, order=order, cse=cse, dtype=np.float32, **kwargs)
49344938
elif dtype == np.longdouble:
49354939
IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
4936-
ret = LLVMLongDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
4940+
ret = LLVMLongDouble(args, *exprs, real=real, order=order, cse=cse, dtype=np.longdouble, **kwargs)
49374941
ELSE:
49384942
raise ValueError("Long double not supported on this platform")
49394943
else:

symengine/tests/test_lambdify.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,3 +835,34 @@ def test_as_ctypes():
835835
out = np.array([0, 0], dtype=np.double)
836836
addr1(out.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), inp.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), addr2)
837837
assert np.all(out == [6, 7])
838+
839+
@unittest.skipUnless(have_numpy, "Numpy not installed")
840+
@unittest.skipUnless(se.have_llvm, "No LLVM support")
841+
def test_llvm_float():
842+
import numpy as np
843+
import ctypes
844+
from symengine.lib.symengine_wrapper import LLVMFloat
845+
x, y, z = se.symbols('x, y, z')
846+
l = se.Lambdify([x, y, z], [se.Min(x, y), se.Max(y, z)], dtype=np.float32, backend='llvm')
847+
inp = np.array([1,2,3], dtype=np.float32)
848+
exp_out = np.array([1, 3], dtype=np.float32)
849+
out = l(inp)
850+
assert type(l) == LLVMFloat
851+
assert out.dtype == np.float32
852+
assert np.allclose(out, exp_out)
853+
854+
@unittest.skipUnless(have_numpy, "Numpy not installed")
855+
@unittest.skipUnless(se.have_llvm, "No LLVM support")
856+
@unittest.skipUnless(se.have_llvm_long_double, "No LLVM IEEE-80 bit support")
857+
def test_llvm_long_double():
858+
import numpy as np
859+
import ctypes
860+
from symengine.lib.symengine_wrapper import LLVMLongDouble
861+
x, y, z = se.symbols('x, y, z')
862+
l = se.Lambdify([x, y, z], [2*x, y/z], dtype=np.longdouble, backend='llvm')
863+
inp = np.array([1,2,3], dtype=np.longdouble)
864+
exp_out = np.array([2, 2.0/3.0], dtype=np.longdouble)
865+
out = l(inp)
866+
assert type(l) == LLVMLongDouble
867+
assert out.dtype == np.longdouble
868+
assert np.allclose(out, exp_out)

0 commit comments

Comments
 (0)