Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions R/augment.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
#' page in the references below). This enables the user to compute performance
#' metrics in the \pkg{yardstick} package.
#'
#' ## Quantile Regression
#'
#' For quantile regression models, a `.pred_quantile` column is added that
#' contains the quantile predictions for each row. This column has a special
#' class `"quantile_pred"` and can be unnested using [tidyr::unnest()]
#'
#' @param new_data A data frame or matrix.
#' @param ... Not currently used.
#' @rdname augment
Expand Down Expand Up @@ -78,14 +84,31 @@
#' augment(cls_xy, cls_tst)
#' augment(cls_xy, cls_tst[, -3])
#'
#' # ------------------------------------------------------------------------------
#'
#' # Quantile regression example
#' qr_form <-
#' linear_reg() |>
#' set_engine("quantreg") |>
#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
#' fit(mpg ~ ., data = car_trn)
#'
#' augment(qr_form, car_tst)
#' augment(qr_form, car_tst[, -1])
#'
augment.model_fit <- function(x, new_data, eval_time = NULL, ...) {
new_data <- tibble::new_tibble(new_data)
res <-
switch(
x$spec$mode,
"regression" = augment_regression(x, new_data),
"classification" = augment_classification(x, new_data),
"censored regression" = augment_censored(x, new_data, eval_time = eval_time),
"regression" = augment_regression(x, new_data),
"classification" = augment_classification(x, new_data),
"censored regression" = augment_censored(
x,
new_data,
eval_time = eval_time
),
"quantile regression" = augment_quantile_regression(x, new_data),
cli::cli_abort(
c(
"Unknown mode {.val {x$spec$mode}}.",
Expand All @@ -106,7 +129,11 @@ augment_regression <- function(x, new_data) {
ret <- dplyr::mutate(ret, .resid = !!rlang::sym(y_nm) - .pred)
}
}
dplyr::relocate(ret, dplyr::starts_with(".pred"), dplyr::starts_with(".resid"))
dplyr::relocate(
ret,
dplyr::starts_with(".pred"),
dplyr::starts_with(".resid")
)
}

augment_classification <- function(x, new_data) {
Expand All @@ -117,11 +144,15 @@ augment_classification <- function(x, new_data) {
}

if (spec_has_pred_type(x, "class")) {
ret <- dplyr::bind_cols(predict(x, new_data = new_data, type = "class"), ret)
ret <- dplyr::bind_cols(
predict(x, new_data = new_data, type = "class"),
ret
)
}
ret
}


# nocov start
# tested in tidymodels/extratests#
augment_censored <- function(x, new_data, eval_time = NULL) {
Expand All @@ -145,7 +176,8 @@ augment_censored <- function(x, new_data, eval_time = NULL) {
.filter_eval_time(eval_time)
ret <- dplyr::bind_cols(
predict(x, new_data = new_data, type = "survival", eval_time = eval_time),
ret)
ret
)
# Add inverse probability weights when the outcome is present in new_data
y_col <- .find_surv_col(new_data, fail = FALSE)
if (length(y_col) != 0) {
Expand All @@ -155,3 +187,10 @@ augment_censored <- function(x, new_data, eval_time = NULL) {
ret
}
# nocov end

augment_quantile_regression <- function(x, new_data) {
ret <- new_data
check_spec_pred_type(x, "quantile")
ret <- dplyr::bind_cols(predict(x, new_data = new_data), ret)
dplyr::relocate(ret, dplyr::starts_with(".pred"))
}
Loading