From 932f7fb040cc6fc567d5d78a002c9a547a9b3d4e Mon Sep 17 00:00:00 2001 From: Catalina Canizares Date: Fri, 19 Sep 2025 14:10:17 -0400 Subject: [PATCH] Fixed # 1212 enabled augment() for quartile regression --- R/augment.R | 51 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/R/augment.R b/R/augment.R index bb66f2e4d..e87578266 100644 --- a/R/augment.R +++ b/R/augment.R @@ -31,6 +31,12 @@ #' page in the references below). This enables the user to compute performance #' metrics in the \pkg{yardstick} package. #' +#' ## Quantile Regression +#' +#' For quantile regression models, a `.pred_quantile` column is added that +#' contains the quantile predictions for each row. This column has a special +#' class `"quantile_pred"` and can be unnested using [tidyr::unnest()] +#' #' @param new_data A data frame or matrix. #' @param ... Not currently used. #' @rdname augment @@ -78,14 +84,31 @@ #' augment(cls_xy, cls_tst) #' augment(cls_xy, cls_tst[, -3]) #' +#' # ------------------------------------------------------------------------------ +#' +#' # Quantile regression example +#' qr_form <- +#' linear_reg() |> +#' set_engine("quantreg") |> +#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |> +#' fit(mpg ~ ., data = car_trn) +#' +#' augment(qr_form, car_tst) +#' augment(qr_form, car_tst[, -1]) +#' augment.model_fit <- function(x, new_data, eval_time = NULL, ...) { new_data <- tibble::new_tibble(new_data) res <- switch( x$spec$mode, - "regression" = augment_regression(x, new_data), - "classification" = augment_classification(x, new_data), - "censored regression" = augment_censored(x, new_data, eval_time = eval_time), + "regression" = augment_regression(x, new_data), + "classification" = augment_classification(x, new_data), + "censored regression" = augment_censored( + x, + new_data, + eval_time = eval_time + ), + "quantile regression" = augment_quantile_regression(x, new_data), cli::cli_abort( c( "Unknown mode {.val {x$spec$mode}}.", @@ -106,7 +129,11 @@ augment_regression <- function(x, new_data) { ret <- dplyr::mutate(ret, .resid = !!rlang::sym(y_nm) - .pred) } } - dplyr::relocate(ret, dplyr::starts_with(".pred"), dplyr::starts_with(".resid")) + dplyr::relocate( + ret, + dplyr::starts_with(".pred"), + dplyr::starts_with(".resid") + ) } augment_classification <- function(x, new_data) { @@ -117,11 +144,15 @@ augment_classification <- function(x, new_data) { } if (spec_has_pred_type(x, "class")) { - ret <- dplyr::bind_cols(predict(x, new_data = new_data, type = "class"), ret) + ret <- dplyr::bind_cols( + predict(x, new_data = new_data, type = "class"), + ret + ) } ret } + # nocov start # tested in tidymodels/extratests# augment_censored <- function(x, new_data, eval_time = NULL) { @@ -145,7 +176,8 @@ augment_censored <- function(x, new_data, eval_time = NULL) { .filter_eval_time(eval_time) ret <- dplyr::bind_cols( predict(x, new_data = new_data, type = "survival", eval_time = eval_time), - ret) + ret + ) # Add inverse probability weights when the outcome is present in new_data y_col <- .find_surv_col(new_data, fail = FALSE) if (length(y_col) != 0) { @@ -155,3 +187,10 @@ augment_censored <- function(x, new_data, eval_time = NULL) { ret } # nocov end + +augment_quantile_regression <- function(x, new_data) { + ret <- new_data + check_spec_pred_type(x, "quantile") + ret <- dplyr::bind_cols(predict(x, new_data = new_data), ret) + dplyr::relocate(ret, dplyr::starts_with(".pred")) +}