Skip to content

Commit 131217e

Browse files
feat: proagate cse kwarg from problem constructors
1 parent 57a66b1 commit 131217e

10 files changed

+78
-65
lines changed

src/systems/abstractsystem.jl

+9-6
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,8 @@ end
535535
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true
536536

537537
function SymbolicIndexingInterface.observed(
538-
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
538+
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__,
539+
checkbounds = true, cse = true)
539540
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
540541
if sym isa Symbol
541542
_sym = get(ic.symbol_to_variable, sym, nothing)
@@ -559,7 +560,7 @@ function SymbolicIndexingInterface.observed(
559560
end
560561
end
561562
return build_explicit_observed_function(
562-
sys, sym; eval_expression, eval_module, checkbounds)
563+
sys, sym; eval_expression, eval_module, checkbounds, cse)
563564
end
564565

565566
function SymbolicIndexingInterface.default_values(sys::AbstractSystem)
@@ -1774,13 +1775,14 @@ struct ObservedFunctionCache{S}
17741775
eval_expression::Bool
17751776
eval_module::Module
17761777
checkbounds::Bool
1778+
cse::Bool
17771779
end
17781780

17791781
function ObservedFunctionCache(
17801782
sys; steady_state = false, eval_expression = false,
1781-
eval_module = @__MODULE__, checkbounds = true)
1783+
eval_module = @__MODULE__, checkbounds = true, cse = true)
17821784
return ObservedFunctionCache(
1783-
sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
1785+
sys, Dict(), steady_state, eval_expression, eval_module, checkbounds, cse)
17841786
end
17851787

17861788
# This is hit because ensemble problems do a deepcopy
@@ -1791,8 +1793,9 @@ function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict)
17911793
eval_expression = ofc.eval_expression
17921794
eval_module = ofc.eval_module
17931795
checkbounds = ofc.checkbounds
1796+
cse = ofc.cse
17941797
newofc = ObservedFunctionCache(
1795-
sys, dict, steady_state, eval_expression, eval_module, checkbounds)
1798+
sys, dict, steady_state, eval_expression, eval_module, checkbounds, cse)
17961799
stackdict[ofc] = newofc
17971800
return newofc
17981801
end
@@ -1801,7 +1804,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
18011804
obs = get!(ofc.dict, value(obsvar)) do
18021805
SymbolicIndexingInterface.observed(
18031806
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
1804-
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
1807+
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds, cse = ofc.cse)
18051808
end
18061809
if ofc.steady_state
18071810
obs = let fn = obs

src/systems/diffeqs/abstractodesystem.jl

+23-16
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,13 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
312312
analytic = nothing,
313313
split_idxs = nothing,
314314
initialization_data = nothing,
315+
cse = true,
315316
kwargs...) where {iip, specialize}
316317
if !iscomplete(sys)
317318
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
318319
end
319320
f_gen = generate_function(sys, dvs, ps; expression = Val{true},
320-
expression_module = eval_module, checkbounds = checkbounds,
321+
expression_module = eval_module, checkbounds = checkbounds, cse,
321322
kwargs...)
322323
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
323324
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
@@ -333,7 +334,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
333334
tgrad_gen = generate_tgrad(sys, dvs, ps;
334335
simplify = simplify,
335336
expression = Val{true},
336-
expression_module = eval_module,
337+
expression_module = eval_module, cse,
337338
checkbounds = checkbounds, kwargs...)
338339
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
339340
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
@@ -345,7 +346,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
345346
jac_gen = generate_jacobian(sys, dvs, ps;
346347
simplify = simplify, sparse = sparse,
347348
expression = Val{true},
348-
expression_module = eval_module,
349+
expression_module = eval_module, cse,
349350
checkbounds = checkbounds, kwargs...)
350351
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
351352

@@ -365,7 +366,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
365366
end
366367

367368
observedfun = ObservedFunctionCache(
368-
sys; steady_state, eval_expression, eval_module, checkbounds)
369+
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
369370

370371
jac_prototype = if sparse
371372
uElType = u0 === nothing ? Float64 : eltype(u0)
@@ -420,12 +421,13 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
420421
eval_module = @__MODULE__,
421422
checkbounds = false,
422423
initialization_data = nothing,
424+
cse = true,
423425
kwargs...) where {iip}
424426
if !iscomplete(sys)
425427
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
426428
end
427429
f_gen = generate_function(sys, dvs, ps; implicit_dae = true,
428-
expression = Val{true},
430+
expression = Val{true}, cse,
429431
expression_module = eval_module, checkbounds = checkbounds,
430432
kwargs...)
431433
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
@@ -435,7 +437,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
435437
jac_gen = generate_dae_jacobian(sys, dvs, ps;
436438
simplify = simplify, sparse = sparse,
437439
expression = Val{true},
438-
expression_module = eval_module,
440+
expression_module = eval_module, cse,
439441
checkbounds = checkbounds, kwargs...)
440442
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
441443

