Skip to content

Commit

Permalink
generalize trigger_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
juliasloan25 committed Sep 10, 2024
1 parent 8b8ffcc commit 450641d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 35 deletions.
2 changes: 1 addition & 1 deletion docs/src/callbackmanager.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ ClimaUtilities.CallbackManager.EveryTimestep
ClimaUtilities.CallbackManager.to_datetime
ClimaUtilities.CallbackManager.strdate_to_datetime
ClimaUtilities.CallbackManager.datetime_to_strdate
ClimaUtilities.CallbackManager.trigger_callback
ClimaUtilities.CallbackManager.trigger_callback!
```
50 changes: 20 additions & 30 deletions src/CallbackManager.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export HourlyCallback,
MonthlyCallback,
Monthly,
EveryTimestep,
trigger_callback,
trigger_callback!,
to_datetime,
strdate_to_datetime,
datetime_to_strdate
Expand All @@ -32,7 +32,7 @@ This is a callback type that triggers at intervals of 1h or multiple hours.
dt::FT = FT(1) # hours
""" Function to be called at each trigger. """
func::Function = do_nothing
""" Reference date for the callback. """
""" Reference date when the callback should be called. """
ref_date::Array = [Dates.DateTime(0)]
""" Whether the callback is active. """
active::Bool = false
Expand Down Expand Up @@ -78,37 +78,27 @@ struct Monthly <: AbstractFrequency end
struct EveryTimestep <: AbstractFrequency end

"""
trigger_callback(date_nextcall::Dates.DateTime,
date_current::Dates.DateTime,
::Monthly,
func::Function,)
trigger_callback!(callback, date_current)
If the current date is equal to or later than the "next call" date at time
00:00:00, call the callback function and increment the next call date by one
month. Otherwise, do nothing and leave the next call date unchanged.
If the callback is active and the current date is equal to or later than the
"next call" reference date/time, call the callback function and increment the
next call date based on the callback frequency. Otherwise, do nothing and leave
the next call date unchanged.
The tuple of arguments `func_args` must match the types, number, and order
of arguments expected by `func`.
Note that the collection of data in `callback.data` must match the types, number,
and orderof arguments expected by `callback.func`.
"""
function trigger_callback!(callback::HourlyCallback, date_current)
if callback.active && date_current >= callback.ref_date[1]
callback.func(callback.data...)
callback.ref_date[1] += Dates.Hour(1)
end
end

# Arguments
- `date_nextcall::DateTime` the next date to call the callback function at or after
- `date_current::DateTime` the current date of the simulation
- `save_freq::AbstractFrequency` frequency with which to trigger callback
- `func::Function` function to be triggered if date is at or past the next call date
- `func_args::Tuple` a tuple of arguments to be passed into the callback function
"""
function trigger_callback(
date_nextcall::Dates.DateTime,
date_current::Dates.DateTime,
::Monthly,
func::Function,
func_args::Tuple,
)
if date_current >= date_nextcall
func(func_args...)
return date_nextcall + Dates.Month(1)
else
return date_nextcall
function trigger_callback!(callback::MonthlyCallback, date_current)
if callback.active && date_current >= callback.ref_date[1]
callback.func(callback.data...)
callback.ref_date[1] += Dates.Month(1)
end
end

Expand Down
8 changes: 4 additions & 4 deletions test/callbackmanager.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ for FT in (Float32, Float64)
"00000101"
end

@testset "test trigger_callback for FT=$FT" begin
@testset "test trigger_callback! for FT=$FT" begin
# Define callback function
func! = (val) -> val[1] += 1
# Case 1: date_current == date_nextcall
Expand All @@ -47,7 +47,7 @@ for FT in (Float32, Float64)
arg_copy = copy(arg)
date_current =
date_nextcall = date_nextcall_copy = Dates.DateTime(1979, 3, 21)
date_nextcall = CallbackManager.trigger_callback(
date_nextcall = CallbackManager.trigger_callback!(
date_nextcall,
date_current,
CallbackManager.Monthly(),
Expand All @@ -61,7 +61,7 @@ for FT in (Float32, Float64)
# Case 2: date_current > date_nextcall
date_nextcall = date_nextcall_copy = Dates.DateTime(1979, 3, 21)
date_current = date_nextcall + Dates.Day(1)
date_nextcall = CallbackManager.trigger_callback(
date_nextcall = CallbackManager.trigger_callback!(
date_nextcall,
date_current,
CallbackManager.Monthly(),
Expand All @@ -75,7 +75,7 @@ for FT in (Float32, Float64)
# Case 3: date_current < date_nextcall
date_nextcall = date_nextcall_copy = Dates.DateTime(1979, 3, 21)
date_current = date_nextcall - Dates.Day(1)
date_nextcall = CallbackManager.trigger_callback(
date_nextcall = CallbackManager.trigger_callback!(
date_nextcall,
date_current,
CallbackManager.Monthly(),
Expand Down

0 comments on commit 450641d

Please # to comment.