Add outcome_names() method for workflows using add_variables() #994

Open
wants to merge 2 commits into
base: main
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
 Package: tune
 Title: Tidy Tuning Tools
-Version: 1.3.0.9000
+Version: 1.3.0.9001
 Authors@R: c(
     person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"),
            comment = c(ORCID = "0000-0003-2402-136X")),
1 change: 1 addition & 0 deletions NAMESPACE
@@ -71,6 +71,7 @@ S3method(outcome_names,recipe)
 S3method(outcome_names,terms)
 S3method(outcome_names,tune_results)
 S3method(outcome_names,workflow)
+S3method(outcome_names,workflow_variables)
 S3method(parameters,model_spec)
 S3method(parameters,recipe)
 S3method(parameters,workflow)
2 changes: 2 additions & 0 deletions NEWS.md
@@ -4,6 +4,8 @@

 * Post-processing: new `schedule_grid()` for scheduling a grid including post-processing (#988).

+* Added a new method for `outcome_names()` for workflows that use `add_variables()` (#993).
+
 # tune 1.3.0

 * The package will now warn when parallel processing has been enabled with foreach but not with future. See [`?parallelism`](https://tune.tidymodels.org/dev/reference/parallelism.html) to learn more about transitioning your code to future (#878, #866). The next version of tune will move to a pure future implementation.
31 changes: 26 additions & 5 deletions R/outcome-names.R
@@ -1,7 +1,9 @@
 #' Determine names of the outcome data in a workflow
 #'
 #' @param x An object.
-#' @param ... Not used.
+#' @param ... Further arguments passed to or from other methods (such as `data`).
+#' @param data The training set data (if needed).
+#' @param call The call to be displayed in warnings or errors.
 #' @return A character string of variable names
 #' @keywords internal
 #' @examples
@@ -39,20 +41,39 @@ outcome_names.recipe <- function(x, ...) {

 #' @export
 #' @rdname outcome_names
-outcome_names.workflow <- function(x, ...) {
-  if (!is.null(x$fit$fit)) {
+outcome_names.workflow <- function(x, ..., call = caller_env()) {
+  if (!is.null(x$pre$mold)) {
     y_vals <- extract_mold(x)$outcomes
     res <- colnames(y_vals)
   } else {
     preprocessor <- extract_preprocessor(x)
-    res <- outcome_names(preprocessor)
+    res <- outcome_names(preprocessor, ..., call = call)
   }
   res
 }

 #' @export
 #' @rdname outcome_names
-outcome_names.tune_results <- function(x, ...) {
+outcome_names.workflow_variables <- function(
+  x,
+  data = NULL,
+  ...,
+  call = caller_env()
+) {
+  if (is.null(data)) {
+    cli::cli_abort(
+      "To determine the outcome names when {.fn add_variables} is used, please
+      pass the training set data to the {.arg data} argument.",
+      call = call
+    )
+  }
Comment on lines +64 to +69
Member

Is this an error message for the user or for us?

Member Author

The user

Member

Okay, I dug around a bit more here, and I think there is a small messaging mismatch to resolve in that error message:

1. Do you mean that the user of `outcome_names()` should use the `data` argument of `outcome_names()`? The user of `outcome_names()` is presumably just us, since `outcome_names()` is labeled with `@keywords internal`. In that case, I wouldn't pass the `call` to `cli_abort()`, and I'd possibly remove the reference to `add_variables()`. The context of how we got here probably lives better elsewhere, likely wherever we call it from and need to add `data`, so in `tune_grid()` or `int_pctl()`.

2. If you mean the general tidymodels user, the current error message is probably confusing: they only call `tune_grid()` or `int_pctl()` and couldn't use that `data` argument to `outcome_names()`. `tune_grid()` does not have a `data` argument, and `int_pctl()` has `.data`, but that is not the same as the `data` for `outcome_names()`. (And imagining a very desperate user, looking into the mentioned `add_variables()` also does not reveal a `data` argument.)

I'm guessing path 1 is the right one here?
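
To make "path 1" concrete, here is a minimal base R sketch. It is not tune's code: `outcome_names_sketch()`, the `variables_sketch` class, and `tune_caller_sketch()` are hypothetical stand-ins, and plain `stop()` replaces `cli::cli_abort()` so the snippet has no dependencies. The idea is that the internal helper raises a terse, developer-facing error, while the user-facing caller, which already holds the training data, always supplies `data` itself.

```r
# Hypothetical stand-ins for illustration only; none of these exist in tune.
outcome_names_sketch <- function(x, ...) UseMethod("outcome_names_sketch")

outcome_names_sketch.variables_sketch <- function(x, data = NULL, ...) {
  if (is.null(data)) {
    # Developer-facing message: no user-level call attached and no mention
    # of add_variables(), since only package internals reach this branch.
    stop("`data` is required to resolve outcome names.", call. = FALSE)
  }
  # Simplified name lookup; the PR itself evaluates the stored outcomes
  # against `data` instead.
  intersect(x$outcomes, names(data))
}

# Hypothetical user-facing caller: it owns the training data and passes it
# along, so an end user never triggers the internal error above.
tune_caller_sketch <- function(preprocessor, training) {
  outcome_names_sketch(preprocessor, data = training)
}

pre <- structure(list(outcomes = "mpg"), class = "variables_sketch")
res <- tune_caller_sketch(pre, mtcars)

# Calling the helper directly without data yields the internal message.
err <- tryCatch(
  outcome_names_sketch(pre),
  error = function(e) conditionMessage(e)
)
```

With this split, any user-facing wording about how the workflow was built would belong in the caller, where the available arguments match what the user can actually change.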

+  res <- rlang::eval_tidy(x$outcomes, data, env = call)
+  res
+}
+
+#' @export
+#' @rdname outcome_names
+outcome_names.tune_results <- function(x, ..., call = caller_env()) {
Member

Do we need this change in function signature for tune_results objects? 🤔

Member Author

You're welcome to suggest a change.

Member

Suggested change
-outcome_names.tune_results <- function(x, ..., call = caller_env()) {
+outcome_names.tune_results <- function(x, ...) {

I don't think the addition is necessary (and tests pass without it) so I'd just leave things as is.
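
One likely reason the extra argument is unnecessary: an S3 method whose signature ends in `...` absorbs a named `call` argument without declaring it. The base R sketch below illustrates this with hypothetical names (`names_sketch()` and the `results_sketch` class are not part of tune), mirroring how the `tune_results` method reads an `outcomes` attribute.

```r
# Hypothetical generic and class, for illustration only.
names_sketch <- function(x, ...) UseMethod("names_sketch")

# The method declares only `x` and `...`; a `call` argument, if supplied,
# lands in `...` and is ignored.
names_sketch.results_sketch <- function(x, ...) {
  attr(x, "outcomes")
}

obj <- structure(list(), outcomes = "mpg", class = "results_sketch")

with_call <- names_sketch(obj, call = environment())
without_call <- names_sketch(obj)
```

Both calls return the same value, so declaring `call` would only matter if the method used it for error reporting.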

   att <- attributes(x)
   if (any(names(att) == "outcomes")) {
     res <- att$outcomes
11 changes: 8 additions & 3 deletions man/outcome_names.Rd


8 changes: 8 additions & 0 deletions tests/testthat/_snaps/outcome-names.md
@@ -0,0 +1,8 @@
+# workflows + variables
+
+    Code
+      outcome_names(wflow_1)
+    Condition
+      Error:
+      ! To determine the outcome names when `add_variables()` is used, please pass the training set data to the `data` argument.
+

13 changes: 13 additions & 0 deletions tests/testthat/test-outcome-names.R
@@ -95,6 +95,19 @@ test_that("workflows + formulas", {
   expect_equal(outcome_names(parsnip::fit(wflow_2, mtcars)), c("mpg", "wt"))
 })

+## -----------------------------------------------------------------------------
+
+test_that("workflows + variables", {
+  lm_mod <- parsnip::linear_reg() %>% parsnip::set_engine("lm")
+  wflow <- workflow() %>% add_model(lm_mod)
+
+  wflow_1 <- wflow %>% add_variables(outcomes = "mpg", predictors = c(wt))
+  fit_1 <- fit(wflow_1, mtcars)
+
+  expect_snapshot(outcome_names(wflow_1), error = TRUE)
+  expect_equal(outcome_names(wflow_1, mtcars), "mpg")
+  expect_equal(outcome_names(fit_1, mtcars), "mpg")
+})

 ## -----------------------------------------------------------------------------
