Skip to content

Commit 1e4c3c0

Browse files
Merge pull request #266 from yonatanwesen/yd/pt-caching
Removing Allocation for Inexact Jacobian
2 parents 3565824 + abaf747 commit 1e4c3c0

File tree

3 files changed

+26 -9 lines changed

3 files changed

+26 -9 lines changed

Diff for: Project.toml

+4 -2
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1111
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
12+
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1213
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1314
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1415
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
@@ -35,13 +36,14 @@ NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
3536
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
3637

3738
[compat]
38-
BandedMatrices = "1"
3939
ADTypes = "0.2"
4040
ArrayInterface = "6.0.24, 7"
41+
BandedMatrices = "1"
4142
ConcreteStructs = "0.2"
4243
DiffEqBase = "6.130"
4344
EnumX = "1"
4445
Enzyme = "0.11"
46+
FastBroadcast = "0.1.9, 0.2"
4547
FastLevenbergMarquardt = "0.1"
4648
FiniteDiff = "2"
4749
ForwardDiff = "0.10.3"
@@ -63,6 +65,7 @@ julia = "1.9"
6365
[extras]
6466
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
6567
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
68+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
6669
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6770
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
6871
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -79,7 +82,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
7982
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
8083
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8184
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
82-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
8385

8486
[targets]
8587
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase"]

Diff for: src/NonlinearSolve.jl

+2 -1
Original file line number | Diff line number | Diff line change
@@ -9,11 +9,12 @@ import PrecompileTools
99

1010
PrecompileTools.@recompile_invalidations begin
1111
using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools
12+
using FastBroadcast: @..
1213
import ArrayInterface: restructure
1314

1415
import ADTypes: AbstractFiniteDifferencesMode
1516
import ArrayInterface: undefmatrix,
16-
matrix_colors, parameterless_type, ismutable, issingular
17+
matrix_colors, parameterless_type, ismutable, issingular,fast_scalar_indexing
1718
import ConcreteStructs: @concrete
1819
import EnumX: @enumx
1920
import ForwardDiff

Diff for: src/pseudotransient.jl

+20 -6
Original file line number | Diff line number | Diff line change
@@ -114,12 +114,25 @@ end
114114
function perform_step!(cache::PseudoTransientCache{true})
115115
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du, alpha, tc_storage = cache
116116
jacobian!!(J, cache)
117-
J_new = J - (1 / alpha) * I
117+
inv_alpha = inv(alpha)
118+
119+
if J isa SciMLBase.AbstractSciMLOperator
120+
J = J - inv_alpha * I
121+
else
122+
idxs = diagind(J)
123+
if fast_scalar_indexing(J)
124+
@inbounds for i in axes(J, 1)
125+
J[i, i] = J[i, i] - inv_alpha
126+
end
127+
else
128+
@.. broadcast=false @view(J[idxs])=@view(J[idxs]) - inv_alpha
129+
end
130+
end
118131

119132
termination_condition = cache.termination_condition(tc_storage)
120133

121134
# u = u - J \ fu
122-
linres = dolinsolve(alg.precs, linsolve; A = J_new, b = _vec(fu1), linu = _vec(du),
135+
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
123136
p, reltol = cache.abstol)
124137
cache.linsolve = linres.cache
125138
@. u = u - du
@@ -147,13 +160,14 @@ function perform_step!(cache::PseudoTransientCache{false})
147160
termination_condition = cache.termination_condition(tc_storage)
148161

149162
cache.J = jacobian!!(cache.J, cache)
163+
inv_alpha = inv(alpha)
164+
165+
cache.J = cache.J - inv_alpha * I
150166
# u = u - J \ fu
151167
if linsolve === nothing
152-
cache.du = fu1 / (cache.J - (1 / alpha) * I)
168+
cache.du = fu1 / cache.J
153169
else
154-
linres = dolinsolve(alg.precs, linsolve; A = cache.J - (1 / alpha) * I,
155-
b = _vec(fu1),
156-
linu = _vec(cache.du), p, reltol = cache.abstol)
170+
linres = dolinsolve(alg.precs, linsolve; A = cache.J,b = _vec(fu1),linu = _vec(cache.du), p, reltol = cache.abstol)
157171
cache.linsolve = linres.cache
158172
end
159173
cache.u = @. u - cache.du # `u` might not support mutation

0 commit comments

Comments
 (0)