diff --git a/R/sysdata.rda b/R/sysdata.rda index beb919b8..60e01c46 100644 Binary files a/R/sysdata.rda and b/R/sysdata.rda differ diff --git a/R/wwinference.R b/R/wwinference.R index 1e10e6ea..385d5e39 100644 --- a/R/wwinference.R +++ b/R/wwinference.R @@ -31,13 +31,15 @@ #' `get_model_spec()`. The default here pertains to the `forecast_date` in the #' example data provided by the package, but this should be specified by the #' user based on the date they are producing a forecast -#' @param fit_opts The fit options, which in this case default to the -#' MCMC parameters as defined using `get_mcmc_options()`. This includes -#' the following arguments, which are passed to -#' [`$sample()`][cmdstanr::model-method-sample]: -#' the number of chains, the number of warmup -#' and sampling iterations, the maximum tree depth, the average acceptance -#' probability, and the stan PRNG seed +#' @param fit_opts MCMC fitting options, as a list of keys and values. +#' These are passed as keyword arguments to +#' [`compiled_model$sample()`][cmdstanr::model-method-sample]. +#' Where no option is specified, [wwinference()] will fall back first on a +#' package-specific default value given by [get_mcmc_options()], if one exists. +#' If no package-specific default exists, [wwinference()] will fall back on +#' the default value defined in [`$sample()`][cmdstanr::model-method-sample]. +#' See the documentation for [`$sample()`][cmdstanr::model-method-sample] for +#' details on available options. 
#' @param generate_initial_values Boolean indicating whether or not to specify #' the initialization of the sampler, default is `TRUE`, meaning that #' initialization lists will be generated and passed as the `init` argument @@ -124,24 +126,27 @@ #' calibration_time <- 90 #' forecast_horizon <- 28 #' include_ww <- 1 -#' ww_fit <- wwinference(input_ww_data, -#' input_count_data, +#' +#' ww_fit <- wwinference( +#' ww_data = input_ww_data, +#' count_data = input_count_data, +#' forecast_date = forecast_date, +#' calibration_time = calibration_time, +#' forecast_horizon = forecast_horizon, #' model_spec = get_model_spec( -#' forecast_date = forecast_date, -#' calibration_time = calibration_time, -#' forecast_horizon = forecast_horizon, #' generation_interval = generation_interval, -#' inf_to_count_delay = inf_to_coutn_delay, +#' inf_to_count_delay = inf_to_count_delay, #' infection_feedback_pmf = infection_feedback_pmf, #' params = params #' ), -#' fit_opts = get_mcmc_options( +#' fit_opts = list( #' iter_warmup = 250, #' iter_sampling = 250, -#' n_chains = 2 +#' chains = 2 #' ) #' ) #' } +#' #' @rdname wwinference #' @aliases wwinference_fit wwinference <- function(ww_data, @@ -150,7 +155,7 @@ wwinference <- function(ww_data, calibration_time = 90, forecast_horizon = 28, model_spec = get_model_spec(), - fit_opts = get_mcmc_options(), + fit_opts = list(), generate_initial_values = TRUE, initial_values_seed = NULL, compiled_model = compile_model()) { @@ -160,6 +165,18 @@ wwinference <- function(ww_data, ) } + fit_opts_use <- get_mcmc_options() # get defaults + # this overwrites defaults with all and only the values the user sets in + # `fit_opts` + fit_opts_use[names(fit_opts)] <- fit_opts + + # Check that the fit options passed to wwinference are valid cmdstanr::sample + # arguments + checkmate::assert_names(names(fit_opts), + subset.of = formalArgs(compiled_model$sample) + ) + + # Check that data is compatible with specifications 
assert_no_dates_after_max(ww_data$date, forecast_date) assert_no_dates_after_max(count_data$date, forecast_date) @@ -204,7 +221,7 @@ wwinference <- function(ww_data, if (generate_initial_values) { withr::with_seed(initial_values_seed, { init_lists <- lapply( - 1:fit_opts$n_chains, + 1:fit_opts_use$chains, \(x) { get_inits_for_one_chain(stan_data_list) } @@ -220,7 +237,7 @@ wwinference <- function(ww_data, fit <- safe_fit_model( compiled_model = compiled_model, stan_data_list = stan_data_list, - fit_opts = fit_opts, + fit_opts = fit_opts_use, init_lists = init_lists ) @@ -329,15 +346,18 @@ fit_model <- function(compiled_model, stan_data_list, fit_opts, init_lists) { - fit <- compiled_model$sample( - data = stan_data_list, - init = init_lists, - seed = fit_opts$seed, - iter_sampling = fit_opts$iter_sampling, - iter_warmup = fit_opts$iter_warmup, - max_treedepth = fit_opts$max_treedepth, - chains = fit_opts$n_chains, - parallel_chains = fit_opts$n_chains + args_for_stan_sampling <- + c( + list( + data = stan_data_list, + init = init_lists + ), + fit_opts + ) + + fit <- do.call( + compiled_model$sample, + args_for_stan_sampling ) return(fit) @@ -348,42 +368,45 @@ fit_model <- function(compiled_model, #' #' @description #' This function returns a list of MCMC settings to pass to the -#' `cmdstanr::sample()` function to fit the model. The default settings are -#' specified for production-level runs, consider adjusting to optimize -#' for speed while iterating. +#' [`$sample()`][cmdstanr::model-method-sample] function to fit the model. +#' The default settings are specified for production-level runs. +#' All input arguments to [`$sample()`][cmdstanr::model-method-sample] +#' are configurable by the user. See +#' [`$sample()`][cmdstanr::model-method-sample] documentation +#' for details of the available arguments. #' #' #' @param iter_warmup integer indicating the number of warm-up iterations, -#' default is `750` +#' default is `750`. 
#' @param iter_sampling integer indicating the number of sampling iterations, -#' default is `500` -#' @param n_chains integer indicating the number of MCMC chains to run, default -#' is `4` -#' @param seed set of integers indicating the random seed of the stan sampler, -#' default is NULL +#' default is `500`. +#' @param seed integer, A seed for the (P)RNG to pass to CmdStan. In the case +#' of multi-chain sampling the single seed will automatically be augmented by +#' the run (chain) ID so that each chain uses a different seed. +#' Default is `NULL`. +#' @param chains integer indicating the number of MCMC chains to run, default +#' is `4`. #' @param adapt_delta float between 0 and 1 indicating the average acceptance -#' probability, default is `0.95` +#' probability, default is `0.95`. #' @param max_treedepth integer indicating the maximum tree depth of the -#' sampler, default is 12 +#' sampler, default is 12. #' -#' @return a list of mcmc settings with the values given by the function +#' @return A list of MCMC settings with the values given by the function 
#' arguments -#' @export #' -#' @examples -#' mcmc_settings <- get_mcmc_options() +#' @export get_mcmc_options <- function( iter_warmup = 750, iter_sampling = 500, - n_chains = 4, seed = NULL, + chains = 4, adapt_delta = 0.95, max_treedepth = 12) { mcmc_settings <- list( iter_warmup = iter_warmup, iter_sampling = iter_sampling, - n_chains = n_chains, seed = seed, + chains = chains, adapt_delta = adapt_delta, max_treedepth = max_treedepth ) diff --git a/data-raw/test_data.R b/data-raw/test_data.R index fbbeb29a..15abb0e6 100644 --- a/data-raw/test_data.R +++ b/data-raw/test_data.R @@ -46,11 +46,12 @@ model_spec <- wwinference::get_model_spec( params = params ) -mcmc_options <- wwinference::get_mcmc_options( - seed = 55, +mcmc_options <- list( + seed = 5, iter_warmup = 25, iter_sampling = 25, - n_chains = 1 + chains = 1, + show_messages = FALSE ) generate_initial_values <- TRUE @@ -66,7 +67,7 @@ model_test_data <- list( generate_initial_values = generate_initial_values ) -withr::with_seed(5, { +withr::with_seed(55, { fit <- do.call( wwinference::wwinference, model_test_data diff --git a/man/get_mcmc_options.Rd b/man/get_mcmc_options.Rd index 454b2c9a..193bb6f1 100644 --- a/man/get_mcmc_options.Rd +++ b/man/get_mcmc_options.Rd @@ -7,41 +7,43 @@ get_mcmc_options( iter_warmup = 750, iter_sampling = 500, - n_chains = 4, seed = NULL, + chains = 4, adapt_delta = 0.95, max_treedepth = 12 ) } \arguments{ \item{iter_warmup}{integer indicating the number of warm-up iterations, -default is \code{750}} +default is \code{750}.} \item{iter_sampling}{integer indicating the number of sampling iterations, -default is \code{500}} +default is \code{500}.} -\item{n_chains}{integer indicating the number of MCMC chains to run, default -is \code{4}} +\item{seed}{integer, A seed for the (P)RNG to pass to CmdStan. In the case +of multi-chain sampling the single seed will automatically be augmented by +the the run (chain) ID so that each chain uses a different seed. 
+Default is \code{NULL}.} -\item{seed}{set of integers indicating the random seed of the stan sampler, -default is NULL} +\item{chains}{integer indicating the number of MCMC chains to run, default +is \code{4}.} \item{adapt_delta}{float between 0 and 1 indicating the average acceptance -probability, default is \code{0.95}} +probability, default is \code{0.95}.} \item{max_treedepth}{integer indicating the maximum tree depth of the -sampler, default is 12} +sampler, default is 12.} } \value{ -a list of mcmc settings with the values given by the function +A list of MCMC settings with the values given by the function. arguments } \description{ This function returns a list of MCMC settings to pass to the -\code{cmdstanr::sample()} function to fit the model. The default settings are -specified for production-level runs, consider adjusting to optimize -for speed while iterating. -} -\examples{ -mcmc_settings <- get_mcmc_options() +\code{\link[cmdstanr:model-method-sample]{$sample()}} function to fit the model. +The default settings are specified for production-level runs. +All input arguments to \code{\link[cmdstanr:model-method-sample]{$sample()}} +are configurable by the user. See +\code{\link[cmdstanr:model-method-sample]{$sample()}} documentation +for details of the available arguments. 
} diff --git a/man/wwinference.Rd b/man/wwinference.Rd index 41306ea9..bfc62d04 100644 --- a/man/wwinference.Rd +++ b/man/wwinference.Rd @@ -15,7 +15,7 @@ wwinference( calibration_time = 90, forecast_horizon = 28, model_spec = get_model_spec(), - fit_opts = get_mcmc_options(), + fit_opts = list(), generate_initial_values = TRUE, initial_values_seed = NULL, compiled_model = compile_model() @@ -50,13 +50,15 @@ forecast date, to produce forecasts for, default is \code{28}} example data provided by the package, but this should be specified by the user based on the date they are producing a forecast} -\item{fit_opts}{The fit options, which in this case default to the -MCMC parameters as defined using \code{get_mcmc_options()}. This includes -the following arguments, which are passed to -\code{\link[cmdstanr:model-method-sample]{$sample()}}: -the number of chains, the number of warmup -and sampling iterations, the maximum tree depth, the average acceptance -probability, and the stan PRNG seed} +\item{fit_opts}{MCMC fitting options, as a list of keys and values. +These are passed as keyword arguments to +\code{\link[cmdstanr:model-method-sample]{compiled_model$sample()}}. +Where no option is specified, \code{\link[=wwinference]{wwinference()}} will fall back first on a +package-specific default value given by \code{\link[=get_mcmc_options]{get_mcmc_options()}}, if one exists. +If no package-specific default exists, \code{\link[=wwinference]{wwinference()}} will fall back on +the default value defined in \code{\link[cmdstanr:model-method-sample]{$sample()}}. 
+See the documentation for \code{\link[cmdstanr:model-method-sample]{$sample()}} for +details on available options.} \item{generate_initial_values}{Boolean indicating whether or not to specify the initialization of the sampler, default is \code{TRUE}, meaning that @@ -170,24 +172,27 @@ forecast_date <- "2023-11-06" calibration_time <- 90 forecast_horizon <- 28 include_ww <- 1 -ww_fit <- wwinference(input_ww_data, - input_count_data, + +ww_fit <- wwinference( + ww_data = input_ww_data, + count_data = input_count_data, + forecast_date = forecast_date, + calibration_time = calibration_time, + forecast_horizon = forecast_horizon, model_spec = get_model_spec( - forecast_date = forecast_date, - calibration_time = calibration_time, - forecast_horizon = forecast_horizon, generation_interval = generation_interval, - inf_to_count_delay = inf_to_coutn_delay, + inf_to_count_delay = inf_to_count_delay, infection_feedback_pmf = infection_feedback_pmf, params = params ), - fit_opts = get_mcmc_options( + fit_opts = list( iter_warmup = 250, iter_sampling = 250, - n_chains = 2 + chains = 2 ) ) } + } \seealso{ Other diagnostics: diff --git a/tests/testthat/helper.R b/tests/testthat/helper.R index 59e37f77..8264c677 100644 --- a/tests/testthat/helper.R +++ b/tests/testthat/helper.R @@ -131,3 +131,10 @@ diff_ar1_from_z_scores_alt <- function(x0, ar, sd, z, stationary = FALSE) { return(x) } + +silent_wwinference <- function(...) { + utils::capture.output( + fit <- suppressMessages(wwinference(...)) + ) + return(fit) +} diff --git a/tests/testthat/test_ww_model.R b/tests/testthat/test_ww_model.R index 0e9b90b6..697d201e 100644 --- a/tests/testthat/test_ww_model.R +++ b/tests/testthat/test_ww_model.R @@ -2,13 +2,18 @@ test_that("Test the wastewater inference model on simulated data.", { ####### # run model briefly on the simulated data ####### - withr::with_seed(5, { + + # This seed sets the initial values seed. Must be the same as the one used + # in generating the test data. 
+ # model_test_data contains the seed that gets passed to stan + withr::with_seed(55, { fit <- do.call( - wwinference::wwinference, + silent_wwinference, model_test_data ) }) + params <- model_test_data$model_spec$params obs_last_draw <- posterior::subset_draws(fit$fit$result$draws(), draw = 25 diff --git a/tests/testthat/test_wwinference.R b/tests/testthat/test_wwinference.R index 6abf6ab5..aa7f904d 100644 --- a/tests/testthat/test_wwinference.R +++ b/tests/testthat/test_wwinference.R @@ -59,12 +59,9 @@ test_that("wwinference model can compile", { test_that("Function to get mcmc options produces the expected outputs", { mcmc_options <- get_mcmc_options() expected_names <- c( - "iter_warmup", "iter_sampling", - "n_chains", "seed", "adapt_delta", "max_treedepth", - "compute_likelihood" + "iter_warmup", "iter_sampling", "seed", "adapt_delta", "max_treedepth" ) - # Checkmade doesn't work here for a list, says it must be a character vector - expect_true(all(names(mcmc_options) %in% expected_names)) + checkmate::expect_names(names(mcmc_options), must.include = expected_names) }) test_that("Function to get model specs produces expected outputs", { @@ -77,3 +74,16 @@ test_that("Function to get model specs produces expected outputs", { # Checkmade doesn't work here for a list, says it must be a character vector expect_true(all(names(model_spec) %in% expected_names)) }) + +test_that("Passing invalid args to fit_opts throws an error ", { + expect_error( + wwinference( + ww_data = input_ww_data, + count_data = input_count_data, + forecast_date = forecast_date, + model_spec = get_model_spec, + fit_opts = list(not_an_arg = 4) + ), + regexp = c("Names must be a subset of ") + ) +}) diff --git a/vignettes/wwinference.Rmd b/vignettes/wwinference.Rmd index e19c7c3f..3e8bb364 100644 --- a/vignettes/wwinference.Rmd +++ b/vignettes/wwinference.Rmd @@ -17,6 +17,7 @@ vignette: > ```{r setup, echo=FALSE} knitr::opts_chunk$set(dev = "svg") +options(mc.cores = 4) # This tells cmdstan to 
run the 4 chains in parallel ``` # Quick start @@ -357,7 +358,7 @@ to achieve improved model convergence and/or faster model fitting times. See the We also pass our preprocessed datasets (`ww_data_to_fit` and `hosp_data_preprocessed`), specify our model using `get_model_spec()`, -set the MCMC settings using `get_mcmc_options()`, and pass in our +set the MCMC settings by passing a list of arguments to `fit_opts` that will be passed to the `cmdstanr::sample()` function, and pass in our pre-compiled model(`model`) to `wwinference()` where they are combined and used to fit the model. @@ -374,7 +375,7 @@ ww_fit <- wwinference::wwinference( infection_feedback_pmf = infection_feedback_pmf, params = params ), - fit_opts = get_mcmc_options(seed = 123), + fit_opts = list(seed = 123), compiled_model = model ) ``` @@ -561,7 +562,7 @@ fit_hosp_only <- wwinference::wwinference( include_ww = FALSE, params = params ), - fit_opts = get_mcmc_options(seed = 123), + fit_opts = list(seed = 123), compiled_model = model ) ```