Skip to content

[WIP] feat: reduce reliance on metadata in structural_simplify #3540

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -668,14 +668,14 @@ function (f::Initial)(x)
iscall(x) && operation(x) isa Initial && return x
result = if symbolic_type(x) == ArraySymbolic()
# create an array for `Initial(array)`
Symbolics.array_term(f, toparam(x))
Symbolics.array_term(f, x)
elseif iscall(x) && operation(x) == getindex
# instead of `Initial(x[1])` create `Initial(x)[1]`
# which allows parameter indexing to handle this case automatically.
arr = arguments(x)[1]
term(getindex, f(toparam(arr)), arguments(x)[2:end]...)
term(getindex, f(arr), arguments(x)[2:end]...)
else
term(f, toparam(x))
term(f, x)
end
# the result should be a parameter
result = toparam(result)
Expand Down Expand Up @@ -1114,9 +1114,25 @@ function _apply_to_variables(f::F, ex) where {F}
metadata(ex))
end

"""
Variable metadata key which contains information about scoping/namespacing of the
variable in a hierarchical system.
"""
abstract type SymScope end

"""
$(TYPEDEF)

The default scope of a variable. It belongs to the system whose equations it is involved
in and is namespaced by every level of the hierarchy.
"""
struct LocalScope <: SymScope end

"""
$(TYPEDSIGNATURES)

Apply `LocalScope` to `sym`.
"""
function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
apply_to_variables(sym) do sym
if iscall(sym) && operation(sym) === getindex
Expand All @@ -1130,9 +1146,25 @@ function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
end
end

"""
$(TYPEDEF)

Denotes that the variable does not belong to the system whose equations it is involved
in. It is not namespaced by this system. In the immediate parent of this system, the
scope of this variable is given by `parent`.

# Fields

$(TYPEDFIELDS)
"""
struct ParentScope <: SymScope
parent::SymScope
end
"""
$(TYPEDSIGNATURES)

Apply `ParentScope` to `sym`, with `parent` being `LocalScope`.
"""
function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
apply_to_variables(sym) do sym
if iscall(sym) && operation(sym) === getindex
Expand All @@ -1148,10 +1180,31 @@ function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
end
end

"""
$(TYPEDEF)

Denotes that a variable belongs to a system that is at least `N + 1` levels up in the
hierarchy from the system whose equations it is involved in. It is namespaced by the
first `N` parents and not namespaced by the `N+1`th parent in the hierarchy. The scope
of the variable after this point is given by `parent`.

In other words, this scope delays applying `ParentScope` by `N` levels, and applies
`LocalScope` in the meantime.

# Fields

$(TYPEDFIELDS)
"""
struct DelayParentScope <: SymScope
parent::SymScope
N::Int
end

"""
$(TYPEDSIGNATURES)

Apply `DelayParentScope` to `sym`, with a delay of `N` and `parent` being `LocalScope`.
"""
function DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}, N)
apply_to_variables(sym) do sym
if iscall(sym) && operation(sym) == getindex
Expand All @@ -1166,9 +1219,29 @@ function DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}, N)
end
end
end

"""
$(TYPEDSIGNATURES)

Apply `DelayParentScope` to `sym`, with a delay of `1` and `parent` being `LocalScope`.
"""
DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) = DelayParentScope(sym, 1)

"""
$(TYPEDEF)

Denotes that a variable belongs to the root system in the hierarchy, regardless of which
equations of subsystems in the hierarchy it is involved in. Variables with this scope
are never namespaced and only added to the unknowns/parameters of a system when calling
`complete` or `structural_simplify`.
"""
struct GlobalScope <: SymScope end

"""
$(TYPEDSIGNATURES)

Apply `GlobalScope` to `sym`.
"""
function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
apply_to_variables(sym) do sym
if iscall(sym) && operation(sym) == getindex
Expand Down
183 changes: 104 additions & 79 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,105 +253,106 @@ function Base.push!(ev::EquationsView, eq)
push!(ev.ts.extra_eqs, eq)
end

