Skip to content

predict ordinal factors from ordinal regression models #1217

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

## Bug Fixes

* Make sure that parsnip does not convert ordered factor predictions to be unordered.

* Ensure that `knit_engine_docs()` has the required packages installed (#1156).

* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).
Expand Down
1 change: 1 addition & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
#' \itemize{
#' \item \code{lvl}: If the outcome is a factor, this contains
#' the factor levels at the time of model fitting.
#' \item \code{ordered}: If the outcome is a factor, was it an ordered factor?
#' \item \code{spec}: The model specification object
#' (\code{object} in the call to \code{fit})
#' \item \code{fit}: when the model is executed without error,
Expand Down
5 changes: 3 additions & 2 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ form_form <-
fit_call <- make_form_call(object, env = env)

res <- list(
lvl = y_levels,
lvl = y_levels$lvl,
ordered = y_levels$ordered,
spec = object
)

Expand Down Expand Up @@ -98,7 +99,7 @@ xy_xy <- function(object,

fit_call <- make_xy_call(object, target, env, call)

res <- list(lvl = levels(env$y), spec = object)
res <- list(lvl = levels(env$y), ordered = is.ordered(env$y), spec = object)

time <- proc.time()
res$fit <- eval_mod(
Expand Down
7 changes: 5 additions & 2 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,12 @@ convert_arg <- function(x) {

levels_from_formula <- function(f, dat) {
if (inherits(dat, "tbl_spark")) {
res <- NULL
res <- list(lvls = NULL, ordered = FALSE)
} else {
res <- levels(eval_tidy(rlang::f_lhs(f), dat))
res <- list()
y_data <- eval_tidy(rlang::f_lhs(f), dat)
res$lvls <- levels(y_data)
res$ordered <- is.ordered(y_data)
}
res
}
Expand Down
6 changes: 4 additions & 2 deletions R/predict_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ predict_class.model_fit <- function(object, new_data, ...) {

# coerce levels to those in `object`
if (is.vector(res) || is.factor(res)) {
res <- factor(as.character(res), levels = object$lvl)
res <- factor(as.character(res), levels = object$lvl, ordered = object$ordered)
} else {
if (!inherits(res, "tbl_spark")) {
# Now case where a parsnip model generated `res`
if (is.data.frame(res) && ncol(res) == 1 && is.factor(res[[1]])) {
res <- res[[1]]
} else {
res$values <- factor(as.character(res$values), levels = object$lvl)
res$values <- factor(as.character(res$values),
levels = object$lvl,
ordered = object$ordered)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions man/fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 44 additions & 0 deletions tests/testthat/test-predict_formats.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,50 @@ test_that('classification predictions', {
c(".pred_high", ".pred_low"))
})


test_that('ordinal classification predictions', {
skip_if_not_installed("modeldata")
skip_if_not_installed("rpart")

set.seed(382)
dat_tr <-
modeldata::sim_multinomial(
200,
~ -0.5 + 0.6 * abs(A),
~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
~ -0.6 * A + 0.50 * B - A * B) %>%
dplyr::mutate(class = as.ordered(class))
dat_te <-
modeldata::sim_multinomial(
5,
~ -0.5 + 0.6 * abs(A),
~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
~ -0.6 * A + 0.50 * B - A * B) %>%
dplyr::mutate(class = as.ordered(class))

###

mod_f_fit <-
decision_tree() %>%
set_mode("classification") %>%
fit(class ~ ., data = dat_tr)
expect_true("ordered" %in% names(mod_f_fit))
mod_f_pred <- predict(mod_f_fit, dat_te)
expect_true(is.ordered(mod_f_pred$.pred_class))

###

mod_xy_fit <-
decision_tree() %>%
set_mode("classification") %>%
fit_xy(x = dat_tr %>% dplyr::select(-class), dat_tr$class)

expect_true("ordered" %in% names(mod_xy_fit))
mod_xy_pred <- predict(mod_xy_fit, dat_te)
expect_true(is.ordered(mod_f_pred$.pred_class))
})


test_that('non-standard levels', {
expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1])))
expect_true(is.factor(parsnip:::predict_class.model_fit(lr_fit, new_data = class_dat[1:5,-1])))
Expand Down
Loading