Skip to content

Commit d35b2a6

Browse files
committed
Add symengine.have_numpy, handle defines in setup.py
1 parent 537c860 commit d35b2a6

File tree

5 files changed

+88
-76
lines changed

5 files changed

+88
-76
lines changed

setup.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ def get_build_dir(dist):
6161
'options to cmake <var>:<type>=<value>'),
6262
]
6363

64+
def _process_define(arg):
65+
(defs, one), = getattr(arg, 'define', ['', '1'])
66+
assert one == '1'
67+
defs = defs.split(';')
68+
if not any(define.startswith('WITH_NUMPY') for define in defs):
69+
try:
70+
import numpy as np
71+
except ImportError:
72+
defs.append('WITH_NUMPY=False')
73+
else:
74+
defs.append('WITH_NUMPY=True')
75+
return [(s.strip(), None) if '=' not in s else
76+
tuple(ss.strip() for ss in s.split('='))
77+
for s in defs]
78+
6479

6580
class BuildWithCmake(_build):
6681
sub_commands = [('build_ext', None)]
@@ -83,12 +98,8 @@ def finalize_options(self):
8398
# The argument parsing will result in self.define being a string, but
8499
# it has to be a list of 2-tuples.
85100
# Multiple symbols can be separated with semi-colons.
86-
if self.define:
87-
defines = self.define.split(';')
88-
self.define = [(s.strip(), None) if '=' not in s else
89-
tuple(ss.strip() for ss in s.split('='))
90-
for s in defines]
91-
cmake_opts.extend(self.define)
101+
self.define = _process_define(self)
102+
cmake_opts.extend(self.define)
92103
if self.symengine_dir:
93104
cmake_opts.extend([('SymEngine_DIR', self.symengine_dir)])
94105

@@ -160,21 +171,8 @@ def finalize_options(self):
160171
# The argument parsing will result in self.define being a string, but
161172
# it has to be a list of 2-tuples.
162173
# Multiple symbols can be separated with semi-colons.
163-
if self.define:
164-
defines = self.define.split(';')
165-
if not any(define.startswith('WITH_NUMPY') for define in defines):
166-
try:
167-
import numpy as np
168-
except ImportError:
169-
defines.append('WITH_NUMPY=False')
170-
else:
171-
defines.append('WITH_NUMPY=True')
172-
defines.append()
173-
self.define = [(s.strip(), None) if '=' not in s else
174-
tuple(ss.strip() for ss in s.split('='))
175-
for s in defines]
176-
cmake_opts.extend(self.define)
177-
174+
self.define = _process_define(self)
175+
cmake_opts.extend(self.define)
178176
cmake_build_type[0] = self.build_type
179177
cmake_opts.extend([('PYTHON_INSTALL_PATH', self.install_platlib)])
180178
cmake_opts.extend([('PYTHON_INSTALL_HEADER_PATH',

symengine/__init__.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
have_llvm, Integer, Rational, Float, Number, RealNumber,
55
RealDouble, ComplexDouble, Max, Min, DenseMatrix, Matrix,
66
ImmutableMatrix, ImmutableDenseMatrix, MutableDenseMatrix,
7-
MatrixBase, Basic, Lambdify, LambdifyCSE, Lambdify as lambdify,
8-
DictBasic, symarray, series, diff, zeros, eye, diag, ones,
9-
Derivative, Subs, add, expand, has_symbol, UndefFunction,
10-
Function, FunctionSymbol as AppliedUndef)
7+
MatrixBase, Basic, DictBasic, symarray, series, diff, zeros,
8+
eye, diag, ones, Derivative, Subs, add, expand, has_symbol,
9+
UndefFunction, Function, FunctionSymbol as AppliedUndef,
10+
have_numpy)
1111
from .utilities import var, symbols
1212
from .functions import *
1313

@@ -17,6 +17,26 @@
1717
if have_mpc:
1818
from .lib.symengine_wrapper import ComplexMPC
1919

20+
if have_numpy:
21+
from .lib.symengine_wrapper import Lambdify, LambdifyCSE
22+
23+
def lambdify(args, exprs):
24+
try:
25+
len(args)
26+
except TypeError:
27+
args = [args]
28+
try:
29+
len(exprs)
30+
except TypeError:
31+
exprs = [exprs]
32+
lmb = Lambdify(args, exprs)
33+
def f(*inner_args):
34+
if len(inner_args) != len(args):
35+
raise TypeError("Incorrect number of arguments")
36+
return lmb(inner_args)
37+
return f
38+
39+
2040
__version__ = "0.2.1.dev"
2141

2242

symengine/lib/symengine_wrapper.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2599,6 +2599,7 @@ have_mpc = False
25992599
have_piranha = False
26002600
have_flint = False
26012601
have_llvm = False
2602+
have_numpy = False
26022603

26032604
IF HAVE_SYMENGINE_MPFR:
26042605
have_mpfr = True
@@ -3015,6 +3016,7 @@ IF HAVE_NUMPY:
30153016
# Lambdify requires NumPy (since b713a61, see gh-112)
30163017
cimport numpy as cnp
30173018
import numpy as np
3019+
have_numpy = True
30183020

30193021
cdef size_t _size(n):
30203022
try:

symengine/sympy_compat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@
22
warnings.warn("sympy_compat module is deprecated. Use `import symengine` instead", DeprecationWarning,
33
stacklevel=2)
44
from symengine import *
5-

0 commit comments

Comments
 (0)