Skip to content

Add support for an external synchronous compiler to discrete and hybrid systems #3399

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables,
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
initial_state, transition, activeState, entry, hasnode,
ticksInState, timeInState, fixpoint_sub, fast_substitute,
CallWithMetadata, CallWithParent
CallWithMetadata, CallWithParent, Transition, InitialState,
StateMachineOperator
const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR)
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
Expand Down
20 changes: 20 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,26 @@ function namespace_expr(
O
end
end

function namespace_expr(
O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys))
return Transition(
O.from === nothing ? O.from : renamespace(sys, O.from),
O.to === nothing ? O.to : renamespace(sys, O.to),
O.cond === nothing ? O.cond : namespace_expr(O.cond, sys),
O.immediate, O.reset, O.synchronize, O.priority
)
end

function namespace_expr(
O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys))
return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s))
end

function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...)
error("Unhandled state machine operator")
end

_nonum(@nospecialize x) = x isa Num ? x.val : x

"""
Expand Down
11 changes: 10 additions & 1 deletion src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function infer_clocks!(ci::ClockInference)
c = BitSet(c′)
idxs = intersect(c, inferred)
isempty(idxs) && continue
if !allequal(var_domain[i] for i in idxs)
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
display(fullvars[c′])
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
end
Expand Down Expand Up @@ -144,6 +144,9 @@ function split_system(ci::ClockInference{S}) where {S}
var_to_cid = Vector{Int}(undef, ndsts(graph))
cid_to_var = Vector{Int}[]
cid_counter = Ref(0)

# populates clock_to_id and id_to_clock
# checks if there is a continuous_id (for some reason? clock to id does this too)
for (i, d) in enumerate(eq_domain)
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
continuous_id = continuous_id
Expand All @@ -161,9 +164,13 @@ function split_system(ci::ClockInference{S}) where {S}
resize_or_push!(cid_to_eq, i, cid)
end
continuous_id = continuous_id[]
# for each clock partition what are the input (indexes/vars)
input_idxs = map(_ -> Int[], 1:cid_counter[])
inputs = map(_ -> Any[], 1:cid_counter[])
# var_domain corresponds to fullvars/all variables in the system
nvv = length(var_domain)
# put variables into the right clock partition
# keep track of inputs to each partition
for i in 1:nvv
d = var_domain[i]
cid = get(clock_to_id, d, 0)
Expand All @@ -177,6 +184,7 @@ function split_system(ci::ClockInference{S}) where {S}
resize_or_push!(cid_to_var, i, cid)
end

# breaks the system up into a continous and 0 or more discrete systems
tss = similar(cid_to_eq, S)
for (id, ieqs) in enumerate(cid_to_eq)
ts_i = system_subset(ts, ieqs)
Expand All @@ -186,6 +194,7 @@ function split_system(ci::ClockInference{S}) where {S}
end
tss[id] = ts_i
end
# put the continous system at the back
if continuous_id != 0
tss[continuous_id], tss[end] = tss[end], tss[continuous_id]
inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id]
Expand Down
4 changes: 3 additions & 1 deletion src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)

# write the new values back to the integrator
_generated_writeback(integ, upd_funs, upd_vals)
if !isnothing(upd_vals)
_generated_writeback(integ, upd_funs, upd_vals)
end

for idx in save_idxs
SciMLBase.save_discretes!(integ, idx)
Expand Down
13 changes: 10 additions & 3 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function structural_simplify(
kwargs...)
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
newsys′ = __structural_simplify(sys, io; simplify,
allow_symbolic, allow_parameter, conservative, fully_determined,
allow_symbolic, allow_parameter, conservative, fully_determined, additional_passes,
kwargs...)
if newsys′ isa Tuple
@assert length(newsys′) == 2
Expand Down Expand Up @@ -82,12 +82,13 @@ end

function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
kwargs...)
sys, statemachines = extract_top_level_statemachines(sys)
sys = expand_connections(sys)
state = TearingState(sys)
append!(state.statemachines, statemachines)

@unpack structure, fullvars = state
@unpack graph, var_to_diff, var_types = structure
eqs = equations(state)
brown_vars = Int[]
new_idxs = zeros(Int, length(var_types))
idx = 0
Expand All @@ -104,7 +105,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
Is = Int[]
Js = Int[]
vals = Num[]
new_eqs = copy(eqs)
make_eqs_zero_equals!(state)
new_eqs = copy(equations(state))
dvar2eq = Dict{Any, Int}()
for (v, dv) in enumerate(var_to_diff)
dv === nothing && continue
Expand Down Expand Up @@ -169,3 +171,8 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
guesses = guesses(sys), initialization_eqs = initialization_equations(sys))
end
end

"""
Mark whether an extra pass `p` can support compiling discrete systems.
"""
discrete_compile_pass(p) = false
149 changes: 121 additions & 28 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using DataStructures
using Symbolics: linear_expansion, unwrap, Connection
using Symbolics: linear_expansion, unwrap, Connection, Transition, InitialState
using SymbolicUtils: iscall, operation, arguments, Symbolic
using SymbolicUtils: quick_cancel, maketerm
using ..ModelingToolkit
Expand Down Expand Up @@ -198,16 +198,35 @@ end

mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
sys::T
original_eqs::Vector{Equation}
fullvars::Vector
structure::SystemStructure
extra_eqs::Vector
statemachines::Vector{T}
end

TransformationState(sys::AbstractSystem) = TearingState(sys)
function system_subset(ts::TearingState, ieqs::Vector{Int})
eqs = equations(ts)
@set! ts.original_eqs = ts.original_eqs[ieqs]
@set! ts.sys.eqs = eqs[ieqs]
@set! ts.structure = system_subset(ts.structure, ieqs)
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))
names = Symbol[]
for eq in get_eqs(ts.sys)
if eq.lhs isa Transition
push!(names, first(namespace_hierarchy(nameof(eq.rhs.from))))
push!(names, first(namespace_hierarchy(nameof(eq.rhs.to))))
elseif eq.lhs isa InitialState
push!(names, first(namespace_hierarchy(nameof(eq.rhs.s))))
else
error("Unhandled state machine operator")
end
end
@set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines)
else
@set! ts.statemachines = eltype(ts.statemachines)[]
end
ts
end

Expand Down Expand Up @@ -247,12 +266,56 @@ function Base.push!(ev::EquationsView, eq)
push!(ev.ts.extra_eqs, eq)
end

"""
$(TYPEDSIGNATURES)

