Skip to content

[v9] fix: fix major compile time regression due to concrete_getu #3692

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -617,30 +617,40 @@ struct ReconstructInitializeprob{GP, GU}
ugetter::GU
end

"""
$(TYPEDEF)

A wrapper over an observed function which allows calling it on a problem-like object.
`TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if
`false`).
"""
struct ObservedWrapper{TD, F}
f::F
end

ObservedWrapper{TD}(f::F) where {TD, F} = ObservedWrapper{TD, F}(f)

function (ow::ObservedWrapper{true})(prob)
ow.f(state_values(prob), parameter_values(prob), current_time(prob))
end

function (ow::ObservedWrapper{false})(prob)
ow.f(state_values(prob), parameter_values(prob))
end

"""
$(TYPEDSIGNATURES)

Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
function by splitting `syms` into contiguous buffers where the getter of each buffer
is type-stable and constructing a function that calls and concatenates the results.
"""
function concrete_getu(indp, syms::AbstractVector)
# a list of contiguous buffer
split_syms = [Any[syms[1]]]
# the type of the getter of the last buffer
current = typeof(getu(indp, syms[1]))
for sym in syms[2:end]
getter = getu(indp, sym)
if typeof(getter) != current
# if types don't match, build a new buffer
push!(split_syms, [])
current = typeof(getter)
end
push!(split_syms[end], sym)
end
split_syms = Tuple(split_syms)
# the getter is now type-stable, and we can vcat it to get the full buffer
return Base.Fix1(reduce, vcat) ∘ getu(indp, split_syms)
function.

Note that the getter ONLY works for problem-like objects, since it generates an observed
function. It does NOT work for solutions.
"""
Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the docstring, this could be typed to avoid mistakenly calling it with other types.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The returned callable has this restriction. I'll have to think about how to handle that case, since it should just work on anything that isn't a solution or a lone parameter object/state vector.

@nospecialize
obsfn = SymbolicIndexingInterface.observed(indp, syms)
return ObservedWrapper{is_time_dependent(indp)}(obsfn)
end

"""
Expand Down
Loading