@@ -4058,6 +4058,7 @@ have_mpc = False
4058
4058
have_piranha = False
4059
4059
have_flint = False
4060
4060
have_llvm = False
4061
+ have_llvm_long_double = False
4061
4062
4062
4063
IF HAVE_SYMENGINE_MPFR:
4063
4064
have_mpfr = True
@@ -4080,6 +4081,9 @@ IF HAVE_SYMENGINE_FLINT:
4080
4081
IF HAVE_SYMENGINE_LLVM:
4081
4082
have_llvm = True
4082
4083
4084
+ IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
4085
+ have_llvm_long_double = True
4086
+
4083
4087
def require (obj , t ):
4084
4088
if not isinstance (obj, t):
4085
4089
raise TypeError (" {} required. {} is of type {}" .format(t, obj, type (obj)))
@@ -4675,7 +4679,7 @@ def create_low_level_callable(lambdify, *args):
4675
4679
4676
4680
4677
4681
cdef class LambdaDouble(_Lambdify):
4678
- def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False ):
4682
+ def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False , dtype = None ):
4679
4683
# reject additional arguments
4680
4684
pass
4681
4685
@@ -4689,7 +4693,7 @@ cdef class LambdaDouble(_Lambdify):
4689
4693
cpdef unsafe_eval(self , inp, out, unsigned nbroadcast = 1 ):
4690
4694
cdef double [::1 ] c_inp, c_out
4691
4695
cdef unsigned idx
4692
- c_inp = np.ascontiguousarray(inp.ravel(order = self .order))
4696
+ c_inp = np.ascontiguousarray(inp.ravel(order = self .order), dtype = self .numpy_dtype )
4693
4697
c_out = out
4694
4698
for idx in range (nbroadcast):
4695
4699
self .lambda_double[0 ].call(& c_out[idx* self .tot_out_size], & c_inp[idx* self .args_size])
@@ -4720,7 +4724,7 @@ cdef class LambdaDouble(_Lambdify):
4720
4724
4721
4725
4722
4726
cdef class LambdaComplexDouble(_Lambdify):
4723
- def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False ):
4727
+ def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False , dtype = None ):
4724
4728
# reject additional arguments
4725
4729
pass
4726
4730
@@ -4734,15 +4738,15 @@ cdef class LambdaComplexDouble(_Lambdify):
4734
4738
cpdef unsafe_eval(self , inp, out, unsigned nbroadcast = 1 ):
4735
4739
cdef double complex [::1 ] c_inp, c_out
4736
4740
cdef unsigned idx
4737
- c_inp = np.ascontiguousarray(inp.ravel(order = self .order))
4741
+ c_inp = np.ascontiguousarray(inp.ravel(order = self .order), dtype = self .numpy_dtype )
4738
4742
c_out = out
4739
4743
for idx in range (nbroadcast):
4740
4744
self .lambda_double[0 ].call(& c_out[idx* self .tot_out_size], & c_inp[idx* self .args_size])
4741
4745
4742
4746
4743
4747
IF HAVE_SYMENGINE_LLVM:
4744
4748
cdef class LLVMDouble(_LLVMLambdify):
4745
- def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False , opt_level = 3 ):
4749
+ def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False , opt_level = 3 , dtype = None ):
4746
4750
self .opt_level = opt_level
4747
4751
4748
4752
cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
@@ -4767,7 +4771,7 @@ IF HAVE_SYMENGINE_LLVM:
4767
4771
cpdef unsafe_eval(self , inp, out, unsigned nbroadcast = 1 ):
4768
4772
cdef double [::1 ] c_inp, c_out
4769
4773
cdef unsigned idx
4770
- c_inp = np.ascontiguousarray(inp.ravel(order = self .order))
4774
+ c_inp = np.ascontiguousarray(inp.ravel(order = self .order), dtype = self .numpy_dtype )
4771
4775
c_out = out
4772
4776
for idx in range (nbroadcast):
4773
4777
self .lambda_double[0 ].call(& c_out[idx* self .tot_out_size], & c_inp[idx* self .args_size])
@@ -4801,7 +4805,7 @@ IF HAVE_SYMENGINE_LLVM:
4801
4805
return addr1, addr2
4802
4806
4803
4807
cdef class LLVMFloat(_LLVMLambdify):
4804
- def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False , opt_level = 3 ):
4808
+ def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False , opt_level = 3 , dtype = None ):
4805
4809
self .opt_level = opt_level
4806
4810
4807
4811
cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
@@ -4826,14 +4830,14 @@ IF HAVE_SYMENGINE_LLVM:
4826
4830
cpdef unsafe_eval(self , inp, out, unsigned nbroadcast = 1 ):
4827
4831
cdef float [::1 ] c_inp, c_out
4828
4832
cdef unsigned idx
4829
- c_inp = np.ascontiguousarray(inp.ravel(order = self .order))
4833
+ c_inp = np.ascontiguousarray(inp.ravel(order = self .order), dtype = self .numpy_dtype )
4830
4834
c_out = out
4831
4835
for idx in range (nbroadcast):
4832
4836
self .lambda_double[0 ].call(& c_out[idx* self .tot_out_size], & c_inp[idx* self .args_size])
4833
4837
4834
4838
IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
4835
4839
cdef class LLVMLongDouble(_LLVMLambdify):
4836
- def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False , opt_level = 3 ):
4840
+ def __cinit__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False , cppbool _load = False , opt_level = 3 , dtype = None ):
4837
4841
self .opt_level = opt_level
4838
4842
4839
4843
cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
@@ -4858,7 +4862,7 @@ IF HAVE_SYMENGINE_LLVM:
4858
4862
cpdef unsafe_eval(self , inp, out, unsigned nbroadcast = 1 ):
4859
4863
cdef long double [::1 ] c_inp, c_out
4860
4864
cdef unsigned idx
4861
- c_inp = np.ascontiguousarray(inp.ravel(order = self .order))
4865
+ c_inp = np.ascontiguousarray(inp.ravel(order = self .order), dtype = self .numpy_dtype )
4862
4866
c_out = out
4863
4867
for idx in range (nbroadcast):
4864
4868
self .lambda_double[0 ].call(& c_out[idx* self .tot_out_size], & c_inp[idx* self .args_size])
@@ -4926,14 +4930,14 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C',
4926
4930
if backend == " llvm" :
4927
4931
IF HAVE_SYMENGINE_LLVM:
4928
4932
if dtype == None :
4929
- dtype = np.double
4930
- if dtype == np.double :
4931
- ret = LLVMDouble(args, * exprs, real = real, order = order, cse = cse, ** kwargs)
4932
- elif dtype == np.float :
4933
- ret = LLVMFloat(args, * exprs, real = real, order = order, cse = cse, ** kwargs)
4933
+ dtype = np.float64
4934
+ if dtype == np.float64 :
4935
+ ret = LLVMDouble(args, * exprs, real = real, order = order, cse = cse, dtype = np.float64, ** kwargs)
4936
+ elif dtype == np.float32 :
4937
+ ret = LLVMFloat(args, * exprs, real = real, order = order, cse = cse, dtype = np.float32, ** kwargs)
4934
4938
elif dtype == np.longdouble:
4935
4939
IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
4936
- ret = LLVMLongDouble(args, * exprs, real = real, order = order, cse = cse, ** kwargs)
4940
+ ret = LLVMLongDouble(args, * exprs, real = real, order = order, cse = cse, dtype = np.longdouble, ** kwargs)
4937
4941
ELSE :
4938
4942
raise ValueError (" Long double not supported on this platform" )
4939
4943
else :
0 commit comments