diff --git a/quantecon/optimize/__init__.py b/quantecon/optimize/__init__.py index 4937fb738..ad3267faf 100644 --- a/quantecon/optimize/__init__.py +++ b/quantecon/optimize/__init__.py @@ -3,4 +3,4 @@ """ from .scalar_maximization import brent_max - +from .root_finding import newton, newton_halley, newton_secant diff --git a/quantecon/optimize/root_finding.py b/quantecon/optimize/root_finding.py new file mode 100644 index 000000000..e88e7de51 --- /dev/null +++ b/quantecon/optimize/root_finding.py @@ -0,0 +1,257 @@ +import numpy as np +from numba import jit, njit +from collections import namedtuple + +__all__ = ['newton', 'newton_halley', 'newton_secant'] + +_ECONVERGED = 0 +_ECONVERR = -1 + +results = namedtuple('results', + ('root function_calls iterations converged')) + +@njit +def _results(r): + r"""Select from a tuple of(root, funccalls, iterations, flag)""" + x, funcalls, iterations, flag = r + return results(x, funcalls, iterations, flag == 0) + +@njit +def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50, + disp=True): + """ + Find a zero from the Newton-Raphson method using the jitted version of + Scipy's newton for scalars. Note that this does not provide an alternative + method such as secant. Thus, it is important that `fprime` can be provided. + + Note that `func` and `fprime` must be jitted via Numba. + They are recommended to be `njit` for performance. + + Parameters + ---------- + func : callable and jitted + The function whose zero is wanted. It must be a function of a + single variable of the form f(x,a,b,c...), where a,b,c... are extra + arguments that can be passed in the `args` parameter. + x0 : float + An initial estimate of the zero that should be somewhere near the + actual zero. + fprime : callable and jitted + The derivative of the function (when available and convenient). + args : tuple, optional + Extra arguments to be used in the function call. + tol : float, optional + The allowable error of the zero value. + maxiter : int, optional + Maximum number of iterations. + disp : bool, optional + If True, raise a RuntimeError if the algorithm didn't converge + + Returns + ------- + results : namedtuple + root - Estimated location where function is zero. + function_calls - Number of times the function was called. + iterations - Number of iterations needed to find the root. + converged - True if the routine converged + """ + + if tol <= 0: + raise ValueError("tol is too small <= 0") + if maxiter < 1: + raise ValueError("maxiter must be greater than 0") + + # Convert to float (don't use float(x0); this works also for complex x0) + p0 = 1.0 * x0 + funcalls = 0 + status = _ECONVERR + + # Newton-Raphson method + for itr in range(maxiter): + # first evaluate fval + fval = func(p0, *args) + funcalls += 1 + # If fval is 0, a root has been found, then terminate + if fval == 0: + status = _ECONVERGED + p = p0 + itr -= 1 + break + fder = fprime(p0, *args) + funcalls += 1 + # derivative is zero, not converged + if fder == 0: + p = p0 + break + newton_step = fval / fder + # Newton step + p = p0 - newton_step + if abs(p - p0) < tol: + status = _ECONVERGED + break + p0 = p + + if disp and status == _ECONVERR: + msg = "Failed to converge" + raise RuntimeError(msg) + + return _results((p, funcalls, itr + 1, status)) + +@njit +def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8, + maxiter=50, disp=True): + """ + Find a zero from Halley's method using the jitted version of + Scipy's. + + `func`, `fprime`, `fprime2` must be jitted via Numba. + + Parameters + ---------- + func : callable and jitted + The function whose zero is wanted. It must be a function of a + single variable of the form f(x,a,b,c...), where a,b,c... are extra + arguments that can be passed in the `args` parameter. + x0 : float + An initial estimate of the zero that should be somewhere near the + actual zero. + fprime : callable and jitted + The derivative of the function (when available and convenient). + fprime2 : callable and jitted + The second order derivative of the function + args : tuple, optional + Extra arguments to be used in the function call. + tol : float, optional + The allowable error of the zero value. + maxiter : int, optional + Maximum number of iterations. + disp : bool, optional + If True, raise a RuntimeError if the algorithm didn't converge + + Returns + ------- + results : namedtuple + root - Estimated location where function is zero. + function_calls - Number of times the function was called. + iterations - Number of iterations needed to find the root. + converged - True if the routine converged + """ + + if tol <= 0: + raise ValueError("tol is too small <= 0") + if maxiter < 1: + raise ValueError("maxiter must be greater than 0") + + # Convert to float (don't use float(x0); this works also for complex x0) + p0 = 1.0 * x0 + funcalls = 0 + status = _ECONVERR + + # Halley Method + for itr in range(maxiter): + # first evaluate fval + fval = func(p0, *args) + funcalls += 1 + # If fval is 0, a root has been found, then terminate + if fval == 0: + status = _ECONVERGED + p = p0 + itr -= 1 + break + fder = fprime(p0, *args) + funcalls += 1 + # derivative is zero, not converged + if fder == 0: + p = p0 + break + newton_step = fval / fder + # Halley's variant + fder2 = fprime2(p0, *args) + p = p0 - newton_step / (1.0 - 0.5 * newton_step * fder2 / fder) + if abs(p - p0) < tol: + status = _ECONVERGED + break + p0 = p + + if disp and status == _ECONVERR: + msg = "Failed to converge" + raise RuntimeError(msg) + + return _results((p, funcalls, itr + 1, status)) + +@njit +def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50, + disp=True): + """ + Find a zero from the secant method using the jitted version of + Scipy's secant method. + + Note that `func` must be jitted via Numba. + + Parameters + ---------- + func : callable and jitted + The function whose zero is wanted. It must be a function of a + single variable of the form f(x,a,b,c...), where a,b,c... are extra + arguments that can be passed in the `args` parameter. + x0 : float + An initial estimate of the zero that should be somewhere near the + actual zero. + args : tuple, optional + Extra arguments to be used in the function call. + tol : float, optional + The allowable error of the zero value. + maxiter : int, optional + Maximum number of iterations. + disp : bool, optional + If True, raise a RuntimeError if the algorithm didn't converge. + + Returns + ------- + results : namedtuple + root - Estimated location where function is zero. + function_calls - Number of times the function was called. + iterations - Number of iterations needed to find the root. + converged - True if the routine converged + """ + + if tol <= 0: + raise ValueError("tol is too small <= 0") + if maxiter < 1: + raise ValueError("maxiter must be greater than 0") + + # Convert to float (don't use float(x0); this works also for complex x0) + p0 = 1.0 * x0 + funcalls = 0 + status = _ECONVERR + + # Secant method + if x0 >= 0: + p1 = x0 * (1 + 1e-4) + 1e-4 + else: + p1 = x0 * (1 + 1e-4) - 1e-4 + q0 = func(p0, *args) + funcalls += 1 + q1 = func(p1, *args) + funcalls += 1 + for itr in range(maxiter): + if q1 == q0: + p = (p1 + p0) / 2.0 + status = _ECONVERGED + break + else: + p = p1 - q1 * (p1 - p0) / (q1 - q0) + if np.abs(p - p1) < tol: + status = _ECONVERGED + break + p0 = p1 + q0 = q1 + p1 = p + q1 = func(p1, *args) + funcalls += 1 + + if disp and status == _ECONVERR: + msg = "Failed to converge" + raise RuntimeError(msg) + + return _results((p, funcalls, itr + 1, status)) \ No newline at end of file diff --git a/quantecon/optimize/tests/test_root_finding.py b/quantecon/optimize/tests/test_root_finding.py new file mode 100644 index 000000000..ad0015ce4 --- /dev/null +++ b/quantecon/optimize/tests/test_root_finding.py @@ -0,0 +1,123 @@ +import numpy as np +from numpy.testing import assert_almost_equal, assert_allclose +from numba import njit + +from quantecon.optimize import newton, newton_halley, newton_secant + +@njit +def func(x): + """ + Function for testing on. + """ + return (x**3 - 1) + + +@njit +def func_prime(x): + """ + Derivative for func. + """ + return (3*x**2) + +@njit +def func_prime2(x): + """ + Second order derivative for func. + """ + return 6*x + +@njit +def func_two(x): + """ + Harder function for testing on. + """ + return np.sin(4 * (x - 1/4)) + x + x**20 - 1 + + +@njit +def func_two_prime(x): + """ + Derivative for func_two. + """ + return 4*np.cos(4*(x - 1/4)) + 20*x**19 + 1 + +@njit +def func_two_prime2(x): + """ + Second order derivative for func_two + """ + return 380*x**18 - 16*np.sin(4*(x - 1/4)) + + +def test_newton_basic(): + """ + Uses the function f defined above to test the scalar maximization + routine. + """ + true_fval = 1.0 + fval = newton(func, 5, func_prime) + assert_almost_equal(true_fval, fval.root, decimal=4) + + +def test_newton_basic_two(): + """ + Uses the function f defined above to test the scalar maximization + routine. + """ + true_fval = 1.0 + fval = newton(func, 5, func_prime) + assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0) + + +def test_newton_hard(): + """ + Harder test for convergence. + """ + true_fval = 0.408 + fval = newton(func_two, 0.4, func_two_prime) + assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01) + +def test_halley_basic(): + """ + Basic test for halley method + """ + true_fval = 1.0 + fval = newton_halley(func, 5, func_prime, func_prime2) + assert_almost_equal(true_fval, fval.root, decimal=4) + +def test_halley_hard(): + """ + Harder test for halley method + """ + true_fval = 0.408 + fval = newton_halley(func_two, 0.4, func_two_prime, func_two_prime2) + assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01) + +def test_secant_basic(): + """ + Basic test for secant option. + """ + true_fval = 1.0 + fval = newton_secant(func, 5) + assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.001) + + +def test_secant_hard(): + """ + Harder test for convergence for secant function. + """ + true_fval = 0.408 + fval = newton_secant(func_two, 0.4) + assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01) + + +# executing testcases. + +if __name__ == '__main__': + import sys + import nose + + argv = sys.argv[:] + argv.append('--verbose') + argv.append('--nocapture') + nose.main(argv=argv, defaultTest=__file__)