Skip to content

Remove DiffEqBase Dependency #9

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 1 commit into from
Oct 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ authors = ["Kanav Gupta <kanav0610@gmail.com>"]
version = "0.1.0"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "Test", "ForwardDiff"]
test = ["BenchmarkTools", "Test", "ForwardDiff"]
9 changes: 7 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
module NonlinearSolve

using Reexport
@reexport using DiffEqBase
using UnPack: @unpack
using FiniteDiff, ForwardDiff
using Setfield
using StaticArrays
using RecursiveArrayTools

abstract type AbstractNonlinearProblem{uType,isinplace} end
abstract type AbstractNonlinearSolveAlgorithm end
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CS,AD} <: AbstractNonlinearSolveAlgorithm end
abstract type AbstractNonlinearSolver end
abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolver end

include("utils.jl")
include("jacobian.jl")
include("types.jl")
include("utils.jl")
include("solve.jl")
include("bisection.jl")
include("falsi.jl")
Expand All @@ -28,5 +29,9 @@ module NonlinearSolve
# DiffEq styled algorithms
export Bisection, Falsi, NewtonRaphson

export NonlinearProblem

export solve, init, solve!

export reinit!
end # module
4 changes: 2 additions & 2 deletions src/scalar.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function DiffEqBase.solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
function solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
x = float(prob.u0)
T = typeof(x)
Expand All @@ -19,7 +19,7 @@ function DiffEqBase.solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, arg
return NewtonSolution(x, MAXITERS_EXCEED)
end

function DiffEqBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
f = Base.Fix2(prob.f, prob.p)
left, right = prob.u0
fl, fr = f(left), f(right)
Expand Down
12 changes: 6 additions & 6 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function DiffEqBase.solve(prob::NonlinearProblem,
function solve(prob::NonlinearProblem,
alg::AbstractNonlinearSolveAlgorithm, args...;
kwargs...)
solver = DiffEqBase.init(prob, alg, args...; kwargs...)
solver = init(prob, alg, args...; kwargs...)
sol = solve!(solver)
return sol
end

function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
function init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
alias_u0 = false,
maxiters = 1000,
kwargs...
Expand All @@ -33,11 +33,11 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracke
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip)
end

function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
function init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
alias_u0 = false,
maxiters = 1000,
tol = 1e-6,
internalnorm = Base.Fix2(DiffEqBase.ODE_DEFAULT_NORM, nothing),
internalnorm = DEFAULT_NORM,
kwargs...
) where {uType, iip}

Expand All @@ -58,7 +58,7 @@ function DiffEqBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewton
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip)
end

function DiffEqBase.solve!(solver::AbstractImmutableNonlinearSolver)
function solve!(solver::AbstractImmutableNonlinearSolver)
solver = mic_check(solver)
while !solver.force_stop && solver.iter < solver.maxiters
solver = perform_step(solver, solver.alg, Val(solver.iip))
Expand Down
14 changes: 14 additions & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
struct NullParameters end

struct NonlinearProblem{uType,isinplace,P,F,K} <: AbstractNonlinearProblem{uType,isinplace}
f::F
u0::uType
p::P
kwargs::K
@add_kwonly function NonlinearProblem{iip}(f,u0,p=NullParameters();kwargs...) where iip
new{typeof(u0),iip,typeof(p),typeof(f),typeof(kwargs)}(f,u0,p,kwargs)
end
end

NonlinearProblem(f,u0,args...;kwargs...) = NonlinearProblem{isinplace(f, 3)}(f,u0,args...;kwargs...)

@enum Retcode::Int begin
DEFAULT
EXACT_SOLUTION_LEFT
Expand Down
213 changes: 210 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,213 @@
"""
@add_kwonly function_definition

Define keyword-only version of the `function_definition`.

@add_kwonly function f(x; y=1)
...
end

expands to:

function f(x; y=1)
...
end
function f(; x = error("No argument x"), y=1)
...
end
"""
macro add_kwonly(ex)
esc(add_kwonly(ex))
end

add_kwonly(ex::Expr) = add_kwonly(Val{ex.head}, ex)

function add_kwonly(::Type{<: Val}, ex)
error("add_only does not work with expression $(ex.head)")
end

function add_kwonly(::Union{Type{Val{:function}},
Type{Val{:(=)}}}, ex::Expr)
body = ex.args[2:end] # function body
default_call = ex.args[1] # e.g., :(f(a, b=2; c=3))
kwonly_call = add_kwonly(default_call)
if kwonly_call === nothing
return ex
end

return quote
begin
$ex
$(Expr(ex.head, kwonly_call, body...))
end
end
end

function add_kwonly(::Type{Val{:where}}, ex::Expr)
default_call = ex.args[1]
rest = ex.args[2:end]
kwonly_call = add_kwonly(default_call)
if kwonly_call === nothing
return nothing
end
return Expr(:where, kwonly_call, rest...)
end

function add_kwonly(::Type{Val{:call}}, default_call::Expr)
# default_call is, e.g., :(f(a, b=2; c=3))
funcname = default_call.args[1] # e.g., :f
required = [] # required positional arguments; e.g., [:a]
optional = [] # optional positional arguments; e.g., [:(b=2)]
default_kwargs = []
for arg in default_call.args[2:end]
if isa(arg, Symbol)
push!(required, arg)
elseif arg.head == :(::)
push!(required, arg)
elseif arg.head == :kw
push!(optional, arg)
elseif arg.head == :parameters
@assert default_kwargs == [] # can I have :parameters twice?
default_kwargs = arg.args
else
error("Not expecting to see: $arg")
end
end
if isempty(required) && isempty(optional)
# If the function is already keyword-only, do nothing:
return nothing
end
if isempty(required)
# It's not clear what should be done. Let's not support it at
# the moment:
error("At least one positional mandatory argument is required.")
end

