Skip to content

Commit f6cb31f

Browse files
authoredFeb 26, 2024
Merge pull request #131 from avik-pal/ap/nlls_adjoint
Forward Mode overloads for Least Squares Problem
2 parents a3f9a13 + 2b8b5e4 commit f6cb31f

9 files changed

+311
-32
lines changed
 

Diff for: ‎lib/SimpleNonlinearSolve/Project.toml

+21-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
11+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1112
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1213
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1314
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -22,34 +23,51 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2223
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2324
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2425
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
26+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2527

2628
[extensions]
2729
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
2830
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
2931
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
32+
SimpleNonlinearSolveZygoteExt = "Zygote"
3033

3134
[compat]
3235
ADTypes = "0.2.6"
36+
AllocCheck = "0.1.1"
3337
ArrayInterface = "7.7"
34-
ChainRulesCore = "1.21"
38+
Aqua = "0.8"
39+
CUDA = "5.2"
40+
ChainRulesCore = "1.22"
3541
ConcreteStructs = "0.2.3"
3642
DiffEqBase = "6.146"
43+
DiffResults = "1.1"
3744
FastClosures = "0.3"
3845
FiniteDiff = "2.22"
3946
ForwardDiff = "0.10.36"
4047
LinearAlgebra = "1.10"
48+
LinearSolve = "2.25"
4149
MaybeInplace = "0.1.1"
50+
NonlinearProblemLibrary = "0.1.2"
51+
Pkg = "1.10"
52+
PolyesterForwardDiff = "0.1.1"
4253
PrecompileTools = "1.2"
54+
Random = "1.10"
55+
ReTestItems = "1.23"
4356
Reexport = "1.2"
44-
SciMLBase = "2.23"
57+
SciMLBase = "2.26.3"
58+
SciMLSensitivity = "7.56"
4559
StaticArrays = "1.9"
4660
StaticArraysCore = "1.4.2"
61+
Test = "1.10"
62+
Zygote = "0.6.69"
4763
julia = "1.10"
4864

4965
[extras]
5066
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
67+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5168
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5269
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
70+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
5371
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5472
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5573
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
@@ -65,4 +83,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6583
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6684

6785
[targets]
68-
test = ["AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test"]
86+
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module SimpleNonlinearSolveZygoteExt
2+
3+
import SimpleNonlinearSolve, Zygote
4+
5+
SimpleNonlinearSolve.__is_extension_loaded(::Val{:Zygote}) = true
6+
7+
function SimpleNonlinearSolve.__zygote_compute_nlls_vjp(f::F, u, p) where {F}
8+
y, pb = Zygote.pullback(Base.Fix2(f, p), u)
9+
return 2 .* only(pb(y))
10+
end
11+
12+
end

Diff for: ‎lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
1111
AbstractSafeBestNonlinearTerminationMode,
1212
NonlinearSafeTerminationReturnCode, get_termination_mode,
1313
NONLINEARSOLVE_DEFAULT_NORM
14+
import DiffResults
1415
import ForwardDiff: Dual
1516
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
1617
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val

Diff for: ‎lib/SimpleNonlinearSolve/src/ad.jl

