Skip to content

predict() on a mlp with nnet double names the output with .pred_ #174

Closed
@mouli3c3

Description

@mouli3c3

This problem is similar to the already-closed issue (#107), but it occurs with mlp() when using the nnet engine.


library(tidymodels)
#> -- Attaching packages ------------------------------------------------- tidymodels 0.0.2 --
#> v broom     0.5.1       v purrr     0.3.2  
#> v dials     0.0.2       v recipes   0.1.5  
#> v dplyr     0.8.0.1     v rsample   0.0.4  
#> v ggplot2   3.1.0       v tibble    2.1.1  
#> v infer     0.4.0       v yardstick 0.0.3  
#> v parsnip   0.0.2
#> -- Conflicts ---------------------------------------------------- tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
# Reproducible example: class-probability predictions from an nnet-backed
# mlp() come back with a doubled ".pred_" prefix, whereas glm predictions
# are named as expected.
data(credit_data)

# Train/test split stratified on the outcome. `prop` is spelled out in full
# rather than relying on partial matching of `p`.
set.seed(7075)
data_split <- initial_split(credit_data, strata = "Status", prop = 0.75)

credit_train <- training(data_split)
credit_test  <- testing(data_split)

# Preprocessing: impute the listed columns, dummy-encode nominal predictors
# (the outcome is excluded), then center and scale all predictors. The recipe
# is prepped on the training data only.
credit_rec <-
  recipe(Status ~ ., data = credit_train) %>%
  step_knnimpute(Home, Job, Marital, Income, Assets, Debt) %>%
  step_dummy(all_nominal(), -Status) %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors()) %>%
  prep(training = credit_train, retain = TRUE)

# Apply the prepped recipe to the test set, keeping predictor columns only.
test_normalized <- bake(credit_rec, new_data = credit_test, all_predictors())

# Classification mlp with 10 hidden units, fitted through the nnet engine.
set.seed(57974)
nnet_fit <- mlp("classification", hidden_units = 10) %>%
  set_engine("nnet") %>%
  fit(Status ~ ., data = juice(credit_rec))

# Logistic regression through glm, used as the comparison model.
glm_fit <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(Status ~ ., data = juice(credit_rec))

# Issue with predict on nnet: columns are named `.pred_.pred_<level>`
glimpse(predict(nnet_fit, new_data = test_normalized, type = "prob"))
#> Observations: 1,113
#> Variables: 2
#> $ .pred_.pred_bad  <dbl> 0.5608545, 0.7023505, 0.3303682, 0.4221877, 0...
#> $ .pred_.pred_good <dbl> 0.4391455, 0.2976495, 0.6696318, 0.5778123, 0...

# Normal with predict on glm (no issue): columns are named `.pred_<level>`
glimpse(predict(glm_fit, new_data = test_normalized, type = "prob"))
#> Observations: 1,113
#> Variables: 2
#> $ .pred_bad  <dbl> 0.04675355, 0.94317298, 0.24316454, 0.06970005, 0.0...
#> $ .pred_good <dbl> 0.95324645, 0.05682702, 0.75683546, 0.93029995, 0.9...

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: an unexpected problem or unintended behavior

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions