Skip to content

Commit 8376e3a

Browse files
committed
Remove order as kwarg from Lambdify.__call__
1 parent 93d2682 commit 8376e3a

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4365,7 +4365,7 @@ cdef class _Lambdify(object):
43654365
cdef list out_shapes
43664366
cdef readonly bint real
43674367
cdef readonly int n_exprs
4368-
cdef readonly str order
4368+
cdef public str order
43694369
cdef vector[int] accum_out_sizes
43704370
cdef object numpy_dtype
43714371

@@ -4440,7 +4440,7 @@ cdef class _Lambdify(object):
44404440
raise ValueError("Size of out incompatible with number of exprs.")
44414441
self.unsafe_complex(inp, out)
44424442

4443-
def __call__(self, inp, *, out=None, order=None):
4443+
def __call__(self, inp, *, out=None):
44444444
"""
44454445
Parameters
44464446
----------
@@ -4469,9 +4469,7 @@ cdef class _Lambdify(object):
44694469
tuple inp_shape
44704470
double[::1] real_out, real_inp
44714471
double complex[::1] cmplx_out, cmplx_inp
4472-
if order is None:
4473-
order = self.order
4474-
if order not in ('C', 'F'):
4472+
if self.order not in ('C', 'F'):
44754473
raise NotImplementedError("Only C & F order supported for now.")
44764474

44774475
try:
@@ -4480,22 +4478,22 @@ cdef class _Lambdify(object):
44804478
inp = np.fromiter(inp, dtype=self.numpy_dtype)
44814479

44824480
if self.real:
4483-
real_inp = np.ascontiguousarray(inp.ravel(order=order))
4481+
real_inp = np.ascontiguousarray(inp.ravel(order=self.order))
44844482
else:
4485-
cmplx_inp = np.ascontiguousarray(inp.ravel(order=order))
4483+
cmplx_inp = np.ascontiguousarray(inp.ravel(order=self.order))
44864484

44874485
if inp.size < self.args_size or inp.size % self.args_size != 0:
44884486
raise ValueError("Broadcasting failed (input/arg size mismatch)")
44894487
nbroadcast = inp.size // self.args_size
44904488

44914489
if inp.ndim > 1:
44924490
if self.args_size > 1:
4493-
if order == 'C':
4491+
if self.order == 'C':
44944492
if inp.shape[inp.ndim-1] != self.args_size:
44954493
raise ValueError(("C order implies last dim (%d) == len(args)"
44964494
" (%d)") % (inp.shape[inp.ndim-1], self.args_size))
44974495
extra_dim = inp.shape[:inp.ndim-1]
4498-
elif order == 'F':
4496+
elif self.order == 'F':
44994497
if inp.shape[0] != self.args_size:
45004498
raise ValueError("F order implies first dim (%d) == len(args) (%d)"
45014499
% (inp.shape[0], self.args_size))
@@ -4507,35 +4505,37 @@ cdef class _Lambdify(object):
45074505
extra_dim = (nbroadcast,) # special case
45084506
else:
45094507
extra_dim = ()
4510-
extra_left = extra_dim if order == 'C' else ()
4511-
extra_right = () if order == 'C' else extra_dim
4508+
extra_left = extra_dim if self.order == 'C' else ()
4509+
extra_right = () if self.order == 'C' else extra_dim
45124510
new_out_shapes = [extra_left + out_shape + extra_right
45134511
for out_shape in self.out_shapes]
45144512

45154513
new_tot_out_size = nbroadcast * self.tot_out_size
45164514
if out is None:
4517-
out = np.empty(new_tot_out_size, dtype=self.numpy_dtype, order=order)
4515+
out = np.empty(new_tot_out_size, dtype=self.numpy_dtype, order=self.order)
45184516
else:
45194517
if out.size < new_tot_out_size:
45204518
raise ValueError("Incompatible size of output argument")
45214519
if out.ndim > 1:
4522-
if order == 'C' and not out.flags['C_CONTIGUOUS']:
4523-
raise ValueError("Output argument needs to be C-contiguous")
4524-
elif order == 'F' and not out.flags['F_CONTIGUOUS']:
4525-
raise ValueError("Output argument needs to be F-contiguous")
45264520
if len(self.out_shapes) > 1:
45274521
raise ValueError("output array with ndim > 1 assumes one output")
45284522
out_shape, = self.out_shapes
4529-
if order == 'C' and out.shape[-len(out_shape):] != tuple(out_shape):
4530-
raise ValueError("shape mismatch for output array")
4531-
elif order == 'F' and out.shape[:len(out_shape)] != tuple(out_shape):
4532-
raise ValueError("shape mismatch for output array")
4523+
if self.order == 'C':
4524+
if not out.flags['C_CONTIGUOUS']:
4525+
raise ValueError("Output argument needs to be C-contiguous")
4526+
if out.shape[-len(out_shape):] != tuple(out_shape):
4527+
raise ValueError("shape mismatch for output array")
4528+
elif self.order == 'F':
4529+
if not out.flags['F_CONTIGUOUS']:
4530+
raise ValueError("Output argument needs to be F-contiguous")
4531+
if out.shape[:len(out_shape)] != tuple(out_shape):
4532+
raise ValueError("shape mismatch for output array")
45334533
else:
45344534
if not out.flags['F_CONTIGUOUS']: # or C_CONTIGUOUS (ndim <= 1)
45354535
raise ValueError("Output array need to be contiguous")
45364536
if not out.flags['WRITEABLE']:
45374537
raise ValueError("Output argument needs to be writeable")
4538-
out = out.ravel(order=order)
4538+
out = out.ravel(order=self.order)
45394539

45404540
if self.real:
45414541
real_out = out
@@ -4551,13 +4551,13 @@ cdef class _Lambdify(object):
45514551
self.unsafe_complex(cmplx_inp, cmplx_out,
45524552
idx*self.args_size, idx*self.tot_out_size)
45534553

4554-
if order == 'C':
4554+
if self.order == 'C':
45554555
out = out.reshape((nbroadcast, self.tot_out_size), order='C')
45564556
result = [
45574557
out[:, self.accum_out_sizes[idx]:self.accum_out_sizes[idx+1]].reshape(
45584558
new_out_shapes[idx], order='C') for idx in range(self.n_exprs)
45594559
]
4560-
elif order == 'F':
4560+
elif self.order == 'F':
45614561
out = out.reshape((self.tot_out_size, nbroadcast), order='F')
45624562
result = [
45634563
out[self.accum_out_sizes[idx]:self.accum_out_sizes[idx+1], :].reshape(
@@ -4657,16 +4657,15 @@ def LambdifyCSE(args, *exprs, cse=None, order='C', **kwargs):
46574657
new_lmb = Lambdify(tuple(_args) + cse_symbs, *new_exprs, order=order, **kwargs)
46584658
cse_lambda = Lambdify(_args, [ce.xreplace(explicit_subs) for ce in cse_exprs], **kwargs)
46594659
def cb(inp, *, out=None, **kw):
4660-
_order = kw.pop('order', order)
46614660
_inp = np.asanyarray(inp)
4662-
cse_vals = cse_lambda(_inp, order=_order, **kw)
4661+
cse_vals = cse_lambda(_inp, **kw)
46634662
if order == 'C':
46644663
new_inp = np.concatenate((_inp[(Ellipsis,) + (np.newaxis,)*(cse_vals.ndim - _inp.ndim)],
46654664
cse_vals), axis=-1)
46664665
else:
46674666
new_inp = np.concatenate((_inp[(np.newaxis,)*(cse_vals.ndim - _inp.ndim) + (Ellipsis,)],
46684667
cse_vals), axis=0)
4669-
return new_lmb(new_inp, out=out, order=_order, **kw)
4668+
return new_lmb(new_inp, out=out, **kw)
46704669
return cb
46714670
else:
46724671
return Lambdify(args, *exprs, **kwargs)

symengine/tests/test_lambdify.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,16 @@ def test_numpy_array_out_exceptions():
147147

148148
all_right_broadcast_C = np.empty((4, len(exprs)), order='C')
149149
inp_bcast = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
150-
lmb(np.array(inp_bcast), out=all_right_broadcast_C, order='C')
151-
152-
all_right_broadcast_F = np.empty((len(exprs), 4), order='F')
153-
lmb(np.array(np.array(inp_bcast).T), out=all_right_broadcast_F, order='F')
150+
lmb(np.array(inp_bcast), out=all_right_broadcast_C)
154151

155152
noncontig_broadcast = np.empty((4, len(exprs), 3)).transpose((1, 2, 0))
156153
raises(ValueError, lambda: (lmb(inp_bcast, out=noncontig_broadcast)))
157154

155+
all_right_broadcast_F = np.empty((len(exprs), 4), order='F')
156+
lmb.order = 'F'
157+
lmb(np.array(np.array(inp_bcast).T), out=all_right_broadcast_F)
158+
159+
158160

159161
@unittest.skipUnless(have_numpy, "Numpy not installed")
160162
def test_broadcast():
@@ -613,7 +615,9 @@ def test_Lambdify_gh174():
613615
out1 = lmb1(3)
614616
assert out1.shape == (3, 1)
615617
assert np.all(out1 == [[3], [9], [27]])
616-
out1a = lmb1([2, 3], order='F') # another dimension
618+
assert lmb1([2, 3]).shape == (2, 3, 1)
619+
lmb1.order = 'F' # change order
620+
out1a = lmb1([2, 3])
617621
assert out1a.shape == (3, 1, 2)
618622
ref1a_squeeze = [[2, 3],
619623
[4, 9],
@@ -661,7 +665,7 @@ def _mtx3(_x, _y):
661665
assert out3c[0].shape == (5,)
662666
assert out3c[1].shape == (5, 4, 3)
663667
assert out3c[2].shape == (5, 3, 1) # user can apply numpy.squeeze if they want to.
664-
for a, b in zip(out3c, lmb3c(np.ravel(inp3c, order='C'))):
668+
for a, b in zip(out3c, lmb3c(np.ravel(inp3c))):
665669
assert np.all(a == b)
666670

667671
out3f = lmb3f(inp3f)
@@ -779,5 +783,5 @@ def _mtx(_x, _y):
779783
assert out3b.shape == (3, 2, 4)
780784
for i in range(4):
781785
assert np.all(out3b[..., i] == _mtx(*inp3b[2*i:2*(i+1)]))
782-
raises(ValueError, lambda: lmb3(inp3b.reshape((4, 2), order='F')))
786+
raises(ValueError, lambda: lmb3(inp3b.reshape((4, 2))))
783787
raises(ValueError, lambda: lmb3(inp3b.reshape((2, 4)).T))

0 commit comments

Comments
 (0)