Add float32 / long double LLVM Lambdify backends and a dtype argument

What this patch does (as visible in the hunks below):
 * CMakeLists.txt: require SymEngine 0.6.0; set HAVE_SYMENGINE_LLVM_LONG_DOUBLE
   to True only when LLVM is enabled, sizeof(long double) > 8 and
   CMAKE_SYSTEM_PROCESSOR matches x86/X86/amd64/AMD64.
 * config.pxi.in / __init__.py: export the new flag as have_llvm_long_double.
 * symengine.pxd: declare LLVMVisitor plus Float/Double/LongDouble subclasses.
 * symengine_wrapper.pxd/.pyx: replace the per-call unsafe_real/unsafe_complex
   dispatch in _Lambdify.__call__ with a single unsafe_eval() that each backend
   overrides; split the complex case into LambdaComplexDouble; add LLVMFloat and
   (conditionally) LLVMLongDouble; Lambdify() selects the LLVM class from dtype.
 * tests: cover the float32 and long double LLVM paths.

Review fixes folded into this revision:
 * symengine_wrapper.pxd: "sef" typo in the unsafe_eval declaration -> "self".
 * symengine_wrapper.pxd: dropped the second, redundant unsafe_real declaration
   that was being added to LambdaDouble (new-side hunk length 46 -> 43).
 * symengine_wrapper.pyx: "dtype if dtype else ..." -> "dtype if dtype is not
   None else ..." (truthiness of a numpy.dtype instance is not a None test).
 * symengine_wrapper.pyx: "dtype == None" -> "dtype is None".

Reconstruction notes (NOTE(review): confirm against the real tree):
 * The patch text had lost its line breaks and indentation; both were rebuilt so
   that every hunk matches the line counts in its "@@" header. Exact whitespace
   of context/removed lines is an assumption.
 * Angle-bracket tokens had been stripped. The extern header names in
   symengine.pxd and the "<size_t>" casts in the ctypes helpers are restored
   from what the surrounding code requires; verify them.
 * The "index" blob ids are the original ones and no longer match the edited
   post-image.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 45d0a63ae..6cabd8ed2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 2.8)
 project(python_wrapper)
 
 set(CMAKE_PREFIX_PATH ${SymEngine_DIR} ${CMAKE_PREFIX_PATH})
