Skip to content

Commit a1987d8

Browse files
committed
Stub for heterogeneous output in Lambdify
1 parent eb13330 commit a1987d8

File tree

2 files changed

+140
-81
lines changed

2 files changed

+140
-81
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 110 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2580,8 +2580,10 @@ cdef class _Lambdify(object):
25802580
Parameters
25812581
----------
25822582
args: iterable of Symbols
2583-
exprs: array_like of expressions
2583+
\*exprs: array_like of expressions
25842584
the shape of exprs is preserved
2585+
real : bool
2586+
Whether datatype is ``double`` (``double complex`` otherwise).
25852587
25862588
Returns
25872589
-------
@@ -2599,49 +2601,45 @@ cdef class _Lambdify(object):
25992601
[ 9., 24.]
26002602
26012603
"""
2602-
cdef size_t args_size, out_size
2603-
cdef tuple out_shape
2604+
cdef size_t args_size, tot_out_size
2605+
cdef list out_shapes
2606+
cdef vector[int] out_sizes, accum_out_sizes
26042607
cdef readonly bool real
2608+
cdef readonly int n_exprs
26052609

2606-
def __cinit__(self, args, exprs, bool real=True):
2610+
def __cinit__(self, args, *exprs, bool real=True):
26072611
self.real = real
2608-
self.out_shape = get_shape(exprs)
2612+
self.out_shapes = [get_shape(expr) for expr in exprs]
2613+
self.n_exprs = len(exprs)
26092614
self.args_size = _size(args)
2610-
self.out_size = reduce(mul, self.out_shape)
2615+
self.out_sizes = [reduce(mul, shape) for shape in self.out_shapes]
2616+
self.accum_out_sizes = [sum(self.out_sizes[:i]) for i in range(self.n_exprs + 1)]
2617+
self.tot_out_size = sum(self.out_sizes)
26112618

2612-
2613-
def __init__(self, args, exprs, bool real=True):
2619+
def __init__(self, args, *exprs, bool real=True):
26142620
cdef:
26152621
Basic e_
26162622
size_t ri, ci, nr, nc
26172623
symengine.MatrixBase *mtx
26182624
RCP[const symengine.Basic] b_
26192625
symengine.vec_basic args_, outs_
26202626

2621-
if isinstance(args, DenseMatrix):
2622-
nr = args.nrows()
2623-
nc = args.ncols()
2624-
mtx = (<DenseMatrix>args).thisptr
2625-
for ri in range(nr):
2626-
for ci in range(nc):
2627-
args_.push_back(deref(mtx).get(ri, ci))
2628-
else:
2629-
for e in args:
2630-
e_ = sympify(e)
2631-
args_.push_back(e_.thisptr)
2632-
2633-
if isinstance(exprs, DenseMatrix):
2634-
nr = exprs.nrows()
2635-
nc = exprs.ncols()
2636-
mtx = (<DenseMatrix>exprs).thisptr
2637-
for ri in range(nr):
2638-
for ci in range(nc):
2639-
b_ = deref(mtx).get(ri, ci)
2640-
outs_.push_back(b_)
2641-
else:
2642-
for e in ravel(exprs):
2643-
e_ = sympify(e)
2644-
outs_.push_back(e_.thisptr)
2627+
for e in args:
2628+
e_ = sympify(e)
2629+
args_.push_back(e_.thisptr)
2630+
for curr_expr in exprs:
2631+
if isinstance(curr_expr, DenseMatrix):
2632+
nr = curr_expr.nrows()
2633+
nc = curr_expr.ncols()
2634+
mtx = (<DenseMatrix>curr_expr).thisptr
2635+
for ri in range(nr):
2636+
for ci in range(nc):
2637+
b_ = deref(mtx).get(ri, ci)
2638+
outs_.push_back(b_)
2639+
else:
2640+
for e in ravel(curr_expr):
2641+
e_ = sympify(e)
2642+
outs_.push_back(e_.thisptr)
26452643

26462644
self._init(args_, outs_)
26472645

@@ -2659,14 +2657,14 @@ cdef class _Lambdify(object):
26592657
cpdef eval_real(self, double[::1] inp, double[::1] out):
26602658
if inp.size != self.args_size:
26612659
raise ValueError("Size of inp incompatible with number of args.")
2662-
if out.size != self.out_size:
2660+
if out.size != self.tot_out_size:
26632661
raise ValueError("Size of out incompatible with number of exprs.")
26642662
self.unsafe_real(inp, out)
26652663

26662664
cpdef eval_complex(self, double complex[::1] inp, double complex[::1] out):
26672665
if inp.size != self.args_size:
26682666
raise ValueError("Size of inp incompatible with number of args.")
2669-
if out.size != self.out_size:
2667+
if out.size != self.tot_out_size:
26702668
raise ValueError("Size of out incompatible with number of exprs.")
26712669
self.unsafe_complex(inp, out)
26722670

@@ -2677,11 +2675,18 @@ cdef class _Lambdify(object):
26772675
inp: array_like
26782676
last dimension must be equal to number of arguments.
26792677
out: array_like or None (default)
2680-
Allows for for low-overhead use (output argument), if None:
2681-
an output container will be allocated (NumPy ndarray or
2682-
cython.view.array)
2678+
Allows for for low-overhead use (output argument, must be contiguous).
2679+
If ``None``: an output container will be allocated (NumPy ndarray or
2680+
cython.view.array). If ``len(exprs) > 0`` output is found in the corresponding
2681+
order. Note that ``out`` is not reshaped.
26832682
use_numpy: bool (default: None)
26842683
None -> use numpy if available
2684+
2685+
Returns
2686+
-------
2687+
If ``len(exprs) == 1``: ``numpy.ndarray`` or ``cython.view.array``, otherwise
2688+
a tuple of such.
2689+
26852690
"""
26862691
cdef cython.view.array tmp
26872692
cdef double[::1] real_out_view, real_inp_view
@@ -2693,14 +2698,15 @@ cdef class _Lambdify(object):
26932698
except TypeError:
26942699
inp = tuple(inp)
26952700
inp_shape = (len(inp),)
2696-
inp_size = reduce(mul, inp_shape)
2701+
inp_size = long(reduce(mul, inp_shape))
26972702
if inp_size % self.args_size != 0:
26982703
raise ValueError("Broadcasting failed")
26992704
nbroadcast = inp_size // self.args_size
27002705
if nbroadcast > 1 and self.args_size == 1 and inp_shape[-1] != 1: # Implicit reshape
27012706
inp_shape = inp_shape + (1,)
2702-
new_out_shape = inp_shape[:-1] + self.out_shape
2703-
new_out_size = nbroadcast * self.out_size
2707+
new_out_shapes = [inp_shape[:-1] + out_shape for out_shape in self.out_shapes]
2708+
new_out_sizes = [nbroadcast*out_size for out_size in self.out_sizes]
2709+
new_tot_out_size = nbroadcast * self.tot_out_size
27042710
if use_numpy is None:
27052711
try:
27062712
import numpy as np
@@ -2730,16 +2736,17 @@ cdef class _Lambdify(object):
27302736
if out is None:
27312737
# allocate output container
27322738
if use_numpy:
2733-
out = np.empty(new_out_size, dtype=numpy_dtype)
2739+
out = np.empty(new_tot_out_size, dtype=numpy_dtype)
27342740
else:
27352741
if self.real:
2736-
out = cython.view.array((new_out_size,),
2742+
out = cython.view.array((new_tot_out_size,),
27372743
sizeof(double), format='d')
27382744
else:
2739-
out = cython.view.array((new_out_size,),
2745+
out = cython.view.array((new_tot_out_size,),
27402746
sizeof(double complex), format='Zd')
2741-
reshape_out = len(new_out_shape) > 1
2747+
reshape_outs = len(new_out_shapes[0]) > 1
27422748
else:
2749+
reshape_outs = False
27432750
if use_numpy:
27442751
try:
27452752
out_dtype = out.dtype
@@ -2748,55 +2755,63 @@ cdef class _Lambdify(object):
27482755
out_dtype = out.dtype
27492756
if out_dtype != numpy_dtype:
27502757
raise TypeError("Output array is of incorrect type")
2751-
if out.size < new_out_size:
2758+
if out.size < new_tot_out_size:
27522759
raise ValueError("Incompatible size of output argument")
27532760
if not out.flags['C_CONTIGUOUS']:
27542761
raise ValueError("Output argument needs to be C-contiguous")
2755-
for idx, length in enumerate(out.shape[-len(self.out_shape)::-1]):
2756-
if length < self.out_shape[-idx]:
2757-
raise ValueError("Incompatible shape of output argument")
2762+
if self.n_exprs == 1:
2763+
for idx, length in enumerate(out.shape[-len(self.out_shapes[0])::-1]):
2764+
if length < self.out_shapes[0][-idx]:
2765+
raise ValueError("Incompatible shape of output argument")
27582766
if not out.flags['WRITEABLE']:
27592767
raise ValueError("Output argument needs to be writeable")
27602768
if out.ndim > 1:
27612769
out = out.ravel()
2762-
reshape_out = True
2763-
else:
2764-
# The user passed a 1-dimensional output argument,
2765-
# we trust the user to do the right thing.
2766-
reshape_out = False
27672770
else:
27682771
out = with_buffer(out, self.real)
2769-
reshape_out = False # only reshape if we allocated.
27702772
for idx in range(nbroadcast):
27712773
if self.real:
27722774
real_inp_view = inp # slicing cython.view.array does not give a memview
27732775
real_out_view = out
27742776
self.unsafe_real(real_inp_view[idx*self.args_size:(idx+1)*self.args_size],
2775-
real_out_view[idx*self.out_size:(idx+1)*self.out_size])
2777+
real_out_view[idx*self.tot_out_size:(idx+1)*self.tot_out_size])
27762778
else:
27772779
complex_inp_view = inp
27782780
complex_out_view = out
27792781
self.unsafe_complex(complex_inp_view[idx*self.args_size:(idx+1)*self.args_size],
2780-
complex_out_view[idx*self.out_size:(idx+1)*self.out_size])
2782+
complex_out_view[idx*self.tot_out_size:(idx+1)*self.tot_out_size])
2783+
2784+
if use_numpy and reshape_outs:
2785+
out = out.reshape((nbroadcast, self.tot_out_size))
2786+
result = [out[:, self.accum_out_sizes[idx]:self.accum_out_sizes[idx+1]].reshape(new_out_shapes[idx])
2787+
for idx in range(self.n_exprs)]
2788+
elif reshape_outs:
2789+
result = []
2790+
for idx in range(self.n_exprs):
2791+
if self.real:
2792+
tmp = cython.view.array(new_out_shapes[idx],
2793+
sizeof(double), format='d')
2794+
real_out_view = out
2795+
memcpy(<double *>tmp.data, &real_out_view[self.accum_out_sizes[idx]],
2796+
sizeof(double)*new_out_sizes[idx])
2797+
result.append(tmp)
2798+
else:
2799+
tmp = cython.view.array(new_out_shapes[idx],
2800+
sizeof(double complex), format='Zd')
2801+
cmplx_out_view = out
2802+
memcpy(<double complex*>tmp.data, &cmplx_out_view[self.accum_out_sizes[idx]],
2803+
sizeof(double complex)*new_out_sizes[idx])
2804+
result.append(tmp)
2805+
else:
2806+
result = [out]
2807+
2808+
if self.n_exprs == 1:
2809+
result = result[0]
2810+
else:
2811+
result = tuple(result)
2812+
2813+
return result
27812814

