From 1154c6c532b71519622cdae39b413b5c894e19bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= Date: Fri, 25 Oct 2024 12:02:50 -0400 Subject: [PATCH 1/2] fixes for ordinal regression --- NEWS.md | 2 ++ R/fit.R | 1 + R/fit_helpers.R | 5 ++-- R/misc.R | 7 +++-- R/predict_class.R | 6 ++-- man/fit.Rd | 1 + tests/testthat/test-predict_formats.R | 43 +++++++++++++++++++++++++++ 7 files changed, 59 insertions(+), 6 deletions(-) diff --git a/NEWS.md b/NEWS.md index ab97048ca..e35af6bcd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -25,6 +25,8 @@ ## Bug Fixes +* Make sure that parsnip does not convert ordered factor predictions to be unordered. + * Ensure that `knit_engine_docs()` has the required packages installed (#1156). * Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166). diff --git a/R/fit.R b/R/fit.R index 5f7416aa9..8ccd8fa59 100644 --- a/R/fit.R +++ b/R/fit.R @@ -87,6 +87,7 @@ #' \itemize{ #' \item \code{lvl}: If the outcome is a factor, this contains #' the factor levels at the time of model fitting. +#' \item \code{ordered}: If the outcome is a factor, was it an ordered factor? #' \item \code{spec}: The model specification object #' (\code{object} in the call to \code{fit}) #' \item \code{fit}: when the model is executed without error, diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 168ea8e44..94c3c754d 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -40,7 +40,8 @@ form_form <- fit_call <- make_form_call(object, env = env) res <- list( - lvl = y_levels, + lvl = y_levels$lvl, + ordered = y_levels$ordered, spec = object ) @@ -98,7 +99,7 @@ xy_xy <- function(object, fit_call <- make_xy_call(object, target, env, call) - res <- list(lvl = levels(env$y), spec = object) + res <- list(lvl = levels(env$y), ordered = is.ordered(env$y), spec = object) time <- proc.time() res$fit <- eval_mod( diff --git a/R/misc.R b/R/misc.R index 7eb22b4f9..3582d8ca2 100644 --- a/R/misc.R +++ b/R/misc.R @@ -260,9 +260,12 @@ convert_arg <- function(x) { levels_from_formula <- function(f, dat) { if (inherits(dat, "tbl_spark")) { - res <- NULL + res <- list(lvls = NULL, ordered = FALSE) } else { - res <- levels(eval_tidy(rlang::f_lhs(f), dat)) + res <- list() + y_data <- eval_tidy(rlang::f_lhs(f), dat) + res$lvls <- levels(y_data) + res$ordered <- is.ordered(y_data) } res } diff --git a/R/predict_class.R b/R/predict_class.R index d11f79f36..98d1adff0 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -41,14 +41,16 @@ predict_class.model_fit <- function(object, new_data, ...) { # coerce levels to those in `object` if (is.vector(res) || is.factor(res)) { - res <- factor(as.character(res), levels = object$lvl) + res <- factor(as.character(res), levels = object$lvl, ordered = object$ordered) } else { if (!inherits(res, "tbl_spark")) { # Now case where a parsnip model generated `res` if (is.data.frame(res) && ncol(res) == 1 && is.factor(res[[1]])) { res <- res[[1]] } else { - res$values <- factor(as.character(res$values), levels = object$lvl) + res$values <- factor(as.character(res$values), + levels = object$lvl, + ordered = object$ordered) } } } diff --git a/man/fit.Rd b/man/fit.Rd index 332d993bc..f0f9ccbcf 100644 --- a/man/fit.Rd +++ b/man/fit.Rd @@ -52,6 +52,7 @@ A \code{model_fit} object that contains several elements: \itemize{ \item \code{lvl}: If the outcome is a factor, this contains the factor levels at the time of model fitting. +\item \code{ordered}: If the outcome is a factor, was it an ordered factor? \item \code{spec}: The model specification object (\code{object} in the call to \code{fit}) \item \code{fit}: when the model is executed without error, diff --git a/tests/testthat/test-predict_formats.R b/tests/testthat/test-predict_formats.R index 0e7125164..39ed6903a 100644 --- a/tests/testthat/test-predict_formats.R +++ b/tests/testthat/test-predict_formats.R @@ -43,6 +43,49 @@ test_that('classification predictions', { c(".pred_high", ".pred_low")) }) + +test_that('ordinal classification predictions', { + skip_if_not_installed("modeldata") + + set.seed(382) + dat_tr <- + modeldata::sim_multinomial( + 200, + ~ -0.5 + 0.6 * abs(A), + ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), + ~ -0.6 * A + 0.50 * B - A * B) %>% + dplyr::mutate(class = as.ordered(class)) + dat_te <- + modeldata::sim_multinomial( + 5, + ~ -0.5 + 0.6 * abs(A), + ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2), + ~ -0.6 * A + 0.50 * B - A * B) %>% + dplyr::mutate(class = as.ordered(class)) + + ### + + mod_f_fit <- + decision_tree() %>% + set_mode("classification") %>% + fit(class ~ ., data = dat_tr) + expect_true("ordered" %in% names(mod_f_fit)) + mod_f_pred <- predict(mod_f_fit, dat_te) + expect_true(is.ordered(mod_f_pred$.pred_class)) + + ### + + mod_xy_fit <- + decision_tree() %>% + set_mode("classification") %>% + fit_xy(x = dat_tr %>% dplyr::select(-class), dat_tr$class) + + expect_true("ordered" %in% names(mod_xy_fit)) + mod_xy_pred <- predict(mod_xy_fit, dat_te) + expect_true(is.ordered(mod_f_pred$.pred_class)) +}) + + test_that('non-standard levels', { expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1]))) expect_true(is.factor(parsnip:::predict_class.model_fit(lr_fit, new_data = class_dat[1:5,-1]))) From 8b8de6ac76b54b0e0ff9480d9a0a349b0603c918 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= Date: Fri, 25 Oct 2024 12:13:18 -0400 Subject: [PATCH 2/2] skip for extra suggested package --- tests/testthat/test-predict_formats.R | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/testthat/test-predict_formats.R b/tests/testthat/test-predict_formats.R index 39ed6903a..aee3369a4 100644 --- a/tests/testthat/test-predict_formats.R +++ b/tests/testthat/test-predict_formats.R @@ -46,6 +46,7 @@ test_that('classification predictions', { test_that('ordinal classification predictions', { skip_if_not_installed("modeldata") + skip_if_not_installed("rpart") set.seed(382) dat_tr <-