-find_package(SymEngine 0.5.0 REQUIRED CONFIG
+find_package(SymEngine 0.6.0 REQUIRED CONFIG
              PATH_SUFFIXES lib/cmake/symengine cmake/symengine CMake/)
 message("SymEngine_DIR : " ${SymEngine_DIR})
 message("SymEngine Version : " ${SymEngine_VERSION})
@@ -30,6 +30,15 @@ if (MINGW AND (CMAKE_BUILD_TYPE STREQUAL "Release"))
     endif()
 endif()
 
+include(CheckTypeSize)
+check_type_size("long double" SYMENGINE_SIZEOF_LONG_DOUBLE)
+
+if (HAVE_SYMENGINE_LLVM AND SYMENGINE_SIZEOF_LONG_DOUBLE GREATER "8" AND CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
+    set (HAVE_SYMENGINE_LLVM_LONG_DOUBLE True)
+else ()
+    set (HAVE_SYMENGINE_LLVM_LONG_DOUBLE False)
+endif ()
+
 foreach (PKG MPC MPFR PIRANHA FLINT LLVM)
     if ("${HAVE_SYMENGINE_${PKG}}" STREQUAL "")
         set(HAVE_SYMENGINE_${PKG} False)
diff --git a/symengine/__init__.py b/symengine/__init__.py
index 1b0ac5a2e..5b5e1062a 100644
--- a/symengine/__init__.py
+++ b/symengine/__init__.py
@@ -1,5 +1,5 @@
 from .lib.symengine_wrapper import (
-    have_mpfr, have_mpc, have_flint, have_piranha, have_llvm,
+    have_mpfr, have_mpc, have_flint, have_piranha, have_llvm, have_llvm_long_double,
     I, E, pi, oo, zoo, nan, Symbol, Dummy, S, sympify, SympifyError,
     Integer, Rational, Float, Number, RealNumber, RealDouble, ComplexDouble,
     add, Add, Mul, Pow, function_symbol,
diff --git a/symengine/lib/config.pxi.in b/symengine/lib/config.pxi.in
index 03d3c06b9..26a33012e 100644
--- a/symengine/lib/config.pxi.in
+++ b/symengine/lib/config.pxi.in
@@ -3,3 +3,4 @@ DEF HAVE_SYMENGINE_MPC = ${HAVE_SYMENGINE_MPC}
 DEF HAVE_SYMENGINE_PIRANHA = ${HAVE_SYMENGINE_PIRANHA}
 DEF HAVE_SYMENGINE_FLINT = ${HAVE_SYMENGINE_FLINT}
 DEF HAVE_SYMENGINE_LLVM = ${HAVE_SYMENGINE_LLVM}
+DEF HAVE_SYMENGINE_LLVM_LONG_DOUBLE = ${HAVE_SYMENGINE_LLVM_LONG_DOUBLE}
diff --git a/symengine/lib/symengine.pxd b/symengine/lib/symengine.pxd
index 1013cd1e1..348f6fb21 100644
--- a/symengine/lib/symengine.pxd
+++ b/symengine/lib/symengine.pxd
@@ -985,13 +985,22 @@ cdef extern from "<symengine/lambda_double.h>" namespace "SymEngine":
         void call(double complex *r, const double complex *x) nogil
 
 cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":
-    cdef cppclass LLVMDoubleVisitor:
-        LLVMDoubleVisitor() nogil
+    cdef cppclass LLVMVisitor:
+        LLVMVisitor() nogil
         void init(const vec_basic &x, const vec_basic &b, bool cse, int opt_level) nogil except +
-        void call(double *r, const double *x) nogil
         const string& dumps() nogil
         void loads(const string&) nogil
 
+    cdef cppclass LLVMFloatVisitor(LLVMVisitor):
+        void call(float *r, const float *x) nogil
+
+    cdef cppclass LLVMDoubleVisitor(LLVMVisitor):
+        void call(double *r, const double *x) nogil
+
+    cdef cppclass LLVMLongDoubleVisitor(LLVMVisitor):
+        void call(long double *r, const long double *x) nogil
+
+
 cdef extern from "<symengine/series.h>" namespace "SymEngine":
     cdef cppclass SeriesCoeffInterface:
         rcp_const_basic as_basic() nogil except +
diff --git a/symengine/lib/symengine_wrapper.pxd b/symengine/lib/symengine_wrapper.pxd
index d1c6215f3..653854dc0 100644
--- a/symengine/lib/symengine_wrapper.pxd
+++ b/symengine/lib/symengine_wrapper.pxd
@@ -39,29 +39,43 @@ cdef class _Lambdify(object):
 
     cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
     cdef _load(self, const string &s)
-    cpdef unsafe_real(self,
-                      double[::1] inp, double[::1] out,
-                      int inp_offset=*, int out_offset=*)
-    cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out,
-                         int inp_offset=*, int out_offset=*)
     cpdef eval_real(self, inp, out)
     cpdef eval_complex(self, inp, out)
+    cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=*)
 
 cdef class LambdaDouble(_Lambdify):
     cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
-    cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
     cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
     cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=*, int out_offset=*)
-    cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=*, int out_offset=*)
     cpdef as_scipy_low_level_callable(self)
     cpdef as_ctypes(self)
+
+cdef class LambdaComplexDouble(_Lambdify):
+    cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double
+    cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
+    cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=*, int out_offset=*)
 
 IF HAVE_SYMENGINE_LLVM:
-    cdef class LLVMDouble(_Lambdify):
+    cdef class _LLVMLambdify(_Lambdify):
         cdef int opt_level
+
+    cdef class LLVMDouble(_LLVMLambdify):
         cdef vector[symengine.LLVMDoubleVisitor] lambda_double
         cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
         cdef _load(self, const string &s)
         cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=*, int out_offset=*)
         cpdef as_scipy_low_level_callable(self)
         cpdef as_ctypes(self)
+
+    cdef class LLVMFloat(_LLVMLambdify):
+        cdef vector[symengine.LLVMFloatVisitor] lambda_double
+        cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
+        cdef _load(self, const string &s)
+        cpdef unsafe_real(self, float[::1] inp, float[::1] out, int inp_offset=*, int out_offset=*)
+
+    IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
+        cdef class LLVMLongDouble(_LLVMLambdify):
+            cdef vector[symengine.LLVMLongDoubleVisitor] lambda_double
+            cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
+            cdef _load(self, const string &s)
+            cpdef unsafe_real(self, long double[::1] inp, long double[::1] out, int inp_offset=*, int out_offset=*)
diff --git a/symengine/lib/symengine_wrapper.pyx b/symengine/lib/symengine_wrapper.pyx
index 6cbb51010..3772dcb1e 100644
--- a/symengine/lib/symengine_wrapper.pyx
+++ b/symengine/lib/symengine_wrapper.pyx
@@ -4058,6 +4058,7 @@ have_mpc = False
 have_piranha = False
 have_flint = False
 have_llvm = False
