Merged
Changes from 2 commits
1 change: 1 addition & 0 deletions NAMESPACE
@@ -185,6 +185,7 @@ export(.dat)
export(.extract_surv_status)
export(.extract_surv_time)
export(.facts)
export(.get_prediction_column_names)
export(.lvls)
export(.model_param_name_key)
export(.obs)
67 changes: 67 additions & 0 deletions R/misc.R
@@ -575,3 +575,70 @@
}
# nocov end

# ------------------------------------------------------------------------------

#' Obtain names of prediction columns for a fitted model or workflow
#'
#' [.get_prediction_column_names()] returns a list that has the names of the
#' columns for the primary prediction types for a model.
#' @param x A fitted model (class `"model_fit"`) or a fitted workflow.
#' @param syms Should the column names be converted to symbols?
#' @return A list with elements `"estimate"` and `"probabilities"`.
#' @examplesIf !parsnip:::is_cran_check()
#' library(dplyr)
#' library(modeldata)
#' data("two_class_dat")
#'
#' levels(two_class_dat$Class)
#' lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat)
#'
#' .get_prediction_column_names(lr_fit)
#' .get_prediction_column_names(lr_fit, syms = TRUE)
#' @export
Member

Could we make this function `@keywords internal`?

.get_prediction_column_names <- function(x, syms = FALSE) {
  if (!inherits(x, c("model_fit", "workflow"))) {
    cli::cli_abort("{.arg x} should be an object with class {.cls model_fit} or
                   {.cls workflow}, not {.obj_type_friendly {x}}.")
  }
  model_spec <- extract_spec_parsnip(x)
  model_engine <- model_spec$engine
  model_mode <- model_spec$mode
  model_type <- class(model_spec)[1]

  # appropriately populate the model db
  inst_res <- purrr::map(required_pkgs(x), rlang::check_installed)
  predict_types <-
    get_from_env(paste0(model_type, "_predict")) %>%
    dplyr::filter(engine == model_engine & mode == model_mode) %>%
    purrr::pluck("type")

  if (length(predict_types) == 0) {
    cli::cli_abort("Prediction information could not be found for this
                   {.fn {model_type}} with engine {.val {model_engine}} and mode
                   {.val {model_mode}}. Does a parsnip extension package need to
Member

That last question would be very nice in an `i` bullet rather than in the error message itself.

                   be loaded?")
  }

  res <- list(estimate = character(0), probabilities = character(0))

  if (model_mode == "regression") {
    res$estimate <- ".pred"
  } else if (model_mode == "classification") {
    res$estimate <- ".pred_class"
    if (any(predict_types == "prob")) {
      res$probabilities <- paste0(".pred_", x$lvl)
    }
  } else if (model_mode == "censored regression") {
    res$estimate <- ".pred_time"
    if (any(predict_types %in% c("survival"))) {
      res$probabilities <- ".pred"

Codecov (codecov/patch) warning on line 634 in R/misc.R: added lines #L631 - L634 were not covered by tests.
    }
  } else {
    cli::cli_abort("Unsupported model mode {model_mode}.")

Codecov (codecov/patch) warning on line 637 in R/misc.R: added line #L637 was not covered by tests.
Member Author

I'm not sure this branch of the if/else will ever be reached, given the error check above. However, we might hit it once we have a mode for quantile regression, so I'd err on the side of leaving it in, untested.

Member

I think that is reasonable.

  }

  if (syms) {
    res <- purrr::map(res, rlang::syms)
  }
  res
}
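
The review suggestion to move the trailing question into an `i` bullet could look roughly like the sketch below. It is an illustration only, not the code that was merged: `abort_no_predict_info()` is a hypothetical helper introduced here, and the sketch assumes the `cli` package is installed. `cli::cli_abort()` accepts a character vector whose elements named `"i"` are rendered as informational bullets beneath the main message.

```r
# Hypothetical helper (not part of parsnip): the same message as in the diff,
# with the closing question moved into an "i" bullet.
abort_no_predict_info <- function(model_type, model_engine, model_mode) {
  cli::cli_abort(
    c(
      "Prediction information could not be found for this {.fn {model_type}} with engine {.val {model_engine}} and mode {.val {model_mode}}.",
      "i" = "Does a parsnip extension package need to be loaded?"
    )
  )
}

# Capture the condition instead of stopping, so the message can be inspected.
err <- tryCatch(
  abort_no_predict_info("linear_reg", "lm", "Depeche"),
  error = function(e) e
)
cat(conditionMessage(err), "\n")
```

Keeping the question in its own bullet separates the statement of what failed from the hint about how to fix it.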
1 change: 1 addition & 0 deletions _pkgdown.yml
@@ -111,3 +111,4 @@ reference:
- .extract_surv_status
- .extract_surv_time
- .model_param_name_key
- .get_prediction_column_names
33 changes: 33 additions & 0 deletions man/dot-get_prediction_column_names.Rd

Some generated files are not rendered by default.

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/misc.md
@@ -227,3 +227,19 @@
Error in `check_outcome()`:
! For a censored regression model, the outcome should be a <Surv> object, not an integer vector.

# obtaining prediction columns

Code
.get_prediction_column_names(1)
Condition
Error in `.get_prediction_column_names()`:
! `x` should be an object with class <model_fit> or <workflow>, not a number.

---

Code
.get_prediction_column_names(unk_fit)
Condition
Error in `.get_prediction_column_names()`:
! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded?

50 changes: 50 additions & 0 deletions tests/testthat/test-misc.R
@@ -249,3 +249,53 @@ test_that('check_outcome works as expected', {
check_outcome(1:2, cens_spec)
)
})

# ------------------------------------------------------------------------------

test_that('obtaining prediction columns', {
  skip_if_not_installed("modeldata")
  data(two_class_dat, package = "modeldata")

  ### classification
  lr_fit <- logistic_reg() %>% fit(Class ~ ., data = two_class_dat)
  expect_equal(
    .get_prediction_column_names(lr_fit),
    list(estimate = ".pred_class",
         probabilities = c(".pred_Class1", ".pred_Class2"))
  )
  expect_equal(
    .get_prediction_column_names(lr_fit, syms = TRUE),
    list(estimate = list(quote(.pred_class)),
         probabilities = list(quote(.pred_Class1), quote(.pred_Class2)))
  )

  ### regression
  ols_fit <- linear_reg() %>% fit(mpg ~ ., data = mtcars)
  expect_equal(
    .get_prediction_column_names(ols_fit),
    list(estimate = ".pred",
         probabilities = character(0))
  )
  expect_equal(
    .get_prediction_column_names(ols_fit, syms = TRUE),
    list(estimate = list(quote(.pred)),
         probabilities = list())
  )

  ### censored regression
  # in extratests

  ### bad input
  expect_snapshot(
    .get_prediction_column_names(1),
    error = TRUE
  )

  unk_fit <- ols_fit
  unk_fit$spec$mode <- "Depeche"
Member

😆

  expect_snapshot(
    .get_prediction_column_names(unk_fit),
    error = TRUE
  )

})
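
The `syms = TRUE` expectations in the test can be reproduced without fitting a model. The sketch below is a minimal illustration that assumes `purrr` and `rlang` are installed. The `lvls` vector is a hypothetical stand-in for the outcome levels that the function reads from the fitted object via `x$lvl`.

```r
# Hypothetical stand-in for the outcome levels of a fitted classifier.
lvls <- c("Class1", "Class2")

# Character form of the result, built the same way as in the diff.
res <- list(
  estimate = ".pred_class",
  probabilities = paste0(".pred_", lvls)
)

# Convert each character vector to a list of symbols, as `syms = TRUE` does.
sym_res <- purrr::map(res, rlang::syms)

# An empty character vector becomes an empty list, which matches the
# `probabilities = list()` expectation in the regression test.
empty_syms <- rlang::syms(character(0))

str(sym_res)
```

Symbols in this form can be spliced into tidy evaluation calls with `!!!`, for example when selecting the probability columns from a data frame of predictions.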