Skip to content

Commit b849190

Browse files
Merge pull request #3566 from AayushSabharwal/as/type-stable-remake
fix: make `remake` type-stable
2 parents 7f90698 + 0e7cad9 commit b849190

File tree

10 files changed

+281
-132
lines changed

10 files changed

+281
-132
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ RecursiveArrayTools = "3.26"
139139
Reexport = "0.2, 1"
140140
RuntimeGeneratedFunctions = "0.5.9"
141141
SCCNonlinearSolve = "1.0.0"
142-
SciMLBase = "2.75"
142+
SciMLBase = "2.84"
143143
SciMLStructures = "1.7"
144144
Serialization = "1"
145145
Setfield = "0.7, 0.8, 1"
@@ -150,7 +150,7 @@ StaticArrays = "0.10, 0.11, 0.12, 1.0"
150150
StochasticDelayDiffEq = "1.8.1"
151151
StochasticDiffEq = "6.72.1"
152152
SymbolicIndexingInterface = "0.3.39"
153-
SymbolicUtils = "3.25.1"
153+
SymbolicUtils = "3.26.1"
154154
Symbolics = "6.37"
155155
URIs = "1"
156156
UnPack = "0.1, 1.0"

docs/src/basics/Events.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ prob = ODEProblem(ball, Pair[], tspan)
126126
127127
sol = solve(prob, Tsit5())
128128
@assert 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close
129-
@assert minimum(sol[y]) > -1.5 # check wall conditions
130-
@assert maximum(sol[y]) < 1.5 # check wall conditions
129+
@assert minimum(sol[y]) >= -1.5 # check wall conditions
130+
@assert maximum(sol[y]) <= 1.5 # check wall conditions
131131
132132
tv = sort([LinRange(0, 10, 200); sol.t])
133133
plot(sol(tv)[y], sol(tv)[x], line_z = tv)

