Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Enable pickling for LLVMDouble class #213

Merged
merged 8 commits into from
Feb 7, 2018
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,8 @@ cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":
LLVMDoubleVisitor() nogil
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
void call(double *r, const double *x) nogil
const string& dumps() nogil
void loads(const string&) nogil

cdef extern from "<symengine/series.h>" namespace "SymEngine":
cdef cppclass SeriesCoeffInterface:
Expand Down
26 changes: 25 additions & 1 deletion symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4369,7 +4369,7 @@ cdef class _Lambdify(object):
cdef vector[int] accum_out_sizes
cdef object numpy_dtype

def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False):
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False):
cdef:
Basic e_
size_t ri, ci, nr, nc
Expand All @@ -4378,6 +4378,13 @@ cdef class _Lambdify(object):
symengine.vec_basic args_, outs_
vector[int] out_sizes

if _load:
self.args_size, self.tot_out_size, self.out_shapes, self.real, \
self.n_exprs, self.order, self.accum_out_sizes, self.numpy_dtype, \
llvm_function = args
self._load(llvm_function)
return

args = np.asanyarray(args)
self.args_size = args.size
exprs = tuple(np.asanyarray(expr) for expr in exprs)
Expand Down Expand Up @@ -4414,6 +4421,9 @@ cdef class _Lambdify(object):
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
raise ValueError("Not supported")

cdef _load(self, const string &s):
raise ValueError("Not supported")

cpdef unsafe_real(self,
double[::1] inp, double[::1] out,
int inp_offset=0, int out_offset=0):
Expand Down Expand Up @@ -4625,6 +4635,18 @@ IF HAVE_SYMENGINE_LLVM:
self.lambda_double.resize(1)
self.lambda_double[0].init(args_, outs_, cse)

cdef _load(self, const string &s):
self.lambda_double.resize(1)
self.lambda_double[0].loads(s)

def __reduce__(self):
"""
Interface for pickle. Note that the resulting object is platform dependent.
"""
cdef bytes s = self.lambda_double[0].dumps()
return llvm_loading_func, (self.args_size, self.tot_out_size, self.out_shapes, self.real, \
self.n_exprs, self.order, self.accum_out_sizes, self.numpy_dtype, s)

cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])

Expand All @@ -4639,6 +4661,8 @@ IF HAVE_SYMENGINE_LLVM:
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
return create_low_level_callable(self, addr1, addr2)

def llvm_loading_func(*args):
return LLVMDouble(args, _load=True)
Copy link
Contributor

@bjodah bjodah Feb 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think the _load kwarg could go into __cinit__ if we cdef the return value here:

def llvm_loading_func(*args):
    cdef LLVMDouble llvm_double = LLVMDouble(args, _load=True)
    return llvm_double

(untested)
cf. e.g. this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow. Can you elaborate more? (Or a PR. 😃)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not important (actually not sure it matters much), so let's leave it for now (and I can try to make a PR against master after this is merged if I can find a way to hide the _load kwarg).


def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False):
"""
Expand Down
17 changes: 17 additions & 0 deletions symengine/tests/test_pickling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from symengine import symbols, sin, sinh, have_numpy, have_llvm
import pickle
import unittest

@unittest.skipUnless(have_llvm, "No LLVM support")
@unittest.skipUnless(have_numpy, "Numpy not installed")
def test_llvm_double():
import numpy as np
from symengine import Lambdify
args = x, y, z = symbols('x y z')
expr = sin(sinh(x+y) + z)
l = Lambdify(args, expr, cse=True, backend='llvm')
ss = pickle.dumps(l)
ll = pickle.loads(ss)
inp = [1, 2, 3]
assert np.allclose(l(inp), ll(inp))

4 changes: 2 additions & 2 deletions symengine/tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_ccode():
assert ccode(x**3) == "pow(x, 3)"
assert ccode(x**(y**3)) == "pow(x, pow(y, 3))"
assert ccode(x**-1.0) == "pow(x, -1.0)"
assert ccode(Max(x, x*x)) == "max(x, pow(x, 2))"
assert ccode(Max(x, x*x)) == "fmax(x, pow(x, 2))"
assert ccode(sin(x)) == "sin(x)"
assert ccode(Integer(67)) == "67"
assert ccode(Integer(-1)) == "-1"
Expand All @@ -24,4 +24,4 @@ def test_CCodePrinter():
assert myprinter.doprint(MutableDenseMatrix(1, 2, [x, y]), "larry") == "larry[0] = x;\nlarry[1] = y;"
raises(TypeError, lambda: myprinter.doprint(sin(x), Integer))
raises(RuntimeError, lambda: myprinter.doprint(MutableDenseMatrix(1, 2, [x, y])))


2 changes: 1 addition & 1 deletion symengine_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2f5ff9db9ff511ee243438a85ea8e2da2d05af39
fff6755331226a08f0b14571bfbce2b23001d911