Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

refactor: use Accessors instead of Setfield #3279

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ version = "9.59.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -73,11 +74,12 @@ MTKBifurcationKitExt = "BifurcationKit"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"

[compat]
AbstractTrees = "0.3, 0.4"
Accessors = "0.1.36"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
@@ -23,6 +24,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
Accessors = "0.1.36"
BenchmarkTools = "1.3"
BifurcationKit = "0.4"
DataInterpolations = "6.5"
34 changes: 17 additions & 17 deletions docs/src/basics/Events.md
Original file line number Diff line number Diff line change
@@ -412,16 +412,16 @@ is below `furnace_on_threshold` and off when above `furnace_off_threshold`, whil
in between. To do this, we create two continuous callbacks:

```@example events
using Setfield
using Accessors
furnace_disable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i
@set! x.furnace_on = false
@reset x.furnace_on = false
end)
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_on_threshold],
ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i
@set! x.furnace_on = true
@reset x.furnace_on = true
end)
```

@@ -432,7 +432,7 @@ You can also write
```julia
[temp ~ furnace_off_threshold] => ModelingToolkit.ImperativeAffect(modified = (;
furnace_on)) do x, o, i, c
@set! x.furnace_on = false
@reset x.furnace_on = false
end
```

@@ -462,7 +462,7 @@ f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple
The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions.
In our example, if `furnace_on` is `false`, then the value of the `x` that's passed in as `modified` will be `(furnace_on = false)`.
The modified values should be passed out in the same format: to set `furnace_on` to `true` we need to return a tuple `(furnace_on = true)`.
The examples does this with Setfield, recreating the result tuple before returning it; the returned tuple may optionally be missing values as
The examples does this with Accessors, recreating the result tuple before returning it; the returned tuple may optionally be missing values as
well, in which case those values will not be written back to the problem.

Accordingly, we can now interpret the `ImperativeAffect` definitions to mean that when `temp = furnace_off_threshold` we
@@ -542,18 +542,18 @@ In our encoder, we interpret this as occlusion or nonocclusion of the sensor, up
```@example events
qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0],
ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, c, i
@set! x.hA = x.qA
@set! x.hB = o.qB
@set! x.qA = 1
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
@reset x.hA = x.qA
@reset x.hB = o.qB
@reset x.qA = 1
@reset x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
x
end,
affect_neg = ModelingToolkit.ImperativeAffect(
(; qA, hA, hB, cnt), (; qB)) do x, o, c, i
@set! x.hA = x.qA
@set! x.hB = o.qB
@set! x.qA = 0
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
@reset x.hA = x.qA
@reset x.hB = o.qB
@reset x.qA = 0
@reset x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
x
end)
```
@@ -566,10 +566,10 @@ Instead, we can use right root finding:
```@example events
qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0],
ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA, theta)) do x, o, c, i
@set! x.hA = o.qA
@set! x.hB = x.qB
@set! x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0)
@set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
@reset x.hA = o.qA
@reset x.hB = x.qB
@reset x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0)
@reset x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
x
end; rootfind = SciMLBase.RightRootFind)
```
4 changes: 2 additions & 2 deletions ext/MTKBifurcationKitExt.jl
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ module MTKBifurcationKitExt
### Preparations ###

# Imports
using ModelingToolkit, Setfield
using ModelingToolkit, Accessors
import BifurcationKit