function symbolic_contains(var, set)
var in set || symbolic_type(var) == ArraySymbolic() && Symbolics.shape(var) != Symbolics.Unknown() && all(i -> var[i] in set, eachindex(var))
end

function TearingState(sys; quick_cancel = false, check = true)
# flatten system
sys = flatten(sys)
ivs = independent_variables(sys)
iv = length(ivs) == 1 ? ivs[1] : nothing
# scalarize array equations, without scalarizing arguments to registered functions
eqs = flatten_equations(copy(equations(sys)))
# flatten array equations
eqs = flatten_equations(equations(sys))
neqs = length(eqs)
dervaridxs = OrderedSet{Int}()
var2idx = Dict{Any, Int}()
symbolic_incidence = []
fullvars = []
var_counter = Ref(0)
var_types = VariableType[]
addvar! = let fullvars = fullvars, var_counter = var_counter, var_types = var_types
# * Scalarize unknowns
dvs = Set{BasicSymbolic}()
fullvars = BasicSymbolic[]
for x in unknowns(sys)
push!(dvs, x)
xx = Symbolics.scalarize(x)
if xx isa AbstractArray
union!(dvs, xx)
append!(fullvars, xx)
else
push!(fullvars, xx)
end
end
var2idx = Dict{BasicSymbolic, Int}(v => k for (k, v) in enumerate(fullvars))
addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx
var -> get!(var2idx, var) do
push!(dvs, var)
push!(fullvars, var)
push!(var_types, getvariabletype(var))
var_counter[] += 1
return length(fullvars)
end
end

vars = OrderedSet()
varsvec = []
for (i, eq′) in enumerate(eqs)
if eq′.lhs isa Connection
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
return nothing
end
if _iszero(eq′.lhs)
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
eq = eq′
else
lhs = quick_cancel ? quick_cancel_expr(eq′.lhs) : eq′.lhs
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
eq = 0 ~ rhs - lhs
# build symbolic incidence
symbolic_incidence = Vector{BasicSymbolic}[]
varsbuf = Set()
for (i, eq) in enumerate(eqs)
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
if !_iszero(eq.lhs)
lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs
eq = eqs[i] = 0 ~ rhs - lhs
end
vars!(vars, eq.rhs, op = Symbolics.Operator)
for v in vars
_var, _ = var_from_nested_derivative(v)
any(isequal(_var), ivs) && continue
if isparameter(_var) ||
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
continue
empty!(varsbuf)
vars!(varsbuf, eq; op = Symbolics.Operator)
incidence = Set{BasicSymbolic}()
for v in varsbuf
# FIXME: This check still needs to rely on metadata
isconstant(v) && continue
vtype = getvariabletype(v)
# additionally track brownians in fullvars
# TODO: When uniting system types, track brownians in their own field
if vtype == BROWNIAN
i = addvar!(v)
push!(incidence, v)
end
v = scalarize(v)
if v isa AbstractArray
append!(varsvec, v)
else
push!(varsvec, v)
end
end
isalgeq = true
unknownvars = []
for var in varsvec
ModelingToolkit.isdelay(var, iv) && continue
set_incidence = true
@label ANOTHER_VAR
_var, _ = var_from_nested_derivative(var)
any(isequal(_var), ivs) && continue
if isparameter(_var) ||
(iscall(_var) && isparameter(operation(_var)) || isconstant(_var))
continue
end
varidx = addvar!(var)
set_incidence && push!(unknownvars, var)

dvar = var
idx = varidx
while isdifferential(dvar)
if !(idx in dervaridxs)
push!(dervaridxs, idx)

vtype == VARIABLE || continue

