helper for bridging causal fits #652


Open
simonpcouch opened this issue Mar 24, 2023 · 0 comments · May be fixed by #679
Labels
feature a feature request or enhancement

simonpcouch (Contributor) commented Mar 24, 2023

fit()ting a single model and resampling a model fit via fit_resamples() have a nice parallelism:

# single model:
          fit(workflow, data)

# multiple models:
fit_resamples(workflow, resamples(data))

Ideally, the two-stage approach in causal inference (fit a propensity model, then fit an outcome model weighted by the resulting propensity weights) could have the same ring to it.

library(tidymodels)
library(causalworkshop)
library(propensity)

# make "TRUE" the first factor level so it is used as the event (treated) level
net_data <- net_data %>% mutate(net = factor(net, levels = c("TRUE", "FALSE")))

Defining workflows:

outcome_wf <- 
  workflow(
    malaria_risk ~ net,
    linear_reg()
  ) %>%
  add_case_weights(wts)

propensity_wf <- 
  workflow(
    net ~ income + health + temperature,
    logistic_reg()
  )

Note that outcome_wf has an add_case_weights() step, but the wts column it refers to can’t exist until propensity_wf has been fit and can generate predictions.

With no changes to tidymodels, the code for a fit to a single dataset looks something like:

net_data_wts <-
  fit(propensity_wf, net_data) %>%
  augment(net_data) %>%
  mutate(
    wts = wt_ate(.pred_TRUE, net, .treated = "TRUE")
  )

results <- 
  outcome_wf %>%
  fit(
    data = net_data_wts %>% mutate(wts = importance_weights(wts))
  ) %>%
  tidy()
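For reference, for a binary exposure wt_ate() computes the usual unstabilized inverse probability weights: 1/p for treated units and 1/(1 - p) for untreated units, where p is the predicted probability of treatment. The base-R sketch below computes them by hand on toy inputs. `ate_weights_sketch()` is a hypothetical helper, not part of propensity, and propensity's own implementation offers options (such as stabilization) that are not shown here.

```r
# Hypothetical helper: unstabilized ATE weights for a binary exposure.
# `.propensity` is P(treated | covariates), `.exposure` holds the observed
# exposure, and `.treated` names the level that counts as treated.
ate_weights_sketch <- function(.propensity, .exposure, .treated) {
  is_treated <- as.character(.exposure) == .treated
  ifelse(is_treated, 1 / .propensity, 1 / (1 - .propensity))
}

# Toy inputs (not from net_data): three units and their predicted
# probabilities of using a net.
p   <- c(0.8, 0.5, 0.2)
net <- factor(c("TRUE", "FALSE", "TRUE"), levels = c("TRUE", "FALSE"))

ate_weights_sketch(p, net, .treated = "TRUE")
# returns c(1.25, 2, 5)
```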

For the resampled fit, we definitely need a helper that bridges the two calls to fit_resamples() by mutating propensity weights onto the analysis set underlying each rsplit. In this example, I call it weight_propensity():

results <-
  fit_resamples(
    propensity_wf,
    resamples = bootstraps(net_data),
    control = control_resamples(extract = identity)
  ) %>%
  weight_propensity(wt_ate) %>%
  fit_resamples(
    outcome_wf,
    resamples = .,
    control = control_resamples(extract = tidy)
  )

The helper (sans error checking) could look something like:

EDIT: Out of date; see the linked PR (#679).

# a function that takes in a resample fit object and outputs a modified
# version of that object where the training data underlying each rsplit
# is augmented with propensity weights for each element of the analysis set. 
# this serves as the "bridge" between two calls to
# `fit_resamples()` (or `tune_*()`) in a causal workflow, the
# first being for the propensity model and the second for the outcome model.
# `tune_results` must have been generated with `control_resamples(extract = identity)`.
weight_propensity <- function(tune_results, wt_fn) {
  for (resample in seq_along(tune_results$splits)) {
    tune_results$splits[[resample]] <- 
      augment_split(
        tune_results$splits[[resample]],
        tune_results$.extracts[[resample]]$.extracts[[1]], 
        wt_fn = wt_fn,
        outcome_name = outcome_names(tune_results)
      )
  }
  
  tibble::new_tibble(
    tune_results[, c("splits", "id")], 
    !!!attr(tune_results, "rset_info")$att,
    class = c(attr(tune_results, "rset_info")$att$class, "rset")
  )
}

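augment_split <- function(split, workflow, wt_fn, outcome_name) {
  # predict propensity scores for the analysis set
  d <- analysis(split)
  d <- vctrs::vec_cbind(d, predict(workflow, d, type = "prob"))
  # bootstrap analysis sets repeat rows; keep one row per unit
  # (assumes `id` is the row index of the full data)
  d <- vctrs::vec_slice(d, !duplicated(d$id))

  # the first factor level of the outcome is taken as the treated level
  model_fit <- extract_fit_parsnip(workflow)
  lvls <- model_fit$lvl
  event_lvl <- lvls[1]
  preds <- d[[paste0(".pred_", event_lvl)]]

  # write the weights back into the full data underlying the rsplit
  split[["data"]][d$id, "wts"] <- 
    importance_weights(wt_fn(preds, d[[outcome_name]], .treated = event_lvl))
  
  split
}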
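The key bookkeeping in augment_split() is that a bootstrap analysis set repeats rows, so predictions are deduplicated by id and written back into the full data by position, and units that appear only in the assessment set get no weight. The base-R sketch below mimics just that step on a toy data frame. It assumes, as the helper above does, that id equals the row number of the full data; `write_back_weights()` and the `.pred_TRUE` values are hypothetical, and plain numerics stand in for importance_weights().

```r
# Toy "full" data: five units whose id equals their row number.
full <- data.frame(id = 1:5, net = c(TRUE, FALSE, TRUE, FALSE, TRUE))

# A bootstrap analysis set samples rows with replacement, so ids repeat.
analysis_rows <- c(1, 1, 3, 4, 4)
analysis_set  <- full[analysis_rows, ]

# Hypothetical stand-in for predicted P(net == TRUE) on each analysis row.
analysis_set$.pred_TRUE <- c(0.6, 0.6, 0.7, 0.4, 0.4)

write_back_weights <- function(full, analysis_set) {
  # Keep one row per unit, mirroring `!duplicated(d$id)` in augment_split().
  d <- analysis_set[!duplicated(analysis_set$id), ]
  # Unstabilized ATE weights: 1/p for treated units, 1/(1 - p) otherwise.
  w <- ifelse(d$net, 1 / d$.pred_TRUE, 1 / (1 - d$.pred_TRUE))
  # Units never drawn into the analysis set (the out-of-bag/assessment
  # rows for a bootstrap) keep NA.
  full$wts <- NA_real_
  full$wts[d$id] <- w
  full
}

write_back_weights(full, analysis_set)$wts
# returns approximately c(1.667, NA, 1.429, 1.667, NA)
```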

Questions:

  • What’s a good name for weight_propensity()? Is there something more compatible with the analogous(?) procedure in survival analysis?

  • Do we want that function to be usable in both the single-fit and resampled-fit settings to aid with that parallelism? We could make a method that takes in data, a weighting function, and a model fit to make the single-fit setting feel more like the resampled-fit setting, à la:

result <- 
  fit(
    propensity_wf, 
    net_data
  ) %>%
  weight_propensity(net_data, wt_ate, .) %>%
  fit(
    outcome_wf, # already includes add_case_weights(wts); adding it again errors
    data = .
  ) %>%
  tidy()

I'd propose we include that parsnip/workflows counterpart; if so, that helper (and probably the generic?) ought to live in parsnip/workflows.
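One way to get that parallelism is an S3 generic that dispatches on its first argument: data for the single-fit case, resampling results for the resampled case, with the argument order (data, weighting function, model fit) from the proposal above. The base-R sketch below is hypothetical throughout: `weight_propensity_sketch()` is not an existing function, a plain logistic glm() stands in for a fitted workflow, a classed list of data frames stands in for tune results, and the toy data are simulated for illustration only.

```r
# Hypothetical generic: dispatches on the class of its first argument.
weight_propensity_sketch <- function(x, wt_fn, ...) {
  UseMethod("weight_propensity_sketch")
}

# Single-fit method: `x` is a data frame, `fit` a fitted propensity model.
weight_propensity_sketch.data.frame <- function(x, wt_fn, fit, ...) {
  p <- predict(fit, newdata = x, type = "response")
  x$wts <- wt_fn(p, x$net)
  x
}

# Stand-in for the resampled method: `x` is a list of data frames with
# class "toy_resamples", each paired with its own fitted model in `fits`.
weight_propensity_sketch.toy_resamples <- function(x, wt_fn, fits, ...) {
  out <- Map(function(d, f) weight_propensity_sketch(d, wt_fn, fit = f), x, fits)
  structure(out, class = "toy_resamples")
}

# Unstabilized ATE weights for a logical exposure.
ate_wts <- function(p, treated) ifelse(treated, 1 / p, 1 / (1 - p))

# Simulated toy data (not net_data).
set.seed(1)
toy <- data.frame(income = rnorm(50))
toy$net <- rbinom(50, 1, plogis(toy$income)) == 1

# Single-fit path.
prop_fit <- glm(net ~ income, family = binomial(), data = toy)
weighted <- weight_propensity_sketch(toy, ate_wts, fit = prop_fit)

# Resampled path: bootstrap the rows, fit per resample, then weight.
boots <- structure(
  lapply(1:3, function(i) toy[sample(nrow(toy), replace = TRUE), ]),
  class = "toy_resamples"
)
boot_fits <- lapply(boots, function(d) glm(net ~ income, family = binomial(), data = d))
weighted_boots <- weight_propensity_sketch(boots, ate_wts, fits = boot_fits)
```

The same verb then reads identically in both settings; only the class of the first argument decides which path runs.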