+have_llvm_long_double = False
 
 IF HAVE_SYMENGINE_MPFR:
     have_mpfr = True
@@ -4080,6 +4081,9 @@ IF HAVE_SYMENGINE_FLINT:
 IF HAVE_SYMENGINE_LLVM:
     have_llvm = True
 
+IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
+    have_llvm_long_double = True
+
 def require(obj, t):
     if not isinstance(obj, t):
         raise TypeError("{} required. {} is of type {}".format(t, obj, type(obj)))
@@ -4465,7 +4469,7 @@ def has_symbol(obj, symbol=None):
 
 
 cdef class _Lambdify(object):
-    def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, **kwargs):
+    def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, dtype=None, **kwargs):
         cdef:
             Basic e_
             size_t ri, ci, nr, nc
@@ -4488,7 +4492,7 @@ cdef class _Lambdify(object):
         self.n_exprs = len(exprs)
         self.real = real
         self.order = order
-        self.numpy_dtype = np.float64 if self.real else np.complex128
+        self.numpy_dtype = dtype if dtype is not None else (np.float64 if self.real else np.complex128)
         if self.args_size == 0:
             raise NotImplementedError("Support for zero arguments not yet supported")
         self.tot_out_size = 0
@@ -4520,15 +4524,6 @@ cdef class _Lambdify(object):
     cdef _load(self, const string &s):
         raise ValueError("Not supported")
 
-    cpdef unsafe_real(self,
-                      double[::1] inp, double[::1] out,
-                      int inp_offset=0, int out_offset=0):
-        raise ValueError("Not supported")
-
-    cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out,
-                         int inp_offset=0, int out_offset=0):
-        raise ValueError("Not supported")
-
     cpdef eval_real(self, inp, out):
         if inp.size != self.args_size:
             raise ValueError("Size of inp incompatible with number of args.")
@@ -4543,6 +4538,9 @@ cdef class _Lambdify(object):
             raise ValueError("Size of out incompatible with number of exprs.")
         self.unsafe_complex(inp, out)
 
+    cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
+        raise ValueError("Not supported")
+
     def __call__(self, *args, out=None):
         """
         Parameters
@@ -4562,10 +4560,6 @@ cdef class _Lambdify(object):
         """
         cdef:
             size_t idx, new_tot_out_size, nbroadcast = 1
-            long inp_size
-            tuple inp_shape
-            double[::1] real_out, real_inp
-            double complex[::1] cmplx_out, cmplx_inp
 
         if self.order not in ('C', 'F'):
             raise NotImplementedError("Only C & F order supported for now.")
@@ -4577,11 +4571,6 @@ cdef class _Lambdify(object):
         except TypeError:
             inp = np.fromiter(args, dtype=self.numpy_dtype)
 
-        if self.real:
-            real_inp = np.ascontiguousarray(inp.ravel(order=self.order))
-        else:
-            cmplx_inp = np.ascontiguousarray(inp.ravel(order=self.order))
-
         if inp.size < self.args_size or inp.size % self.args_size != 0:
             raise ValueError("Broadcasting failed (input/arg size mismatch)")
         nbroadcast = inp.size // self.args_size
@@ -4637,19 +4626,7 @@ cdef class _Lambdify(object):
                 raise ValueError("Output argument needs to be writeable")
             out = out.ravel(order=self.order)
 
-        if self.real:
-            real_out = out
-        else:
-            cmplx_out = out
-
-        if self.real:
-            for idx in range(nbroadcast):
-                self.unsafe_real(real_inp, real_out,
-                                 idx*self.args_size, idx*self.tot_out_size)
-        else:
-            for idx in range(nbroadcast):
-                self.unsafe_complex(cmplx_inp, cmplx_out,
-                                    idx*self.args_size, idx*self.tot_out_size)
+        self.unsafe_eval(inp, out, nbroadcast)
 
         if self.order == 'C':
             out = out.reshape((nbroadcast, self.tot_out_size), order='C')
@@ -4702,28 +4679,27 @@ def create_low_level_callable(lambdify, *args):
 
 
 cdef class LambdaDouble(_Lambdify):
-    def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False):
+    def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, dtype=None):
         # reject additional arguments
         pass
 
     cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
-        if self.real:
-            self.lambda_double.resize(1)
-            self.lambda_double[0].init(args_, outs_, cse)
-        else:
-            self.lambda_double_complex.resize(1)
-            self.lambda_double_complex[0].init(args_, outs_, cse)
+        self.lambda_double.resize(1)
+        self.lambda_double[0].init(args_, outs_, cse)
 
     cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
         self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
 
