Merge pull request #417 from chrishyland/root-finding
ENH: Root finding
Showing 3 changed files with 381 additions and 1 deletion.
@@ -3,4 +3,4 @@
"""

from .scalar_maximization import brent_max

from .root_finding import newton, newton_halley, newton_secant
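For orientation, the change above simply widens the public import surface of quantecon.optimize. A minimal sketch of what becomes importable after this commit (names taken directly from the diff):

# Newly exported root-finding routines (this commit)
from quantecon.optimize import newton, newton_halley, newton_secant

# Existing export, unchanged
from quantecon.optimize import brent_max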
@@ -0,0 +1,257 @@
import numpy as np
from numba import jit, njit
from collections import namedtuple

__all__ = ['newton', 'newton_halley', 'newton_secant']

_ECONVERGED = 0
_ECONVERR = -1

results = namedtuple('results',
                     ('root function_calls iterations converged'))

@njit
def _results(r):
    r"""Select from a tuple of (root, funccalls, iterations, flag)."""
    x, funcalls, iterations, flag = r
    return results(x, funcalls, iterations, flag == 0)

@njit
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
           disp=True):
    """
    Find a zero of a scalar function using the Newton-Raphson method.
    This is a jitted version of SciPy's `newton` for scalars. It does not
    fall back to an alternative method such as the secant method, so
    `fprime` must be provided.

    Both `func` and `fprime` must be jitted via Numba; `njit` is
    recommended for performance.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    fprime : callable and jitted
        The derivative of the function.
    args : tuple, optional
        Extra arguments to be used in the function call.
    tol : float, optional
        The allowable error of the zero value.
    maxiter : int, optional
        Maximum number of iterations.
    disp : bool, optional
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        root - Estimated location where function is zero.
        function_calls - Number of times the function was called.
        iterations - Number of iterations needed to find the root.
        converged - True if the routine converged.
    """

    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float (don't use float(x0); this works also for complex x0)
    p0 = 1.0 * x0
    funcalls = 0
    status = _ECONVERR

    # Newton-Raphson method
    for itr in range(maxiter):
        # first evaluate fval
        fval = func(p0, *args)
        funcalls += 1
        # If fval is 0, a root has been found, then terminate
        if fval == 0:
            status = _ECONVERGED
            p = p0
            itr -= 1  # root was already exact; undo the +1 applied when reporting iterations
            break
        fder = fprime(p0, *args)
        funcalls += 1
        # derivative is zero, not converged
        if fder == 0:
            p = p0
            break
        newton_step = fval / fder
        # Newton step
        p = p0 - newton_step
        if abs(p - p0) < tol:
            status = _ECONVERGED
            break
        p0 = p

    if disp and status == _ECONVERR:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    return _results((p, funcalls, itr + 1, status))

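As a quick illustration of the interface documented above, here is a minimal usage sketch. The names f and f_prime are illustrative; the problem (a root of x**3 - 1 starting from 5) mirrors the tests added later in this commit.

from numba import njit
from quantecon.optimize import newton

@njit
def f(x):
    # simple cubic with a single real root at x = 1
    return x**3 - 1

@njit
def f_prime(x):
    return 3 * x**2

res = newton(f, 5.0, f_prime)
# res is the namedtuple defined above:
# res.root is approximately 1.0, res.converged is True,
# res.function_calls and res.iterations report the work done
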
@njit
def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
                  maxiter=50, disp=True):
    """
    Find a zero of a scalar function using Halley's method. This is a
    jitted version of SciPy's implementation.

    `func`, `fprime`, and `fprime2` must be jitted via Numba.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    fprime : callable and jitted
        The derivative of the function.
    fprime2 : callable and jitted
        The second-order derivative of the function.
    args : tuple, optional
        Extra arguments to be used in the function call.
    tol : float, optional
        The allowable error of the zero value.
    maxiter : int, optional
        Maximum number of iterations.
    disp : bool, optional
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        root - Estimated location where function is zero.
        function_calls - Number of times the function was called.
        iterations - Number of iterations needed to find the root.
        converged - True if the routine converged.
    """

    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float (don't use float(x0); this works also for complex x0)
    p0 = 1.0 * x0
    funcalls = 0
    status = _ECONVERR

    # Halley Method
    for itr in range(maxiter):
        # first evaluate fval
        fval = func(p0, *args)
        funcalls += 1
        # If fval is 0, a root has been found, then terminate
        if fval == 0:
            status = _ECONVERGED
            p = p0
            itr -= 1  # root was already exact; undo the +1 applied when reporting iterations
            break
        fder = fprime(p0, *args)
        funcalls += 1
        # derivative is zero, not converged
        if fder == 0:
            p = p0
            break
        newton_step = fval / fder
        # Halley's variant of the Newton step:
        # p = p0 - (f/f') / (1 - (f/f') * f'' / (2*f'))
        fder2 = fprime2(p0, *args)
        p = p0 - newton_step / (1.0 - 0.5 * newton_step * fder2 / fder)
        if abs(p - p0) < tol:
            status = _ECONVERGED
            break
        p0 = p

    if disp and status == _ECONVERR:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    return _results((p, funcalls, itr + 1, status))

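A corresponding sketch for the Halley routine, reusing the same illustrative cubic; here the second derivative must also be supplied:

from numba import njit
from quantecon.optimize import newton_halley

@njit
def f(x):
    return x**3 - 1

@njit
def f_prime(x):
    return 3 * x**2

@njit
def f_prime2(x):
    return 6 * x

res = newton_halley(f, 5.0, f_prime, f_prime2)
# res.root is approximately 1.0
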
@njit
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
                  disp=True):
    """
    Find a zero of a scalar function using the secant method. This is a
    jitted version of SciPy's secant method.

    `func` must be jitted via Numba.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    args : tuple, optional
        Extra arguments to be used in the function call.
    tol : float, optional
        The allowable error of the zero value.
    maxiter : int, optional
        Maximum number of iterations.
    disp : bool, optional
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        root - Estimated location where function is zero.
        function_calls - Number of times the function was called.
        iterations - Number of iterations needed to find the root.
        converged - True if the routine converged.
    """

    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float (don't use float(x0); this works also for complex x0)
    p0 = 1.0 * x0
    funcalls = 0
    status = _ECONVERR

    # Secant method: choose a second starting point close to x0
    if x0 >= 0:
        p1 = x0 * (1 + 1e-4) + 1e-4
    else:
        p1 = x0 * (1 + 1e-4) - 1e-4
    q0 = func(p0, *args)
    funcalls += 1
    q1 = func(p1, *args)
    funcalls += 1
    for itr in range(maxiter):
        if q1 == q0:
            p = (p1 + p0) / 2.0
            status = _ECONVERGED
            break
        else:
            p = p1 - q1 * (p1 - p0) / (q1 - q0)
        if np.abs(p - p1) < tol:
            status = _ECONVERGED
            break
        p0 = p1
        q0 = q1
        p1 = p
        q1 = func(p1, *args)
        funcalls += 1

    if disp and status == _ECONVERR:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    return _results((p, funcalls, itr + 1, status))
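And a sketch for the secant variant, which requires no derivative. The name g is illustrative; the problem and expected root (about 0.408) are taken from the tests added below.

import numpy as np
from numba import njit
from quantecon.optimize import newton_secant

@njit
def g(x):
    # harder test function used in the test suite below
    return np.sin(4 * (x - 1/4)) + x + x**20 - 1

res = newton_secant(g, 0.4)
# res.root is approximately 0.408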
@@ -0,0 +1,123 @@
import numpy as np
from numpy.testing import assert_almost_equal, assert_allclose
from numba import njit

from quantecon.optimize import newton, newton_halley, newton_secant


@njit
def func(x):
    """
    Function for testing on.
    """
    return (x**3 - 1)


@njit
def func_prime(x):
    """
    Derivative for func.
    """
    return (3*x**2)


@njit
def func_prime2(x):
    """
    Second order derivative for func.
    """
    return 6*x


@njit
def func_two(x):
    """
    Harder function for testing on.
    """
    return np.sin(4 * (x - 1/4)) + x + x**20 - 1


@njit
def func_two_prime(x):
    """
    Derivative for func_two.
    """
    return 4*np.cos(4*(x - 1/4)) + 20*x**19 + 1


@njit
def func_two_prime2(x):
    """
    Second order derivative for func_two.
    """
    return 380*x**18 - 16*np.sin(4*(x - 1/4))

def test_newton_basic():
    """
    Uses `func` defined above to test the Newton root-finding routine.
    """
    true_fval = 1.0
    fval = newton(func, 5, func_prime)
    assert_almost_equal(true_fval, fval.root, decimal=4)


def test_newton_basic_two():
    """
    Uses `func` defined above to test the Newton root-finding routine,
    checking the result with `assert_allclose`.
    """
    true_fval = 1.0
    fval = newton(func, 5, func_prime)
    assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0)


def test_newton_hard():
    """
    Harder test for convergence.
    """
    true_fval = 0.408
    fval = newton(func_two, 0.4, func_two_prime)
    assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)

def test_halley_basic():
    """
    Basic test for the Halley method.
    """
    true_fval = 1.0
    fval = newton_halley(func, 5, func_prime, func_prime2)
    assert_almost_equal(true_fval, fval.root, decimal=4)


def test_halley_hard():
    """
    Harder test for the Halley method.
    """
    true_fval = 0.408
    fval = newton_halley(func_two, 0.4, func_two_prime, func_two_prime2)
    assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)


def test_secant_basic():
    """
    Basic test for the secant method.
    """
    true_fval = 1.0
    fval = newton_secant(func, 5)
    assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.001)


def test_secant_hard():
    """
    Harder test for convergence of the secant method.
    """
    true_fval = 0.408
    fval = newton_secant(func_two, 0.4)
    assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)

# Execute test cases.

if __name__ == '__main__':
    import sys
    import nose

    argv = sys.argv[:]
    argv.append('--verbose')
    argv.append('--nocapture')
    nose.main(argv=argv, defaultTest=__file__)