From ec900af0d378d3de124244bbc4ee011a0106748f Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Thu, 2 Mar 2023 16:00:52 +0000 Subject: [PATCH 1/4] check type for all modes --- DESCRIPTION | 2 +- R/glmnet-engines.R | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 5be3a50f8..ea2b232e2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.0.4.9002 +Version: 1.0.4.9003 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R index b22dcf802..173996dde 100644 --- a/R/glmnet-engines.R +++ b/R/glmnet-engines.R @@ -173,8 +173,15 @@ multi_predict_glmnet <- function(object, type = NULL, penalty = NULL, ...) { + type <- check_pred_type(object, type) + check_spec_pred_type(object, type) + if (type == "prob") { + check_spec_levels(object) + } + + dots <- list(...) - if (any(names(enquos(...)) == "newdata")) { + if (any(names(dots) == "newdata")) { rlang::abort("Did you mean to use `new_data` instead of `newdata`?") } @@ -184,8 +191,6 @@ multi_predict_glmnet <- function(object, } } - dots <- list(...) - object$spec <- eval_args(object$spec) if (is.null(penalty)) { @@ -200,12 +205,6 @@ multi_predict_glmnet <- function(object, model_type <- class(object$spec)[1] if (object$spec$mode == "classification") { - if (is.null(type)) { - type <- "class" - } - if (!(type %in% c("class", "prob", "link", "raw"))) { - rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.") - } if (type == "prob" | model_type == "logistic_reg") { dots$type <- "response" From 40c24be224774e73cf4e8670bee95eb82bdf79d9 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Mon, 6 Mar 2023 16:04:47 +0000 Subject: [PATCH 2/4] bump version in anticipation of #897 being merged first --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index ea2b232e2..57224a4a5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.0.4.9003 +Version: 1.0.4.9004 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), From dc0935096d7c39159155f84b945d573b1e3257c6 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Mon, 6 Mar 2023 17:00:09 +0000 Subject: [PATCH 3/4] update NEWS --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index d0b6c0bcb..cae3ba78a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,6 +10,8 @@ * Several internal functions (to help work with `Surv` objects) were added as a standalone file that can be used in other packages via `usethis::use_standalone("tidymodels/parsnip")`. +* `multi_predict()` methods for `linear_reg()`, `logistic_reg()`, and `multinomial_reg()` models fitted with the `"glmnet"` engine now check the `type` better and error accordingly (#900). + # parsnip 1.0.4 From f01f55576ba5ca1091d6e9f1ea92f50cbf2cfc91 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Tue, 14 Mar 2023 16:24:46 +0000 Subject: [PATCH 4/4] update to merge PR --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 97acaab61..14121dc61 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.0.4.9004 +Version: 1.0.4.9003 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),