From 23bbbbc1383b0f5f7b8e9c59c485124509bd5e4a Mon Sep 17 00:00:00 2001 From: Bart de Koning <74617371+SouthEndMusic@users.noreply.github.com> Date: Thu, 29 Aug 2024 15:53:24 +0200 Subject: [PATCH] Re-enable `BackTracking` (#1761) `BackTracking` as relaxation is now enabled again, with a thin wrapper to reject it when the residual gets worse. Upstream issue: https://github.com/SciML/OrdinaryDiffEq.jl/issues/2442 --- Manifest.toml | 12 ++---- core/Project.toml | 2 + core/ext/RibasimMakieExt.jl | 34 ++++++--------- core/src/Ribasim.jl | 13 +++++- core/src/config.jl | 6 ++- core/src/model.jl | 4 +- core/src/read.jl | 10 +++-- core/src/solve.jl | 11 +++++ core/src/util.jl | 82 +++++++++++++++++++++++++++++++++++-- core/test/main_test.jl | 1 - 10 files changed, 133 insertions(+), 42 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 2512ea2a4..1413e6984 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.4" manifest_format = "2.0" -project_hash = "c2cb085c326f61a96abd1a295e6fa775c585beba" +project_hash = "a410a350a7b0c63bc6696029509aa68c14023275" [[deps.ADTypes]] git-tree-sha1 = "6778bcc27496dae5723ff37ee30af451db8b35fe" @@ -1070,7 +1070,7 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" [[deps.NonlinearSolve]] -deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "SymbolicIndexingInterface", "TimerOutputs"] +deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "SymbolicIndexingInterface"] git-tree-sha1 = "3adb1e5945b5a6b1eaee754077f25ccc402edd7f" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" version = "3.13.1" @@ -1302,7 +1302,7 @@ uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" version = "3.5.18" [[deps.Ribasim]] -deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"] +deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqCallbacks", "EnumX", "FiniteDiff", "ForwardDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LineSearches", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEq", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"] path = "core" uuid = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635" version = "2024.10.0" @@ -1669,12 +1669,6 @@ weakdeps = ["RecipesBase"] [deps.TimeZones.extensions] TimeZonesRecipesBaseExt = "RecipesBase" -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.24" - [[deps.TranscodingStreams]] git-tree-sha1 = "d73336d81cafdc277ff45558bb7eaa2b04a8e472" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" diff --git a/core/Project.toml b/core/Project.toml index ce3e4882c..cf25cc035 100644 --- a/core/Project.toml +++ b/core/Project.toml @@ -24,6 +24,7 @@ HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" JuMP = "4076af6c-e467-56ae-b986-b466b2749572" Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd" +LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" @@ -70,6 +71,7 @@ IOCapture = "0.2" IterTools = "1.4" JuMP = "1.15" Legolas = "0.5" +LineSearches = "7" LinearSolve = "2.24" Logging = "<0.0.1, 1" LoggingExtras = "1" diff --git a/core/ext/RibasimMakieExt.jl b/core/ext/RibasimMakieExt.jl index 97077bab8..7c21330ae 100644 --- a/core/ext/RibasimMakieExt.jl +++ b/core/ext/RibasimMakieExt.jl @@ -1,13 +1,13 @@ module RibasimMakieExt using DataFrames: DataFrame -using Makie: Figure, Axis, lines!, axislegend +using Makie: Figure, Axis, scatterlines!, axislegend using Ribasim: Ribasim, Model function Ribasim.plot_basin_data!(model::Model, ax::Axis, column::Symbol) basin_data = DataFrame(Ribasim.basin_table(model)) for node_id in unique(basin_data.node_id) group = filter(:node_id => ==(node_id), basin_data) - lines!(ax, group.time, getproperty(group, column); label = "Basin #$node_id") + scatterlines!(ax, group.time, getproperty(group, column); label = "Basin #$node_id") end axislegend(ax) @@ -23,31 +23,23 @@ function Ribasim.plot_basin_data(model::Model) f end -function Ribasim.plot_flow!( - model::Model, - ax::Axis, - edge_id::Int32; - skip_conservative_out = false, -) +function Ribasim.plot_flow!(model::Model, ax::Axis, edge_metadata::Ribasim.EdgeMetadata) flow_data = DataFrame(Ribasim.flow_table(model)) - flow_data = filter(:edge_id => ==(edge_id), flow_data) - first_row = first(flow_data) - # Skip outflows of conservative nodes because these are the same as the inflows - if skip_conservative_out && - Ribasim.NodeType.T(first_row.from_node_type) in Ribasim.conservative_nodetypes - return nothing - end - label = "$(first_row.from_node_type) #$(first_row.from_node_id) → $(first_row.to_node_type) #$(first_row.to_node_id)" - lines!(ax, flow_data.time, flow_data.flow_rate; label) + flow_data = filter(:edge_id => ==(edge_metadata.id), flow_data) + label = "$(edge_metadata.edge[1]) → $(edge_metadata.edge[2])" + scatterlines!(ax, flow_data.time, flow_data.flow_rate; label) return nothing end -function Ribasim.plot_flow(model::Model) +function Ribasim.plot_flow(model::Model; skip_conservative_out = true) f = Figure() ax = Axis(f[1, 1]; xlabel = "time", ylabel = "flow rate [m³s⁻¹]") - edge_ids = unique(Ribasim.flow_table(model).edge_id) - for edge_id in edge_ids - Ribasim.plot_flow!(model, ax, edge_id; skip_conservative_out = true) + for edge_metadata in values(model.integrator.p.graph.edge_data) + if skip_conservative_out && + edge_metadata.edge[1].type in Ribasim.conservative_nodetypes + continue + end + Ribasim.plot_flow!(model, ax, edge_metadata) end axislegend(ax) f diff --git a/core/src/Ribasim.jl b/core/src/Ribasim.jl index 059a62d4e..01d735dd2 100644 --- a/core/src/Ribasim.jl +++ b/core/src/Ribasim.jl @@ -15,7 +15,15 @@ For more granular access, see: module Ribasim # Algorithms for solving ODEs. -using OrdinaryDiffEq: OrdinaryDiffEq, OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, get_du +using OrdinaryDiffEq: + OrdinaryDiffEq, + OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, + get_du, + AbstractNLSolver, + relax!, + _compute_rhs!, + calculate_residuals! +using LineSearches: BackTracking # Interface for defining and solving the ODE problem of the physical layer. using SciMLBase: @@ -31,7 +39,8 @@ using SciMLBase: ODEProblem, ODESolution, VectorContinuousCallback, - get_proposed_dt + get_proposed_dt, + DEIntegrator # Automatically detecting the sparsity pattern of the Jacobian of water_balance! # through operator overloading diff --git a/core/src/config.jl b/core/src/config.jl index 794a17477..bb024cd0f 100644 --- a/core/src/config.jl +++ b/core/src/config.jl @@ -230,7 +230,7 @@ const algorithms = Dict{String, Type}( ) "Create an OrdinaryDiffEqAlgorithm from solver config" -function algorithm(solver::Solver)::OrdinaryDiffEqAlgorithm +function algorithm(solver::Solver; u0 = [])::OrdinaryDiffEqAlgorithm algotype = get(algorithms, solver.algorithm, nothing) if algotype === nothing options = join(keys(algorithms), ", ") @@ -239,7 +239,9 @@ function algorithm(solver::Solver)::OrdinaryDiffEqAlgorithm end kwargs = Dict{Symbol, Any}() if algotype <: OrdinaryDiffEqNewtonAdaptiveAlgorithm - kwargs[:nlsolve] = NLNewton(; relax = 0.1) + kwargs[:nlsolve] = NLNewton(; + relax = Ribasim.MonitoredBackTracking(; z_tmp = copy(u0), dz_tmp = copy(u0)), + ) end # not all algorithms support this keyword kwargs[:autodiff] = solver.autodiff diff --git a/core/src/model.jl b/core/src/model.jl index 13d898336..9c6a447be 100644 --- a/core/src/model.jl +++ b/core/src/model.jl @@ -37,7 +37,6 @@ function Model(config_path::AbstractString)::Model end function Model(config::Config)::Model - alg = algorithm(config.solver) db_path = input_path(config, config.database) if !isfile(db_path) @error "Database file not found" db_path @@ -109,6 +108,9 @@ function Model(config::Config)::Model u0 = ComponentVector{Float64}(; storage, integral) du0 = zero(u0) + # The Solver algorithm + alg = algorithm(config.solver; u0) + # Synchronize level with storage set_current_basin_properties!(parameters.basin, u0, du0) diff --git a/core/src/read.jl b/core/src/read.jl index e76793e28..05102d64e 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -572,7 +572,8 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin error("Invalid Basin / profile table.") end - level_to_area = LinearInterpolation.(area, level; extrapolate = true) + level_to_area = + LinearInterpolation.(area, level; extrapolate = true, cache_parameters = true) storage_to_level = invert_integral.(level_to_area) t_end = seconds_since(config.endtime, config.starttime) @@ -921,6 +922,7 @@ function user_demand_static!( fill(first_row.return_factor, 2), return_factor_old.t; extrapolate = true, + cache_parameters = true, ) min_level[user_demand_idx] = first_row.min_level @@ -1026,8 +1028,10 @@ function UserDemand(db::DB, config::Config, graph::MetaGraph)::UserDemand ] demand_from_timeseries = fill(false, n_user) allocated = fill(Inf, n_user, n_priority) - return_factor = - [LinearInterpolation(zeros(2), trivial_timespan) for i in eachindex(node_ids)] + return_factor = [ + LinearInterpolation(zeros(2), trivial_timespan; cache_parameters = true) for + i in eachindex(node_ids) + ] min_level = zeros(n_user) # Process static table diff --git a/core/src/solve.jl b/core/src/solve.jl index cabfcd6e0..04aea58a7 100644 --- a/core/src/solve.jl +++ b/core/src/solve.jl @@ -51,9 +51,20 @@ function water_balance!( # Formulate du (controlled by PidControl) formulate_du_pid_controlled!(du, graph, pid_control) + # https://github.com/Deltares/Ribasim/issues/1705#issuecomment-2283293974 + stop_declining_negative_storage!(du, u) + return nothing end +function stop_declining_negative_storage!(du, u) + for (i, s) in enumerate(u.storage) + if s < 0 + du.storage[i] = max(du.storage[i], 0.0) + end + end +end + function formulate_continuous_control!(du, p, t)::Nothing (; compound_variable, target_ref, func) = p.continuous_control diff --git a/core/src/util.jl b/core/src/util.jl index 695f62435..90d197cf3 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -53,9 +53,18 @@ end Compute the area and level of a basin given its storage. """ function get_area_and_level(basin::Basin, state_idx::Int, storage::T)::Tuple{T, T} where {T} - level = basin.storage_to_level[state_idx](max(storage, 0.0)) - area = basin.level_to_area[state_idx](level) - + storage_to_level = basin.storage_to_level[state_idx] + level_to_area = basin.level_to_area[state_idx] + if storage >= 0 + level = storage_to_level(storage) + else + # Negative storage is not feasible and this yields a level + # below the basin bottom, but this does yield usable gradients + # for the non-linear solver + bottom = first(level_to_area.t) + level = bottom + derivative(storage_to_level, 0.0) * storage + end + area = level_to_area(level) return area, level end @@ -887,3 +896,70 @@ end (A::AbstractInterpolation)(t::GradientTracer) = t reduction_factor(x::GradientTracer, threshold::Real) = x relaxed_root(x::GradientTracer, threshold::Real) = x +get_area_and_level(basin::Basin, state_idx::Int, storage::GradientTracer) = storage, storage +stop_declining_negative_storage!(du, u::ComponentVector{<:GradientTracer}) = nothing + +@kwdef struct MonitoredBackTracking{B, V} + linesearch::B = BackTracking() + dz_tmp::V = [] + z_tmp::V = [] +end + +""" +Compute the residual of the non-linear solver, i.e. a measure of the +error in the solution to the implicit equation defined by the solver algorithm +""" +function residual(z, integrator, nlsolver, f) + (; uprev, t, p, dt, opts, isdae) = integrator + (; tmp, ztmp, γ, α, cache, method) = nlsolver + (; ustep, atmp, tstep, k, invγdt, tstep, k, invγdt) = cache + if isdae + _uprev = get_dae_uprev(integrator, uprev) + b, ustep2 = + _compute_rhs!(tmp, ztmp, ustep, α, tstep, k, invγdt, p, _uprev, f::TF, z) + else + b, ustep2 = + _compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f, z) + end + calculate_residuals!( + atmp, + b, + uprev, + ustep2, + opts.abstol, + opts.reltol, + opts.internalnorm, + t, + ) + ndz = opts.internalnorm(atmp, t) + return ndz +end + +""" +MonitoredBackTracing is a thin wrapper of BackTracking, making sure that +the BackTracking relaxation is rejected if it results in a residual increase +""" +function OrdinaryDiffEq.relax!( + dz, + nlsolver::AbstractNLSolver, + integrator::DEIntegrator, + f, + linesearch::MonitoredBackTracking, +) + (; linesearch, dz_tmp, z_tmp) = linesearch + + # Store step before relaxation + @. dz_tmp = dz + + # Apply relaxation and measure the residual change + @. z_tmp = nlsolver.z + dz + resid_before = residual(z_tmp, integrator, nlsolver, f) + relax!(dz, nlsolver, integrator, f, linesearch) + @. z_tmp = nlsolver.z + dz + resid_after = residual(z_tmp, integrator, nlsolver, f) + + # If the residual increased due to the relaxation, reject it + if resid_after > resid_before + @. dz = dz_tmp + end +end diff --git a/core/test/main_test.jl b/core/test/main_test.jl index 788e9c7c3..95dfac40e 100644 --- a/core/test/main_test.jl +++ b/core/test/main_test.jl @@ -24,7 +24,6 @@ @show backtrace end @test occursin("version in the TOML config file does not match", output) - @test occursin("Info: Convergence bottlenecks in descending order of severity:", output) end @testitem "main error logging" begin