if !symbolic_contains(v, dvs)
isvalid = iscall(v) && operation(v) isa Union{Shift, Sample, Hold}
v′ = v
while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift}
v′ = arguments(v)[1]
if v′ in dvs || getmetadata(v′, SymScope, LocalScope()) isa GlobalScope
isvalid = true
break
end
end
if !isvalid
throw(ArgumentError("$v is present in the system but $v′ is not an unknown."))
end
isalgeq = false
dvar = arguments(dvar)[1]
idx = addvar!(dvar)
end

dvar = var
idx = varidx
addvar!(v)
if iscall(v) && operation(v) isa Symbolics.Operator && !isdifferential(v) && (it = input_timedomain(v)) !== nothing
v′ = only(arguments(v))
addvar!(setmetadata(v′, VariableTimeDomain, it))
end
end

if iscall(var) && operation(var) isa Symbolics.Operator &&
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
set_incidence = false
var = only(arguments(var))
var = setmetadata(var, VariableTimeDomain, it)
@goto ANOTHER_VAR
if symbolic_type(v) == ArraySymbolic()
union!(incidence, collect(v))
else
push!(incidence, v)
end
end
push!(symbolic_incidence, copy(unknownvars))
empty!(unknownvars)
empty!(vars)
empty!(varsvec)
if isalgeq
eqs[i] = eq
else
eqs[i] = eqs[i].lhs ~ rhs

push!(symbolic_incidence, collect(incidence))
end

dervaridxs = Int[]
for (i, v) in enumerate(fullvars)
while isdifferential(v)
push!(dervaridxs, i)
v = arguments(v)[1]
i = addvar!(v)
end
end

# Handle shifts - find lowest shift and add intermediates with derivative edges
### Handle discrete variables
lowest_shift = Dict()
for var in fullvars
Expand Down Expand Up @@ -391,6 +392,9 @@ function TearingState(sys; quick_cancel = false, check = true)
end
end
end

var_types = Vector{VariableType}(getvariabletype.(fullvars))

# sort `fullvars` such that the mass matrix is as diagonal as possible.
dervaridxs = collect(dervaridxs)
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
Expand All @@ -414,6 +418,7 @@ function TearingState(sys; quick_cancel = false, check = true)
var2idx = Dict(fullvars .=> eachindex(fullvars))
dervaridxs = 1:length(dervaridxs)

# build `var_to_diff`
nvars = length(fullvars)
diffvars = []
var_to_diff = DiffGraph(nvars, true)
Expand All @@ -425,20 +430,24 @@ function TearingState(sys; quick_cancel = false, check = true)
var_to_diff[diffvaridx] = dervaridx
end

# build incidence graph
graph = BipartiteGraph(neqs, nvars, Val(false))
for (ie, vars) in enumerate(symbolic_incidence), v in vars
jv = var2idx[v]
add_edge!(graph, ie, jv)
end

@set! sys.eqs = eqs
@set! sys.unknowns = [v for (i, v) in enumerate(fullvars) if var_types[i] != BROWNIAN]

eq_to_diff = DiffGraph(nsrcs(graph))

ts = TearingState(sys, fullvars,
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
Any[])

# `shift_discrete_system`
if sys isa DiscreteSystem
ts = shift_discrete_system(ts)
end
Expand Down Expand Up @@ -726,3 +735,19 @@ function _structural_simplify!(state::TearingState, io; simplify = false,

ModelingToolkit.invalidate_cache!(sys), input_idxs
end

struct DifferentiatedVariableNotUnknownError <: Exception
differentiated
undifferentiated
end

function Base.showerror(io::IO, err::DifferentiatedVariableNotUnknownError)
undiff = err.undifferentiated
diff = err.differentiated
print(io, "Variable $undiff occurs differentiated as $diff but is not an unknown of the system.")
scope = getmetadata(undiff, SymScope, LocalScope())
depth = expected_scope_depth(scope)
if depth > 0
print(io, "\nVariable $undiff expects $depth more levels in the hierarchy to be an unknown.")
end
end
Loading
Loading