diff --git a/DESCRIPTION b/DESCRIPTION index 5c2223ba..b00fa94f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: hardhat Title: Construct Modeling Packages -Version: 1.4.0.9001 +Version: 1.4.0.9002 Authors@R: c( person("Hannah", "Frick", , "hannah@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-6049-5258")), diff --git a/NAMESPACE b/NAMESPACE index d9cf172b..e3d2995a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,16 +1,21 @@ # Generated by roxygen2: do not edit by hand +S3method(as.matrix,quantile_pred) +S3method(as_tibble,quantile_pred) S3method(forge,data.frame) S3method(forge,default) S3method(forge,matrix) S3method(format,formula_blueprint) +S3method(format,quantile_pred) S3method(format,recipe_blueprint) S3method(format,xy_blueprint) +S3method(median,quantile_pred) S3method(mold,data.frame) S3method(mold,default) S3method(mold,formula) S3method(mold,matrix) S3method(mold,recipe) +S3method(obj_print_footer,quantile_pred) S3method(print,formula_blueprint) S3method(print,hardhat_blueprint) S3method(print,hardhat_model) @@ -45,8 +50,10 @@ S3method(vec_ptype2,hardhat_frequency_weights.hardhat_frequency_weights) S3method(vec_ptype2,hardhat_importance_weights.hardhat_importance_weights) S3method(vec_ptype_abbr,hardhat_frequency_weights) S3method(vec_ptype_abbr,hardhat_importance_weights) +S3method(vec_ptype_abbr,quantile_pred) S3method(vec_ptype_full,hardhat_frequency_weights) S3method(vec_ptype_full,hardhat_importance_weights) +S3method(vec_ptype_full,quantile_pred) export(add_intercept_column) export(check_column_names) export(check_no_formula_duplication) @@ -56,6 +63,7 @@ export(check_outcomes_are_numeric) export(check_outcomes_are_univariate) export(check_prediction_size) export(check_predictors_are_numeric) +export(check_quantile_levels) export(create_modeling_package) export(default_formula_blueprint) export(default_recipe_blueprint) @@ -69,6 +77,7 @@ export(extract_parameter_dials) export(extract_parameter_set_dials) export(extract_postprocessor) export(extract_preprocessor) +export(extract_quantile_levels) export(extract_recipe) export(extract_spec_parsnip) export(extract_workflow) @@ -98,6 +107,7 @@ export(new_importance_weights) export(new_model) export(new_recipe_blueprint) export(new_xy_blueprint) +export(quantile_pred) export(recompose) export(refresh_blueprint) export(run_forge) @@ -129,7 +139,9 @@ import(vctrs) importFrom(glue,glue) importFrom(stats,delete.response) importFrom(stats,get_all_vars) +importFrom(stats,median) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,terms) +importFrom(tibble,as_tibble) importFrom(tibble,tibble) diff --git a/NEWS.md b/NEWS.md index 4ce3259d..b76ba201 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # hardhat (development version) +* Added a new vector class called `quantile_pred()` to house predictions made from a quantile regression model (tidymodels/parsnip#1191, @dajmcdon). + # hardhat 1.4.0 * Added `extract_postprocessor()` generic (#247). diff --git a/R/hardhat-package.R b/R/hardhat-package.R index 41e1a01f..6eb515bd 100644 --- a/R/hardhat-package.R +++ b/R/hardhat-package.R @@ -7,11 +7,13 @@ #' @import rlang #' @import vctrs #' @importFrom glue glue +#' @importFrom tibble as_tibble #' @importFrom tibble tibble #' @importFrom stats model.frame #' @importFrom stats model.matrix #' @importFrom stats delete.response #' @importFrom stats get_all_vars #' @importFrom stats terms +#' @importFrom stats median ## usethis namespace: end NULL diff --git a/R/quantile-pred.R b/R/quantile-pred.R new file mode 100644 index 00000000..6630357a --- /dev/null +++ b/R/quantile-pred.R @@ -0,0 +1,207 @@ +#' Create a vector containing sets of quantiles +#' +#' [quantile_pred()] is a special vector class used to efficiently store +#' predictions from a quantile regression model. It requires the same quantile +#' levels for each row being predicted. +#' +#' @param values A matrix of values. Each column should correspond to one of +#' the quantile levels. +#' @param quantile_levels A vector of probabilities corresponding to `values`. +#' @param x An object produced by [quantile_pred()]. +#' @param .rows,.name_repair,rownames Arguments not used but required by the +#' original S3 method. +#' @param ... Not currently used. +#' +#' @export +#' @return +#' * [quantile_pred()] returns a vector of values associated with the +#' quantile levels. +#' * [extract_quantile_levels()] returns a numeric vector of levels. +#' * [as_tibble()] returns a tibble with rows `".pred_quantile"`, +#' `".quantile_levels"`, and `".row"`. +#' * [as.matrix()] returns an unnamed matrix with rows as samples, columns as +#' quantile levels, and entries are predictions. +#' @examples +#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +#' +#' unclass(.pred_quantile) +#' +#' # Access the underlying information +#' extract_quantile_levels(.pred_quantile) +#' +#' # Matrix format +#' as.matrix(.pred_quantile) +#' +#' # Tidy format +#' library(tibble) +#' as_tibble(.pred_quantile) +quantile_pred <- function(values, quantile_levels = double()) { + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) + check_quantile_levels(quantile_levels) + check_quantile_pred_inputs(values, quantile_levels) + + rownames(values) <- NULL + colnames(values) <- NULL + values <- lapply(vctrs::vec_chop(values), drop) + new_quantile_pred(values, quantile_levels) +} + +new_quantile_pred <- function(values = list(), quantile_levels = double()) { + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) + vctrs::new_vctr( + values, quantile_levels = quantile_levels, class = "quantile_pred" + ) +} + +#' @export +#' @rdname quantile_pred +extract_quantile_levels <- function(x) { + if (!inherits(x, "quantile_pred")) { + cli::cli_abort("{.arg x} should have class {.cls quantile_pred}.") + } + attr(x, "quantile_levels") +} + +#' @export +#' @rdname quantile_pred +as_tibble.quantile_pred <- + function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) { + lvls <- attr(x, "quantile_levels") + n_samp <- length(x) + n_quant <- length(lvls) + tibble::new_tibble(list( + .pred_quantile = unlist(x), + .quantile_levels = rep(lvls, n_samp), + .row = rep(1:n_samp, each = n_quant) + )) + } + +#' @export +#' @rdname quantile_pred +as.matrix.quantile_pred <- function(x, ...) { + num_samp <- length(x) + matrix(unlist(x), nrow = num_samp, byrow = TRUE) +} + +#' @export +format.quantile_pred <- function(x, digits = 3L, ...) { + quantile_levels <- attr(x, "quantile_levels") + if (length(quantile_levels) == 1L) { + x <- unlist(x) + out <- signif(x, digits = digits) + out[is.na(x)] <- NA_real_ + } else { + m <- median(x, na.rm = TRUE) + out <- paste0("[", signif(m, digits = digits), "]") + } + out +} + +#' @export +median.quantile_pred <- function(x, ...) { + lvls <- attr(x, "quantile_levels") + loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps)) + if (any(loc_median)) { + return(map_dbl(x, ~ .x[min(which(loc_median))])) + } + if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) { + return(rep(NA, vctrs::vec_size(x))) + } + map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y) +} + +#' @export +vec_ptype_abbr.quantile_pred <- function(x, ...) { + n_lvls <- length(attr(x, "quantile_levels")) + cli::format_inline("qtl{?s}({n_lvls})") +} + +#' @export +vec_ptype_full.quantile_pred <- function(x, ...) "quantiles" + +#' @export +obj_print_footer.quantile_pred <- function(x, digits = 3, ...) { + lvls <- attr(x, "quantile_levels") + footer <- cli::format_inline("# Quantile {cli::qty(length(lvls))}level{?s}:") + cat(footer, format(lvls, digits = digits), "\n", sep = " ") +} + + +# ------------------------------------------------------------------------------ +# Checking functions + +check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { + if (!is.matrix(values)) { + cli::cli_abort( + "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", + call = call + ) + } + + num_lvls <- length(levels) + + if (ncol(values) != num_lvls) { + cli::cli_abort( + "The number of columns in {.arg values} must be equal to the length of + {.arg quantile_levels}.", call = call + ) + } + + invisible(TRUE) +} + +#' Check levels of quantiles +#' @param levels The quantile levels. +#' @param arg,call Inputs to use to write error messages +#' @return Invisible `TRUE` +#' @keywords internal +#' @details +#' Checks the levels for their data type, range, uniqueness, order and missingness. +#' @export +check_quantile_levels <- function(levels, call = rlang::caller_env()) { + # data type, range, etc + check_quantile_level_values(levels, arg = "quantile_levels", call = call) + + # uniqueness + is_dup <- duplicated(levels) + if (any(is_dup)) { + redund <- levels[is_dup] + redund <- unique(redund) + redund <- signif(redund, digits = 5) + cli::cli_abort(c( + "Quantile levels should be unique.", + i = "The following {cli::qty(length(redund))}value{?s} {?was/were} repeated: + {redund}."), + call = call + ) + } + + # order + if (is.unsorted(levels)) { + cli::cli_abort( + "{.arg quantile_levels} must be sorted in increasing order.", + call = call + ) + } + + invisible(TRUE) +} + +check_quantile_level_values <- function(levels, arg, call) { + if (is.null(levels)) { + cli::cli_abort("{.arg {arg}} cannot be {.val NULL}.", call = call) + } + for (val in levels) { + check_number_decimal( + val, + min = 0, + max = 1, + arg = arg, + call = call, + allow_na = FALSE, + allow_null = FALSE, + allow_infinite = FALSE + ) + } + invisible(TRUE) +} diff --git a/_pkgdown.yml b/_pkgdown.yml index 3d0ce780..6894b4c5 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -24,7 +24,9 @@ reference: - forge - title: Prediction - contents: contains("spruce") + contents: + - contains("spruce") + - quantile_pred - title: Utility contents: diff --git a/man/check_quantile_levels.Rd b/man/check_quantile_levels.Rd new file mode 100644 index 00000000..3f258287 --- /dev/null +++ b/man/check_quantile_levels.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/quantile-pred.R +\name{check_quantile_levels} +\alias{check_quantile_levels} +\title{Check levels of quantiles} +\usage{ +check_quantile_levels(levels, call = rlang::caller_env()) +} +\arguments{ +\item{levels}{The quantile levels.} + +\item{arg, call}{Inputs to use to write error messages} +} +\value{ +Invisible \code{TRUE} +} +\description{ +Check levels of quantiles +} +\details{ +Checks the levels for their data type, range, uniqueness, order and missingness. +} +\keyword{internal} diff --git a/man/quantile_pred.Rd b/man/quantile_pred.Rd new file mode 100644 index 00000000..b73166a3 --- /dev/null +++ b/man/quantile_pred.Rd @@ -0,0 +1,61 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/quantile-pred.R +\name{quantile_pred} +\alias{quantile_pred} +\alias{extract_quantile_levels} +\alias{as_tibble.quantile_pred} +\alias{as.matrix.quantile_pred} +\title{Create a vector containing sets of quantiles} +\usage{ +quantile_pred(values, quantile_levels = double()) + +extract_quantile_levels(x) + +\method{as_tibble}{quantile_pred}(x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) + +\method{as.matrix}{quantile_pred}(x, ...) +} +\arguments{ +\item{values}{A matrix of values. Each column should correspond to one of +the quantile levels.} + +\item{quantile_levels}{A vector of probabilities corresponding to \code{values}.} + +\item{x}{An object produced by \code{\link[=quantile_pred]{quantile_pred()}}.} + +\item{...}{Not currently used.} + +\item{.rows, .name_repair, rownames}{Arguments not used but required by the +original S3 method.} +} +\value{ +\itemize{ +\item \code{\link[=quantile_pred]{quantile_pred()}} returns a vector of values associated with the +quantile levels. +\item \code{\link[=extract_quantile_levels]{extract_quantile_levels()}} returns a numeric vector of levels. +\item \code{\link[=as_tibble]{as_tibble()}} returns a tibble with rows \code{".pred_quantile"}, +\code{".quantile_levels"}, and \code{".row"}. +\item \code{\link[=as.matrix]{as.matrix()}} returns an unnamed matrix with rows as samples, columns as +quantile levels, and entries are predictions. +} +} +\description{ +\code{\link[=quantile_pred]{quantile_pred()}} is a special vector class used to efficiently store +predictions from a quantile regression model. It requires the same quantile +levels for each row being predicted. +} +\examples{ +.pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) + +unclass(.pred_quantile) + +# Access the underlying information +extract_quantile_levels(.pred_quantile) + +# Matrix format +as.matrix(.pred_quantile) + +# Tidy format +library(tibble) +as_tibble(.pred_quantile) +} diff --git a/tests/testthat/_snaps/quantile-pred.md b/tests/testthat/_snaps/quantile-pred.md new file mode 100644 index 00000000..2a706946 --- /dev/null +++ b/tests/testthat/_snaps/quantile-pred.md @@ -0,0 +1,170 @@ +# quantile_pred error types + + Code + quantile_pred(1:10, 1:4 / 5) + Condition + Error in `quantile_pred()`: + ! `values` must be a , not an integer vector. + +--- + + Code + quantile_pred(matrix(1:20, 5), -1:4 / 5) + Condition + Error in `quantile_pred()`: + ! `quantile_levels` must be a number between 0 and 1, not the number -0.2. + +--- + + Code + quantile_pred(matrix(1:20, 5), 1:5 / 6) + Condition + Error in `quantile_pred()`: + ! The number of columns in `values` must be equal to the length of `quantile_levels`. + +--- + + Code + quantile_pred(matrix(1:20, 5), 4:1 / 5) + Condition + Error in `quantile_pred()`: + ! `quantile_levels` must be sorted in increasing order. + +# quantile levels are checked + + Code + quantile_pred(matrix(1:20, 5), quantile_levels = NULL) + Condition + Error in `quantile_pred()`: + ! `quantile_levels` cannot be "NULL". + +--- + + Code + quantile_pred(matrix(1:20, 5), quantile_levels = c(0.7, 0.7, 0.7)) + Condition + Error in `quantile_pred()`: + ! Quantile levels should be unique. + i The following value was repeated: 0.7. + +--- + + Code + quantile_pred(matrix(1:20, 5), quantile_levels = c(rep(0.7, 2), rep(0.8, 3))) + Condition + Error in `quantile_pred()`: + ! Quantile levels should be unique. + i The following values were repeated: 0.7 and 0.8. + +--- + + Code + quantile_pred(matrix(1:20, 5), quantile_levels = c(0.8, 0.7)) + Condition + Error in `quantile_pred()`: + ! `quantile_levels` must be sorted in increasing order. + +# extract_quantile_levels + + Code + extract_quantile_levels(1:10) + Condition + Error in `extract_quantile_levels()`: + ! `x` should have class . + +# quantile_pred formatting + + Code + v + Output + + [1] [8.5] [9.5] [10.5] [11.5] [12.5] + # Quantile levels: 0.2 0.4 0.6 0.8 + +--- + + Code + quantile_pred(matrix(1:18, 9), c(1 / 3, 2 / 3)) + Output + + [1] [5.5] [6.5] [7.5] [8.5] [9.5] [10.5] [11.5] [12.5] [13.5] + # Quantile levels: 0.333 0.667 + +--- + + Code + quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(0.2, 0.8)) + Output + + [1] [0.304] [0.5] [0.696] + # Quantile levels: 0.2 0.8 + +--- + + Code + tibble(qntls = v) + Output + # A tibble: 5 x 1 + qntls + + 1 [8.5] + 2 [9.5] + 3 [10.5] + 4 [11.5] + 5 [12.5] + +--- + + Code + quantile_pred(m, 1:4 / 5) + Output + + [1] [8.5] [9.5] [10.5] [11.5] [12.5] + # Quantile levels: 0.2 0.4 0.6 0.8 + +--- + + Code + one_quantile + Output + + [1] 1 2 3 4 5 + # Quantile level: 0.556 + +--- + + Code + tibble(qntls = one_quantile) + Output + # A tibble: 5 x 1 + qntls + + 1 1 + 2 2 + 3 3 + 4 4 + 5 5 + +--- + + Code + quantile_pred(m, 5 / 9) + Output + + [1] 1 NA 3 4 5 + # Quantile level: 0.556 + +--- + + Code + format(v) + Output + [1] "[1.72]" "[0.568]" "[1.24]" "[2.21]" "[0.767]" + +--- + + Code + format(v, digits = 5) + Output + [1] "[1.7154]" "[0.56784]" "[1.2393]" "[2.2062]" "[0.76714]" + diff --git a/tests/testthat/test-quantile-pred.R b/tests/testthat/test-quantile-pred.R new file mode 100644 index 00000000..a18ac570 --- /dev/null +++ b/tests/testthat/test-quantile-pred.R @@ -0,0 +1,97 @@ +test_that("quantile_pred error types", { + expect_snapshot( + error = TRUE, + quantile_pred(1:10, 1:4 / 5) + ) + expect_snapshot( + error = TRUE, + quantile_pred(matrix(1:20, 5), -1:4 / 5) + ) + expect_snapshot( + error = TRUE, + quantile_pred(matrix(1:20, 5), 1:5 / 6) + ) + expect_snapshot( + error = TRUE, + quantile_pred(matrix(1:20, 5), 4:1 / 5) + ) +}) + +test_that("quantile levels are checked", { + expect_snapshot(error = TRUE, { + quantile_pred(matrix(1:20, 5), quantile_levels = NULL) + }) + expect_snapshot(error = TRUE, { + quantile_pred(matrix(1:20, 5), quantile_levels = c(0.7, 0.7, 0.7)) + }) + expect_snapshot(error = TRUE, { + quantile_pred(matrix(1:20, 5), quantile_levels = c(rep(0.7, 2), rep(0.8, 3))) + }) + expect_snapshot(error = TRUE, { + quantile_pred(matrix(1:20, 5), quantile_levels = c(0.8, 0.7)) + }) +}) + +test_that("quantile_pred outputs", { + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + expect_s3_class(v, "quantile_pred") + expect_identical(attr(v, "quantile_levels"), 1:4 / 5) + expect_identical( + vctrs::vec_data(v), + lapply(vctrs::vec_chop(matrix(1:20, 5)), drop) + ) +}) + +test_that("extract_quantile_levels", { + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + expect_identical(extract_quantile_levels(v), 1:4 / 5) + + expect_snapshot( + error = TRUE, + extract_quantile_levels(1:10) + ) +}) + +test_that("quantile_pred formatting", { + # multiple quantiles + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + expect_snapshot(v) + expect_snapshot(quantile_pred(matrix(1:18, 9), c(1/3, 2/3))) + expect_snapshot( + quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(.2, .8)) + ) + expect_snapshot(tibble(qntls = v)) + m <- matrix(1:20, 5) + m[2, 3] <- NA + m[4, 2] <- NA + expect_snapshot(quantile_pred(m, 1:4 / 5)) + + # single quantile + m <- matrix(1:5) + one_quantile <- quantile_pred(m, 5/9) + expect_snapshot(one_quantile) + expect_snapshot(tibble(qntls = one_quantile)) + m[2] <- NA + expect_snapshot(quantile_pred(m, 5/9)) + + set.seed(393) + v <- quantile_pred(matrix(exp(rnorm(20)), ncol = 4), 1:4 / 5) + expect_snapshot(format(v)) + expect_snapshot(format(v, digits = 5)) +}) + +test_that("as_tibble() for quantile_pred", { + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + tbl <- as_tibble(v) + expect_s3_class(tbl, c("tbl_df", "tbl", "data.frame")) + expect_named(tbl, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(tbl) == 20) +}) + +test_that("as.matrix() for quantile_pred", { + x <- matrix(1:20, 5) + v <- quantile_pred(x, 1:4 / 5) + m <- as.matrix(v) + expect_true(is.matrix(m)) + expect_identical(m, x) +})