2782-
if use_numpy and reshape_out:
2783-
out = out.reshape(new_out_shape)
2784-
elif reshape_out:
2785-
if self.real:
2786-
tmp = cython.view.array(new_out_shape,
2787-
sizeof(double), format='d')
2788-
real_out_view = out
2789-
memcpy(<double *>tmp.data, &real_out_view[0],
2790-
sizeof(double)*new_out_size)
2791-
out = tmp
2792-
else:
2793-
tmp = cython.view.array(new_out_shape,
2794-
sizeof(double complex), format='Zd')
2795-
cmplx_out_view = tmp
2796-
memcpy(<double complex*>tmp.data, &cmplx_out_view[0],
2797-
sizeof(double complex)*new_out_size)
2798-
out = tmp
2799-
return out
28002815

28012816
cdef class LambdaDouble(_Lambdify):
28022817

@@ -2831,17 +2846,18 @@ IF HAVE_SYMENGINE_LLVM:
28312846
self.lambda_double[0].call(&out[0], &inp[0])
28322847

28332848

2834-
def Lambdify(args, exprs, bool real=True, backend="lambda"):
2849+
def Lambdify(args, *exprs, bool real=True, backend="lambda"):
28352850
if backend == "llvm":
28362851
IF HAVE_SYMENGINE_LLVM:
2837-
return LLVMDouble(args, exprs, real)
2852+
return LLVMDouble(args, *exprs, real=real)
28382853
ELSE:
28392854
raise ValueError("""llvm backend is chosen, but symengine is not compiled
28402855
with llvm support.""")
28412856

