From 8bce0949be92b430a124e238740e33a454585f92 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Feb 2024 17:38:12 -0500 Subject: [PATCH 1/5] Run Formatter --- Project.toml | 2 +- docs/src/devdocs/internal_interfaces.md | 1 + src/NonlinearSolve.jl | 2 + src/abstract_types.jl | 13 +--- src/core/approximate_jacobian.jl | 96 +++++++++++++------------ src/core/generalized_first_order.jl | 90 ++++++++++++----------- src/descent/common.jl | 26 +++++++ src/descent/damped_newton.jl | 9 ++- src/descent/dogleg.jl | 23 +++--- src/descent/geodesic_acceleration.jl | 11 +-- src/descent/multistep.jl | 0 src/descent/newton.jl | 8 +-- src/descent/steepest.jl | 2 +- 13 files changed, 163 insertions(+), 120 deletions(-) create mode 100644 src/descent/common.jl create mode 100644 src/descent/multistep.jl diff --git a/Project.toml b/Project.toml index 9afd3d558..5c1501bb0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.5.4" +version = "3.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/devdocs/internal_interfaces.md b/docs/src/devdocs/internal_interfaces.md index 843054cc8..97762e18b 100644 --- a/docs/src/devdocs/internal_interfaces.md +++ b/docs/src/devdocs/internal_interfaces.md @@ -13,6 +13,7 @@ NonlinearSolve.AbstractNonlinearSolveCache ```@docs NonlinearSolve.AbstractDescentAlgorithm NonlinearSolve.AbstractDescentCache +NonlinearSolve.DescentResult ``` ## Approximate Jacobian diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index de46fe153..abd4d5e40 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -45,11 +45,13 @@ include("adtypes.jl") include("timer_outputs.jl") include("internal/helpers.jl") +include("descent/common.jl") include("descent/newton.jl") include("descent/steepest.jl") include("descent/dogleg.jl") include("descent/damped_newton.jl") include("descent/geodesic_acceleration.jl") +include("descent/multistep.jl") include("internal/operators.jl") include("internal/jacobian.jl") diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 1b30f2b9f..63c33b931 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -66,8 +66,8 @@ Abstract Type for all Descent Caches. ### `__internal_solve!` specification ```julia -δu, success, intermediates = __internal_solve!(cache::AbstractDescentCache, J, fu, u, - idx::Val; skip_solve::Bool = false, kwargs...) +descent_result = __internal_solve!(cache::AbstractDescentCache, J, fu, u, idx::Val; + skip_solve::Bool = false, kwargs...) ``` - `J`: Jacobian or Inverse Jacobian (if `pre_inverted = Val(true)`). @@ -79,14 +79,7 @@ Abstract Type for all Descent Caches. direction was rejected and we want to try with a modified trust region. - `kwargs`: keyword arguments to pass to the linear solver if there is one. -#### Returned values - - - `δu`: the descent direction. - - `success`: Certain Descent Algorithms can reject a descent direction for example - `GeodesicAcceleration`. - - `intermediates`: A named tuple containing intermediates computed during the solve. - For example, `GeodesicAcceleration` returns `NamedTuple{(:v, :a)}` containing the - "velocity" and "acceleration" terms. +Returns a result of type [`DescentResult`](@ref). 
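+
+For example, a caller consuming this interface might look like the following sketch
+(illustrative only; `descent_cache` stands for any concrete `AbstractDescentCache`):
+
+```julia
+res = __internal_solve!(descent_cache, J, fu, u)
+if res.success
+    δu = res.δu  # the descent direction; multi-step methods may populate `res.u` instead
+end
+```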
### Interface Functions diff --git a/src/core/approximate_jacobian.jl b/src/core/approximate_jacobian.jl index ffc77cea8..22a192e41 100644 --- a/src/core/approximate_jacobian.jl +++ b/src/core/approximate_jacobian.jl @@ -163,8 +163,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem{uType, iip}, linsolve = get_linear_solver(alg.descent) initialization_cache = __internal_init(prob, alg.initialization, alg, f, fu, u, p; - linsolve, - maxiters, internalnorm) + linsolve, maxiters, internalnorm) abstol, reltol, termination_cache = init_termination_cache(abstol, reltol, fu, u, termination_condition) @@ -222,9 +221,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip}; new_jacobian = true @static_timeit cache.timer "jacobian init/reinit" begin if get_nsteps(cache) == 0 # First Step is special ignore kwargs - J_init = __internal_solve!(cache.initialization_cache, - cache.fu, - cache.u, + J_init = __internal_solve!(cache.initialization_cache, cache.fu, cache.u, Val(false)) if INV if jacobian_initialized_preinverted(cache.initialization_cache.alg) @@ -283,54 +280,65 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip}; @static_timeit cache.timer "descent" begin if cache.trustregion_cache !== nothing && hasfield(typeof(cache.trustregion_cache), :trust_region) - δu, descent_success, descent_intermediates = __internal_solve!( - cache.descent_cache, - J, cache.fu, cache.u; new_jacobian, - trust_region = cache.trustregion_cache.trust_region) + descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u; + new_jacobian, trust_region = cache.trustregion_cache.trust_region) else - δu, descent_success, descent_intermediates = __internal_solve!( - cache.descent_cache, - J, cache.fu, cache.u; new_jacobian) + descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u; + new_jacobian) end end - if descent_success - if GB === :LineSearch - @static_timeit cache.timer "linesearch" begin - needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu) - end - if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period - cache.force_reinit = true - else - @static_timeit cache.timer "step" begin - @bb axpy!(α, δu, cache.u) - evaluate_f!(cache, cache.u, cache.p) - end - end - elseif GB === :TrustRegion - @static_timeit cache.timer "trustregion" begin - tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, J, - cache.fu, cache.u, δu, descent_intermediates) - if tr_accepted - @bb copyto!(cache.u, u_new) - @bb copyto!(cache.fu, fu_new) - end - if hasfield(typeof(cache.trustregion_cache), :shrink_counter) && - cache.trustregion_cache.shrink_counter > cache.max_shrink_times - cache.retcode = ReturnCode.ShrinkThresholdExceeded - cache.force_stop = true - end - end - α = true - elseif GB === :None + if descent_result.success + if GB === :None @static_timeit cache.timer "step" begin - @bb axpy!(1, δu, cache.u) + if descent_result.u !== missing + @bb copyto!(cache.u, descent_result.u) + elseif descent_result.δu !== missing + @bb axpy!(1, descent_result.δu, cache.u) + else + error("This shouldn't occur. `$(cache.alg.descent)` is incorrectly \ + specified.") + end evaluate_f!(cache, cache.u, cache.p) end α = true else - error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \ - :TrustRegion, :None)") + δu = descent_result.δu + @assert δu!==missing "Descent Supporting LineSearch or TrustRegion must return a `δu`." 
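+            # `δu` is now a concrete direction: a line search only scales it by `α`,
+            # while a trust region may still reject the proposed step outright.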
+ + if GB === :LineSearch + @static_timeit cache.timer "linesearch" begin + needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu) + end + if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period + cache.force_reinit = true + else + @static_timeit cache.timer "step" begin + @bb axpy!(α, δu, cache.u) + evaluate_f!(cache, cache.u, cache.p) + end + end + elseif GB === :TrustRegion + @static_timeit cache.timer "trustregion" begin + tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, + J, cache.fu, cache.u, δu, descent_result.extras) + if tr_accepted + @bb copyto!(cache.u, u_new) + @bb copyto!(cache.fu, fu_new) + α = true + else + α = false + end + if hasfield(typeof(cache.trustregion_cache), :shrink_counter) && + cache.trustregion_cache.shrink_counter > cache.max_shrink_times + cache.retcode = ReturnCode.ShrinkThresholdExceeded + cache.force_stop = true + end + end + else + error("Unknown Globalization Strategy: $(GB). Allowed values are \ + (:LineSearch, :TrustRegion, :None)") + end end check_and_update!(cache, cache.fu, cache.u, cache.u_cache) else diff --git a/src/core/generalized_first_order.jl b/src/core/generalized_first_order.jl index 0812e7f05..fb7580319 100644 --- a/src/core/generalized_first_order.jl +++ b/src/core/generalized_first_order.jl @@ -215,59 +215,67 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB}; @static_timeit cache.timer "descent" begin if cache.trustregion_cache !== nothing && hasfield(typeof(cache.trustregion_cache), :trust_region) - δu, descent_success, descent_intermediates = __internal_solve!( - cache.descent_cache, - J, cache.fu, cache.u; new_jacobian, - trust_region = cache.trustregion_cache.trust_region) + descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u; + new_jacobian, trust_region = cache.trustregion_cache.trust_region) else - δu, descent_success, descent_intermediates = __internal_solve!( - cache.descent_cache, - J, cache.fu, cache.u; new_jacobian) + descent_result = __internal_solve!(cache.descent_cache, J, cache.fu, cache.u; + new_jacobian) end end - if descent_success + if descent_result.success cache.make_new_jacobian = true - if GB === :LineSearch - @static_timeit cache.timer "linesearch" begin - linesearch_failed, α = __internal_solve!(cache.linesearch_cache, - cache.u, δu) - end - if linesearch_failed - cache.retcode = ReturnCode.InternalLineSearchFailed - cache.force_stop = true - end + if GB === :None @static_timeit cache.timer "step" begin - @bb axpy!(α, δu, cache.u) - evaluate_f!(cache, cache.u, cache.p) - end - elseif GB === :TrustRegion - @static_timeit cache.timer "trustregion" begin - tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, J, - cache.fu, cache.u, δu, descent_intermediates) - if tr_accepted - @bb copyto!(cache.u, u_new) - @bb copyto!(cache.fu, fu_new) - α = true + if descent_result.u !== missing + @bb copyto!(cache.u, descent_result.u) + elseif descent_result.δu !== missing + @bb axpy!(1, descent_result.δu, cache.u) else - α = false - cache.make_new_jacobian = false + error("This shouldn't occur. 
`$(cache.alg.descent)` is incorrectly \ + specified.") end - if hasfield(typeof(cache.trustregion_cache), :shrink_counter) && - cache.trustregion_cache.shrink_counter > cache.max_shrink_times - cache.retcode = ReturnCode.ShrinkThresholdExceeded - cache.force_stop = true - end - end - elseif GB === :None - @static_timeit cache.timer "step" begin - @bb axpy!(1, δu, cache.u) evaluate_f!(cache, cache.u, cache.p) end α = true else - error("Unknown Globalization Strategy: $(GB). Allowed values are (:LineSearch, \ - :TrustRegion, :None)") + δu = descent_result.δu + @assert δu!==missing "Descent Supporting LineSearch or TrustRegion must return a `δu`." + + if GB === :LineSearch + @static_timeit cache.timer "linesearch" begin + failed, α = __internal_solve!(cache.linesearch_cache, cache.u, δu) + end + if failed + cache.retcode = ReturnCode.InternalLineSearchFailed + cache.force_stop = true + else + @static_timeit cache.timer "step" begin + @bb axpy!(α, δu, cache.u) + evaluate_f!(cache, cache.u, cache.p) + end + end + elseif GB === :TrustRegion + @static_timeit cache.timer "trustregion" begin + tr_accepted, u_new, fu_new = __internal_solve!(cache.trustregion_cache, + J, cache.fu, cache.u, δu, descent_result.extras) + if tr_accepted + @bb copyto!(cache.u, u_new) + @bb copyto!(cache.fu, fu_new) + α = true + else + α = false + end + if hasfield(typeof(cache.trustregion_cache), :shrink_counter) && + cache.trustregion_cache.shrink_counter > cache.max_shrink_times + cache.retcode = ReturnCode.ShrinkThresholdExceeded + cache.force_stop = true + end + end + else + error("Unknown Globalization Strategy: $(GB). Allowed values are \ + (:LineSearch, :TrustRegion, :None)") + end end check_and_update!(cache, cache.fu, cache.u, cache.u_cache) else diff --git a/src/descent/common.jl b/src/descent/common.jl new file mode 100644 index 000000000..10b14ad14 --- /dev/null +++ b/src/descent/common.jl @@ -0,0 +1,26 @@ +""" + DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;)) + +Construct a `DescentResult` object. + +### Keyword Arguments + + * `δu`: The descent direction. + * `u`: The new iterate. This is provided only for multi-step methods currently. + * `success`: Certain Descent Algorithms can reject a descent direction for example + [`GeodesicAcceleration`](@ref). + * `extras`: A named tuple containing intermediates computed during the solve. + For example, [`GeodesicAcceleration`](@ref) returns `NamedTuple{(:v, :a)}` containing + the "velocity" and "acceleration" terms. +""" +@concrete struct DescentResult + δu + u + success::Bool + extras +end + +function DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;)) + @assert δu !== missing || u !== missing + return DescentResult(δu, u, success, extras) +end diff --git a/src/descent/damped_newton.jl b/src/descent/damped_newton.jl index 77ad95b54..a00b480f8 100644 --- a/src/descent/damped_newton.jl +++ b/src/descent/damped_newton.jl @@ -138,7 +138,7 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu, u, idx::Val{N} = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) 
where {INV, N, mode} δu = get_du(cache, idx) - skip_solve && return δu, true, (;) + skip_solve && return DescentResult(; δu) recompute_A = idx === Val(1) @@ -203,15 +203,14 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu, u, end @static_timeit cache.timer "linear solve" begin - δu = cache.lincache(; A, b, - reuse_A_if_factorization = !new_jacobian && !recompute_A, - kwargs..., linu = _vec(δu)) + δu = cache.lincache(; A, b, linu = _vec(δu), + reuse_A_if_factorization = !new_jacobian && !recompute_A, kwargs...) δu = _restructure(get_du(cache, idx), δu) end @bb @. δu *= -1 set_du!(cache, δu, idx) - return δu, true, (;) + return DescentResult(; δu) end # Define special concatenation for certain Array combinations diff --git a/src/descent/dogleg.jl b/src/descent/dogleg.jl index e1a50832f..772f06295 100644 --- a/src/descent/dogleg.jl +++ b/src/descent/dogleg.jl @@ -40,7 +40,7 @@ end newton_cache cauchy_cache internalnorm - JᵀJ_cache + Jᵀδu_cache δu_cache_1 δu_cache_2 δu_cache_mul @@ -68,10 +68,10 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u; normal_form = prob isa NonlinearLeastSquaresProblem && __needs_square_A(alg.newton_descent.linsolve, u) - JᵀJ_cache = !normal_form ? J * _vec(δu) : nothing # TODO: Rename + Jᵀδu_cache = !normal_form ? J * _vec(δu) : nothing return DoglegCache{INV, normal_form}(δu, δus, newton_cache, cauchy_cache, internalnorm, - JᵀJ_cache, δu_cache_1, δu_cache_2, δu_cache_mul) + Jᵀδu_cache, δu_cache_1, δu_cache_2, δu_cache_mul) end # If TrustRegion is not specified, then use a Gauss-Newton step @@ -82,14 +82,16 @@ function __internal_solve!(cache::DoglegCache{INV, NF}, J, fu, u, idx::Val{N} = want to use a Trust Region." δu = get_du(cache, idx) T = promote_type(eltype(u), eltype(fu)) - δu_newton, _, _ = __internal_solve!(cache.newton_cache, J, fu, u, idx; skip_solve, + + res_newton = __internal_solve!(cache.newton_cache, J, fu, u, idx; skip_solve, kwargs...) + δu_newton = res_newton.δu # Newton's Step within the trust region if cache.internalnorm(δu_newton) ≤ trust_region @bb copyto!(δu, δu_newton) set_du!(cache, δu, idx) - return δu, true, (; δuJᵀJδu = T(NaN)) + return DescentResult(; δu, extras = (; δuJᵀJδu = T(NaN))) end # Take intersection of steepest descent direction and trust region if Cauchy point lies @@ -103,12 +105,13 @@ function __internal_solve!(cache::DoglegCache{INV, NF}, J, fu, u, idx::Val{N} = @bb cache.δu_cache_mul = JᵀJ × vec(δu_cauchy) δuJᵀJδu = __dot(δu_cauchy, cache.δu_cache_mul) else - δu_cauchy, _, _ = __internal_solve!(cache.cauchy_cache, J, fu, u, idx; skip_solve, + res_cauchy = __internal_solve!(cache.cauchy_cache, J, fu, u, idx; skip_solve, kwargs...) + δu_cauchy = res_cauchy.δu J_ = INV ? inv(J) : J l_grad = cache.internalnorm(δu_cauchy) - @bb cache.JᵀJ_cache = J × vec(δu_cauchy) # TODO: Rename - δuJᵀJδu = __dot(cache.JᵀJ_cache, cache.JᵀJ_cache) + @bb cache.Jᵀδu_cache = J × vec(δu_cauchy) + δuJᵀJδu = __dot(cache.Jᵀδu_cache, cache.Jᵀδu_cache) end d_cauchy = (l_grad^3) / δuJᵀJδu @@ -116,7 +119,7 @@ function __internal_solve!(cache::DoglegCache{INV, NF}, J, fu, u, idx::Val{N} = λ = trust_region / l_grad @bb @. δu = λ * δu_cauchy set_du!(cache, δu, idx) - return δu, true, (; δuJᵀJδu = λ^2 * δuJᵀJδu) + return DescentResult(; δu, extras = (; δuJᵀJδu = λ^2 * δuJᵀJδu)) end # FIXME: For anything other than 2-norm a quadratic root will give incorrect results @@ -134,5 +137,5 @@ function __internal_solve!(cache::DoglegCache{INV, NF}, J, fu, u, idx::Val{N} = @bb @. 
δu = cache.δu_cache_1 + τ * cache.δu_cache_2 set_du!(cache, δu, idx) - return δu, true, (; δuJᵀJδu = T(NaN)) + return DescentResult(; δu, extras = (; δuJᵀJδu = τ^2 * δuJᵀJδu)) end diff --git a/src/descent/geodesic_acceleration.jl b/src/descent/geodesic_acceleration.jl index 35764783c..76033da0f 100644 --- a/src/descent/geodesic_acceleration.jl +++ b/src/descent/geodesic_acceleration.jl @@ -106,9 +106,11 @@ function __internal_solve!( cache::GeodesicAccelerationCache, J, fu, u, idx::Val{N} = Val(1); skip_solve::Bool = false, kwargs...) where {N} a, v, δu = get_acceleration(cache, idx), get_velocity(cache, idx), get_du(cache, idx) - skip_solve && return δu, true, (; a, v) - v, _, _ = __internal_solve!(cache.descent_cache, J, fu, u, Val(2N - 1); skip_solve, + skip_solve && return DescentResult(; δu, extras = (; a, v)) + + res_v = __internal_solve!(cache.descent_cache, J, fu, u, Val(2N - 1); skip_solve, kwargs...) + v = res_v.δu @bb @. cache.u_cache = u + cache.h * v cache.fu_cache = evaluate_f!!(cache.f, cache.fu_cache, cache.u_cache, cache.p) @@ -117,8 +119,9 @@ function __internal_solve!( Jv = _restructure(cache.fu_cache, cache.Jv) @bb @. cache.fu_cache = (2 / cache.h) * ((cache.fu_cache - fu) / cache.h - Jv) - a, _, _ = __internal_solve!(cache.descent_cache, J, cache.fu_cache, u, Val(2N); + res_a = __internal_solve!(cache.descent_cache, J, cache.fu_cache, u, Val(2N); skip_solve, kwargs..., reuse_A_if_factorization = true) + a = res_a.δu norm_v = cache.internalnorm(v) norm_a = cache.internalnorm(a) @@ -131,5 +134,5 @@ function __internal_solve!( cache.last_step_accepted = false end - return δu, cache.last_step_accepted, (; a, v) + return DescentResult(; δu, success = cache.last_step_accepted, extras = (; a, v)) end diff --git a/src/descent/multistep.jl b/src/descent/multistep.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/descent/newton.jl b/src/descent/newton.jl index c8ba35ed9..26bea6350 100644 --- a/src/descent/newton.jl +++ b/src/descent/newton.jl @@ -75,7 +75,7 @@ function __internal_solve!(cache::NewtonDescentCache{INV, false}, J, fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) where {INV} δu = get_du(cache, idx) - skip_solve && return δu, true, (;) + skip_solve && return DescentResult(; δu) if INV @assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`." @bb δu = J × vec(fu) @@ -88,13 +88,13 @@ function __internal_solve!(cache::NewtonDescentCache{INV, false}, J, fu, u, end @bb @. δu *= -1 set_du!(cache, δu, idx) - return δu, true, (;) + return DescentResult(; δu) end function __internal_solve!(cache::NewtonDescentCache{false, true}, J, fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) δu = get_du(cache, idx) - skip_solve && return δu, true, (;) + skip_solve && return DescentResult(; δu) if idx === Val(1) @bb cache.JᵀJ_cache = transpose(J) × J end @@ -107,5 +107,5 @@ function __internal_solve!(cache::NewtonDescentCache{false, true}, J, fu, u, end @bb @. δu *= -1 set_du!(cache, δu, idx) - return δu, true, (;) + return DescentResult(; δu) end diff --git a/src/descent/steepest.jl b/src/descent/steepest.jl index d19505a86..da7812fa0 100644 --- a/src/descent/steepest.jl +++ b/src/descent/steepest.jl @@ -63,5 +63,5 @@ function __internal_solve!(cache::SteepestDescentCache{INV}, J, fu, u, idx::Val end @bb @. 
δu *= -1 set_du!(cache, δu, idx) - return δu, true, (;) + return DescentResult(; δu) end From ff82fd9832588e11cbf17adfd4fadc7670e717f1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Jan 2024 03:46:46 -0500 Subject: [PATCH 2/5] Add PotraPtak3 --- src/NonlinearSolve.jl | 8 ++- src/algorithms/multistep.jl | 7 ++ src/descent/common.jl | 8 +-- src/descent/multistep.jl | 134 ++++++++++++++++++++++++++++++++++++ src/internal/tracing.jl | 1 + 5 files changed, 152 insertions(+), 6 deletions(-) create mode 100644 src/algorithms/multistep.jl diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index abd4d5e40..bd39e63d0 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -71,6 +71,7 @@ include("core/spectral_methods.jl") include("algorithms/raphson.jl") include("algorithms/pseudo_transient.jl") +include("algorithms/multistep.jl") include("algorithms/broyden.jl") include("algorithms/klement.jl") include("algorithms/lbroyden.jl") @@ -140,7 +141,8 @@ include("default.jl") end # Core Algorithms -export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane +export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane, + MultiStepNonlinearSolver export GaussNewton, LevenbergMarquardt, TrustRegion export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg @@ -154,7 +156,9 @@ export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, Genera # Descent Algorithms export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, - GeodesicAcceleration + GeodesicAcceleration, GenericMultiStepDescent +## Multistep Algorithms +export MultiStepSchemes # Globalization ## Line Search Algorithms diff --git a/src/algorithms/multistep.jl b/src/algorithms/multistep.jl new file mode 100644 index 000000000..35b204094 --- /dev/null +++ b/src/algorithms/multistep.jl @@ -0,0 +1,7 @@ +function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing, + scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing) + descent = GenericMultiStepDescent(; scheme, linsolve, precs) + # TODO: Use the scheme as the name + return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :MultiStepNonlinearSolver, + descent, jacobian_ad = autodiff) +end diff --git a/src/descent/common.jl b/src/descent/common.jl index 10b14ad14..2a614d84a 100644 --- a/src/descent/common.jl +++ b/src/descent/common.jl @@ -5,11 +5,11 @@ Construct a `DescentResult` object. ### Keyword Arguments - * `δu`: The descent direction. - * `u`: The new iterate. This is provided only for multi-step methods currently. - * `success`: Certain Descent Algorithms can reject a descent direction for example + - `δu`: The descent direction. + - `u`: The new iterate. This is provided only for multi-step methods currently. + - `success`: Certain Descent Algorithms can reject a descent direction for example [`GeodesicAcceleration`](@ref). - * `extras`: A named tuple containing intermediates computed during the solve. + - `extras`: A named tuple containing intermediates computed during the solve. For example, [`GeodesicAcceleration`](@ref) returns `NamedTuple{(:v, :a)}` containing the "velocity" and "acceleration" terms. """ diff --git a/src/descent/multistep.jl b/src/descent/multistep.jl index e69de29bb..2879a9bef 100644 --- a/src/descent/multistep.jl +++ b/src/descent/multistep.jl @@ -0,0 +1,134 @@ +""" + MultiStepSchemes + +This module defines the multistep schemes used in the multistep descent algorithms. 
The +naming convention follows . The name of method is +typically the last names of the authors of the paper that introduced the method. +""" +module MultiStepSchemes + +abstract type AbstractMultiStepScheme end + +function Base.show(io::IO, mss::AbstractMultiStepScheme) + print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])") +end + +struct __PotraPtak3 <: AbstractMultiStepScheme end +const PotraPtak3 = __PotraPtak3() + +alg_steps(::__PotraPtak3) = 1 + +struct __SinghSharma4 <: AbstractMultiStepScheme end +const SinghSharma4 = __SinghSharma4() + +alg_steps(::__SinghSharma4) = 3 + +struct __SinghSharma5 <: AbstractMultiStepScheme end +const SinghSharma5 = __SinghSharma5() + +alg_steps(::__SinghSharma5) = 3 + +struct __SinghSharma7 <: AbstractMultiStepScheme end +const SinghSharma7 = __SinghSharma7() + +alg_steps(::__SinghSharma7) = 4 + +end + +const MSS = MultiStepSchemes + +@kwdef @concrete struct GenericMultiStepDescent <: AbstractDescentAlgorithm + scheme + linsolve = nothing + precs = DEFAULT_PRECS +end + +supports_line_search(::GenericMultiStepDescent) = false +supports_trust_region(::GenericMultiStepDescent) = false + +@concrete mutable struct GenericMultiStepDescentCache{S, INV} <: AbstractDescentCache + f + p + δu + δus + scheme::S + lincache + timer + nf::Int +end + +@internal_caches GenericMultiStepDescentCache :lincache + +function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = cache.p, + kwargs...) + cache.nf = 0 + cache.p = p +end + +function __δu_caches(scheme::MSS.__PotraPtak3, fu, u, ::Val{N}) where {N} + caches = ntuple(N) do i + @bb δu = similar(u) + @bb y = similar(u) + @bb fy = similar(fu) + @bb δy = similar(u) + @bb u_new = similar(u) + (δu, δy, fy, y, u_new) + end + return first(caches), (N ≤ 1 ? nothing : caches[2:end]) +end + +function __internal_init(prob::NonlinearProblem, alg::GenericMultiStepDescent, J, fu, u; + shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;), + abstol = nothing, reltol = nothing, timer = get_timer_output(), + kwargs...) where {INV, N} + δu, δus = __δu_caches(alg.scheme, fu, u, shared) + INV && return GenericMultiStepDescentCache{true}(prob.f, prob.p, δu, δus, + alg.scheme, nothing, timer, 0) + lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, + linsolve_kwargs...) + return GenericMultiStepDescentCache{false}(prob.f, prob.p, δu, δus, alg.scheme, + lincache, timer, 0) +end + +function __internal_init(prob::NonlinearLeastSquaresProblem, alg::GenericMultiStepDescent, + J, fu, u; kwargs...) + error("Multi-Step Descent Algorithms for NLLS are not implemented yet.") +end + +function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J, + fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true, + kwargs...) where {INV} + (u_new, δy, fy, y, δu) = get_du(cache, idx) + skip_solve && return DescentResult(; u = u_new) + + @static_timeit cache.timer "linear solve" begin + @static_timeit cache.timer "solve and step 1" begin + if INV + J !== nothing && @bb(δu=J × _vec(fu)) + else + δu = cache.lincache(; A = J, b = _vec(fu), kwargs..., linu = _vec(δu), + du = _vec(δu), + reuse_A_if_factorization = !new_jacobian || (idx !== Val(1))) + δu = _restructure(u, δu) + end + @bb @. 
y = u - δu + end + + fy = evaluate_f!!(cache.f, fy, y, cache.p) + cache.nf += 1 + + @static_timeit cache.timer "solve and step 2" begin + if INV + J !== nothing && @bb(δy=J × _vec(fy)) + else + δy = cache.lincache(; A = J, b = _vec(fy), kwargs..., linu = _vec(δy), + du = _vec(δy), reuse_A_if_factorization = true) + δy = _restructure(u, δy) + end + @bb @. u_new = y - δy + end + end + + set_du!(cache, (u_new, δy, fy, y, δu), idx) + return DescentResult(; u = u_new) +end diff --git a/src/internal/tracing.jl b/src/internal/tracing.jl index 667c6ce07..bfb93c6d7 100644 --- a/src/internal/tracing.jl +++ b/src/internal/tracing.jl @@ -187,6 +187,7 @@ function update_trace!(cache::AbstractNonlinearSolveCache, α = true) trace === nothing && return nothing J = __getproperty(cache, Val(:J)) + # TODO: fix tracing for multi-step methods where du is not aliased properly if J === nothing update_trace!(trace, get_nsteps(cache) + 1, get_u(cache), get_fu(cache), nothing, cache.du, α) From ceeadcbe5871e9d0fe394c04d9f78b0e6b69b635 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Feb 2024 17:37:42 -0500 Subject: [PATCH 3/5] materialize the multi-step scheme --- Project.toml | 4 +++- docs/src/basics/faq.md | 4 ++-- docs/src/basics/sparsity_detection.md | 4 ++-- docs/src/tutorials/large_systems.md | 18 +++++++++--------- src/NonlinearSolve.jl | 10 +++++----- src/algorithms/multistep.jl | 9 +++++---- src/descent/multistep.jl | 26 ++++++++++++++++++++++---- src/utils.jl | 19 +++++++++++++++++++ 8 files changed, 67 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index 5c1501bb0..403831c1d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "3.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" @@ -55,12 +56,13 @@ NonlinearSolveZygoteExt = "Zygote" [compat] ADTypes = "0.2.6" +Accessors = "0.1" Aqua = "0.8" ArrayInterface = "7.7" BandedMatrices = "1.4" BenchmarkTools = "1.4" -ConcreteStructs = "0.2.3" CUDA = "5.1" +ConcreteStructs = "0.2.3" DiffEqBase = "6.146.0" Enzyme = "0.11.11" FastBroadcast = "0.2.8" diff --git a/docs/src/basics/faq.md b/docs/src/basics/faq.md index e40b57b33..2144fcba4 100644 --- a/docs/src/basics/faq.md +++ b/docs/src/basics/faq.md @@ -72,7 +72,7 @@ differentiate the function based on the input types. However, this function has `xx = [1.0, 2.0, 3.0, 4.0]` followed by a `xx[1] = var[1] - v_true[1]` where `var` might be a Dual number. This causes the error. To fix it: - 1. Specify the `autodiff` to be `AutoFiniteDiff` +1. Specify the `autodiff` to be `AutoFiniteDiff` ```@example dual_error_faq sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiters = 10000, @@ -81,7 +81,7 @@ sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiter This worked but, Finite Differencing is not the recommended approach in any scenario. - 2. Rewrite the function to use +2. 
Rewrite the function to use [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) or write it as ```@example dual_error_faq diff --git a/docs/src/basics/sparsity_detection.md b/docs/src/basics/sparsity_detection.md index 222aebe19..e23d42dab 100644 --- a/docs/src/basics/sparsity_detection.md +++ b/docs/src/basics/sparsity_detection.md @@ -34,7 +34,7 @@ prob = NonlinearProblem( If the `colorvec` is not provided, then it is computed on demand. !!! note - + One thing to be careful about in this case is that `colorvec` is dependent on the autodiff backend used. Forward Mode and Finite Differencing will assume that the colorvec is the column colorvec, while Reverse Mode will assume that the colorvec is the @@ -76,7 +76,7 @@ loaded, we default to using `SymbolicsSparsityDetection()`, else we default to u options if those are provided. !!! warning - + If you provide a non-sparse AD, and provide a `sparsity` or `jac_prototype` then we will use dense AD. This is because, if you provide a specific AD type, we assume that you know what you are doing and want to override the default choice of `nothing`. diff --git a/docs/src/tutorials/large_systems.md b/docs/src/tutorials/large_systems.md index aedd58445..d6f3c96fb 100644 --- a/docs/src/tutorials/large_systems.md +++ b/docs/src/tutorials/large_systems.md @@ -2,15 +2,15 @@ This tutorial is for getting into the extra features of using NonlinearSolve.jl. Solving ill-conditioned nonlinear systems requires specializing the linear solver on properties of -the Jacobian in order to cut down on the ``\mathcal{O}(n^3)`` linear solve and the -``\mathcal{O}(n^2)`` back-solves. This tutorial is designed to explain the advanced usage of +the Jacobian in order to cut down on the `\mathcal{O}(n^3)` linear solve and the +`\mathcal{O}(n^2)` back-solves. This tutorial is designed to explain the advanced usage of NonlinearSolve.jl by solving the steady state stiff Brusselator partial differential equation (BRUSS) using NonlinearSolve.jl. ## Definition of the Brusselator Equation !!! note - + Feel free to skip this section: it simply defines the example problem. The Brusselator PDE is defined as follows: @@ -118,11 +118,11 @@ However, if you know the sparsity of your problem, then you can pass a different type. For example, a `SparseMatrixCSC` will give a sparse matrix. Other sparse matrix types include: - - Bidiagonal - - Tridiagonal - - SymTridiagonal - - BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl)) - - BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl)) +- Bidiagonal +- Tridiagonal +- SymTridiagonal +- BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl)) +- BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl)) ## Approximate Sparsity Detection & Sparse Jacobians @@ -213,7 +213,7 @@ choices, see the `linsolve` choices are any valid [LinearSolve.jl](https://linearsolve.sciml.ai/dev/) solver. !!! note - + Switching to a Krylov linear solver will automatically change the nonlinear problem solver into Jacobian-free mode, dramatically reducing the memory required. This can be overridden by adding `concrete_jac=true` to the algorithm. 
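+
+For instance (a hedged sketch — it assumes the `prob_brusselator_2d` problem constructed
+earlier in this tutorial):
+
+```julia
+using NonlinearSolve, LinearSolve
+# Jacobian-free Newton-Krylov; add `concrete_jac = true` only if a materialized
+# Jacobian is needed (e.g. to build a preconditioner from its entries)
+sol = solve(prob_brusselator_2d, NewtonRaphson(; linsolve = KrylovJL_GMRES()))
+```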
diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index bd39e63d0..784b00b76 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,9 +8,9 @@ import Reexport: @reexport import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload @recompile_invalidations begin - using ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures, LazyArrays, - LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf, - SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools + using Accessors, ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures, + LazyArrays, LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences, + Printf, SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing import DiffEqBase: AbstractNonlinearTerminationMode, @@ -142,7 +142,7 @@ end # Core Algorithms export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane, - MultiStepNonlinearSolver + MultiStepNonlinearSolver export GaussNewton, LevenbergMarquardt, TrustRegion export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg @@ -156,7 +156,7 @@ export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, Genera # Descent Algorithms export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, - GeodesicAcceleration, GenericMultiStepDescent + GeodesicAcceleration, GenericMultiStepDescent ## Multistep Algorithms export MultiStepSchemes diff --git a/src/algorithms/multistep.jl b/src/algorithms/multistep.jl index 35b204094..d1f087fe3 100644 --- a/src/algorithms/multistep.jl +++ b/src/algorithms/multistep.jl @@ -1,7 +1,8 @@ function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing, - scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing) - descent = GenericMultiStepDescent(; scheme, linsolve, precs) - # TODO: Use the scheme as the name - return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :MultiStepNonlinearSolver, + scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing, + vjp_autodiff = nothing) + scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff)) + descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs) + return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme), descent, jacobian_ad = autodiff) end diff --git a/src/descent/multistep.jl b/src/descent/multistep.jl index 2879a9bef..e92653eb8 100644 --- a/src/descent/multistep.jl +++ b/src/descent/multistep.jl @@ -7,32 +7,47 @@ typically the last names of the authors of the paper that introduced the method. 
""" module MultiStepSchemes +using ConcreteStructs + abstract type AbstractMultiStepScheme end function Base.show(io::IO, mss::AbstractMultiStepScheme) print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])") end +alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T()) + struct __PotraPtak3 <: AbstractMultiStepScheme end const PotraPtak3 = __PotraPtak3() -alg_steps(::__PotraPtak3) = 1 +alg_steps(::__PotraPtak3) = 2 -struct __SinghSharma4 <: AbstractMultiStepScheme end +@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme + vjp_autodiff = nothing +end const SinghSharma4 = __SinghSharma4() alg_steps(::__SinghSharma4) = 3 -struct __SinghSharma5 <: AbstractMultiStepScheme end +@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme + vjp_autodiff = nothing +end const SinghSharma5 = __SinghSharma5() alg_steps(::__SinghSharma5) = 3 -struct __SinghSharma7 <: AbstractMultiStepScheme end +@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme + vjp_autodiff = nothing +end const SinghSharma7 = __SinghSharma7() alg_steps(::__SinghSharma7) = 4 +@generated function display_name(alg::T) where {T <: AbstractMultiStepScheme} + res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end]) + return :($(Meta.quot(res))) +end + end const MSS = MultiStepSchemes @@ -43,6 +58,8 @@ const MSS = MultiStepSchemes precs = DEFAULT_PRECS end +Base.show(io::IO, alg::GenericMultiStepDescent) = print(io, "$(alg.scheme)()") + supports_line_search(::GenericMultiStepDescent) = false supports_trust_region(::GenericMultiStepDescent) = false @@ -51,6 +68,7 @@ supports_trust_region(::GenericMultiStepDescent) = false p δu δus + extras scheme::S lincache timer diff --git a/src/utils.jl b/src/utils.jl index 7f4c2c439..e5595ea0d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -158,3 +158,22 @@ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the i """ @inline pickchunksize(x) = pickchunksize(length(x)) @inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) + +""" + apply_patch(scheme, patch::NamedTuple{names}) + +Applies the patch to the scheme, returning the new scheme. If some of the `names` are not, +present in the scheme, they are ignored. +""" +@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names} + exprs = [] + for name in names + hasfield(scheme, name) || continue + push!(exprs, quote + lens = PropertyLens{$(Meta.quot(name))}() + return set(scheme, lens, getfield(patch, $(Meta.quot(name)))) + end) + end + push!(exprs, :(return scheme)) + return Expr(:block, exprs...) 
+end From dccc1ddb1a66a4d970df3278314620dd6def92bb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Feb 2024 18:57:07 -0500 Subject: [PATCH 4/5] Reuse NewtonDescent for MultiStepSchemes --- Project.toml | 2 +- docs/src/basics/faq.md | 4 +- docs/src/basics/sparsity_detection.md | 4 +- docs/src/tutorials/large_systems.md | 14 +-- src/abstract_types.jl | 28 ++++++ src/algorithms/multistep.jl | 4 +- src/descent/multistep.jl | 134 +++++++++++++++----------- 7 files changed, 118 insertions(+), 72 deletions(-) diff --git a/Project.toml b/Project.toml index 403831c1d..ca81165b0 100644 --- a/Project.toml +++ b/Project.toml @@ -56,7 +56,7 @@ NonlinearSolveZygoteExt = "Zygote" [compat] ADTypes = "0.2.6" -Accessors = "0.1" +Accessors = "0.1.32" Aqua = "0.8" ArrayInterface = "7.7" BandedMatrices = "1.4" diff --git a/docs/src/basics/faq.md b/docs/src/basics/faq.md index 2144fcba4..e40b57b33 100644 --- a/docs/src/basics/faq.md +++ b/docs/src/basics/faq.md @@ -72,7 +72,7 @@ differentiate the function based on the input types. However, this function has `xx = [1.0, 2.0, 3.0, 4.0]` followed by a `xx[1] = var[1] - v_true[1]` where `var` might be a Dual number. This causes the error. To fix it: -1. Specify the `autodiff` to be `AutoFiniteDiff` + 1. Specify the `autodiff` to be `AutoFiniteDiff` ```@example dual_error_faq sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiters = 10000, @@ -81,7 +81,7 @@ sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiter This worked but, Finite Differencing is not the recommended approach in any scenario. -2. Rewrite the function to use + 2. Rewrite the function to use [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) or write it as ```@example dual_error_faq diff --git a/docs/src/basics/sparsity_detection.md b/docs/src/basics/sparsity_detection.md index e23d42dab..222aebe19 100644 --- a/docs/src/basics/sparsity_detection.md +++ b/docs/src/basics/sparsity_detection.md @@ -34,7 +34,7 @@ prob = NonlinearProblem( If the `colorvec` is not provided, then it is computed on demand. !!! note - + One thing to be careful about in this case is that `colorvec` is dependent on the autodiff backend used. Forward Mode and Finite Differencing will assume that the colorvec is the column colorvec, while Reverse Mode will assume that the colorvec is the @@ -76,7 +76,7 @@ loaded, we default to using `SymbolicsSparsityDetection()`, else we default to u options if those are provided. !!! warning - + If you provide a non-sparse AD, and provide a `sparsity` or `jac_prototype` then we will use dense AD. This is because, if you provide a specific AD type, we assume that you know what you are doing and want to override the default choice of `nothing`. diff --git a/docs/src/tutorials/large_systems.md b/docs/src/tutorials/large_systems.md index d6f3c96fb..4788a99af 100644 --- a/docs/src/tutorials/large_systems.md +++ b/docs/src/tutorials/large_systems.md @@ -10,7 +10,7 @@ equation (BRUSS) using NonlinearSolve.jl. ## Definition of the Brusselator Equation !!! note - + Feel free to skip this section: it simply defines the example problem. The Brusselator PDE is defined as follows: @@ -118,11 +118,11 @@ However, if you know the sparsity of your problem, then you can pass a different type. For example, a `SparseMatrixCSC` will give a sparse matrix. 
Other sparse matrix types include: -- Bidiagonal -- Tridiagonal -- SymTridiagonal -- BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl)) -- BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl)) + - Bidiagonal + - Tridiagonal + - SymTridiagonal + - BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl)) + - BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl)) ## Approximate Sparsity Detection & Sparse Jacobians @@ -213,7 +213,7 @@ choices, see the `linsolve` choices are any valid [LinearSolve.jl](https://linearsolve.sciml.ai/dev/) solver. !!! note - + Switching to a Krylov linear solver will automatically change the nonlinear problem solver into Jacobian-free mode, dramatically reducing the memory required. This can be overridden by adding `concrete_jac=true` to the algorithm. diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 63c33b931..6209359f8 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -87,6 +87,11 @@ Returns a result of type [`DescentResult`](@ref). - `get_du(cache, ::Val{N})`: get the `N`th descent direction. - `set_du!(cache, δu)`: set the descent direction. - `set_du!(cache, δu, ::Val{N})`: set the `N`th descent direction. + - `get_internal_cache(cache, ::Val{field})`: get the internal cache field. + - `get_internal_cache(cache, field::Val, ::Val{N})`: get the `N`th internal cache field. + - `set_internal_cache!(cache, value, ::Val{field})`: set the internal cache field. + - `set_internal_cache!(cache, value, field::Val, ::Val{N})`: set the `N`th internal cache + field. - `last_step_accepted(cache)`: whether or not the last step was accepted. Checks if the cache has a `last_step_accepted` field and returns it if it does, else returns `true`. 
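+
+For example, a multi-step descent cache that stores intermediate iterates in fields
+`u`/`us` might use the internal-cache accessors as follows (illustrative sketch):
+
+```julia
+(y,) = get_internal_cache(cache, Val(:u), idx)    # read the `idx`-th intermediate iterate
+set_internal_cache!(cache, (y,), Val(:u), idx)    # write it back after updating `y`
+```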
""" @@ -98,6 +103,29 @@ SciMLBase.get_du(cache::AbstractDescentCache, ::Val{N}) where {N} = cache.δus[N set_du!(cache::AbstractDescentCache, δu) = (cache.δu = δu) set_du!(cache::AbstractDescentCache, δu, ::Val{1}) = set_du!(cache, δu) set_du!(cache::AbstractDescentCache, δu, ::Val{N}) where {N} = (cache.δus[N - 1] = δu) +function get_internal_cache(cache::AbstractDescentCache, ::Val{field}) where {field} + return getproperty(cache, field) +end +function get_internal_cache(cache::AbstractDescentCache, field::Val, ::Val{1}) + return get_internal_cache(cache, field) +end +function get_internal_cache( + cache::AbstractDescentCache, ::Val{field}, ::Val{N}) where {field, N} + true_field = Symbol(string(field), "s") # Julia 1.10 compiles this away + return getproperty(cache, true_field)[N] +end +function set_internal_cache!(cache::AbstractDescentCache, value, ::Val{field}) where {field} + return setproperty!(cache, field, value) +end +function set_internal_cache!( + cache::AbstractDescentCache, value, field::Val, ::Val{1}) + return set_internal_cache!(cache, value, field) +end +function set_internal_cache!( + cache::AbstractDescentCache, value, ::Val{field}, ::Val{N}) where {field, N} + true_field = Symbol(string(field), "s") # Julia 1.10 compiles this away + return setproperty!(cache, true_field, value, N) +end function last_step_accepted(cache::AbstractDescentCache) hasfield(typeof(cache), :last_step_accepted) && return cache.last_step_accepted diff --git a/src/algorithms/multistep.jl b/src/algorithms/multistep.jl index d1f087fe3..abb056402 100644 --- a/src/algorithms/multistep.jl +++ b/src/algorithms/multistep.jl @@ -1,8 +1,8 @@ function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing, scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing, - vjp_autodiff = nothing) + vjp_autodiff = nothing, linesearch = NoLineSearch()) scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff)) descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs) return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme), - descent, jacobian_ad = autodiff) + descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff) end diff --git a/src/descent/multistep.jl b/src/descent/multistep.jl index e92653eb8..67c756a2c 100644 --- a/src/descent/multistep.jl +++ b/src/descent/multistep.jl @@ -21,23 +21,24 @@ struct __PotraPtak3 <: AbstractMultiStepScheme end const PotraPtak3 = __PotraPtak3() alg_steps(::__PotraPtak3) = 2 +nintermediates(::__PotraPtak3) = 1 @kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme - vjp_autodiff = nothing + jvp_autodiff = nothing end const SinghSharma4 = __SinghSharma4() alg_steps(::__SinghSharma4) = 3 @kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme - vjp_autodiff = nothing + jvp_autodiff = nothing end const SinghSharma5 = __SinghSharma5() alg_steps(::__SinghSharma5) = 3 @kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme - vjp_autodiff = nothing + jvp_autodiff = nothing end const SinghSharma7 = __SinghSharma7() @@ -60,93 +61,110 @@ end Base.show(io::IO, alg::GenericMultiStepDescent) = print(io, "$(alg.scheme)()") -supports_line_search(::GenericMultiStepDescent) = false +supports_line_search(::GenericMultiStepDescent) = true supports_trust_region(::GenericMultiStepDescent) = false -@concrete mutable struct GenericMultiStepDescentCache{S, INV} <: AbstractDescentCache +@concrete mutable struct GenericMultiStepDescentCache{S} <: AbstractDescentCache f p δu δus - 
extras + u + us + fu + fus + internal_cache + internal_caches scheme::S - lincache timer nf::Int end -@internal_caches GenericMultiStepDescentCache :lincache +# FIXME: @internal_caches needs to be updated to support tuples and namedtuples +# @internal_caches GenericMultiStepDescentCache :internal_caches function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = cache.p, kwargs...) cache.nf = 0 cache.p = p + reset_timer!(cache.timer) end -function __δu_caches(scheme::MSS.__PotraPtak3, fu, u, ::Val{N}) where {N} - caches = ntuple(N) do i - @bb δu = similar(u) - @bb y = similar(u) - @bb fy = similar(fu) - @bb δy = similar(u) - @bb u_new = similar(u) - (δu, δy, fy, y, u_new) +function __internal_multistep_caches( + scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent, + prob, args...; shared::Val{N} = Val(1), kwargs...) where {N} + internal_descent = NewtonDescent(; alg.linsolve, alg.precs) + internal_cache = __internal_init( + prob, internal_descent, args...; kwargs..., shared = Val(2)) + internal_caches = N ≤ 1 ? nothing : + map(2:N) do i + __internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2)) end - return first(caches), (N ≤ 1 ? nothing : caches[2:end]) + return internal_cache, internal_caches end -function __internal_init(prob::NonlinearProblem, alg::GenericMultiStepDescent, J, fu, u; - shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;), +function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1), + pre_inverted::Val{INV} = False, linsolve_kwargs = (;), abstol = nothing, reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N} - δu, δus = __δu_caches(alg.scheme, fu, u, shared) - INV && return GenericMultiStepDescentCache{true}(prob.f, prob.p, δu, δus, - alg.scheme, nothing, timer, 0) - lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, - linsolve_kwargs...) - return GenericMultiStepDescentCache{false}(prob.f, prob.p, δu, δus, alg.scheme, - lincache, timer, 0) -end - -function __internal_init(prob::NonlinearLeastSquaresProblem, alg::GenericMultiStepDescent, - J, fu, u; kwargs...) - error("Multi-Step Descent Algorithms for NLLS are not implemented yet.") + @bb δu = similar(u) + δus = N ≤ 1 ? nothing : map(2:N) do i + @bb δu_ = similar(u) + end + fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i + @bb xx = similar(fu) + end + fus_cache = N ≤ 1 ? nothing : map(2:N) do i + ntuple(MSS.nintermediates(alg.scheme)) do j + @bb xx = similar(fu) + end + end + u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i + @bb xx = similar(u) + end + us_cache = N ≤ 1 ? nothing : map(2:N) do i + ntuple(MSS.nintermediates(alg.scheme)) do j + @bb xx = similar(u) + end + end + internal_cache, internal_caches = __internal_multistep_caches( + alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs, + abstol, reltol, timer, kwargs...) + return GenericMultiStepDescentCache( + prob.f, prob.p, δu, δus, u_cache, us_cache, fu_cache, fus_cache, + internal_cache, internal_caches, alg.scheme, timer, 0) end function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J, fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) 
where {INV} - (u_new, δy, fy, y, δu) = get_du(cache, idx) - skip_solve && return DescentResult(; u = u_new) - - @static_timeit cache.timer "linear solve" begin - @static_timeit cache.timer "solve and step 1" begin - if INV - J !== nothing && @bb(δu=J × _vec(fu)) - else - δu = cache.lincache(; A = J, b = _vec(fu), kwargs..., linu = _vec(δu), - du = _vec(δu), - reuse_A_if_factorization = !new_jacobian || (idx !== Val(1))) - δu = _restructure(u, δu) - end - @bb @. y = u - δu - end + δu = get_du(cache, idx) + skip_solve && return DescentResult(; δu) + + (y,) = get_internal_cache(cache, Val(:u), idx) + (fy,) = get_internal_cache(cache, Val(:fu), idx) + internal_cache = get_internal_cache(cache, Val(:internal_cache), idx) + @static_timeit cache.timer "descent step" begin + result_1 = __internal_solve!( + internal_cache, J, fu, u, Val(1); new_jacobian, kwargs...) + δx = result_1.δu + + @bb @. y = u + δx fy = evaluate_f!!(cache.f, fy, y, cache.p) cache.nf += 1 - @static_timeit cache.timer "solve and step 2" begin - if INV - J !== nothing && @bb(δy=J × _vec(fy)) - else - δy = cache.lincache(; A = J, b = _vec(fy), kwargs..., linu = _vec(δy), - du = _vec(δy), reuse_A_if_factorization = true) - δy = _restructure(u, δy) - end - @bb @. u_new = y - δy - end + result_2 = __internal_solve!( + internal_cache, J, fy, y, Val(2); kwargs...) + δy = result_2.δu + + @bb @. δu = δx + δy end - set_du!(cache, (u_new, δy, fy, y, δu), idx) - return DescentResult(; u = u_new) + set_du!(cache, δu, idx) + set_internal_cache!(cache, (y,), Val(:u), idx) + set_internal_cache!(cache, (fy,), Val(:fu), idx) + set_internal_cache!(cache, internal_cache, Val(:internal_cache), idx) + return DescentResult(; δu) end From 75f18743d4f2604f6b87b0dc37079b40b52d137b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Feb 2024 19:22:53 -0500 Subject: [PATCH 5/5] Use macro for shared caches --- src/algorithms/multistep.jl | 6 ++- src/descent/damped_newton.jl | 6 +-- src/descent/dogleg.jl | 5 +-- src/descent/geodesic_acceleration.jl | 5 +-- src/descent/multistep.jl | 56 ++++++++++++---------------- src/descent/newton.jl | 10 +---- src/descent/steepest.jl | 5 +-- src/internal/helpers.jl | 38 +++++++++++++++++++ src/utils.jl | 19 ---------- 9 files changed, 72 insertions(+), 78 deletions(-) diff --git a/src/algorithms/multistep.jl b/src/algorithms/multistep.jl index abb056402..cd5a31890 100644 --- a/src/algorithms/multistep.jl +++ b/src/algorithms/multistep.jl @@ -1,8 +1,10 @@ function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing, scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing, vjp_autodiff = nothing, linesearch = NoLineSearch()) - scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff)) + forward_ad = ifelse(autodiff isa ADTypes.AbstractForwardMode, autodiff, nothing) + scheme_concrete = apply_patch( + scheme, (; autodiff, vjp_autodiff, jvp_autodiff = forward_ad)) descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs) return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme), - descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff) + descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff, forward_ad) end diff --git a/src/descent/damped_newton.jl b/src/descent/damped_newton.jl index a00b480f8..cee437d7e 100644 --- a/src/descent/damped_newton.jl +++ b/src/descent/damped_newton.jl @@ -58,11 +58,7 @@ function __internal_init( shared::Val{N} = Val(1), kwargs...) 
where {INV, N} length(fu) != length(u) && @assert !INV "Precomputed Inverse for Non-Square Jacobian doesn't make sense." - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end - + δu, δus = @shared_caches N (@bb δu = similar(u)) normal_form_damping = returns_norm_form_damping(alg.damping_fn) normal_form_linsolve = __needs_square_A(alg.linsolve, u) if u isa Number diff --git a/src/descent/dogleg.jl b/src/descent/dogleg.jl index 772f06295..ca7314760 100644 --- a/src/descent/dogleg.jl +++ b/src/descent/dogleg.jl @@ -56,10 +56,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u; linsolve_kwargs, abstol, reltol, shared, kwargs...) cauchy_cache = __internal_init(prob, alg.steepest_descent, J, fu, u; pre_inverted, linsolve_kwargs, abstol, reltol, shared, kwargs...) - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end + δu, δus = @shared_caches N (@bb δu = similar(u)) @bb δu_cache_1 = similar(u) @bb δu_cache_2 = similar(u) @bb δu_cache_mul = similar(u) diff --git a/src/descent/geodesic_acceleration.jl b/src/descent/geodesic_acceleration.jl index 76033da0f..a989c0376 100644 --- a/src/descent/geodesic_acceleration.jl +++ b/src/descent/geodesic_acceleration.jl @@ -89,10 +89,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::GeodesicAccelerati abstol = nothing, reltol = nothing, internalnorm::F = DEFAULT_NORM, kwargs...) where {INV, N, F} T = promote_type(eltype(u), eltype(fu)) - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end + δu, δus = @shared_caches N (@bb δu = similar(u)) descent_cache = __internal_init(prob, alg.descent, J, fu, u; shared = Val(N * 2), pre_inverted, linsolve_kwargs, abstol, reltol, kwargs...) 
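+    # `shared = Val(N * 2)` because every shared slot performs two inner solves per
+    # step: the "velocity" (index `2N - 1`) and the "acceleration" (index `2N`).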
@bb Jv = similar(fu) diff --git a/src/descent/multistep.jl b/src/descent/multistep.jl index 67c756a2c..eae086493 100644 --- a/src/descent/multistep.jl +++ b/src/descent/multistep.jl @@ -15,12 +15,12 @@ function Base.show(io::IO, mss::AbstractMultiStepScheme) print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])") end -alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T()) +newton_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = newton_steps(T()) struct __PotraPtak3 <: AbstractMultiStepScheme end const PotraPtak3 = __PotraPtak3() -alg_steps(::__PotraPtak3) = 2 +newton_steps(::__PotraPtak3) = 2 nintermediates(::__PotraPtak3) = 1 @kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme @@ -28,21 +28,23 @@ nintermediates(::__PotraPtak3) = 1 end const SinghSharma4 = __SinghSharma4() -alg_steps(::__SinghSharma4) = 3 +newton_steps(::__SinghSharma4) = 4 +nintermediates(::__SinghSharma4) = 2 @kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme jvp_autodiff = nothing end const SinghSharma5 = __SinghSharma5() -alg_steps(::__SinghSharma5) = 3 +newton_steps(::__SinghSharma5) = 4 +nintermediates(::__SinghSharma5) = 2 @kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme jvp_autodiff = nothing end const SinghSharma7 = __SinghSharma7() -alg_steps(::__SinghSharma7) = 4 +newton_steps(::__SinghSharma7) = 6 @generated function display_name(alg::T) where {T <: AbstractMultiStepScheme} res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end]) @@ -75,6 +77,8 @@ supports_trust_region(::GenericMultiStepDescent) = false fus internal_cache internal_caches + extra + extras scheme::S timer nf::Int @@ -91,49 +95,37 @@ function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = ca end function __internal_multistep_caches( - scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent, - prob, args...; shared::Val{N} = Val(1), kwargs...) where {N} + scheme::Union{MSS.__PotraPtak3, MSS.__SinghSharma4, MSS.__SinghSharma5}, + alg::GenericMultiStepDescent, prob, args...; + shared::Val{N} = Val(1), kwargs...) where {N} internal_descent = NewtonDescent(; alg.linsolve, alg.precs) - internal_cache = __internal_init( + return @shared_caches N __internal_init( prob, internal_descent, args...; kwargs..., shared = Val(2)) - internal_caches = N ≤ 1 ? nothing : - map(2:N) do i - __internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2)) - end - return internal_cache, internal_caches end +__extras_cache(::MSS.AbstractMultiStepScheme, args...; kwargs...) = nothing, nothing + function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;), abstol = nothing, reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N} - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end - fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i + δu, δus = @shared_caches N (@bb δu = similar(u)) + fu_cache, fus_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i @bb xx = similar(fu) - end - fus_cache = N ≤ 1 ? nothing : map(2:N) do i - ntuple(MSS.nintermediates(alg.scheme)) do j - @bb xx = similar(fu) - end - end - u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i + end) + u_cache, us_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i @bb xx = similar(u) - end - us_cache = N ≤ 1 ? 
diff --git a/src/descent/newton.jl b/src/descent/newton.jl
index 26bea6350..52f8e9743 100644
--- a/src/descent/newton.jl
+++ b/src/descent/newton.jl
@@ -36,10 +36,7 @@ function __internal_init(prob::NonlinearProblem, alg::NewtonDescent, J, fu, u;
         shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
         abstol = nothing, reltol = nothing, timer = get_timer_output(),
         kwargs...) where {INV, N}
-    @bb δu = similar(u)
-    δus = N ≤ 1 ? nothing : map(2:N) do i
-        @bb δu_ = similar(u)
-    end
+    δu, δus = @shared_caches N (@bb δu = similar(u))
     INV && return NewtonDescentCache{true, false}(δu, δus, nothing, nothing, nothing, timer)
     lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol,
         linsolve_kwargs...)
@@ -64,10 +61,7 @@ function __internal_init(prob::NonlinearLeastSquaresProblem, alg::NewtonDescent,
     end
     lincache = LinearSolverCache(alg, alg.linsolve, A, b, _vec(u); abstol, reltol,
         linsolve_kwargs...)
-    @bb δu = similar(u)
-    δus = N ≤ 1 ? nothing : map(2:N) do i
-        @bb δu_ = similar(u)
-    end
+    δu, δus = @shared_caches N (@bb δu = similar(u))
     return NewtonDescentCache{false, normal_form}(δu, δus, lincache, JᵀJ, Jᵀfu, timer)
 end
diff --git a/src/descent/steepest.jl b/src/descent/steepest.jl
index da7812fa0..9fd7cc9a9 100644
--- a/src/descent/steepest.jl
+++ b/src/descent/steepest.jl
@@ -34,10 +34,7 @@ end
         linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
         timer = get_timer_output(), kwargs...) where {INV, N}
     INV && @assert length(fu)==length(u) "Non-Square Jacobian Inverse doesn't make sense."
-    @bb δu = similar(u)
-    δus = N ≤ 1 ? nothing : map(2:N) do i
-        @bb δu_ = similar(u)
-    end
+    δu, δus = @shared_caches N (@bb δu = similar(u))
     if INV
         lincache = LinearSolverCache(alg, alg.linsolve, transpose(J), _vec(fu), _vec(u);
             abstol, reltol, linsolve_kwargs...)
diff --git a/src/internal/helpers.jl b/src/internal/helpers.jl
index 4f475214b..8226f6cf7 100644
--- a/src/internal/helpers.jl
+++ b/src/internal/helpers.jl
@@ -268,3 +268,41 @@ function __internal_caches(__source__, __module__, cType, internal_cache_names::
         end
     end)
 end
+
+"""
+    apply_patch(scheme, patch::NamedTuple{names})
+
+Applies the patch to the scheme, returning the new scheme. If some of the `names` are
+not present in the scheme, they are ignored.
+"""
+@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names}
+    exprs = []
+    for name in names
+        hasfield(scheme, name) || continue
+        push!(exprs, quote
+            lens = PropertyLens{$(Meta.quot(name))}()
+            return set(scheme, lens, getfield(patch, $(Meta.quot(name))))
+        end)
+    end
+    push!(exprs, :(return scheme))
+    return Expr(:block, exprs...)
+end
+
+"""
+    @shared_caches N expr
+
+Create a shared cache and a vector of caches. If `N ≤ 1`, the vector of caches is
+`nothing`.
+"""
+macro shared_caches(N, expr)
+    @gensym cache caches
+    return esc(quote
+        begin
+            $(cache) = $(expr)
+            $(caches) = $(N) ≤ 1 ? nothing : map(2:($(N))) do i
+                $(expr)
+            end
+            ($cache, $caches)
+        end
+    end)
+end
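
For reference, a hypothetical usage sketch of the relocated `apply_patch`; the scheme types below are made up for illustration, and the function resolves `PropertyLens`/`set` from Accessors.jl inside NonlinearSolve, so the caller needs no extra imports:

```julia
# Assumes NonlinearSolve's internal `apply_patch` is in scope.
struct SchemeWithAD
    jvp_autodiff
end
struct SchemeWithoutAD end

apply_patch(SchemeWithAD(nothing), (; jvp_autodiff = :forward))
# -> SchemeWithAD(:forward): the matching field is swapped in via a lens

apply_patch(SchemeWithoutAD(), (; jvp_autodiff = :forward))
# -> SchemeWithoutAD(): no matching field, so the scheme comes back untouched
```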
+""" +macro shared_caches(N, expr) + @gensym cache caches + return esc(quote + begin + $(cache) = $(expr) + $(caches) = $(N) ≤ 1 ? nothing : map(2:($(N))) do i + $(expr) + end + ($cache, $caches) + end + end) +end diff --git a/src/utils.jl b/src/utils.jl index e5595ea0d..7f4c2c439 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -158,22 +158,3 @@ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the i """ @inline pickchunksize(x) = pickchunksize(length(x)) @inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) - -""" - apply_patch(scheme, patch::NamedTuple{names}) - -Applies the patch to the scheme, returning the new scheme. If some of the `names` are not, -present in the scheme, they are ignored. -""" -@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names} - exprs = [] - for name in names - hasfield(scheme, name) || continue - push!(exprs, quote - lens = PropertyLens{$(Meta.quot(name))}() - return set(scheme, lens, getfield(patch, $(Meta.quot(name)))) - end) - end - push!(exprs, :(return scheme)) - return Expr(:block, exprs...) -end