docs/src/tutorials/disturbance_modeling.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ To see full examples that perform state estimation with ModelingToolkit models,
224224
Pages = ["disturbance_modeling.md"]
225225
```
226226

227-
```@autodocs
227+
```@autodocs; canonical = false
228228
Modules = [ModelingToolkit]
229229
Pages = ["systems/analysis_points.jl"]
230230
Order = [:function, :type]

src/systems/abstractsystem.jl

+1
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,7 @@ for prop in [:eqs
915915
:substitutions
916916
:metadata
917917
:gui_metadata
918+
:is_initializesystem
918919
:discrete_subsystems
919920
:parameter_dependencies
920921
:assertions

src/systems/diffeqs/abstractodesystem.jl

+2-8
Original file line numberDiff line numberDiff line change
@@ -1457,12 +1457,12 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
14571457
elseif isempty(u0map) && get_initializesystem(sys) === nothing
14581458
isys = generate_initializesystem(
14591459
sys; initialization_eqs, check_units, pmap = parammap,
1460-
guesses, extra_metadata = (; use_scc), algebraic_only)
1460+
guesses, algebraic_only)
14611461
simplify_system = true
14621462
else
14631463
isys = generate_initializesystem(
14641464
sys; u0map, initialization_eqs, check_units,
1465-
pmap = parammap, guesses, extra_metadata = (; use_scc), algebraic_only)
1465+
pmap = parammap, guesses, algebraic_only)
14661466
simplify_system = true
14671467
end
14681468

@@ -1477,12 +1477,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
14771477
isys = structural_simplify(isys; fully_determined)
14781478
end
14791479

1480-
meta = get_metadata(isys)
1481-
if meta isa InitializationSystemMetadata
1482-
@set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(
1483-
sys, isys)
1484-
end
1485-
14861480
ts = get_tearing_state(isys)
14871481
unassigned_vars = StructuralTransformations.singular_check(ts)
14881482
if warn_initialize_determined && !isempty(unassigned_vars)

src/systems/index_cache.jl

+29-12
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ function reorder_parameters(
502502
end
503503
end
504504

505-
function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
505+
function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = true)
506506
isempty(ps) && return ()
507507
param_buf = if ic.tunable_buffer_size.length == 0
508508
()
@@ -555,20 +555,37 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
555555
end
556556
end
557557

558-
result = broadcast.(
559-
unwrap, (
560-
param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...))
558+
param_buf = broadcast.(unwrap, param_buf)
559+
initials_buf = broadcast.(unwrap, initials_buf)
560+
disc_buf = broadcast.(unwrap, disc_buf)
561+
const_buf = broadcast.(unwrap, const_buf)
562+
nonnumeric_buf = broadcast.(unwrap, nonnumeric_buf)
563+
561564
if drop_missing
562-
result = map(result) do buf
563-
filter(buf) do sym
564-
return !isequal(sym, unwrap(variable(:DEF)))
565-
end
565+
filterer = !isequal(unwrap(variable(:DEF)))
566+
param_buf = filter.(filterer, param_buf)
567+
initials_buf = filter.(filterer, initials_buf)
568+
disc_buf = filter.(filterer, disc_buf)
569+
const_buf = filter.(filterer, const_buf)
570+
nonnumeric_buf = filter.(filterer, nonnumeric_buf)
571+
end
572+
573+
if flatten
574+
result = (
575+
param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...)
576+
if all(isempty, result)
577+
return ()
566578
end
579+
return result
580+
else
581+
if isempty(param_buf)
582+
param_buf = ((),)
583+
end
584+
if isempty(initials_buf)
585+
initials_buf = ((),)
586+
end
587+
return (param_buf..., initials_buf..., disc_buf, const_buf, nonnumeric_buf)
567588
end
568-
if all(isempty, result)
569-
return ()
570-
end
571-
return result
572589
end
573590

574591
# Given a parameter index, find the index of the buffer it is in when

src/systems/nonlinear/initializesystem.jl

+33-99
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem;
1111
default_dd_guess = Bool(0),
1212
algebraic_only = false,
1313
check_units = true, check_defguess = false,
14-
name = nameof(sys), extra_metadata = (;), kwargs...)
14+
name = nameof(sys), kwargs...)
1515
eqs = equations(sys)
1616
if !(eqs isa Vector{Equation})
1717
eqs = Equation[x for x in eqs if x isa Equation]
@@ -143,17 +143,15 @@ function generate_initializesystem(sys::AbstractTimeDependentSystem;
143143
for k in keys(defs)
144144
defs[k] = substitute(defs[k], paramsubs)
145145
end
146-
meta = InitializationSystemMetadata(
147-
anydict(u0map), anydict(pmap), additional_guesses,
148-
additional_initialization_eqs, extra_metadata, nothing)
146+
149147
return NonlinearSystem(eqs_ics,
150148
vars,
151149
pars;
152150
defaults = defs,
153151
checks = check_units,
154152
parameter_dependencies = new_parameter_deps,
155153
name,
156-
metadata = meta,
154+
is_initializesystem = true,
157155
kwargs...)
158156
end
159157

@@ -169,7 +167,7 @@ function generate_initializesystem(sys::AbstractTimeIndependentSystem;
169167
guesses = Dict(),
170168
algebraic_only = false,
171169
check_units = true, check_defguess = false,
172-
name = nameof(sys), extra_metadata = (;), kwargs...)
170+
name = nameof(sys), kwargs...)
173171
eqs = equations(sys)
174172
trueobs, eqs = unhack_observed(observed(sys), eqs)
175173
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
@@ -244,17 +242,15 @@ function generate_initializesystem(sys::AbstractTimeIndependentSystem;
244242
for k in keys(defs)
245243
defs[k] = substitute(defs[k], paramsubs)
246244
end
247-
meta = InitializationSystemMetadata(
248-
anydict(u0map), anydict(pmap), additional_guesses,
249-
additional_initialization_eqs, extra_metadata, nothing)
245+
250246
return NonlinearSystem(eqs_ics,
251247
vars,
252248
pars;
253249
defaults = defs,
254250
checks = check_units,
255251
parameter_dependencies = new_parameter_deps,
256252
name,
257-
metadata = meta,
253+
is_initializesystem = true,
258254
kwargs...)
259255
end
260256

@@ -436,64 +432,6 @@ function _has_delays(sys::AbstractSystem, ex, banned)
436432
return any(x -> _has_delays(sys, x, banned), args)
437433
end
438434

439-
struct ReconstructInitializeprob
440-
getter::Any
441-
setter::Any
442-
end
443-
444-
function ReconstructInitializeprob(
445-
srcsys::AbstractSystem, dstsys::AbstractSystem)
446-
syms = reduce(
447-
vcat, reorder_parameters(dstsys, parameters(dstsys));
448-
init = [])
449-
getter = getu(srcsys, syms)
450-
setter = setp_oop(dstsys, syms)
451-
return ReconstructInitializeprob(getter, setter)
452-
end
453-
454-
function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
455-
newp = rip.setter(dstvalp, rip.getter(srcvalp))
456-
if state_values(dstvalp) === nothing
457-
return nothing, newp
458-
end
459-
srcu0 = state_values(srcvalp)
460-
T = srcu0 === nothing || isempty(srcu0) ? Union{} : eltype(srcu0)
461-
if parameter_values(dstvalp) isa MTKParameters
462-
if !isempty(newp.tunable)
463-
T = promote_type(eltype(newp.tunable), T)
464-
end
465-
elseif !isempty(newp)
466-
T = promote_type(eltype(newp), T)
467-
end
468-
if T == eltype(state_values(dstvalp))
469-
u0 = state_values(dstvalp)
470-
elseif T != Union{}
471-
u0 = T.(state_values(dstvalp))
472-
end
473-
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp)
474-
if eltype(buf) != T
475-
newbuf = similar(buf, T)
476-
copyto!(newbuf, buf)
477-
newp = repack(newbuf)
478-
end
479-
buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp)
480-
if eltype(buf) != T
481-
newbuf = similar(buf, T)
482-
copyto!(newbuf, buf)
483-
newp = repack(newbuf)
484-
end
485-
return u0, newp
486-
end
487-
488-
struct InitializationSystemMetadata
489-
u0map::Dict{Any, Any}
490-
pmap::Dict{Any, Any}
491-
additional_guesses::Dict{Any, Any}
492-
additional_initialization_eqs::Vector{Equation}
493-
extra_metadata::NamedTuple
494-
oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob}
495-
end
496-
497435
function get_possibly_array_fallback_singletons(varmap, p)
498436
if haskey(varmap, p)
499437
return varmap[p]
@@ -543,22 +481,19 @@ function SciMLBase.remake_initialization_data(
543481
if u0 === missing && p === missing
544482
return odefn.initialization_data
545483
end
484+
485+
oldinitdata = odefn.initialization_data
486+
546487
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
547-
oldinitdata = odefn.initialization_data
548488
oldinitdata === nothing && return nothing
549489

550490
oldinitprob = oldinitdata.initializeprob
551491
oldinitprob === nothing && return nothing
552-
if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem)
553-
return oldinitdata
554-
end
555-
oldinitsys = oldinitprob.f.sys
556-
meta = get_metadata(oldinitsys)
557-
if meta isa InitializationSystemMetadata && meta.oop_reconstruct_u0_p !== nothing
558-
reconstruct_fn = meta.oop_reconstruct_u0_p
559-
else
560-
reconstruct_fn = ReconstructInitializeprob(sys, oldinitsys)
561-
end
492+
493+
meta = oldinitdata.metadata
494+
meta isa InitializationMetadata || return oldinitdata
495+
496+
reconstruct_fn = meta.oop_reconstruct_u0_p
562497
# the history function doesn't matter because `reconstruct_fn` is only going to
563498
# update the values of parameters, which aren't time dependent. The reason it
564499
# is called is because `Initial` parameters are calculated from the corresponding
@@ -569,16 +504,15 @@ function SciMLBase.remake_initialization_data(
569504
if oldinitprob.f.resid_prototype === nothing
570505
newf = oldinitprob.f
571506
else
572-
newf = NonlinearFunction{
573-
SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}(
574-
oldinitprob.f;
507+
newf = remake(oldinitprob.f;
575508
resid_prototype = calculate_resid_prototype(
576509
length(oldinitprob.f.resid_prototype), new_initu0, new_initp))
577510
end
578511
initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp)
579512
return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!,
580-
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap)
513+
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap; metadata = oldinitdata.metadata)
581514
end
515+
582516
dvs = unknowns(sys)
583517
ps = parameters(sys)
584518
u0map = to_varmap(u0, dvs)
@@ -592,16 +526,13 @@ function SciMLBase.remake_initialization_data(
592526
use_scc = true
593527
initialization_eqs = Equation[]
594528

595-
if SciMLBase.has_initializeprob(odefn)
596-
oldsys = odefn.initialization_data.initializeprob.f.sys
597-
meta = get_metadata(oldsys)
598-
if meta isa InitializationSystemMetadata
599-
u0map = merge(meta.u0map, u0map)
600-
pmap = merge(meta.pmap, pmap)
601-
merge!(guesses, meta.additional_guesses)
602-
use_scc = get(meta.extra_metadata, :use_scc, true)
603-
initialization_eqs = meta.additional_initialization_eqs
604-
end
529+
if oldinitdata !== nothing && oldinitdata.metadata isa InitializationMetadata
530+
meta = oldinitdata.metadata
531+
u0map = merge(meta.u0map, u0map)
532+
pmap = merge(meta.pmap, pmap)
533+
merge!(guesses, meta.guesses)
534+
use_scc = meta.use_scc
535+
initialization_eqs = meta.additional_initialization_eqs
605536
else
606537
# there is no initializeprob, so the original problem construction
607538
# had no solvable parameters and had the differential variables
@@ -662,19 +593,22 @@ function SciMLBase.late_binding_update_u0_p(
662593
if !(eltype(u0) <: Pair)
663594
# if `p` is not provided or is symbolic
664595
p === missing || eltype(p) <: Pair || return newu0, newp
665-
newu0 === nothing && return newu0, newp
666-
all(is_parameter(sys, Initial(x)) for x in unknowns(sys)) || return newu0, newp
596+
(newu0 === nothing || isempty(newu0)) && return newu0, newp
597+
initdata = prob.f.initialization_data
598+
initdata === nothing && return newu0, newp
599+
meta = initdata.metadata
600+
meta isa InitializationMetadata || return newu0, newp
667601
newp = p === missing ? copy(newp) : newp
668602
initials, repack, alias = SciMLStructures.canonicalize(
669603
SciMLStructures.Initials(), newp)
670604
if eltype(initials) != eltype(newu0)
671605
initials = DiffEqBase.promote_u0(initials, newu0, t0)
672606
newp = repack(initials)
673607
end
674-
if length(newu0) != length(unknowns(sys))
675-
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(unknowns(sys)))). Got $(typeof(newu0)) of length $(length(newu0))"))
608+
if length(newu0) != length(prob.u0)
609+
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
676610
end
677-
setp(sys, Initial.(unknowns(sys)))(newp, newu0)
611+
meta.set_initial_unknowns!(newp, newu0)
678612
return newu0, newp
679613
end
680614

@@ -714,7 +648,7 @@ end
714648
Check if the given system is an initialization system.
715649
"""
716650
function is_initializesystem(sys::AbstractSystem)
717-
sys isa NonlinearSystem && get_metadata(sys) isa InitializationSystemMetadata
651+
has_is_initializesystem(sys) && get_is_initializesystem(sys)
718652
end
719653

720654
"""

0 commit comments

Comments
 (0)