diff --git a/src/callbacks.jl b/src/callbacks.jl index 9f0999c83..cc8a0a002 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -122,18 +122,17 @@ function ContinuousCallback(condition,affect!; rootfind,interp_points, collect(save_positions), dtrelax,abstol,reltol) - end """ ```julia VectorContinuousCallback(condition,affect!,affect_neg!,len; - initialize = INITIALIZE_DEFAULT, - idxs = nothing, - rootfind=true, - save_positions=(true,true), - interp_points=10, - abstol=10eps(),reltol=0) + initialize = INITIALIZE_DEFAULT, + idxs = nothing, + rootfind=true, + save_positions=(true,true), + interp_points=10, + abstol=10eps(),reltol=0,pooltol=nothing,pool_events=false) ``` ```julia @@ -144,7 +143,7 @@ VectorContinuousCallback(condition,affect!,len; save_positions=(true,true), affect_neg! = affect!, interp_points=10, - abstol=10eps(),reltol=0) + abstol=10eps(),reltol=0,pooltol=nothing,pool_events=false) ``` This is also a subtype of `AbstractContinuousCallback`. `CallbackSet` is not feasible when you have a large number of callbacks, @@ -159,10 +158,13 @@ multiple events. - `affect!`: This is a function `affect!(integrator, event_index)` which lets you modify `integrator` and it tells you about which event occured using `event_idx` i.e. gives you index `i` for which `out[i]` came out to be zero. - `len`: Number of callbacks chained. This is compulsory to be specified. +- `pool_events`: Whether multiple concurrent events should be passed as one array of indexs instead of the indexes on a time. +- `pooltol`: Custom limit which values get grouped. Callback accepted if it's absolute value is smaller than pooltol at callback time. + The default value is `eps(integrator.t) + eps(callback_return_type)`. Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref). """ -struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,I,R} <: AbstractContinuousCallback +struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,T3,I,R} <: AbstractContinuousCallback condition::F1 affect!::F2 affect_neg!::F3 @@ -175,15 +177,17 @@ struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,I,R} <: AbstractContinuousCallb dtrelax::R abstol::T reltol::T2 + pooltol::T3 + pool_events::Bool VectorContinuousCallback(condition::F1,affect!::F2,affect_neg!::F3,len::Int, initialize::F4,idxs::I,rootfind, interp_points,save_positions,dtrelax::R, - abstol::T,reltol::T2) where {F1,F2,F3,F4,T,T2,I,R} = - new{F1,F2,F3,F4,T,T2,I,R}(condition, + abstol::T,reltol::T2, pooltol::T3, pool_events) where {F1,F2,F3,F4,T,T2,T3,I,R} = + new{F1,F2,F3,F4,T,T2,T3,I,R}(condition, affect!,affect_neg!,len, initialize,idxs,rootfind,interp_points, BitArray(collect(save_positions)), - dtrelax,abstol,reltol) + dtrelax,abstol,reltol,pooltol,pool_events) end VectorContinuousCallback(condition,affect!,affect_neg!,len; @@ -193,13 +197,13 @@ VectorContinuousCallback(condition,affect!,affect_neg!,len; save_positions=(true,true), interp_points=10, dtrelax=1, - abstol=10eps(),reltol=0) = VectorContinuousCallback( + abstol=10eps(),reltol=0, pooltol=missing, pool_events=false) = VectorContinuousCallback( condition,affect!,affect_neg!,len, initialize, idxs, rootfind,interp_points, save_positions,dtrelax, - abstol,reltol) + abstol,reltol, pooltol, pool_events) function VectorContinuousCallback(condition,affect!,len; initialize = INITIALIZE_DEFAULT, @@ -209,14 +213,13 @@ function VectorContinuousCallback(condition,affect!,len; affect_neg! = affect!, interp_points=10, dtrelax=1, - abstol=10eps(),reltol=0) + abstol=10eps(),reltol=0, pooltol=missing, pool_events=false) VectorContinuousCallback( condition,affect!,affect_neg!,len,initialize,idxs, rootfind,interp_points, collect(save_positions), - dtrelax,abstol,reltol) - + dtrelax,abstol,reltol,pooltol,pool_events) end """ @@ -754,6 +757,16 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte new_t = integrator.dt min_event_idx = event_idx[1] end + if callback.pool_events + tmp = get_condition(integrator, callback, integrator.dt + new_t) + if callback.pooltol isa Missing + # This is still dubious + pool_tol = eps(integrator.t + new_t) + eps(typeof(tmp[end])) + else + pool_tol = callback.pooltol + end + min_event_idx = findall(x-> abs(x) < pool_tol, tmp) + end end else new_t = zero(typeof(integrator.t))