-    cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=0, int out_offset=0):
-        self.lambda_double_complex[0].call(&out[out_offset], &inp[inp_offset])
+    cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
+        cdef double[::1] c_inp, c_out
+        cdef unsigned idx
+        c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
+        c_out = out
+        for idx in range(nbroadcast):
+            self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
 
     cpdef as_scipy_low_level_callable(self):
         from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
-        if not self.real:
-            raise RuntimeError("Lambda function has to be real")
         if self.tot_out_size > 1:
             raise RuntimeError("SciPy LowLevelCallable supports only functions with 1 output")
         addr1 = cast(<size_t>&_scipy_callback_lambda_real,
@@ -4741,17 +4717,36 @@ cdef class LambdaDouble(_Lambdify):
         passed as input to the function as the third argument `user_data`.
         """
         from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
-        if not self.real:
-            raise RuntimeError("Lambda function has to be real")
         addr1 = cast(<size_t>&_ctypes_callback_lambda_real,
                      CFUNCTYPE(c_void_p, POINTER(c_double), POINTER(c_double), c_void_p))
         addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
         return addr1, addr2
 
 
+cdef class LambdaComplexDouble(_Lambdify):
+    def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, dtype=None):
+        # reject additional arguments
+        pass
+
+    cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
+        self.lambda_double.resize(1)
+        self.lambda_double[0].init(args_, outs_, cse)
+
+    cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=0, int out_offset=0):
+        self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
+
+    cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
+        cdef double complex[::1] c_inp, c_out
+        cdef unsigned idx
+        c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
+        c_out = out
+        for idx in range(nbroadcast):
+            self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
+
+
 IF HAVE_SYMENGINE_LLVM:
-    cdef class LLVMDouble(_Lambdify):
-        def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3):
+    cdef class LLVMDouble(_LLVMLambdify):
+        def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3, dtype=None):
             self.opt_level = opt_level
 
         cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
@@ -4773,6 +4768,14 @@ IF HAVE_SYMENGINE_LLVM:
         cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
             self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
 
+        cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
+            cdef double[::1] c_inp, c_out
+            cdef unsigned idx
+            c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
+            c_out = out
+            for idx in range(nbroadcast):
+                self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
+
         cpdef as_scipy_low_level_callable(self):
             from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
             if not self.real:
@@ -4801,10 +4804,81 @@ IF HAVE_SYMENGINE_LLVM:
             addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
             return addr1, addr2
 
+    cdef class LLVMFloat(_LLVMLambdify):
+        def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3, dtype=None):
+            self.opt_level = opt_level
+
+        cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
+            self.lambda_double.resize(1)
+            self.lambda_double[0].init(args_, outs_, cse, self.opt_level)
+
+        cdef _load(self, const string &s):
+            self.lambda_double.resize(1)
+            self.lambda_double[0].loads(s)
+
+        def __reduce__(self):
+            """
+            Interface for pickle. Note that the resulting object is platform dependent.
+            """
+            cdef bytes s = self.lambda_double[0].dumps()
+            return llvm_float_loading_func, (self.args_size, self.tot_out_size, self.out_shapes, self.real, \
+                self.n_exprs, self.order, self.accum_out_sizes, self.numpy_dtype, s)
+
+        cpdef unsafe_real(self, float[::1] inp, float[::1] out, int inp_offset=0, int out_offset=0):
+            self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
+
+        cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
+            cdef float[::1] c_inp, c_out
+            cdef unsigned idx
+            c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
+            c_out = out
+            for idx in range(nbroadcast):
+                self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
+
+    IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
+        cdef class LLVMLongDouble(_LLVMLambdify):
+            def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3, dtype=None):
+                self.opt_level = opt_level
+
+            cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
+                self.lambda_double.resize(1)
+                self.lambda_double[0].init(args_, outs_, cse, self.opt_level)
+
+            cdef _load(self, const string &s):
+                self.lambda_double.resize(1)
+                self.lambda_double[0].loads(s)
+
+            def __reduce__(self):
+                """
+                Interface for pickle. Note that the resulting object is platform dependent.
+                """
+                cdef bytes s = self.lambda_double[0].dumps()
+                return llvm_long_double_loading_func, (self.args_size, self.tot_out_size, self.out_shapes, self.real, \
+                    self.n_exprs, self.order, self.accum_out_sizes, self.numpy_dtype, s)
+
+            cpdef unsafe_real(self, long double[::1] inp, long double[::1] out, int inp_offset=0, int out_offset=0):
+                self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
+
+            cpdef unsafe_eval(self, inp, out, unsigned nbroadcast=1):
+                cdef long double[::1] c_inp, c_out
+                cdef unsigned idx
+                c_inp = np.ascontiguousarray(inp.ravel(order=self.order), dtype=self.numpy_dtype)
+                c_out = out
+                for idx in range(nbroadcast):
+                    self.lambda_double[0].call(&c_out[idx*self.tot_out_size], &c_inp[idx*self.args_size])
+
     def llvm_loading_func(*args):
         return LLVMDouble(args, _load=True)
 
-def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False, **kwargs):
+    def llvm_float_loading_func(*args):
+        return LLVMFloat(args, _load=True)
+
+    IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
+        def llvm_long_double_loading_func(*args):
+            return LLVMLongDouble(args, _load=True)
+
+def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C',
+             as_scipy=False, cse=False, dtype=None, **kwargs):
     """
     Lambdify instances are callbacks that numerically evaluate their symbolic
     expressions from user provided input (real or complex) into (possibly user
@@ -4833,6 +4907,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
     cse : bool
         Run Common Subexpression Elimination on the output before generating
         the callback.
+    dtype : numpy.dtype type
 
     Returns
     -------
@@ -4854,7 +4929,20 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
         backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
     if backend == "llvm":
         IF HAVE_SYMENGINE_LLVM:
-            ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
+            if dtype is None:
+                dtype = np.float64
+            if dtype == np.float64:
+                ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse, dtype=np.float64, **kwargs)
+            elif dtype == np.float32:
+                ret = LLVMFloat(args, *exprs, real=real, order=order, cse=cse, dtype=np.float32, **kwargs)
+            elif dtype == np.longdouble:
+                IF HAVE_SYMENGINE_LLVM_LONG_DOUBLE:
+                    ret = LLVMLongDouble(args, *exprs, real=real, order=order, cse=cse, dtype=np.longdouble, **kwargs)
+                ELSE:
+                    raise ValueError("Long double not supported on this platform")
+            else:
+                raise ValueError("Unknown numpy dtype.")
+
             if as_scipy:
                 return ret.as_scipy_low_level_callable()
             return ret
@@ -4865,7 +4953,10 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
         pass
     else:
         warnings.warn("Unknown SymEngine backend: %s\nUsing backend='lambda'" % backend)
-    ret = LambdaDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
+    if real:
+        ret = LambdaDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
+    else:
+        ret = LambdaComplexDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
     if as_scipy:
         return ret.as_scipy_low_level_callable()
     return ret
diff --git a/symengine/tests/test_lambdify.py b/symengine/tests/test_lambdify.py
index 6ec90722a..b2ad57c0a 100644
--- a/symengine/tests/test_lambdify.py
+++ b/symengine/tests/test_lambdify.py
@@ -835,3 +835,34 @@ def test_as_ctypes():
     out = np.array([0, 0], dtype=np.double)
     addr1(out.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), inp.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), addr2)
     assert np.all(out == [6, 7])
