Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

## Bug Fixes

* Make sure that parsnip does not convert ordered factor predictions to be unordered.

* Ensure that `knit_engine_docs()` has the required packages installed (#1156).

* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).
Expand Down
1 change: 1 addition & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
#' \itemize{
#' \item \code{lvl}: If the outcome is a factor, this contains
#' the factor levels at the time of model fitting.
#' \item \code{ordered}: If the outcome is a factor, was it an ordered factor?
#' \item \code{spec}: The model specification object
#' (\code{object} in the call to \code{fit})
#' \item \code{fit}: when the model is executed without error,
Expand Down
5 changes: 3 additions & 2 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ form_form <-
fit_call <- make_form_call(object, env = env)

res <- list(
lvl = y_levels,
lvl = y_levels$lvl,
ordered = y_levels$ordered,
spec = object
)

Expand Down Expand Up @@ -98,7 +99,7 @@ xy_xy <- function(object,

fit_call <- make_xy_call(object, target, env, call)

res <- list(lvl = levels(env$y), spec = object)
res <- list(lvl = levels(env$y), ordered = is.ordered(env$y), spec = object)

time <- proc.time()
res$fit <- eval_mod(
Expand Down
7 changes: 5 additions & 2 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,12 @@ convert_arg <- function(x) {

# Extract outcome-level information from the left-hand side of a formula.
#
# @param f A formula whose left-hand side identifies the outcome.
# @param dat The data in which the left-hand side is evaluated. Objects
#   inheriting from "tbl_spark" are not evaluated at all.
# @return A list with `ordered` (logical: is the outcome an ordered factor?)
#   and, when the outcome has levels, `lvls` (character vector of levels).
#   For "tbl_spark" data the result is `list(lvls = NULL, ordered = FALSE)`.
#
# NOTE(review): the element is named `lvls`; code that reads it as `$lvl`
# only works through `$` partial matching on lists -- confirm callers.
levels_from_formula <- function(f, dat) {
  if (inherits(dat, "tbl_spark")) {
    # The outcome column is not evaluated for Spark tables, so report
    # "no levels, not ordered".
    res <- list(lvls = NULL, ordered = FALSE)
  } else {
    res <- list()
    # Evaluate the outcome once and derive both pieces from the same vector.
    y_data <- eval_tidy(rlang::f_lhs(f), dat)
    # Assigning NULL with `$<-` leaves the element absent, so a non-factor
    # outcome yields a list without `lvls`; `res$lvls` still reads as NULL.
    res$lvls <- levels(y_data)
    res$ordered <- is.ordered(y_data)
  }
  res
}
Expand Down
6 changes: 4 additions & 2 deletions R/predict_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ predict_class.model_fit <- function(object, new_data, ...) {

# coerce levels to those in `object`
if (is.vector(res) || is.factor(res)) {
res <- factor(as.character(res), levels = object$lvl)
res <- factor(as.character(res), levels = object$lvl, ordered = object$ordered)
} else {
if (!inherits(res, "tbl_spark")) {
# Now case where a parsnip model generated `res`
if (is.data.frame(res) && ncol(res) == 1 && is.factor(res[[1]])) {
res <- res[[1]]
} else {
res$values <- factor(as.character(res$values), levels = object$lvl)
res$values <- factor(as.character(res$values),
levels = object$lvl,
ordered = object$ordered)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions man/fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 44 additions & 0 deletions tests/testthat/test-predict_formats.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,50 @@ test_that('classification predictions', {
c(".pred_high", ".pred_low"))
})


test_that('ordinal classification predictions', {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("rpart")

  # Simulate a multiclass outcome and convert it to an ordered factor so that
  # both fitting interfaces are given an ordinal response.
  set.seed(382)
  dat_tr <-
    modeldata::sim_multinomial(
      200,
      ~ -0.5 + 0.6 * abs(A),
      ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
      ~ -0.6 * A + 0.50 * B - A * B) %>%
    dplyr::mutate(class = as.ordered(class))
  dat_te <-
    modeldata::sim_multinomial(
      5,
      ~ -0.5 + 0.6 * abs(A),
      ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
      ~ -0.6 * A + 0.50 * B - A * B) %>%
    dplyr::mutate(class = as.ordered(class))

  ###
  # Formula interface: the fit records `ordered` and class predictions stay
  # ordered.

  mod_f_fit <-
    decision_tree() %>%
    set_mode("classification") %>%
    fit(class ~ ., data = dat_tr)
  expect_true("ordered" %in% names(mod_f_fit))
  mod_f_pred <- predict(mod_f_fit, dat_te)
  expect_true(is.ordered(mod_f_pred$.pred_class))

  ###
  # x/y interface: same expectations, checked against the predictions made by
  # the `fit_xy()` model (not the formula-based ones).

  mod_xy_fit <-
    decision_tree() %>%
    set_mode("classification") %>%
    fit_xy(x = dat_tr %>% dplyr::select(-class), y = dat_tr$class)

  expect_true("ordered" %in% names(mod_xy_fit))
  mod_xy_pred <- predict(mod_xy_fit, dat_te)
  expect_true(is.ordered(mod_xy_pred$.pred_class))
})


test_that('non-standard levels', {
expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1])))
expect_true(is.factor(parsnip:::predict_class.model_fit(lr_fit, new_data = class_dat[1:5,-1])))
Expand Down
Loading