Skip to content

Survival helpers #893

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 10 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.0.4.9000
Version: 1.0.4.9001
Authors@R: c(
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),
Expand Down Expand Up @@ -76,4 +76,4 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3.9000
RoxygenNote: 7.2.3
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

* Fixed bug with prediction from a boosted tree model fitted with `"xgboost"` using a custom objective function (#875).

* Several internal functions (to help work with `Surv` objects) were added as a standalone file that can be used in other packages via `usethis::use_standalone("tidymodels/parsnip")`.


# parsnip 1.0.4

Expand Down
5 changes: 0 additions & 5 deletions R/fit.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# General TODOs
# Q: think about case weights in each instance below

# TODO write a better deparser for calls to avoid off-screen text and tabs

#' Fit a Model Specification to a Dataset
#'
#' `fit()` and `fit_xy()` take a model specification, translate the required
Expand Down
34 changes: 34 additions & 0 deletions R/ipcw.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# ------------------------------------------------------------------------------
# Functions for using inverse probability of censoring weights (IPCW) in
# censored regression models

# ------------------------------------------------------------------------------
# Simple helpers for computing the probability of censoring

# For avoiding extremely large, outlier weights
trunc_probs <- function(probs, trunc = 0.01) {
is_complt_prob <- !is.na(probs)
complt_prob <- probs[is_complt_prob]
non_zero_min <- min(complt_prob[complt_prob > 0])
if (non_zero_min < trunc) {
trunc <- non_zero_min / 2
}
probs[is_complt_prob] <-
ifelse(probs[is_complt_prob] <= trunc, trunc, probs[is_complt_prob])
probs
}

.filter_eval_time <- function(eval_time, fail = TRUE) {
# will still propagate nulls:
eval_time <- eval_time[!is.na(eval_time)]
eval_time <- unique(eval_time)
eval_time <- sort(eval_time)
eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)]
if (fail && identical(eval_time, numeric(0))) {
rlang::abort(
"There were no usable evaluation times (finite, non-missing, and >= 0).",
call = NULL
)
}
eval_time
}
88 changes: 88 additions & 0 deletions R/standalone-survival.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# ---
# repo: tidymodels/parsnip
# file: standalone-survival.R
# last-updated: 2023-02-28
# license: https://unlicense.org
# ---

# This file provides a portable set of helper functions for Surv objects

# ## Changelog

# 2023-02-28:
# * Initial version


# @param surv A [survival::Surv()] object
# @details
# `.is_censored_right()` always returns a logical while
# `.check_censored_right()` will fail if `FALSE`.
#
# `.extract_status()` will return the data as 0/1 even if the original object
# used the legacy encoding of 1/2. See [survival::Surv()].
# @return
# - `.extract_surv_status()` returns a vector.
# - `.extract_surv_time()` returns a vector when the type is `"right"` or `"left"`
# and a tibble otherwise.
# - Functions starting with `.is_` or `.check_` return logicals although the
# latter will fail when `FALSE`.

# nocov start
# These are tested in the extratests repo since it would require a dependency
# on the survival package. https://github.com/tidymodels/extratests/pull/78
.is_censored_right <- function(surv) {
.check_cens_type(surv, fail = FALSE)
}

.check_censored_right <- function(surv) {
.check_cens_type(surv, fail = TRUE)
} # will add more as we need them

.extract_surv_time <- function(surv) {
.is_surv(surv)
keepers <- c("time", "start", "stop", "time1", "time2")
res <- surv[, colnames(surv) %in% keepers]
if (NCOL(res) > 1) {
res <- tibble::tibble(as.data.frame(res))
}
res
}

.extract_surv_status <- function(surv) {
.is_surv(surv)
res <- surv[, "status"]
un_vals <- sort(unique(res))
event_type_to_01 <- !(.extract_surv_type(surv) %in% c("interval", "interval2", "mstate"))
if (
event_type_to_01 &&
(identical(un_vals, 1:2) | identical(un_vals, c(1.0, 2.0))) ) {
res <- res - 1
}
res
}

.is_surv <- function(surv, fail = TRUE) {
is_surv <- inherits(surv, "Surv")
if (!is_surv && fail) {
rlang::abort("The object does not have class `Surv`.", call = NULL)
}
is_surv
}

.extract_surv_type <- function(surv) {
attr(surv, "type")
}

.check_cens_type <- function(surv, type = "right", fail = TRUE) {
.is_surv(surv)
obj_type <- .extract_surv_type(surv)
good_type <- all(obj_type %in% type)
if (!good_type && fail) {
c_list <- paste0("'", type, "'")
msg <- cli::format_inline("For this usage, the allowed censoring type{?s} {?is/are}: {c_list}")
rlang::abort(msg, call = NULL)
}
good_type
}

# nocov end
8 changes: 8 additions & 0 deletions tests/testthat/_snaps/ipcw.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# time filtering

Code
parsnip:::.filter_eval_time(-1)
Condition
Error:
! There were no usable evaluation times (finite, non-missing, and >= 0).

42 changes: 42 additions & 0 deletions tests/testthat/test-ipcw.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
test_that('probability truncation', {
probs <- seq(0, 1, length.out = 5)

expect_equal(
min(parsnip:::trunc_probs(probs, .4)),
min(probs[probs > 0]) / 2
)
expect_equal(
min(parsnip:::trunc_probs(c(NA, probs), .4), na.rm = TRUE),
min(probs[probs > 0]) / 2
)
expect_equal(
min(parsnip:::trunc_probs(probs)),
0.01
)
expect_equal(
min(parsnip:::trunc_probs((1:200)/200)),
1 / 200
)
})


test_that('time filtering', {
times_1 <- 0:10
times_2 <- c(Inf, NA, -3, times_1, times_1)

expect_equal(
parsnip:::.filter_eval_time(times_1),
times_1
)
expect_equal(
parsnip:::.filter_eval_time(times_1),
times_1
)
expect_snapshot(error = TRUE, parsnip:::.filter_eval_time(-1))
expect_null(parsnip:::.filter_eval_time(NULL))
})