Skip to content

Commit

Permalink
Consistent naming of prior weights for groups
Browse files Browse the repository at this point in the history
  • Loading branch information
bschneidr committed Feb 23, 2025
1 parent cdc69ec commit 469f3eb
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 56 deletions.
4 changes: 2 additions & 2 deletions R/data-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,8 @@ validate_newdata <- function(
mf[[i]] <- as.factor(mf[[i]])
}
}
gr_weights_vars <- ufrom_list(get_re(bterms)$gcall, "weights")
for (v in setdiff(gr_weights_vars, names(newdata))) {
pw_vars <- ufrom_list(get_re(bterms)$gcall, "pw")
for (v in setdiff(pw_vars, names(newdata))) {
newdata[[v]] <- 1
}
# fixes issue #279
Expand Down
57 changes: 47 additions & 10 deletions R/data-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ data_gr_local <- function(bframe, data) {
# all members get equal weights by default
weights <- matrix(1 / ngs, nrow = nrow(data), ncol = ngs)
}
group_prior_weights <- id_reframe$gcall[[1]]$pw
for (i in seq_along(gs)) {
gdata <- get(gs[i], data)
J <- match(gdata, levels)
Expand All @@ -283,6 +284,42 @@ data_gr_local <- function(bframe, data) {
out[[paste0("J_", idresp, "_", i)]] <- as.array(J)
out[[paste0("W_", idresp, "_", i)]] <- as.array(weights[, i])
}
if (is.formula(group_prior_weights)) {
group_prior_weights <- as.matrix(eval_rhs(group_prior_weights, data))
if (!identical(dim(group_prior_weights), c(nrow(data), ngs))) {
stop2(
"Grouping structure 'mm' expects 'pw' to be ",
"a matrix with as many columns as grouping factors."
)
}
if (!is.numeric(group_prior_weights)) {
stop2("Prior weights supplied to `pw` argument in `mm()` must be numeric.")
}
if (any(group_prior_weights < 0)) {
warning2("Negative weights supplied to `gr()`.")
}
} else {
# all groups get equal prior weights by default
group_prior_weights <- matrix(1, nrow = nrow(data), ncol = ngs)
}
gdata <- do.call(`c`, lapply(seq_along(gs), \(i) get(gs[i], data)))
J <- match(gdata, levels)
group_prior_weights <- as.vector(group_prior_weights)
# check that group-level weights do not vary within a group
group_weights_consistent <- tapply(
X = group_prior_weights, INDEX = J,
FUN = function(x) length(unique(x)) == 1
)
if (!all(group_weights_consistent)) {
stop2("Weights supplied in `gr()` cannot vary within a group.")
}

# deduplicate weights vector (so length matches number of groups)
# and order the weights vector to match groups' assigned indices
distinct_J_indices <- !duplicated(J)
group_prior_weights <- group_prior_weights[distinct_J_indices]
group_prior_weights <- group_prior_weights[order(J[distinct_J_indices])]
out[[paste0("PW_", id)]] <- as.array(group_prior_weights)
} else {
# ordinary grouping term
g <- id_reframe$gcall[[1]]$groups
Expand All @@ -296,19 +333,19 @@ data_gr_local <- function(bframe, data) {
}
out[[paste0("J_", idresp)]] <- as.array(J)

group_model_weights <- id_reframe$gcall[[1]]$weights
if (nzchar(group_model_weights)) {
group_prior_weights <- id_reframe$gcall[[1]]$pw
if (nzchar(group_prior_weights)) {
# extract weights from data as a vector (length equals number of observations)
group_model_weights <- str2formula(id_reframe$gcall[[1]]$weights)
group_model_weights <- as.vector(eval_rhs(group_model_weights, data))
group_prior_weights <- str2formula(id_reframe$gcall[[1]]$weights)
group_prior_weights <- as.vector(eval_rhs(group_prior_weights, data))

if (!is.numeric(group_model_weights)) {
if (!is.numeric(group_prior_weights)) {
stop2("Weights supplied in `gr()` must be numeric.")
}

# check that group-level weights do not vary within a group
group_weights_consistent <- tapply(
X = group_model_weights, INDEX = J,
X = group_prior_weights, INDEX = J,
FUN = function(x) length(unique(x)) == 1
)
if (!all(group_weights_consistent)) {
Expand All @@ -318,14 +355,14 @@ data_gr_local <- function(bframe, data) {
# deduplicate weights vector (so length matches number of groups)
# and order the weights vector to match groups' assigned indices
distinct_J_indices <- !duplicated(J)
group_model_weights <- group_model_weights[distinct_J_indices]
group_model_weights <- group_model_weights[order(J[distinct_J_indices])]
group_prior_weights <- group_prior_weights[distinct_J_indices]
group_prior_weights <- group_prior_weights[order(J[distinct_J_indices])]

if (any(group_model_weights < 0)) {
if (any(group_prior_weights < 0)) {
warning2("Negative weights supplied to `gr()`.")
}

out[[paste0("GMW_", id)]] <- as.array(group_model_weights)
out[[paste0("PW_", id)]] <- as.array(group_prior_weights)
}
}
}
Expand Down
27 changes: 16 additions & 11 deletions R/formula-re.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#' @param id Optional character string. All group-level terms across the model
#' with the same \code{id} will be modeled as correlated (if \code{cor} is
#' \code{TRUE}). See \code{\link{brmsformula}} for more details.
#' @param weights Optional numeric variable. Weights the contribution
#' @param pw Optional numeric variable. "Prior weights": weights the contribution
#' of each group to the log-likelihood for the distribution of the group-level effects.
#' The \code{pw} variable in the data should have one distinct value for
#' each level of the grouping variable.
Expand Down Expand Up @@ -50,13 +50,13 @@
#'
#' # include a group-level weight variable
#' epilepsy[['patient_samp_wgt']] <- c(1, rep(c(0.9, 1.1), each = 29))
#' fit4 <- brm(count ~ Trt + (1|gr(patient, weights = patient_samp_wgt)),
#' fit4 <- brm(count ~ Trt + (1|gr(patient, pw = patient_samp_wgt)),
#' data = epilepsy)
#' summary(fit4)
#' }
#'
#' @export
gr <- function(..., by = NULL, cor = TRUE, id = NA, weights = NULL,
gr <- function(..., by = NULL, cor = TRUE, id = NA, pw = NULL,
cov = NULL, dist = "gaussian") {
label <- deparse0(match.call())
groups <- as.character(as.list(substitute(list(...)))[-1])
Expand All @@ -67,16 +67,16 @@ gr <- function(..., by = NULL, cor = TRUE, id = NA, weights = NULL,
cor <- as_one_logical(cor)
id <- as_one_character(id, allow_na = TRUE)
by <- substitute(by)
weights <- substitute(weights)
pw <- substitute(pw)
if (!is.null(by)) {
by <- deparse0(by)
} else {
by <- ""
}
if (!is.null(weights)) {
weights <- deparse0(weights)
if (!is.null(pw)) {
pw <- deparse0(pw)
} else {
weights <- ""
pw <- ""
}
cov <- substitute(cov)
if (!is.null(cov)) {
Expand All @@ -89,9 +89,9 @@ gr <- function(..., by = NULL, cor = TRUE, id = NA, weights = NULL,
}
dist <- match.arg(dist, c("gaussian", "student"))
byvars <- all_vars(by)
weights_vars <- all_vars(weights)
allvars <- str2formula(c(groups, byvars, weights_vars))
nlist(groups, allvars, label, by, cor, id, weights, cov, dist, type = "")
pw_vars <- all_vars(pw)
allvars <- str2formula(c(groups, byvars, pw_vars))
nlist(groups, allvars, label, by, cor, id, pw, cov, dist, type = "")
}

#' Set up multi-membership grouping terms in \pkg{brms}
Expand All @@ -113,7 +113,12 @@ gr <- function(..., by = NULL, cor = TRUE, id = NA, weights = NULL,
#' weights are standardized in order to sum to one per row.
#' If negative weights are specified, \code{scale} needs
#' to be set to \code{FALSE}.
#'
#' @param pw An optional numeric matrix.
#' It should have as many columns as grouping terms specified in \code{...}.
#' These are "prior weights": they weight
#' the contribution of each group to the log-likelihood
#' for the distribution of the group-level effects.
#' There should be only one distinct value for each group.
#' @seealso \code{\link{brmsformula}}, \code{\link{mmc}}
#'
#' @examples
Expand Down
18 changes: 9 additions & 9 deletions R/stan-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -515,10 +515,10 @@ stan_re <- function(bframe, prior, normalize, ...) {
stopifnot(is.reframe(r))
has_cov <- nzchar(r$cov[1])
has_by <- nzchar(r$by[[1]])
has_weights <- ifelse(
test = is.null(r$gcall[[1]]$weights[[1]]),
has_pw <- ifelse(
test = is.null(r$gcall[[1]]$pw[[1]]),
yes = FALSE,
no = nzchar(r$gcall[[1]]$weights[[1]])
no = nzchar(r$gcall[[1]]$pw[[1]])
)
Nby <- seq_along(r$bylevels[[1]])
ng <- seq_along(r$gcall[[1]]$groups)
Expand Down Expand Up @@ -565,9 +565,9 @@ stan_re <- function(bframe, prior, normalize, ...) {
" // cholesky factor of known covariance matrix\n"
)
}
if (has_weights) {
if (has_pw) {
str_add(out$data) <- glue(
" vector[N_{id}] GMW_{id};",
" vector[N_{id}] PW_{id};",
" // weights for group contribution to the prior\n"
)
}
Expand Down Expand Up @@ -640,9 +640,9 @@ stan_re <- function(bframe, prior, normalize, ...) {
" matrix[M_{id}, N_{id}] z_{id};",
" // standardized group-level effects\n"
)
if (has_weights) {
if (has_pw) {
str_add(out$model_prior) <- glue(
" target += GMW_{id} * std_normal_{lpdf}(to_vector(z_{id}));\n"
" target += PW_{id} * std_normal_{lpdf}(to_vector(z_{id}));\n"
)
} else {
str_add(out$model_prior) <- glue(
Expand Down Expand Up @@ -738,9 +738,9 @@ stan_re <- function(bframe, prior, normalize, ...) {
" array[M_{id}] vector[N_{id}] z_{id};",
" // standardized group-level effects\n"
)
if (has_weights) {
if (has_pw) {
str_add(out$model_prior) <- cglue(
" target += GMW_{id} * std_normal_{lpdf}(z_{id}[{seq_rows(r)}]);\n"
" target += PW_{id} * std_normal_{lpdf}(z_{id}[{seq_rows(r)}]);\n"
)
} else {
str_add(out$model_prior) <- cglue(
Expand Down
6 changes: 3 additions & 3 deletions man/gr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions man/mm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 11 additions & 11 deletions tests/testthat/tests.stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -2742,24 +2742,24 @@ test_that("Weights from `gr()` incorporated into prior", {
wtd_epilepsy[['patient_samp_wgt']] <- patient_weights[match(epilepsy$patient, levels(epilepsy$patient))]

scode <- stancode(
count ~ Trt + (1 + Trt | gr(patient, weights = patient_samp_wgt)),
count ~ Trt + (1 + Trt | gr(patient, pw = patient_samp_wgt)),
data = wtd_epilepsy, family = gaussian()
)
expect_match2(scode, "vector[N_1] GMW_1; // weights for group contribution to the prior")
expect_match2(scode, "target += GMW_1 * std_normal_lpdf(to_vector(z_1));")
expect_match2(scode, "vector[N_1] PW_1; // weights for group contribution to the prior")
expect_match2(scode, "target += PW_1 * std_normal_lpdf(to_vector(z_1));")

# Check for multiple grouping variables, varying intercept and slope
wtd_epilepsy[['random_group']] <- rep(4:1, times = 59)
wtd_epilepsy[['random_group_wgt']] <- rep(c(0.8, 1.2, 0.7, 1.3), times = 59)

scode <- stancode(
count ~ Trt + (1 + Trt | gr(patient, weights = patient_samp_wgt))
+ (1 | gr(random_group, weights = random_group_wgt)),
count ~ Trt + (1 + Trt | gr(patient, pw = patient_samp_wgt))
+ (1 | gr(random_group, pw = random_group_wgt)),
data = wtd_epilepsy, family = gaussian()
)
expect_match2(scode, "vector[N_2] GMW_2; // weights for group contribution to the prior")
expect_match2(scode, "target += GMW_1 * std_normal_lpdf(to_vector(z_1));")
expect_match2(scode, "target += GMW_2 * std_normal_lpdf(z_2[1]);")
expect_match2(scode, "vector[N_2] PW_2; // weights for group contribution to the prior")
expect_match2(scode, "target += PW_1 * std_normal_lpdf(to_vector(z_1));")
expect_match2(scode, "target += PW_2 * std_normal_lpdf(z_2[1]);")

# Check for multivariate model
dat <- data.frame(
Expand All @@ -2772,10 +2772,10 @@ test_that("Weights from `gr()` incorporated into prior", {
censi = sample(0:1, 10, TRUE)
)
# models with residual correlations
form <- bf(mvbind(y1, y2) ~ x + (1 | gr(g1, weights = g1wgt)) + (1 | gr(g2, weights = g2wgt))) + set_rescor(TRUE)
form <- bf(mvbind(y1, y2) ~ x + (1 | gr(g1, pw = g1wgt)) + (1 | gr(g2, pw = g2wgt))) + set_rescor(TRUE)
prior <- prior(horseshoe(2), resp = "y1") +
prior(horseshoe(2), resp = "y2")
scode <- stancode(form, dat, prior = prior)
expect_match2(scode, "vector[N_4] GMW_4; // weights for group contribution to the prior")
expect_match2(scode, "target += GMW_4 * std_normal_lpdf(z_4[1]);")
expect_match2(scode, "vector[N_4] PW_4; // weights for group contribution to the prior")
expect_match2(scode, "target += PW_4 * std_normal_lpdf(z_4[1]);")
})
20 changes: 10 additions & 10 deletions tests/testthat/tests.standata.R
Original file line number Diff line number Diff line change
Expand Up @@ -1131,23 +1131,23 @@ test_that("Group weights correctly created from `gr()`", {
wtd_epilepsy[['patient_samp_wgt']] <- patient_weights[match(epilepsy$patient, levels(epilepsy$patient))]

sdata <- standata(
count ~ Trt + (1 + Trt | gr(patient, weights = patient_samp_wgt)),
count ~ Trt + (1 + Trt | gr(patient, pw = patient_samp_wgt)),
data = wtd_epilepsy, family = gaussian()
)
expect_equal(object = as.vector(sdata[['GMW_1']]), expected = patient_weights)
expect_equal(object = as.vector(sdata[['PW_1']]), expected = patient_weights)

# Multiple grouping variables
# with one variable whose factor level order differs from order of appearance
wtd_epilepsy[['random_group']] <- rep(c('d', 'b', 'a', 'c'), times = 59)
wtd_epilepsy[['random_group_wgt']] <- rep(c(0.8, 1.2, 0.7, 1.3), times = 59)

sdata <- standata(
count ~ Trt + (1 + Trt | gr(patient, weights = patient_samp_wgt))
+ (1 | gr(random_group, weights = random_group_wgt)),
count ~ Trt + (1 + Trt | gr(patient, pw = patient_samp_wgt))
+ (1 | gr(random_group, pw = random_group_wgt)),
data = wtd_epilepsy, family = gaussian()
)

expect_equal(object = as.vector(sdata[['GMW_2']]),
expect_equal(object = as.vector(sdata[['PW_2']]),
expected = c(0.7, 1.2, 1.3, 0.8))

# Model with multiple outcomes
Expand All @@ -1161,17 +1161,17 @@ test_that("Group weights correctly created from `gr()`", {
censi = sample(0:1, 10, TRUE)
)

form <- bf(mvbind(y1, y2) ~ x + (1 | gr(g1, weights = g1wgt)) + (1 | gr(g2, weights = g2wgt))) + set_rescor(TRUE)
form <- bf(mvbind(y1, y2) ~ x + (1 | gr(g1, pw = g1wgt)) + (1 | gr(g2, pw = g2wgt))) + set_rescor(TRUE)
prior <- prior(horseshoe(2), resp = "y1") +
prior(horseshoe(2), resp = "y2")
sdata <- standata(form, dat, prior = prior)
expect_in(c("GMW_1", "GMW_4"), names(sdata))
expect_in(c("PW_1", "PW_4"), names(sdata))

# Informative error message if group weight variable varies among observations in a group
expect_error(
object = {
sdata <- standata(
count ~ Trt + (1 | gr(patient, weights = bad_group_wgt)),
count ~ Trt + (1 | gr(patient, pw = bad_group_wgt)),
data = wtd_epilepsy |> transform(bad_group_wgt = runif(n = nrow(wtd_epilepsy))),
family = gaussian()
)
Expand All @@ -1196,7 +1196,7 @@ test_that("Group weights correctly created from `gr()`", {
expect_warning(
object = {
sdata <- standata(
count ~ Trt + (1 | gr(random_group, weights = bad_random_group_wgt)),
count ~ Trt + (1 | gr(random_group, pw = bad_random_group_wgt)),
data = wtd_epilepsy,
family = gaussian()
)
Expand All @@ -1208,7 +1208,7 @@ test_that("Group weights correctly created from `gr()`", {
expect_warning(
object = {
sdata <- standata(
count ~ Trt + (1 | gr(random_group, weights = bad_random_group_wgt)),
count ~ Trt + (1 | gr(random_group, pw = bad_random_group_wgt)),
data = wtd_epilepsy,
family = gaussian()
)
Expand Down

0 comments on commit 469f3eb

Please sign in to comment.