+103-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,17 @@ function SciMLBase.solve(
55
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
66
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
77
return SciMLBase.build_solution(
8-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
9-
sol.original)
8+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
9+
end
10+
11+
function SciMLBase.solve(
12+
prob::NonlinearLeastSquaresProblem{<:AbstractArray,
13+
iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
14+
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip}
15+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
16+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
17+
return SciMLBase.build_solution(
18+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1019
end
1120

1221
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -24,7 +33,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
2433
end
2534
end
2635

27-
function __nlsolve_ad(prob, alg, args...; kwargs...)
36+
function __nlsolve_ad(
37+
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
2838
p = value(prob.p)
2939
if prob isa IntervalNonlinearProblem
3040
tspan = value.(prob.tspan)
@@ -55,6 +65,96 @@ function __nlsolve_ad(prob, alg, args...; kwargs...)
5565
return sol, partials
5666
end
5767

68+
function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
69+
p = value(prob.p)
70+
u0 = value(prob.u0)
71+
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)
72+
73+
sol = solve(newprob, alg, args...; kwargs...)
74+
75+
uu = sol.u
76+
77+
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
78+
# nested autodiff as the last resort
79+
if SciMLBase.has_vjp(prob.f)
80+
if isinplace(prob)
81+
_F = @closure (du, u, p) -> begin
82+
resid = similar(du, length(sol.resid))
83+
prob.f(resid, u, p)
84+
prob.f.vjp(du, resid, u, p)
85+
du .*= 2
86+
return nothing
87+
end
88+
else
89+
_F = @closure (u, p) -> begin
90+
resid = prob.f(u, p)
91+
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
92+
end
93+
end
94+
elseif SciMLBase.has_jac(prob.f)
95+
if isinplace(prob)
96+
_F = @closure (du, u, p) -> begin
97+
J = similar(du, length(sol.resid), length(u))
98+
prob.f.jac(J, u, p)
99+
resid = similar(du, length(sol.resid))
100+
prob.f(resid, u, p)
101+
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
102+
return nothing
103+
end
104+
else
105+
_F = @closure (u, p) -> begin
106+
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
107+
end
108+
end
109+
else
110+
if isinplace(prob)
111+
_F = @closure (du, u, p) -> begin
112+
resid = similar(du, length(sol.resid))
113+
res = DiffResults.DiffResult(
114+
resid, similar(du, length(sol.resid), length(u)))
115+
_f = @closure (du, u) -> prob.f(du, u, p)
116+
ForwardDiff.jacobian!(res, _f, resid, u)
117+
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
118+
DiffResults.jacobian(res), 2, false)
119+
return nothing
120+
end
121+
else
122+
# For small problems, nesting ForwardDiff is actually quite fast
123+
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) 50)
124+
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)
125+
else
126+
_F = @closure (u, p) -> begin
127+
T = promote_type(eltype(u), eltype(p))
128+
res = DiffResults.DiffResult(
129+
similar(u, T, size(sol.resid)), similar(
130+
u, T, length(sol.resid), length(u)))
131+
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
132+
return reshape(
133+
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
134+
size(u))
135+
end
136+
end
137+
end
138+
end
139+
140+
f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
141+
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)
142+
143+
z_arr = -f_x \ f_p
144+
145+
pp = prob.p
146+
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
147+
if uu isa Number
148+
partials = sum(sumfun, zip(z_arr, pp))
149+
elseif p isa Number
150+
partials = sumfun((z_arr, pp))
151+
else
152+
partials = sum(sumfun, zip(eachcol(z_arr), pp))
153+
end
154+
155+
return sol, partials
156+
end
157+
58158
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
59159
if isinplace(prob)
60160
__f = p -> begin

Diff for: ‎lib/SimpleNonlinearSolve/src/utils.jl

