@@ -2580,8 +2580,10 @@ cdef class _Lambdify(object):
2580
2580
Parameters
2581
2581
----------
2582
2582
args: iterable of Symbols
2583
- exprs: array_like of expressions
2583
+ \* exprs: array_like of expressions
2584
2584
the shape of exprs is preserved
2585
+ real : bool
2586
+ Whether datatype is ``double`` (``double complex`` otherwise).
2585
2587
2586
2588
Returns
2587
2589
-------
@@ -2599,49 +2601,45 @@ cdef class _Lambdify(object):
2599
2601
[ 9., 24.]
2600
2602
2601
2603
"""
2602
- cdef size_t args_size, out_size
2603
- cdef tuple out_shape
2604
+ cdef size_t args_size, tot_out_size
2605
+ cdef list out_shapes
2606
+ cdef vector[int ] out_sizes, accum_out_sizes
2604
2607
cdef readonly bool real
2608
+ cdef readonly int n_exprs
2605
2609
2606
- def __cinit__ (self , args , exprs , bool real = True ):
2610
+ def __cinit__ (self , args , * exprs , bool real = True ):
2607
2611
self .real = real
2608
- self .out_shape = get_shape(exprs)
2612
+ self .out_shapes = [get_shape(expr) for expr in exprs]
2613
+ self .n_exprs = len (exprs)
2609
2614
self .args_size = _size(args)
2610
- self .out_size = reduce (mul, self .out_shape)
2615
+ self .out_sizes = [reduce (mul, shape) for shape in self .out_shapes]
2616
+ self .accum_out_sizes = [sum (self .out_sizes[:i]) for i in range (self .n_exprs + 1 )]
2617
+ self .tot_out_size = sum (self .out_sizes)
2611
2618
2612
-
2613
- def __init__ (self , args , exprs , bool real = True ):
2619
+ def __init__ (self , args , *exprs , bool real = True ):
2614
2620
cdef:
2615
2621
Basic e_
2616
2622
size_t ri, ci, nr, nc
2617
2623
symengine.MatrixBase * mtx
2618
2624
RCP[const symengine.Basic] b_
2619
2625
symengine.vec_basic args_, outs_
2620
2626
2621
- if isinstance (args, DenseMatrix):
2622
- nr = args.nrows()
2623
- nc = args.ncols()
2624
- mtx = (< DenseMatrix> args).thisptr
2625
- for ri in range (nr):
2626
- for ci in range (nc):
2627
- args_.push_back(deref(mtx).get(ri, ci))
2628
- else :
2629
- for e in args:
2630
- e_ = sympify(e)
2631
- args_.push_back(e_.thisptr)
2632
-
2633
- if isinstance (exprs, DenseMatrix):
2634
- nr = exprs.nrows()
2635
- nc = exprs.ncols()
2636
- mtx = (< DenseMatrix> exprs).thisptr
2637
- for ri in range (nr):
2638
- for ci in range (nc):
2639
- b_ = deref(mtx).get(ri, ci)
2640
- outs_.push_back(b_)
2641
- else :
2642
- for e in ravel(exprs):
2643
- e_ = sympify(e)
2644
- outs_.push_back(e_.thisptr)
2627
+ for e in args:
2628
+ e_ = sympify(e)
2629
+ args_.push_back(e_.thisptr)
2630
+ for curr_expr in exprs:
2631
+ if isinstance (curr_expr, DenseMatrix):
2632
+ nr = curr_expr.nrows()
2633
+ nc = curr_expr.ncols()
2634
+ mtx = (< DenseMatrix> curr_expr).thisptr
2635
+ for ri in range (nr):
2636
+ for ci in range (nc):
2637
+ b_ = deref(mtx).get(ri, ci)
2638
+ outs_.push_back(b_)
2639
+ else :
2640
+ for e in ravel(curr_expr):
2641
+ e_ = sympify(e)
2642
+ outs_.push_back(e_.thisptr)
2645
2643
2646
2644
self ._init(args_, outs_)
2647
2645
@@ -2659,14 +2657,14 @@ cdef class _Lambdify(object):
2659
2657
cpdef eval_real(self , double [::1 ] inp, double [::1 ] out):
2660
2658
if inp.size != self .args_size:
2661
2659
raise ValueError (" Size of inp incompatible with number of args." )
2662
- if out.size != self .out_size :
2660
+ if out.size != self .tot_out_size :
2663
2661
raise ValueError (" Size of out incompatible with number of exprs." )
2664
2662
self .unsafe_real(inp, out)
2665
2663
2666
2664
cpdef eval_complex(self , double complex [::1 ] inp, double complex [::1 ] out):
2667
2665
if inp.size != self .args_size:
2668
2666
raise ValueError (" Size of inp incompatible with number of args." )
2669
- if out.size != self .out_size :
2667
+ if out.size != self .tot_out_size :
2670
2668
raise ValueError (" Size of out incompatible with number of exprs." )
2671
2669
self .unsafe_complex(inp, out)
2672
2670
@@ -2677,11 +2675,18 @@ cdef class _Lambdify(object):
2677
2675
inp: array_like
2678
2676
last dimension must be equal to number of arguments.
2679
2677
out: array_like or None (default)
2680
- Allows for for low-overhead use (output argument), if None:
2681
- an output container will be allocated (NumPy ndarray or
2682
- cython.view.array)
2678
+ Allows for for low-overhead use (output argument, must be contiguous).
2679
+ If ``None``: an output container will be allocated (NumPy ndarray or
2680
+ cython.view.array). If ``len(exprs) > 0`` output is found in the corresponding
2681
+ order. Note that ``out`` is not reshaped.
2683
2682
use_numpy: bool (default: None)
2684
2683
None -> use numpy if available
2684
+
2685
+ Returns
2686
+ -------
2687
+ If ``len(exprs) == 1``: ``numpy.ndarray`` or ``cython.view.array``, otherwise
2688
+ a tuple of such.
2689
+
2685
2690
"""
2686
2691
cdef cython.view.array tmp
2687
2692
cdef double [::1 ] real_out_view, real_inp_view
@@ -2693,14 +2698,15 @@ cdef class _Lambdify(object):
2693
2698
except TypeError :
2694
2699
inp = tuple (inp)
2695
2700
inp_shape = (len (inp),)
2696
- inp_size = reduce (mul, inp_shape)
2701
+ inp_size = long ( reduce (mul, inp_shape) )
2697
2702
if inp_size % self .args_size != 0 :
2698
2703
raise ValueError (" Broadcasting failed" )
2699
2704
nbroadcast = inp_size // self .args_size
2700
2705
if nbroadcast > 1 and self .args_size == 1 and inp_shape[- 1 ] != 1 : # Implicit reshape
2701
2706
inp_shape = inp_shape + (1 ,)
2702
- new_out_shape = inp_shape[:- 1 ] + self .out_shape
2703
- new_out_size = nbroadcast * self .out_size
2707
+ new_out_shapes = [inp_shape[:- 1 ] + out_shape for out_shape in self .out_shapes]
2708
+ new_out_sizes = [nbroadcast* out_size for out_size in self .out_sizes]
2709
+ new_tot_out_size = nbroadcast * self .tot_out_size
2704
2710
if use_numpy is None :
2705
2711
try :
2706
2712
import numpy as np
@@ -2730,16 +2736,17 @@ cdef class _Lambdify(object):
2730
2736
if out is None :
2731
2737
# allocate output container
2732
2738
if use_numpy:
2733
- out = np.empty(new_out_size , dtype = numpy_dtype)
2739
+ out = np.empty(new_tot_out_size , dtype = numpy_dtype)
2734
2740
else :
2735
2741
if self .real:
2736
- out = cython.view.array((new_out_size ,),
2742
+ out = cython.view.array((new_tot_out_size ,),
2737
2743
sizeof(double ), format = ' d' )
2738
2744
else :
2739
- out = cython.view.array((new_out_size ,),
2745
+ out = cython.view.array((new_tot_out_size ,),
2740
2746
sizeof(double complex ), format = ' Zd' )
2741
- reshape_out = len (new_out_shape ) > 1
2747
+ reshape_outs = len (new_out_shapes[ 0 ] ) > 1
2742
2748
else :
2749
+ reshape_outs = False
2743
2750
if use_numpy:
2744
2751
try :
2745
2752
out_dtype = out.dtype
@@ -2748,55 +2755,63 @@ cdef class _Lambdify(object):
2748
2755
out_dtype = out.dtype
2749
2756
if out_dtype != numpy_dtype:
2750
2757
raise TypeError (" Output array is of incorrect type" )
2751
- if out.size < new_out_size :
2758
+ if out.size < new_tot_out_size :
2752
2759
raise ValueError (" Incompatible size of output argument" )
2753
2760
if not out.flags[' C_CONTIGUOUS' ]:
2754
2761
raise ValueError (" Output argument needs to be C-contiguous" )
2755
- for idx, length in enumerate (out.shape[- len (self .out_shape)::- 1 ]):
2756
- if length < self .out_shape[- idx]:
2757
- raise ValueError (" Incompatible shape of output argument" )
2762
+ if self .n_exprs == 1 :
2763
+ for idx, length in enumerate (out.shape[- len (self .out_shapes[0 ])::- 1 ]):
2764
+ if length < self .out_shapes[0 ][- idx]:
2765
+ raise ValueError (" Incompatible shape of output argument" )
2758
2766
if not out.flags[' WRITEABLE' ]:
2759
2767
raise ValueError (" Output argument needs to be writeable" )
2760
2768
if out.ndim > 1 :
2761
2769
out = out.ravel()
2762
- reshape_out = True
2763
- else :
2764
- # The user passed a 1-dimensional output argument,
2765
- # we trust the user to do the right thing.
2766
- reshape_out = False
2767
2770
else :
2768
2771
out = with_buffer(out, self .real)
2769
- reshape_out = False # only reshape if we allocated.
2770
2772
for idx in range (nbroadcast):
2771
2773
if self .real:
2772
2774
real_inp_view = inp # slicing cython.view.array does not give a memview
2773
2775
real_out_view = out
2774
2776
self .unsafe_real(real_inp_view[idx* self .args_size:(idx+ 1 )* self .args_size],
2775
- real_out_view[idx* self .out_size :(idx+ 1 )* self .out_size ])
2777
+ real_out_view[idx* self .tot_out_size :(idx+ 1 )* self .tot_out_size ])
2776
2778
else :
2777
2779
complex_inp_view = inp
2778
2780
complex_out_view = out
2779
2781
self .unsafe_complex(complex_inp_view[idx* self .args_size:(idx+ 1 )* self .args_size],
2780
- complex_out_view[idx* self .out_size:(idx+ 1 )* self .out_size])
2782
+ complex_out_view[idx* self .tot_out_size:(idx+ 1 )* self .tot_out_size])
2783
+
2784
+ if use_numpy and reshape_outs:
2785
+ out = out.reshape((nbroadcast, self .tot_out_size))
2786
+ result = [out[:, self .accum_out_sizes[idx]:self .accum_out_sizes[idx+ 1 ]].reshape(new_out_shapes[idx])
2787
+ for idx in range (self .n_exprs)]
2788
+ elif reshape_outs:
2789
+ result = []
2790
+ for idx in range (self .n_exprs):
2791
+ if self .real:
2792
+ tmp = cython.view.array(new_out_shapes[idx],
2793
+ sizeof(double ), format = ' d' )
2794
+ real_out_view = out
2795
+ memcpy(< double * > tmp.data, & real_out_view[self .accum_out_sizes[idx]],
2796
+ sizeof(double )* new_out_sizes[idx])
2797
+ result.append(tmp)
2798
+ else :
2799
+ tmp = cython.view.array(new_out_shapes[idx],
2800
+ sizeof(double complex ), format = ' Zd' )
2801
+ cmplx_out_view = out
2802
+ memcpy(< double complex * > tmp.data, & cmplx_out_view[self .accum_out_sizes[idx]],
2803
+ sizeof(double complex )* new_out_sizes[idx])
2804
+ result.append(tmp)
2805
+ else :
2806
+ result = [out]
2807
+
2808
+ if self .n_exprs == 1 :
2809
+ result = result[0 ]
2810
+ else :
2811
+ result = tuple (result)
2812
+
2813
+ return result
2781
2814
2782
- if use_numpy and reshape_out:
2783
- out = out.reshape(new_out_shape)
2784
- elif reshape_out:
2785
- if self .real:
2786
- tmp = cython.view.array(new_out_shape,
2787
- sizeof(double ), format = ' d' )
2788
- real_out_view = out
2789
- memcpy(< double * > tmp.data, & real_out_view[0 ],
2790
- sizeof(double )* new_out_size)
2791
- out = tmp
2792
- else :
2793
- tmp = cython.view.array(new_out_shape,
2794
- sizeof(double complex ), format = ' Zd' )
2795
- cmplx_out_view = tmp
2796
- memcpy(< double complex * > tmp.data, & cmplx_out_view[0 ],
2797
- sizeof(double complex )* new_out_size)
2798
- out = tmp
2799
- return out
2800
2815
2801
2816
cdef class LambdaDouble(_Lambdify):
2802
2817
@@ -2831,17 +2846,18 @@ IF HAVE_SYMENGINE_LLVM:
2831
2846
self .lambda_double[0 ].call(& out[0 ], & inp[0 ])
2832
2847
2833
2848
2834
- def Lambdify (args , exprs , bool real = True , backend = " lambda" ):
2849
+ def Lambdify (args , * exprs , bool real = True , backend = " lambda" ):
2835
2850
if backend == " llvm" :
2836
2851
IF HAVE_SYMENGINE_LLVM:
2837
- return LLVMDouble(args, exprs, real)
2852
+ return LLVMDouble(args, * exprs, real = real)
2838
2853
ELSE :
2839
2854
raise ValueError (""" llvm backend is chosen, but symengine is not compiled
2840
2855
with llvm support.""" )
2841
2856
2842
- return LambdaDouble(args, exprs, real)
2857
+ return LambdaDouble(args, * exprs, real = real)
2843
2858
2844
- def LambdifyCSE (args , exprs , real = True , cse = None , concatenate = None ):
2859
+
2860
+ def LambdifyCSE (args , *exprs , real = True , cse = None , concatenate = None ):
2845
2861
"""
2846
2862
Analogous with Lambdify but performs common subexpression elimination
2847
2863
internally. See docstring of Lambdify.
@@ -2863,10 +2879,23 @@ def LambdifyCSE(args, exprs, real=True, cse=None, concatenate=None):
2863
2879
if concatenate is None :
2864
2880
from numpy import concatenate
2865
2881
from sympy import sympify as ssympify
2866
- subs, new_exprs = cse([ssympify(expr) for expr in exprs])
2882
+ flat_exprs = list (itertools.chain(* map (ravel, exprs)))
2883
+ subs, flat_new_exprs = cse([ssympify(expr) for expr in flat_exprs])
2867
2884
if subs:
2868
2885
cse_symbs, cse_exprs = zip (* subs)
2869
- lmb = Lambdify(tuple (args) + cse_symbs, new_exprs, real = real)
2886
+ new_exprs = []
2887
+ n_taken = 0
2888
+ for expr in exprs:
2889
+ shape = get_shape(exprs)
2890
+ size = long (reduce (mul, shape))
2891
+ if len (shape) == 1 :
2892
+ new_exprs.append(flat_new_exprs[n_taken:n_taken+ size])
2893
+ elif len (shape) == 2 :
2894
+ new_exprs.append(DenseMatrix(shape[0 ], shape[1 ], flat_new_exprs[n_taken:n_taken+ size]))
2895
+ else :
2896
+ raise NotImplementedError (" n-dimensional output not yet supported." )
2897
+ n_taken += size
2898
+ lmb = Lambdify(tuple (args) + cse_symbs, * new_exprs, real = real)
2870
2899
cse_lambda = Lambdify(args, cse_exprs, real = real)
2871
2900
2872
2901
def cb (inp , out = None , **kwargs ):
@@ -2876,7 +2905,7 @@ def LambdifyCSE(args, exprs, real=True, cse=None, concatenate=None):
2876
2905
2877
2906
return cb
2878
2907
else :
2879
- return Lambdify(args, exprs, real = real)
2908
+ return Lambdify(args, * exprs, real = real)
2880
2909
2881
2910
2882
2911
def has_symbol (obj , symbol = None ):
0 commit comments