Skip to content

TrustRegion fix #1

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 2 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 55 additions & 33 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
TrustRegion(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0,
step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25,
expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25,
expand_factor::Real = 2.0,
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
```

Expand Down Expand Up @@ -98,13 +98,13 @@ function TrustRegion(; chunk_size = Val{0}(),
autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
max_trust_radius::Real = 0.0,
initial_trust_radius::Real = 0.0,
step_threshold::Real = 0.1,
shrink_threshold::Real = 0.25,
expand_threshold::Real = 0.75,
shrink_factor::Real = 0.25,
expand_factor::Real = 2.0,
max_trust_radius::Real = 0 // 1,
initial_trust_radius::Real = 0 // 1,
step_threshold::Real = 1 // 10,
shrink_threshold::Real = 1 // 4,
expand_threshold::Real = 3 // 4,
shrink_factor::Real = 1 // 4,
expand_factor::Real = 2 // 1,
max_shrink_times::Int = 32)
TrustRegion{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
Expand Down Expand Up @@ -141,6 +141,11 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
prob::probType
trust_r::trustType
max_trust_r::trustType
step_threshold::trustType
shrink_threshold::trustType
expand_threshold::trustType
shrink_factor::trustType
expand_factor::trustType
loss::floatType
loss_new::floatType
H::jType
Expand All @@ -158,10 +163,12 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType, trust_r::trustType,
max_trust_r::trustType, loss::floatType,
loss_new::floatType, H::jType, g::resType,
shrink_counter::Int, step_size::uType, u_tmp::uType,
fu_new::resType, make_new_J::Bool,
max_trust_r::trustType, step_threshold::trustType,
shrink_threshold::trustType, expand_threshold::trustType,
shrink_factor::trustType, expand_factor::trustType,
loss::floatType, loss_new::floatType, H::jType,
g::resType, shrink_counter::Int, step_size::uType,
u_tmp::uType, fu_new::resType, make_new_J::Bool,
r::floatType) where {iip, fType, algType, uType,
resType, pType, INType,
tolType, probType, ufType, L,
Expand All @@ -171,7 +178,10 @@ mutable struct TrustRegionCache{iip, fType, algType, uType, resType, pType,
}(f, alg, u, fu, p, uf, linsolve, J,
jac_config, iter, force_stop,
maxiters, internalnorm, retcode,
abstol, prob, trust_r, max_trust_r, loss,
abstol, prob, trust_r, max_trust_r,
step_threshold, shrink_threshold,
expand_threshold, shrink_factor,
expand_factor, loss,
loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new,
make_new_J, r)
Expand Down Expand Up @@ -228,25 +238,37 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
loss = get_loss(fu)
uf, linsolve, J, u_tmp, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
shrink_threshold = convert(eltype(u), alg.shrink_threshold)
expand_threshold = convert(eltype(u), alg.expand_threshold)
shrink_factor = convert(eltype(u), alg.shrink_factor)
expand_factor = convert(eltype(u), alg.expand_factor)
# Set default trust region radius if not specified
u_elType = uType <: Number ? uType : eltype(u)
max_trust_radius = u_elType(alg.max_trust_radius)
initial_trust_radius = u_elType(alg.initial_trust_radius)
if iszero(max_trust_radius)
max_trust_radius = max(norm(fu), maximum(u) - minimum(u))
max_trust_radius = convert(eltype(u), max(norm(fu), maximum(u) - minimum(u)))
end
if iszero(initial_trust_radius)
initial_trust_radius = max_trust_radius / 11
initial_trust_radius = convert(eltype(u), max_trust_radius / 11)
end

loss_new = loss
H = ArrayInterfaceCore.undefmatrix(u)
g = zero(fu)
shrink_counter = 0
step_size = zero(u)
fu_new = zero(fu)
make_new_J = true
r = loss

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
max_trust_radius, loss, loss, H, zero(fu), 0, zero(u),
u_tmp, zero(fu), true,
loss)
max_trust_radius, step_threshold, shrink_threshold,
expand_threshold, shrink_factor, expand_factor, loss,
loss_new, H, g, shrink_counter, step_size, u_tmp, fu_new,
make_new_J, r)
end

function perform_step!(cache::TrustRegionCache{true})
Expand Down Expand Up @@ -295,27 +317,27 @@ function perform_step!(cache::TrustRegionCache{false})
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, u_tmp, step_size, g, H, loss, alg, max_trust_r = cache
@unpack fu_new, step_size, g, H, loss, max_trust_r = cache
cache.loss_new = get_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) / (step_size' * g + step_size' * H * step_size / 2)
@unpack r = cache

# Update the trust region radius.
if r < alg.shrink_threshold
cache.trust_r *= alg.shrink_factor
if r < cache.shrink_threshold
cache.trust_r *= cache.shrink_factor
cache.shrink_counter += 1
else
cache.shrink_counter = 0
end
if r > alg.step_threshold
if r > cache.step_threshold
take_step!(cache)
cache.loss = cache.loss_new

# Update the trust region radius.
if r > alg.expand_threshold
cache.trust_r = min(alg.expand_factor * cache.trust_r, max_trust_r)
if r > cache.expand_threshold
cache.trust_r = min(cache.expand_factor * cache.trust_r, max_trust_r)
end

cache.make_new_J = true
Expand Down
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ sol = benchmark_inplace(ffiip, u0)
u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(ffiip, u0)
solver = init(probN, TrustRegion(), abstol = 1e-9)
@test (@ballocated solve!(solver)) < 120
@test (@ballocated solve!(solver)) < 200

# AD Tests
using ForwardDiff
Expand Down