### Observable Plotting Handling ###
@@ -94,7 +94,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
if !ModelingToolkit.iscomplete(nsys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
end
@set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
@reset nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
# Creates F and J functions.
ofun = NonlinearFunction(nsys; jac = jac)
F = ofun.f
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
@@ -31,7 +31,8 @@ using JumpProcesses
using DataStructures
using Base.Threads
using Latexify, Unitful, ArrayInterface
using Setfield, ConstructionBase
import Setfield
using Accessors, ConstructionBase
import Libdl
using DocStringExtensions
using Base: RefValue
2 changes: 1 addition & 1 deletion src/bipartite_graph.jl
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ using DocStringExtensions
using UnPack
using SparseArrays
using Graphs
using Setfield
using Accessors

### Matching
struct Unassigned
16 changes: 8 additions & 8 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
@@ -315,12 +315,12 @@ function inputs_to_parameters!(state::TransformationState, io)
@assert new_v > 0
new_var_to_diff[new_i] = new_v
end
@set! structure.var_to_diff = complete(new_var_to_diff)
@set! structure.graph = complete(new_graph)
@reset structure.var_to_diff = complete(new_var_to_diff)
@reset structure.graph = complete(new_graph)

@set! sys.eqs = isempty(input_to_parameters) ? equations(sys) :
@reset sys.eqs = isempty(input_to_parameters) ? equations(sys) :
fast_substitute(equations(sys), input_to_parameters)
@set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
@reset sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
ps = parameters(sys)

if io !== nothing
@@ -334,11 +334,11 @@ function inputs_to_parameters!(state::TransformationState, io)
new_parameters = new_parameters[permutation]
end

@set! sys.ps = [ps; new_parameters]
@reset sys.ps = [ps; new_parameters]

@set! state.sys = sys
@set! state.fullvars = new_fullvars
@set! state.structure = structure
@reset state.sys = sys
@reset state.fullvars = new_fullvars
@reset state.structure = structure
base_params = length(ps)
return state, (base_params + 1):(base_params + length(new_parameters)) # (1:length(new_parameters)) .+ base_params
end
2 changes: 1 addition & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module StructuralTransformations

using Setfield: @set!, @set
using Accessors: @set, @reset
using UnPack: @unpack

using Symbolics: unwrap, linear_expansion, fast_substitute
2 changes: 1 addition & 1 deletion src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
@@ -300,7 +300,7 @@ function build_torn_function(sys;
rhss)

unknown_vars = Any[fullvars[i] for i in unknowns_idxs]
@set! sys.solved_unknowns = unknown_vars
@reset sys.solved_unknowns = unknown_vars

pre = get_postprocess_fbody(sys)
cpre = get_preprocess_constants(rhss)
4 changes: 2 additions & 2 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
@@ -65,8 +65,8 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
filter(x -> value(x.lhs) !== nothing,
out_eqs[sort(filter(x -> x !== unassigned, var_eq_matching))]))

@set! sys.eqs = final_eqs
@set! sys.unknowns = final_vars
@reset sys.eqs = final_eqs
@reset sys.unknowns = final_vars
return sys
end

28 changes: 14 additions & 14 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
@@ -133,9 +133,9 @@ end

function tearing_substitution(sys::AbstractSystem; kwargs...)
neweqs = full_equations(sys::AbstractSystem; kwargs...)
@set! sys.eqs = neweqs
@set! sys.substitutions = nothing
@set! sys.schedule = nothing
@reset sys.eqs = neweqs
@reset sys.substitutions = nothing
@reset sys.schedule = nothing
end

function tearing_assignments(sys::AbstractSystem)
@@ -563,10 +563,10 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
diff_to_var = invview(var_to_diff)

old_fullvars = fullvars
@set! state.structure.graph = complete(graph)
@set! state.structure.var_to_diff = var_to_diff
@set! state.structure.eq_to_diff = eq_to_diff
@set! state.fullvars = fullvars = fullvars[invvarsperm]
@reset state.structure.graph = complete(graph)
@reset state.structure.var_to_diff = var_to_diff
@reset state.structure.eq_to_diff = eq_to_diff
@reset state.fullvars = fullvars = fullvars[invvarsperm]
ispresent = let var_to_diff = var_to_diff, graph = graph
i -> (!isempty(𝑑neighbors(graph, i)) ||
(var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
@@ -590,24 +590,24 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
push!(unknowns, old_fullvars[v])
end
end
@set! sys.unknowns = unknowns
@reset sys.unknowns = unknowns

obs, subeqs, deps = cse_and_array_hacks(
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)

@set! sys.eqs = neweqs
@set! sys.observed = obs
@reset sys.eqs = neweqs
@reset sys.observed = obs

@set! sys.substitutions = Substitutions(subeqs, deps)
@reset sys.substitutions = Substitutions(subeqs, deps)

# Only makes sense for time-dependent
# TODO: generalize to SDE
if sys isa ODESystem
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
@reset sys.schedule = Schedule(var_eq_matching, dummy_sub)
end
sys = schedule(sys)
@set! state.sys = sys
@set! sys.tearing_state = state
@reset state.sys = sys
@reset sys.tearing_state = state
return invalidate_cache!(sys)
end

Loading
Loading