@@ -445,7 +447,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
445447
end
446448

447449
observedfun = ObservedFunctionCache(
448-
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
450+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse)
449451

450452
jac_prototype = if sparse
451453
uElType = u0 === nothing ? Float64 : eltype(u0)
@@ -479,14 +481,15 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
479481
eval_module = @__MODULE__,
480482
checkbounds = false,
481483
initialization_data = nothing,
484+
cse = true,
482485
kwargs...) where {iip}
483486
if !iscomplete(sys)
484487
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DDEFunction`")
485488
end
486489
f_gen = generate_function(sys, dvs, ps; isdde = true,
487490
expression = Val{true},
488491
expression_module = eval_module, checkbounds = checkbounds,
489-
kwargs...)
492+
cse, kwargs...)
490493
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
491494
f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip)
492495

@@ -503,19 +506,20 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
503506
eval_module = @__MODULE__,
504507
checkbounds = false,
505508
initialization_data = nothing,
509+
cse = true,
506510
kwargs...) where {iip}
507511
if !iscomplete(sys)
508512
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `SDDEFunction`")
509513
end
510514
f_gen = generate_function(sys, dvs, ps; isdde = true,
511515
expression = Val{true},
512516
expression_module = eval_module, checkbounds = checkbounds,
513-
kwargs...)
517+
cse, kwargs...)
514518
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
515519
f = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(f_oop, f_iip)
516520

517521
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
518-
isdde = true, kwargs...)
522+
isdde = true, cse, kwargs...)
519523
g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module)
520524
g = GeneratedFunctionWrapper{(3, 4, is_split(sys))}(g_oop, g_iip)
521525

@@ -841,6 +845,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
841845
warn_initialize_determined = true,
842846
eval_expression = false,
843847
eval_module = @__MODULE__,
848+
cse = true,
844849
kwargs...) where {iip, specialize}
845850
if !iscomplete(sys)
846851
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
@@ -864,12 +869,12 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
864869
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
865870
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
866871
t = tspan !== nothing ? tspan[1] : tspan, guesses,
867-
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
872+
check_length, warn_initialize_determined, eval_expression, eval_module, cse, kwargs...)
868873

869874
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
870875
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k, v) in u0map]
871876

872-
fns = generate_function_bc(sys, u0, u0_idxs, tspan)
877+
fns = generate_function_bc(sys, u0, u0_idxs, tspan; cse)
873878
bc_oop, bc_iip = eval_or_rgf.(fns; eval_expression, eval_module)
874879
bc(sol, p, t) = bc_oop(sol, p, t)
875880
bc(resid, u, p, t) = bc_iip(resid, u, p, t)
@@ -988,15 +993,16 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
988993
eval_expression = false,
989994
eval_module = @__MODULE__,
990995
u0_constructor = identity,
996+
cse = true,
991997
kwargs...) where {iip}
992998
if !iscomplete(sys)
993999
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DDEProblem`")
9941000
end
9951001
f, u0, p = process_SciMLProblem(DDEFunction{iip}, sys, u0map, parammap;
9961002
t = tspan !== nothing ? tspan[1] : tspan,
997-
symbolic_u0 = true, u0_constructor,
1003+
symbolic_u0 = true, u0_constructor, cse,
9981004
check_length, eval_expression, eval_module, kwargs...)
999-
h_gen = generate_history(sys, u0; expression = Val{true})
1005+
h_gen = generate_history(sys, u0; expression = Val{true}, cse)
10001006
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
10011007
h = h_oop
10021008
u0 = float.(h(p, tspan[1]))
@@ -1027,15 +1033,16 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10271033
eval_expression = false,
10281034
eval_module = @__MODULE__,
10291035
u0_constructor = identity,
1036+
cse = true,
10301037
kwargs...) where {iip}
10311038
if !iscomplete(sys)
10321039
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SDDEProblem`")
10331040
end
10341041
f, u0, p = process_SciMLProblem(SDDEFunction{iip}, sys, u0map, parammap;
10351042
t = tspan !== nothing ? tspan[1] : tspan,
10361043
symbolic_u0 = true, eval_expression, eval_module, u0_constructor,
1037-
check_length, kwargs...)
1038-
h_gen = generate_history(sys, u0; expression = Val{true})
1044+
check_length, cse, kwargs...)
1045+
h_gen = generate_history(sys, u0; expression = Val{true}, cse)
10391046
h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module)
10401047
h = h_oop
10411048
u0 = h(p, tspan[1])

src/systems/diffeqs/odesystem.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,8 @@ Generates a function that computes the observed value(s) `ts` in the system `sys
454454
- `checkbounds = true` checks bounds if true when destructuring parameters
455455
- `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
456456
- `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist.
457-
- `mkarray`; only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in
458-
the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
457+
- `mkarray`: only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
458+
- `cse = true`: Whether to use Common Subexpression Elimination (CSE) to generate a more efficient function.
459459
460460
## Returns
461461
@@ -493,6 +493,7 @@ function build_explicit_observed_function(sys, ts;
493493
param_only = false,
494494
op = Operator,
495495
throw = true,
496+
cse = true,
496497
mkarray = nothing)
497498
is_tuple = ts isa Tuple
498499
if is_tuple
@@ -579,7 +580,7 @@ function build_explicit_observed_function(sys, ts;
579580
p_end = length(dvs) + length(inputs) + length(ps)
580581
fns = build_function_wrapper(
581582
sys, ts, args...; p_start, p_end, filter_observed = obsfilter,
582-
output_type, mkarray, try_namespaced = true, expression = Val{true})
583+
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse)
583584
if fns isa Tuple
584585
if expression
585586
return return_inplace ? fns : fns[1]

