Skip to content

glmnet multi_predict(): Check type for all modes #900

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 5 commits into from
Mar 14, 2023

Conversation

hfrick
Copy link
Member

@hfrick hfrick commented Mar 2, 2023

Since multi_predict() calls predict() with type = "raw", the type provided to multi_predict() did not get checked consistently.

This PR improves checks on the type by extending it to all modes, not just "classification". Closes #517

library(parsnip)
data(Chicago, package = "modeldata")

lm_spec <- linear_reg(penalty = 0.1) %>% set_engine("glmnet")
lm_fit <- fit(lm_spec, ridership ~ Clark_Lake + Quincy_Wells, data = Chicago)

multi_predict(lm_fit, Chicago[1:6,], penalty = c(0.05, 0.1), type = "class")
#> Error in `check_pred_type()` at parsnip/R/glmnet-engines.R:176:2:
#> ! For class predictions, the object should be a classification model.

Created on 2023-03-02 with reprex v2.0.2

@hfrick hfrick requested a review from simonpcouch March 6, 2023 17:30
Copy link
Contributor

@simonpcouch simonpcouch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! This is definitely an improvement.

@hfrick hfrick merged commit d8f273c into main Mar 14, 2023
@hfrick hfrick deleted the glmnet-multi_predict-type-check branch March 14, 2023 16:59
@github-actions
Copy link

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 30, 2023
# for free to subscribe to this conversation on GitHub. Already have an account? #.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

multi_predict._elnet() doesn't error on inappropriate type like type = "class"
2 participants