+3
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,6 @@ function __get_tolerance(x::Union{SArray, Number}, ::Nothing, ::Type{T}) where {
388388
η = real(oneunit(T)) * (eps(real(one(T))))^(real(T)(0.8))
389389
return T(η)
390390
end
391+
392+
# Extension
393+
function __zygote_compute_nlls_vjp end

Diff for: ‎lib/SimpleNonlinearSolve/test/core/aqua_tests.jl

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@testitem "Aqua" begin
2+
using Aqua
3+
4+
Aqua.test_all(SimpleNonlinearSolve; piracies = false, ambiguities = false)
5+
Aqua.test_piracies(SimpleNonlinearSolve;
6+
treat_as_own = [
7+
NonlinearProblem, NonlinearLeastSquaresProblem, IntervalNonlinearProblem])
8+
Aqua.test_ambiguities(SimpleNonlinearSolve; recursive = false)
9+
end

Diff for: ‎lib/SimpleNonlinearSolve/test/core/forward_ad_tests.jl

+120-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testsetup module ForwardADTesting
1+
@testsetup module ForwardADRootfindingTesting
22
using Reexport
33
@reexport using ForwardDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra
44
import SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm
@@ -40,7 +40,7 @@ __compatible(::SimpleHalley, ::Val{:iip}) = false
4040
export test_f, test_f!, jacobian_f, solve_with, __compatible
4141
end
4242

43-
@testitem "ForwardDiff.jl Integration" setup=[ForwardADTesting] begin
43+
@testitem "ForwardDiff.jl Integration: Rootfinding" setup=[ForwardADRootfindingTesting] begin
4444
@testset "$(nameof(typeof(alg)))" for alg in (SimpleNewtonRaphson(),
4545
SimpleTrustRegion(), SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
4646
SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane())
@@ -88,3 +88,121 @@ end
8888
end
8989
end
9090
end
91+
92+
@testsetup module ForwardADNLLSTesting
93+
using Reexport
94+
@reexport using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra,
95+
Zygote
96+
97+
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
98+
99+
const θ_true = [1.0, 0.1, 2.0, 0.5]
100+
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]
101+
const y_target = true_function(x, θ_true)
102+
103+
function loss_function(θ, p)
104+
= true_function(p, θ)
105+
return.- y_target
106+
end
107+
108+
function loss_function_jac(θ, p)
109+
return ForwardDiff.jacobian-> loss_function(θ, p), θ)
110+
end
111+
112+
loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
113+
114+
function loss_function!(resid, θ, p)
115+
= true_function(p, θ)
116+
@. resid =- y_target
117+
return
118+
end
119+
120+
function loss_function_jac!(J, θ, p)
121+
J .= ForwardDiff.jacobian-> loss_function(θ, p), θ)
122+
return
123+
end
124+
125+
function loss_function_vjp!(vJ, v, θ, p)
126+
vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ))
127+
return
128+
end
129+
130+
θ_init = θ_true .+ 0.1
131+
132+
export loss_function, loss_function!, loss_function_jac, loss_function_vjp,
133+
loss_function_jac!, loss_function_vjp!, θ_init, x, y_target
134+
end
135+
136+
@testitem "ForwardDiff.jl Integration: NLLS" setup=[ForwardADNLLSTesting] begin
137+
@testset "$(nameof(typeof(alg)))" for alg in (
138+
SimpleNewtonRaphson(), SimpleGaussNewton(),
139+
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff()))
140+
function obj_1(p)
141+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p)
142+
sol = solve(prob_oop, alg)
143+
return sum(abs2, sol.u)
144+
end
145+
146+
function obj_2(p)
147+
ff = NonlinearFunction{false}(loss_function; jac = loss_function_jac)
148+
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
149+
sol = solve(prob_oop, alg)
150+
return sum(abs2, sol.u)
151+
end
152+
153+
function obj_3(p)
154+
ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp)
155+
prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p)
156+
sol = solve(prob_oop, alg)
157+
return sum(abs2, sol.u)
158+
end
159+
160+
finitediff = FiniteDiff.finite_difference_gradient(obj_1, x)
161+
162+
fdiff1 = ForwardDiff.gradient(obj_1, x)
163+
fdiff2 = ForwardDiff.gradient(obj_2, x)
164+
fdiff3 = ForwardDiff.gradient(obj_3, x)
165+
166+
@test finitedifffdiff1 atol=1e-5
167+
@test finitedifffdiff2 atol=1e-5
168+
@test finitedifffdiff3 atol=1e-5
169+
@test fdiff1 fdiff2 fdiff3
170+
171+
function obj_4(p)
172+
prob_iip = NonlinearLeastSquaresProblem(
173+
NonlinearFunction{true}(
174+
loss_function!; resid_prototype = zeros(length(y_target))), θ_init, p)
175+
sol = solve(prob_iip, alg)
176+
return sum(abs2, sol.u)
177+
end
178+
179+
function obj_5(p)
180+
ff = NonlinearFunction{true}(
181+
loss_function!; resid_prototype = zeros(length(y_target)), jac = loss_function_jac!)
182+
prob_iip = NonlinearLeastSquaresProblem(
183+
ff, θ_init, p)
184+
sol = solve(prob_iip, alg)
185+
return sum(abs2, sol.u)
186+
end
187+
188+
function obj_6(p)
189+
ff = NonlinearFunction{true}(
190+
loss_function!; resid_prototype = zeros(length(y_target)), vjp = loss_function_vjp!)
191+
prob_iip = NonlinearLeastSquaresProblem(
192+
ff, θ_init, p)
193+
sol = solve(prob_iip, alg)
194+
return sum(abs2, sol.u)
195+
end
196+
197+
finitediff = FiniteDiff.finite_difference_gradient(obj_4, x)
198+
199+
fdiff4 = ForwardDiff.gradient(obj_4, x)
200+
fdiff5 = ForwardDiff.gradient(obj_5, x)
201+
fdiff6 = ForwardDiff.gradient(obj_6, x)
202+
203+
@test finitedifffdiff4 atol=1e-5
204+
@test finitedifffdiff5 atol=1e-5
205+
@test finitedifffdiff6 atol=1e-5
206+
@test fdiff4 fdiff5 fdiff6
207+
end
208+
end

Diff for: ‎lib/SimpleNonlinearSolve/test/core/least_squares_tests.jl

+16
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
return.- y_target
1313
end
1414

15+
function loss_function!(resid, θ, p)
16+
= true_function(p, θ)
17+
@. resid =- y_target
18+
return
19+
end
20+
1521
θ_init = θ_true .+ 0.1
1622
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
1723

@@ -21,4 +27,14 @@
2127
sol = solve(prob_oop, solver)
2228
@test norm(sol.resid, Inf) < 1e-12
2329
end
30+
31+
prob_iip = NonlinearLeastSquaresProblem(
32+
NonlinearFunction{true}(loss_function!, resid_prototype = zeros(length(y_target))), θ_init, x)
33+
34+
@testset "Solver: $(nameof(typeof(solver)))" for solver in [
35+
SimpleNewtonRaphson(AutoForwardDiff()), SimpleGaussNewton(AutoForwardDiff()),
36+
SimpleNewtonRaphson(AutoFiniteDiff()), SimpleGaussNewton(AutoFiniteDiff())]
37+
sol = solve(prob_iip, solver)
38+
@test norm(sol.resid, Inf) < 1e-12
39+
end
2440
end

0 commit comments

Comments
 (0)