kwonly_kwargs = Expr(:parameters, [
Expr(:kw, pa, :(error($("No argument $pa"))))
for pa in required
]..., optional..., default_kwargs...)
kwonly_call = Expr(:call, funcname, kwonly_kwargs)
# e.g., :(f(; a=error(...), b=error(...), c=1, d=2))

return kwonly_call
end

function num_types_in_tuple(sig)
length(sig.parameters)
end

function num_types_in_tuple(sig::UnionAll)
length(Base.unwrap_unionall(sig).parameters)
end

function numargs(f)
typ = Tuple{Any, Val{:analytic}, Vararg}
typ2 = Tuple{Any, Type{Val{:analytic}}, Vararg} # This one is required for overloaded types
typ3 = Tuple{Any, Val{:jac}, Vararg}
typ4 = Tuple{Any, Type{Val{:jac}}, Vararg} # This one is required for overloaded types
typ5 = Tuple{Any, Val{:tgrad}, Vararg}
typ6 = Tuple{Any, Type{Val{:tgrad}}, Vararg} # This one is required for overloaded types
numparam = maximum([(m.sig<:typ || m.sig<:typ2 || m.sig<:typ3 || m.sig<:typ4 || m.sig<:typ5 || m.sig<:typ6) ? 0 : num_types_in_tuple(m.sig) for m in methods(f)])
return (numparam-1) #-1 in v0.5 since it adds f as the first parameter
end

function isinplace(f,inplace_param_number)
numargs(f)>=inplace_param_number
end

### Default Linsolve

# Try to be as smart as possible
# lu! if Matrix
# lu if sparse
# gmres if operator

mutable struct DefaultLinSolve
A
iterable
end
DefaultLinSolve() = DefaultLinSolve(nothing, nothing)

function (p::DefaultLinSolve)(x,A,b,update_matrix=false;tol=nothing, kwargs...)
if p.iterable isa Vector && eltype(p.iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector
F = LU{eltype(A)}(A, p.iterable, zero(LinearAlgebra.BlasInt))
ldiv!(x, F, b)
return nothing
end
if update_matrix
if typeof(A) <: Matrix
blasvendor = BLAS.vendor()
# if the user doesn't use OpenBLAS, we assume that is a better BLAS
# implementation like MKL
#
# RecursiveFactorization seems to be consistantly winning below 100
# https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 || ((blasvendor === :openblas || blasvendor === :openblas64) && size(A,1) <= 500))
p.A = RecursiveFactorization.lu!(A)
else
p.A = lu!(A)
end
elseif typeof(A) <: Tridiagonal
p.A = lu!(A)
elseif typeof(A) <: Union{SymTridiagonal}
p.A = ldlt!(A)
elseif typeof(A) <: Union{Symmetric,Hermitian}
p.A = bunchkaufman!(A)
elseif typeof(A) <: SparseMatrixCSC
p.A = lu(A)
elseif ArrayInterface.isstructured(A)
p.A = factorize(A)
elseif !(typeof(A) <: AbstractDiffEqOperator)
# Most likely QR is the one that is overloaded
# Works on things like CuArrays
p.A = qr(A)
end
end

if typeof(A) <: Union{Matrix,SymTridiagonal,Tridiagonal,Symmetric,Hermitian} # No 2-arg form for SparseArrays!
x .= b
ldiv!(p.A,x)
# Missing a little bit of efficiency in a rare case
#elseif typeof(A) <: DiffEqArrayOperator
# ldiv!(x,p.A,b)
elseif ArrayInterface.isstructured(A) || A isa SparseMatrixCSC
ldiv!(x,p.A,b)
elseif typeof(A) <: AbstractDiffEqOperator
# No good starting guess, so guess zero
if p.iterable === nothing
p.iterable = IterativeSolvers.gmres_iterable!(x,A,b;initially_zero=true,restart=5,maxiter=5,tol=1e-16,kwargs...)
p.iterable.reltol = tol
end
x .= false
iter = p.iterable
purge_history!(iter, x, b)

for residual in iter
end
else
ldiv!(x,p.A,b)
end
return nothing
end

function (p::DefaultLinSolve)(::Type{Val{:init}},f,u0_prototype)
DefaultLinSolve()
end

const DEFAULT_LINSOLVE = DefaultLinSolve()

@inline UNITLESS_ABS2(x) = real(abs2(x))
@inline DEFAULT_NORM(u::Union{AbstractFloat,Complex}) = @fastmath abs(u)
@inline DEFAULT_NORM(u::Array{T}) where T<:Union{AbstractFloat,Complex} =
sqrt(real(sum(abs2,u)) / length(u))
@inline DEFAULT_NORM(u::StaticArray{T}) where T<:Union{AbstractFloat,Complex} =
sqrt(real(sum(abs2,u)) / length(u))
@inline DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray) =
sum(sqrt(real(sum(UNITLESS_ABS2,_u)) / length(_u)) for _u in u.u)
@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2,u)) / length(u))
@inline DEFAULT_NORM(u) = norm(u)

"""
prevfloat_tdir(x, x0, x1)

Expand All @@ -24,6 +234,3 @@ function value_derivative(f::F, x::R) where {F,R}
out = f(ForwardDiff.Dual{T}(x, one(x)))
ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
end

DiffEqBase.has_Wfact(f::Function) = false
DiffEqBase.has_Wfact_t(f::Function) = false