Descend through the system hierarchy and look for statemachines. Remove equations from
the inner statemachine systems. Return the new `sys` and an array of top-level
statemachines.
"""
function extract_top_level_statemachines(sys::AbstractSystem)
eqs = get_eqs(sys)

if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs)
# top-level statemachine
with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys))
return with_removed, [sys]
elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs)
# error: can't mix
error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.")
else
# descend
subsystems = get_systems(sys)
newsubsystems = eltype(subsystems)[]
statemachines = eltype(subsystems)[]
for subsys in subsystems
newsubsys, sub_statemachines = extract_top_level_statemachines(subsys)
push!(newsubsystems, newsubsys)
append!(statemachines, sub_statemachines)
end
@set! sys.systems = newsubsystems
return sys, statemachines
end
end

"""
$(TYPEDSIGNATURES)

Return `sys` with all equations (including those in subsystems) removed.
"""
function remove_child_equations(sys::AbstractSystem)
@set! sys.eqs = eltype(get_eqs(sys))[]
@set! sys.systems = map(remove_child_equations, get_systems(sys))
return sys
end

function TearingState(sys; quick_cancel = false, check = true)
sys = flatten(sys)
ivs = independent_variables(sys)
iv = length(ivs) == 1 ? ivs[1] : nothing
# scalarize array equations, without scalarizing arguments to registered functions
eqs = flatten_equations(copy(equations(sys)))
original_eqs = flatten_equations(copy(equations(sys)))
eqs = copy(original_eqs)
neqs = length(eqs)
dervaridxs = OrderedSet{Int}()
var2idx = Dict{Any, Int}()
Expand All @@ -275,7 +338,12 @@ function TearingState(sys; quick_cancel = false, check = true)
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
return nothing
end
if _iszero(eq′.lhs)
is_statemachine_equation = false
if eq′.lhs isa StateMachineOperator
is_statemachine_equation = true
eq = eq′
rhs = eq.rhs
elseif _iszero(eq′.lhs)
rhs = quick_cancel ? quick_cancel_expr(eq′.rhs) : eq′.rhs
eq = eq′
else
Expand Down Expand Up @@ -340,7 +408,7 @@ function TearingState(sys; quick_cancel = false, check = true)
empty!(unknownvars)
empty!(vars)
empty!(varsvec)
if isalgeq
if isalgeq || is_statemachine_equation
eqs[i] = eq
else
eqs[i] = eqs[i].lhs ~ rhs
Expand Down Expand Up @@ -428,10 +496,10 @@ function TearingState(sys; quick_cancel = false, check = true)

eq_to_diff = DiffGraph(nsrcs(graph))

ts = TearingState(sys, fullvars,
ts = TearingState(sys, original_eqs, fullvars,
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
complete(graph), nothing, var_types, sys isa DiscreteSystem),
Any[])
Any[], typeof(sys)[])
if sys isa DiscreteSystem
ts = shift_discrete_system(ts)
end
Expand Down Expand Up @@ -622,44 +690,69 @@ function merge_io(io, inputs)
return io
end

function make_eqs_zero_equals!(ts::TearingState)
neweqs = map(enumerate(get_eqs(ts.sys))) do kvp
i, eq = kvp
isalgeq = true
for j in 𝑠neighbors(ts.structure.graph, i)
isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing
end
if isalgeq
return 0 ~ eq.rhs - eq.lhs
else
return eq
end
end
copyto!(get_eqs(ts.sys), neweqs)
end

function structural_simplify!(state::TearingState, io = nothing; simplify = false,
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
kwargs...)
if state.sys isa ODESystem
# split_system returns one or two systems and the inputs for each
# mod clock inference to be binary
# if it's continous keep going, if not then error unless given trait impl in additional passes
ci = ModelingToolkit.ClockInference(state)
ci = ModelingToolkit.infer_clocks!(ci)
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
if continuous_id == 0
# do a trait check here - handle fully discrete system
additional_passes = get(kwargs, :additional_passes, nothing)
if !isnothing(additional_passes) &&
any(discrete_compile_pass, additional_passes)
# take the first discrete compilation pass given for now
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
discrete_compile = additional_passes[discrete_pass_idx]
deleteat!(additional_passes, discrete_pass_idx)
return discrete_compile(tss, inputs, ci)
else
# error goes here! this is a purely discrete system
throw(HybridSystemNotSupportedException("Discrete systems without JuliaSimCompiler are currently not supported in ODESystem."))
end
end
make_eqs_zero_equals!(tss[continuous_id])
# puts the ios passed in to the call into the continous system
cont_io = merge_io(io, inputs[continuous_id])
# simplify as normal
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
check_consistency, fully_determined,
kwargs...)
if length(tss) > 1
if continuous_id > 0
additional_passes = get(kwargs, :additional_passes, nothing)
if !isnothing(additional_passes) &&
any(discrete_compile_pass, additional_passes)
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
discrete_compile = additional_passes[discrete_pass_idx]
deleteat!(additional_passes, discrete_pass_idx)
# in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems
# and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed
sys = discrete_compile(sys, tss[[i for i in eachindex(tss) if i != continuous_id]], inputs, ci)
else
throw(HybridSystemNotSupportedException("Hybrid continuous-discrete systems are currently not supported with the standard MTK compiler. This system requires JuliaSimCompiler.jl, see https://help.juliahub.com/juliasimcompiler/stable/"))
end
# TODO: rename it to something else
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
# Note that the appended_parameters must agree with
# `generate_discrete_affect`!
appended_parameters = parameters(sys)
for (i, state) in enumerate(tss)
if i == continuous_id
discrete_subsystems[i] = sys
continue
end
dist_io = merge_io(io, inputs[i])
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
fully_determined, kwargs...)
append!(appended_parameters, inputs[i], unknowns(ss))
discrete_subsystems[i] = ss
end
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id,
id_to_clock
@set! sys.ps = appended_parameters
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
ps = [sym isa CallWithMetadata ? sym :
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous()))
Expand Down
19 changes: 19 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,25 @@ vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
function vars!(vars, eq::Equation; op = Differential)
(vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars)
end
function vars!(vars, O::AbstractSystem; op = Differential)
for eq in equations(O)
vars!(vars, eq; op)
end
return vars
end
function vars!(vars, O::Transition; op = Differential)
vars!(vars, O.from)
vars!(vars, O.to)
vars!(vars, O.cond; op)
return vars
end
function vars!(vars, O::InitialState; op = Differential)
vars!(vars, O.s; op)
return vars
end
function vars!(vars, O::StateMachineOperator; op = Differential)
error("Unhandled state machine operator")
end
function vars!(vars, O; op = Differential)
if isvariable(O)
if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O)))
Expand Down