2842-
return LambdaDouble(args, exprs, real)
2857+
return LambdaDouble(args, *exprs, real=real)
28432858

2844-
def LambdifyCSE(args, exprs, real=True, cse=None, concatenate=None):
2859+
2860+
def LambdifyCSE(args, *exprs, real=True, cse=None, concatenate=None):
28452861
"""
28462862
Analogous with Lambdify but performs common subexpression elimination
28472863
internally. See docstring of Lambdify.
@@ -2863,10 +2879,23 @@ def LambdifyCSE(args, exprs, real=True, cse=None, concatenate=None):
28632879
if concatenate is None:
28642880
from numpy import concatenate
28652881
from sympy import sympify as ssympify
2866-
subs, new_exprs = cse([ssympify(expr) for expr in exprs])
2882+
flat_exprs = list(itertools.chain(*map(ravel, exprs)))
2883+
subs, flat_new_exprs = cse([ssympify(expr) for expr in flat_exprs])
28672884
if subs:
28682885
cse_symbs, cse_exprs = zip(*subs)
2869-
lmb = Lambdify(tuple(args) + cse_symbs, new_exprs, real=real)
2886+
new_exprs = []
2887+
n_taken = 0
2888+
for expr in exprs:
2889+
shape = get_shape(exprs)
2890+
size = long(reduce(mul, shape))
2891+
if len(shape) == 1:
2892+
new_exprs.append(flat_new_exprs[n_taken:n_taken+size])
2893+
elif len(shape) == 2:
2894+
new_exprs.append(DenseMatrix(shape[0], shape[1], flat_new_exprs[n_taken:n_taken+size]))
2895+
else:
2896+
raise NotImplementedError("n-dimensional output not yet supported.")
2897+
n_taken += size
2898+
lmb = Lambdify(tuple(args) + cse_symbs, *new_exprs, real=real)
28702899
cse_lambda = Lambdify(args, cse_exprs, real=real)
28712900

28722901
def cb(inp, out=None, **kwargs):
@@ -2876,7 +2905,7 @@ def LambdifyCSE(args, exprs, real=True, cse=None, concatenate=None):
28762905

28772906
return cb
28782907
else:
2879-
return Lambdify(args, exprs, real=real)
2908+
return Lambdify(args, *exprs, real=real)
28802909

28812910

28822911
def has_symbol(obj, symbol=None):

symengine/tests/test_lambdify.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,3 +527,33 @@ def test_more_than_255_args():
527527
ref[i, 1] = q + n*i + r
528528
ref[:, 2] = -99
529529
assert np.allclose(out, ref)
530+
531+
532+
def _Lambdify_heterogeneous_output(Lambdify):
533+
if not HAVE_NUMPY: # nosetests work-around
534+
return
535+
x, y = se.symbols('x, y')
536+
args = se.DenseMatrix(2, 1, [x, y])
537+
v = se.DenseMatrix(2, 1, [x**3 * y, (x+1)*(y+1)])
538+
jac = v.jacobian(args)
539+
exprs = [jac, x+y, v, x*y]
540+
lmb = se.Lambdify(args, exprs)
541+
inp0 = 7, 11
542+
inp1 = 8, 13
543+
inp2 = 5, 9
544+
inp = np.array([inp0, inp1, inp2])
545+
o_j, o_xpy, o_v, o_xty = lmb(inp, out)
546+
for idx, (X, Y) in enumerate([inp0, inp1, inp2]):
547+
assert np.allclose(o_j[idx, ...], [[3 * X**2 * Y, X**3],
548+
[Y + 1, X + 1]])
549+
assert np.allclose(o_xpy[idx, ...], [X+Y])
550+
assert np.allclose(o_v[idx, ...], [X**3 * Y, (X+1)*(Y+1)])
551+
assert np.allclose(o_xty[idx, ...], [X*Y])
552+
553+
554+
def test_Lambdify_heterogeneous_output():
555+
_Lambdify_heterogeneous_output(se.Lambdify)
556+
557+
558+
def test_LambdifyCSE_heterogeneous_output():
559+
_Lambdify_heterogeneous_output(se.LambdifyCSE)

0 commit comments

Comments
 (0)