src/systems/diffeqs/sdesystem.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -595,23 +595,23 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
595595
jac = false, Wfact = false, eval_expression = false,
596596
eval_module = @__MODULE__,
597597
checkbounds = false, initialization_data = nothing,
598-
kwargs...) where {iip, specialize}
598+
cse = true, kwargs...) where {iip, specialize}
599599
if !iscomplete(sys)
600600
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
601601
end
602602
dvs = scalarize.(dvs)
603603

604-
f_gen = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
604+
f_gen = generate_function(sys, dvs, ps; expression = Val{true}, cse, kwargs...)
605605
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
606606
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
607-
kwargs...)
607+
cse, kwargs...)
608608
g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module)
609609

610610
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
611611
g = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(g_oop, g_iip)
612612

613613
if tgrad
614-
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{true},
614+
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{true}, cse,
615615
kwargs...)
616616
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
617617
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
@@ -621,7 +621,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
621621

622622
if jac
623623
jac_gen = generate_jacobian(sys, dvs, ps; expression = Val{true},
624-
sparse = sparse, kwargs...)
624+
sparse = sparse, cse, kwargs...)
625625
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
626626

627627
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
@@ -631,7 +631,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
631631

632632
if Wfact
633633
tmp_Wfact, tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true;
634-
expression = Val{true}, kwargs...)
634+
expression = Val{true}, cse, kwargs...)
635635
Wfact_oop, Wfact_iip = eval_or_rgf.(tmp_Wfact; eval_expression, eval_module)
636636
Wfact_oop_t, Wfact_iip_t = eval_or_rgf.(tmp_Wfact_t; eval_expression, eval_module)
637637

@@ -645,7 +645,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
645645
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
646646

647647
observedfun = ObservedFunctionCache(
648-
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
648+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse)
649649

650650
SDEFunction{iip, specialize}(f, g;
651651
sys = sys,

src/systems/discrete_system/discrete_system.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -360,13 +360,13 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
360360
t = nothing,
361361
eval_expression = false,
362362
eval_module = @__MODULE__,
363-
analytic = nothing,
363+
analytic = nothing, cse = true,
364364
kwargs...) where {iip, specialize}
365365
if !iscomplete(sys)
366366
error("A completed `DiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
367367
end
368368
f_gen = generate_function(sys, dvs, ps; expression = Val{true},
369-
expression_module = eval_module, kwargs...)
369+
expression_module = eval_module, cse, kwargs...)
370370
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
371371
f = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
372372

@@ -378,7 +378,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
378378
end
379379

380380
observedfun = ObservedFunctionCache(
381-
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
381+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse)
382382

383383
DiscreteFunction{iip, specialize}(f;
384384
sys = sys,

src/systems/discrete_system/implicit_discrete_system.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -369,13 +369,13 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
369369
t = nothing,
370370
eval_expression = false,
371371
eval_module = @__MODULE__,
372-
analytic = nothing,
372+
analytic = nothing, cse = true,
373373
kwargs...) where {iip, specialize}
374374
if !iscomplete(sys)
375375
error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`")
376376
end
377377
f_gen = generate_function(sys, dvs, ps; expression = Val{true},
378-
expression_module = eval_module, kwargs...)
378+
expression_module = eval_module, cse, kwargs...)
379379
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
380380
f(u_next, u, p, t) = f_oop(u_next, u, p, t)
381381
f(resid, u_next, u, p, t) = f_iip(resid, u_next, u, p, t)
@@ -388,7 +388,7 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
388388
end
389389

390390
observedfun = ObservedFunctionCache(
391-
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
391+
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse)
392392

393393
ImplicitDiscreteFunction{iip, specialize}(f;
394394
sys = sys,

0 commit comments

Comments
 (0)