@@ -21,23 +21,24 @@ struct __PotraPtak3 <: AbstractMultiStepScheme end
21
21
const PotraPtak3 = __PotraPtak3 ()
22
22
23
23
alg_steps (:: __PotraPtak3 ) = 2
24
+ nintermediates (:: __PotraPtak3 ) = 1
24
25
25
26
@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
26
- vjp_autodiff = nothing
27
+ jvp_autodiff = nothing
27
28
end
28
29
const SinghSharma4 = __SinghSharma4 ()
29
30
30
31
alg_steps (:: __SinghSharma4 ) = 3
31
32
32
33
@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
33
- vjp_autodiff = nothing
34
+ jvp_autodiff = nothing
34
35
end
35
36
const SinghSharma5 = __SinghSharma5 ()
36
37
37
38
alg_steps (:: __SinghSharma5 ) = 3
38
39
39
40
@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
40
- vjp_autodiff = nothing
41
+ jvp_autodiff = nothing
41
42
end
42
43
const SinghSharma7 = __SinghSharma7 ()
43
44
60
61
61
62
Base. show (io:: IO , alg:: GenericMultiStepDescent ) = print (io, " $(alg. scheme) ()" )
62
63
63
- supports_line_search (:: GenericMultiStepDescent ) = false
64
+ supports_line_search (:: GenericMultiStepDescent ) = true
64
65
supports_trust_region (:: GenericMultiStepDescent ) = false
65
66
66
- @concrete mutable struct GenericMultiStepDescentCache{S, INV } <: AbstractDescentCache
67
+ @concrete mutable struct GenericMultiStepDescentCache{S} <: AbstractDescentCache
67
68
f
68
69
p
69
70
δu
70
71
δus
71
- extras
72
+ u
73
+ us
74
+ fu
75
+ fus
76
+ internal_cache
77
+ internal_caches
72
78
scheme:: S
73
- lincache
74
79
timer
75
80
nf:: Int
76
81
end
77
82
78
- @internal_caches GenericMultiStepDescentCache :lincache
83
+ # FIXME : @internal_caches needs to be updated to support tuples and namedtuples
84
+ # @internal_caches GenericMultiStepDescentCache :internal_caches
79
85
80
86
function __reinit_internal! (cache:: GenericMultiStepDescentCache , args... ; p = cache. p,
81
87
kwargs... )
82
88
cache. nf = 0
83
89
cache. p = p
90
+ reset_timer! (cache. timer)
84
91
end
85
92
86
- function __δu_caches (scheme:: MSS.__PotraPtak3 , fu, u, :: Val{N} ) where {N}
87
- caches = ntuple (N) do i
88
- @bb δu = similar (u)
89
- @bb y = similar (u)
90
- @bb fy = similar (fu)
91
- @bb δy = similar (u)
92
- @bb u_new = similar (u)
93
- (δu, δy, fy, y, u_new)
93
+ function __internal_multistep_caches (
94
+ scheme:: MSS.__PotraPtak3 , alg:: GenericMultiStepDescent ,
95
+ prob, args... ; shared:: Val{N} = Val (1 ), kwargs... ) where {N}
96
+ internal_descent = NewtonDescent (; alg. linsolve, alg. precs)
97
+ internal_cache = __internal_init (
98
+ prob, internal_descent, args... ; kwargs... , shared = Val (2 ))
99
+ internal_caches = N ≤ 1 ? nothing :
100
+ map (2 : N) do i
101
+ __internal_init (prob, internal_descent, args... ; kwargs... , shared = Val (2 ))
94
102
end
95
- return first (caches), (N ≤ 1 ? nothing : caches[ 2 : end ])
103
+ return internal_cache, internal_caches
96
104
end
97
105
98
- function __internal_init (prob:: NonlinearProblem , alg:: GenericMultiStepDescent , J, fu, u;
99
- shared:: Val{N} = Val (1 ), pre_inverted:: Val{INV} = False, linsolve_kwargs = (;),
106
+ function __internal_init (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ,
107
+ alg:: GenericMultiStepDescent , J, fu, u; shared:: Val{N} = Val (1 ),
108
+ pre_inverted:: Val{INV} = False, linsolve_kwargs = (;),
100
109
abstol = nothing , reltol = nothing , timer = get_timer_output (),
101
110
kwargs... ) where {INV, N}
102
- δu, δus = __δu_caches (alg. scheme, fu, u, shared)
103
- INV && return GenericMultiStepDescentCache {true} (prob. f, prob. p, δu, δus,
104
- alg. scheme, nothing , timer, 0 )
105
- lincache = LinearSolverCache (alg, alg. linsolve, J, _vec (fu), _vec (u); abstol, reltol,
106
- linsolve_kwargs... )
107
- return GenericMultiStepDescentCache {false} (prob. f, prob. p, δu, δus, alg. scheme,
108
- lincache, timer, 0 )
109
- end
110
-
111
- function __internal_init (prob:: NonlinearLeastSquaresProblem , alg:: GenericMultiStepDescent ,
112
- J, fu, u; kwargs... )
113
- error (" Multi-Step Descent Algorithms for NLLS are not implemented yet." )
111
+ @bb δu = similar (u)
112
+ δus = N ≤ 1 ? nothing : map (2 : N) do i
113
+ @bb δu_ = similar (u)
114
+ end
115
+ fu_cache = ntuple (MSS. nintermediates (alg. scheme)) do i
116
+ @bb xx = similar (fu)
117
+ end
118
+ fus_cache = N ≤ 1 ? nothing : map (2 : N) do i
119
+ ntuple (MSS. nintermediates (alg. scheme)) do j
120
+ @bb xx = similar (fu)
121
+ end
122
+ end
123
+ u_cache = ntuple (MSS. nintermediates (alg. scheme)) do i
124
+ @bb xx = similar (u)
125
+ end
126
+ us_cache = N ≤ 1 ? nothing : map (2 : N) do i
127
+ ntuple (MSS. nintermediates (alg. scheme)) do j
128
+ @bb xx = similar (u)
129
+ end
130
+ end
131
+ internal_cache, internal_caches = __internal_multistep_caches (
132
+ alg. scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
133
+ abstol, reltol, timer, kwargs... )
134
+ return GenericMultiStepDescentCache (
135
+ prob. f, prob. p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
136
+ internal_cache, internal_caches, alg. scheme, timer, 0 )
114
137
end
115
138
116
139
function __internal_solve! (cache:: GenericMultiStepDescentCache{MSS.__PotraPtak3, INV} , J,
117
140
fu, u, idx:: Val = Val (1 ); skip_solve:: Bool = false , new_jacobian:: Bool = true ,
118
141
kwargs... ) where {INV}
119
- (u_new, δy, fy, y, δu) = get_du (cache, idx)
120
- skip_solve && return DescentResult (; u = u_new)
121
-
122
- @static_timeit cache. timer " linear solve" begin
123
- @static_timeit cache. timer " solve and step 1" begin
124
- if INV
125
- J != = nothing && @bb (δu= J × _vec (fu))
126
- else
127
- δu = cache. lincache (; A = J, b = _vec (fu), kwargs... , linu = _vec (δu),
128
- du = _vec (δu),
129
- reuse_A_if_factorization = ! new_jacobian || (idx != = Val (1 )))
130
- δu = _restructure (u, δu)
131
- end
132
- @bb @. y = u - δu
133
- end
142
+ δu = get_du (cache, idx)
143
+ skip_solve && return DescentResult (; δu)
144
+
145
+ (y,) = get_internal_cache (cache, Val (:u ), idx)
146
+ (fy,) = get_internal_cache (cache, Val (:fu ), idx)
147
+ internal_cache = get_internal_cache (cache, Val (:internal_cache ), idx)
134
148
149
+ @static_timeit cache. timer " descent step" begin
150
+ result_1 = __internal_solve! (
151
+ internal_cache, J, fu, u, Val (1 ); new_jacobian, kwargs... )
152
+ δx = result_1. δu
153
+
154
+ @bb @. y = u + δx
135
155
fy = evaluate_f!! (cache. f, fy, y, cache. p)
136
156
cache. nf += 1
137
157
138
- @static_timeit cache. timer " solve and step 2" begin
139
- if INV
140
- J != = nothing && @bb (δy= J × _vec (fy))
141
- else
142
- δy = cache. lincache (; A = J, b = _vec (fy), kwargs... , linu = _vec (δy),
143
- du = _vec (δy), reuse_A_if_factorization = true )
144
- δy = _restructure (u, δy)
145
- end
146
- @bb @. u_new = y - δy
147
- end
158
+ result_2 = __internal_solve! (
159
+ internal_cache, J, fy, y, Val (2 ); kwargs... )
160
+ δy = result_2. δu
161
+
162
+ @bb @. δu = δx + δy
148
163
end
149
164
150
- set_du! (cache, (u_new, δy, fy, y, δu), idx)
151
- return DescentResult (; u = u_new)
165
+ set_du! (cache, δu, idx)
166
+ set_internal_cache! (cache, (y,), Val (:u ), idx)
167
+ set_internal_cache! (cache, (fy,), Val (:fu ), idx)
168
+ set_internal_cache! (cache, internal_cache, Val (:internal_cache ), idx)
169
+ return DescentResult (; δu)
152
170
end
0 commit comments