+
+@unittest.skipUnless(have_numpy, "Numpy not installed")
+@unittest.skipUnless(se.have_llvm, "No LLVM support")
+def test_llvm_float():
+    import numpy as np
+    import ctypes
+    from symengine.lib.symengine_wrapper import LLVMFloat
+    x, y, z = se.symbols('x, y, z')
+    l = se.Lambdify([x, y, z], [se.Min(x, y), se.Max(y, z)], dtype=np.float32, backend='llvm')
+    inp = np.array([1,2,3], dtype=np.float32)
+    exp_out = np.array([1, 3], dtype=np.float32)
+    out = l(inp)
+    assert type(l) == LLVMFloat
+    assert out.dtype == np.float32
+    assert np.allclose(out, exp_out)
+
+@unittest.skipUnless(have_numpy, "Numpy not installed")
+@unittest.skipUnless(se.have_llvm, "No LLVM support")
+@unittest.skipUnless(se.have_llvm_long_double, "No LLVM IEEE-80 bit support")
+def test_llvm_long_double():
+    import numpy as np
+    import ctypes
+    from symengine.lib.symengine_wrapper import LLVMLongDouble
+    x, y, z = se.symbols('x, y, z')
+    l = se.Lambdify([x, y, z], [2*x, y/z], dtype=np.longdouble, backend='llvm')
+    inp = np.array([1,2,3], dtype=np.longdouble)
+    exp_out = np.array([2, 2.0/3.0], dtype=np.longdouble)
+    out = l(inp)
+    assert type(l) == LLVMLongDouble
+    assert out.dtype == np.longdouble
+    assert np.allclose(out, exp_out)