Skip to content

Commit 221ccac

Browse files
committed
Enable pickling for LLVMDouble class
1 parent fefb5f4 commit 221ccac

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

symengine/lib/symengine.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,8 @@ cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":
972972
LLVMDoubleVisitor() nogil
973973
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
974974
void call(double *r, const double *x) nogil
975+
const string& dumps() nogil
976+
void loads(const string&) nogil
975977

976978
cdef extern from "<symengine/series.h>" namespace "SymEngine":
977979
cdef cppclass SeriesCoeffInterface:

symengine/lib/symengine_wrapper.pyx

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4369,7 +4369,7 @@ cdef class _Lambdify(object):
43694369
cdef vector[int] accum_out_sizes
43704370
cdef object numpy_dtype
43714371

4372-
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False):
4372+
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool load=False):
43734373
cdef:
43744374
Basic e_
43754375
size_t ri, ci, nr, nc
@@ -4378,6 +4378,13 @@ cdef class _Lambdify(object):
43784378
symengine.vec_basic args_, outs_
43794379
vector[int] out_sizes
43804380

4381+
if load:
4382+
self.args_size, self.tot_out_size, self.out_shapes, self.real, \
4383+
self.n_exprs, self.order, self.accum_out_sizes, self.numpy_dtype, \
4384+
llvm_function = args
4385+
self._load(llvm_function)
4386+
return
4387+
43814388
args = np.asanyarray(args)
43824389
self.args_size = args.size
43834390
exprs = tuple(np.asanyarray(expr) for expr in exprs)
@@ -4414,6 +4421,9 @@ cdef class _Lambdify(object):
44144421
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
44154422
raise ValueError("Not supported")
44164423

4424+
cdef _load(self, const string &s):
4425+
raise ValueError("Not supported")
4426+
44174427
cpdef unsafe_real(self,
44184428
double[::1] inp, double[::1] out,
44194429
int inp_offset=0, int out_offset=0):
@@ -4625,6 +4635,18 @@ IF HAVE_SYMENGINE_LLVM:
46254635
self.lambda_double.resize(1)
46264636
self.lambda_double[0].init(args_, outs_, cse)
46274637

4638+
cdef _load(self, const string &s):
4639+
self.lambda_double.resize(1)
4640+
self.lambda_double[0].loads(s)
4641+
4642+
def __reduce__(self):
4643+
"""
4644+
Interface for pickle. Note that the resulting object is platform dependent.
4645+
"""
4646+
cdef bytes s = self.lambda_double[0].dumps()
4647+
return llvm_loading_func, (self.args_size, self.tot_out_size, self.out_shapes, self.real, \
4648+
self.n_exprs, self.order, self.accum_out_sizes, self.numpy_dtype, s)
4649+
46284650
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
46294651
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
46304652

@@ -4639,6 +4661,8 @@ IF HAVE_SYMENGINE_LLVM:
46394661
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
46404662
return create_low_level_callable(self, addr1, addr2)
46414663

4664+
def llvm_loading_func(*args):
4665+
return LLVMDouble(args, load=True)
46424666

46434667
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False):
46444668
"""

symengine/tests/test_pickling.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from symengine import symbols, sin, sinh, Lambdify, have_numpy
2+
import pickle
3+
import unittest
4+
5+
@unittest.skipUnless(have_numpy, "Numpy not installed")
6+
def test_llvm_double():
7+
import numpy as np
8+
args = x, y, z = symbols('x y z')
9+
expr = sin(sinh(x+y) + z)
10+
l = Lambdify(args, expr, cse=True, backend='llvm')
11+
ss = pickle.dumps(l)
12+
ll = pickle.loads(ss)
13+
inp = [1, 2, 3]
14+
assert np.allclose(l(inp), ll(inp))
15+

0 commit comments

Comments
 (0)