Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0b4cf5d
code from parsnip
Sep 13, 2024
1739032
version bump
Sep 13, 2024
82a89c4
news update
Sep 13, 2024
8add7cb
add pkgdown entry
Sep 13, 2024
7192170
Apply suggestions from code review
topepo Sep 13, 2024
fa85a3b
add snapshot file
hfrick Sep 16, 2024
cae0633
#259 already claimed `.9001`
hfrick Sep 16, 2024
b18861e
tidy style
hfrick Sep 16, 2024
a8c9355
change styling to class and test
hfrick Sep 16, 2024
754e385
add test for `as.matrix()` to illustrate current behavior
hfrick Sep 16, 2024
100ddcb
move constructor
hfrick Sep 16, 2024
89d98d1
group vctrs methods
hfrick Sep 16, 2024
9ab5809
group checking functions together
hfrick Sep 16, 2024
26bc327
move up functions which are placed on the main help page
hfrick Sep 16, 2024
0979985
group remaining methods
hfrick Sep 16, 2024
c8c4230
move unused check function with others for visibility
hfrick Sep 16, 2024
9f14926
Merge branch 'main' into quantile-pred
hfrick Sep 17, 2024
db0e1a5
fix typo
Sep 17, 2024
9a659ef
remove function for parsnip's set_mode()
Sep 17, 2024
30de46f
remove test helpers
Sep 17, 2024
9dc4b12
remove functions for restructuring
Sep 17, 2024
dfe8288
remove range code; add digits argument
Sep 17, 2024
7d6c4e2
fix as.matrix method
Sep 17, 2024
d0444bd
remove @importFrom
Sep 18, 2024
c597214
don't re-export vctrs generics
hfrick Sep 18, 2024
9006d02
add pluralization to `obj_print_footer()` method
hfrick Sep 18, 2024
fbed1c1
refactor inout checks
Sep 23, 2024
f9808c6
fix typo
hfrick Sep 23, 2024
04525e8
only show unique values in error
hfrick Sep 23, 2024
480c2ff
add pluralization
hfrick Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hardhat
Title: Construct Modeling Packages
Version: 1.4.0.9000
Version: 1.4.0.9002
Authors@R: c(
person("Hannah", "Frick", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-6049-5258")),
Expand Down
17 changes: 17 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# Generated by roxygen2: do not edit by hand

S3method(as.matrix,quantile_pred)
S3method(as_tibble,quantile_pred)
S3method(forge,data.frame)
S3method(forge,default)
S3method(forge,matrix)
S3method(format,formula_blueprint)
S3method(format,quantile_pred)
S3method(format,recipe_blueprint)
S3method(format,xy_blueprint)
S3method(median,quantile_pred)
S3method(mold,data.frame)
S3method(mold,default)
S3method(mold,formula)
S3method(mold,matrix)
S3method(mold,recipe)
S3method(obj_print_footer,quantile_pred)
S3method(print,formula_blueprint)
S3method(print,hardhat_blueprint)
S3method(print,hardhat_model)
Expand Down Expand Up @@ -45,8 +50,10 @@ S3method(vec_ptype2,hardhat_frequency_weights.hardhat_frequency_weights)
S3method(vec_ptype2,hardhat_importance_weights.hardhat_importance_weights)
S3method(vec_ptype_abbr,hardhat_frequency_weights)
S3method(vec_ptype_abbr,hardhat_importance_weights)
S3method(vec_ptype_abbr,quantile_pred)
S3method(vec_ptype_full,hardhat_frequency_weights)
S3method(vec_ptype_full,hardhat_importance_weights)
S3method(vec_ptype_full,quantile_pred)
export(add_intercept_column)
export(check_column_names)
export(check_no_formula_duplication)
Expand All @@ -69,6 +76,7 @@ export(extract_parameter_dials)
export(extract_parameter_set_dials)
export(extract_postprocessor)
export(extract_preprocessor)
export(extract_quantile_levels)
export(extract_recipe)
export(extract_spec_parsnip)
export(extract_workflow)
Expand Down Expand Up @@ -98,6 +106,8 @@ export(new_importance_weights)
export(new_model)
export(new_recipe_blueprint)
export(new_xy_blueprint)
export(obj_print_footer)
export(quantile_pred)
export(recompose)
export(refresh_blueprint)
export(run_forge)
Expand All @@ -123,13 +133,20 @@ export(validate_outcomes_are_numeric)
export(validate_outcomes_are_univariate)
export(validate_prediction_size)
export(validate_predictors_are_numeric)
export(vec_ptype_abbr)
export(vec_ptype_full)
export(weighted_table)
import(rlang)
import(vctrs)
importFrom(glue,glue)
importFrom(stats,delete.response)
importFrom(stats,get_all_vars)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,terms)
importFrom(tibble,as_tibble)
importFrom(tibble,tibble)
importFrom(vctrs,obj_print_footer)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# hardhat (development version)

* Added a new vector class called `quantile_pred()` to house predictions made from a quantile regression model (tidymodels/parsnip#1191, @dajmcdon).

# hardhat 1.4.0

* Added `extract_postprocessor()` generic (#247).
Expand Down
2 changes: 2 additions & 0 deletions R/hardhat-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
#' @import rlang
#' @import vctrs
#' @importFrom glue glue
#' @importFrom tibble as_tibble
#' @importFrom tibble tibble
#' @importFrom stats model.frame
#' @importFrom stats model.matrix
#' @importFrom stats delete.response
#' @importFrom stats get_all_vars
#' @importFrom stats terms
#' @importFrom stats median
## usethis namespace: end
NULL
219 changes: 219 additions & 0 deletions R/quantile-pred.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
#' Create a vector containing sets of quantiles
#'
#' [quantile_pred()] is a special vector class used to efficiently store
#' predictions from a quantile regression model. It requires the same quantile
#' levels for each row being predicted.
#'
#' @param values A matrix of values. Each column should correspond to one of
#' the quantile levels.
#' @param quantile_levels A vector of probabilities corresponding to `values`.
#' @param x An object produced by [quantile_pred()].
#' @param .rows,.name_repair,rownames Arguments not used but required by the
#' original S3 method.
#' @param ... Not currently used.
#'
#' @export
#' @return
#' * [quantile_pred()] returns a vector of values associated with the
#' quantile levels.
#' * [extract_quantile_levels()] returns a numeric vector of levels.
#' * [as_tibble()] returns a tibble with columns `".pred_quantile"`,
#' `".quantile_levels"`, and `".row"`.
#' * [as.matrix()] returns an unnamed matrix with rows as samples, columns as
#' quantile levels, and entries as predictions.
#' @examples
#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8))
#'
#' unclass(.pred_quantile)
#'
#' # Access the underlying information
#' extract_quantile_levels(.pred_quantile)
#'
#' # Matrix format
#' as.matrix(.pred_quantile)
#'
#' # Tidy format
#' library(tibble)
#' as_tibble(.pred_quantile)
quantile_pred <- function(values, quantile_levels = double()) {
  # User-facing constructor: validates the inputs, then stores one vector of
  # predicted quantiles per row of `values`, with the levels kept as an
  # attribute by `new_quantile_pred()`.
  check_quantile_pred_inputs(values, quantile_levels)

  quantile_levels <- vctrs::vec_cast(quantile_levels, double())

  # Each column of `values` must pair up with exactly one quantile level.
  if (ncol(values) != length(quantile_levels)) {
    cli::cli_abort(
      "The number of columns in {.arg values} must be equal to the length of
{.arg quantile_levels}."
    )
  }

  # Strip row/column names, then split the matrix into one slice per row and
  # drop each slice down to a plain vector.
  colnames(values) <- NULL
  rownames(values) <- NULL
  row_slices <- vctrs::vec_chop(values)
  per_row <- lapply(row_slices, drop)

  new_quantile_pred(per_row, quantile_levels)
}

# Low-level constructor for the `quantile_pred` class.
#
# `values` is the underlying list (one element per prediction) and
# `quantile_levels` is cast to double and stored as an attribute of the
# resulting vctrs vector. No validation of `values` is performed here.
new_quantile_pred <- function(values = list(), quantile_levels = double()) {
  quantile_levels <- vctrs::vec_cast(quantile_levels, double())

  vctrs::new_vctr(
    values,
    quantile_levels = quantile_levels,
    class = "quantile_pred"
  )
}

# Validate the raw inputs of `quantile_pred()`.
#
# Checks, in this order: no missing quantile levels, `values` is a matrix,
# every level passes `check_vector_probability()`, and the levels are not
# unsorted. Errors are reported against `call`. Returns `NULL` invisibly when
# every check passes.
check_quantile_pred_inputs <- function(values, levels, call = caller_env()) {
  has_missing_level <- any(is.na(levels))
  if (has_missing_level) {
    cli::cli_abort(
      "Missing values are not allowed in {.arg quantile_levels}.",
      call = call
    )
  }

  if (is.matrix(values)) {
    check_vector_probability(levels, arg = "quantile_levels", call = call)
  } else {
    cli::cli_abort(
      "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.",
      call = call
    )
  }

  levels_out_of_order <- is.unsorted(levels)
  if (levels_out_of_order) {
    cli::cli_abort(
      "{.arg quantile_levels} must be sorted in increasing order.",
      call = call
    )
  }

  invisible(NULL)
}

# Check every element of `x` individually with `check_number_decimal()`,
# bounded to the interval [0, 1] and with infinite values disallowed.
#
# `allow_na` and `allow_null` are forwarded to each per-element check; `arg`
# and `call` control how a failure is reported. A zero-length (or `NULL`) `x`
# performs no checks at all.
check_vector_probability <- function(x, ...,
                                     allow_na = FALSE,
                                     allow_null = FALSE,
                                     arg = caller_arg(x),
                                     call = caller_env()) {
  for (prob in x) {
    check_number_decimal(
      prob,
      min = 0,
      max = 1,
      allow_infinite = FALSE,
      allow_na = allow_na,
      allow_null = allow_null,
      arg = arg,
      call = call
    )
  }
}

# Validate the quantile level(s) `x` supplied alongside a model `object`.
#
# `object` must expose a `$mode` element. For any mode other than
# "quantile regression" nothing is checked and `TRUE` is returned invisibly.
# Otherwise `x` must be non-NULL and free of missing values; it is then
# de-duplicated, sorted, checked with `check_vector_probability()`, and
# returned. `call` is the frame reported in error messages; every abort now
# forwards it, and it defaults to the caller like the other checkers here.
# NOTE(review): nothing in the visible file calls this helper - confirm it is
# still needed.
check_quantile_level <- function(x, object, call = caller_env()) {
  if (object$mode != "quantile regression") {
    return(invisible(TRUE))
  }

  if (is.null(x)) {
    cli::cli_abort(
      "In {.fn check_mode}, at least one value of
{.arg quantile_level} must be specified for quantile regression models.",
      call = call
    )
  }

  if (any(is.na(x))) {
    cli::cli_abort(
      "Missing values are not allowed in {.arg quantile_levels}.",
      call = call
    )
  }

  x <- sort(unique(x))
  check_vector_probability(x, arg = "quantile_level", call = call)
  x
}

#' @export
#' @rdname quantile_pred
extract_quantile_levels <- function(x) {
  # Return the `quantile_levels` attribute of a `quantile_pred` vector;
  # anything that does not inherit from that class is rejected.
  is_quantile_pred <- inherits(x, "quantile_pred")
  if (is_quantile_pred) {
    return(attr(x, "quantile_levels"))
  }
  cli::cli_abort("{.arg x} should have class {.cls quantile_pred}.")
}

#' @export
#' @rdname quantile_pred
as_tibble.quantile_pred <-
  function(x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) {
    # Long ("tidy") format: one row per (prediction, quantile level) pair.
    # `.rows`, `.name_repair` and `rownames` are accepted but not used.
    lvls <- attr(x, "quantile_levels")
    n_samp <- length(x)
    n_quant <- length(lvls)

    # `unlist()` of an empty vector is NULL, which is not a valid column.
    preds <- unlist(x)
    if (is.null(preds)) {
      preds <- double()
    }

    tibble::new_tibble(
      list(
        .pred_quantile = preds,
        .quantile_levels = rep(lvls, n_samp),
        # `seq_len()` rather than `1:n_samp`: the latter is c(1, 0) when
        # `x` is empty and would give this column the wrong length.
        .row = rep(seq_len(n_samp), each = n_quant)
      ),
      nrow = n_samp * n_quant
    )
  }

#' @export
#' @rdname quantile_pred
as.matrix.quantile_pred <- function(x, ...) {
  # One row per prediction, one column per quantile level.
  num_samp <- length(x)
  num_lvls <- length(attr(x, "quantile_levels"))

  # `unlist()` of an empty vector is NULL, which `matrix()` rejects.
  vals <- unlist(x)
  if (is.null(vals)) {
    vals <- double()
  }

  # `unlist(x)` emits all quantiles of the first prediction, then all of the
  # second, and so on (row-major), so the matrix must be filled by row.
  matrix(vals, nrow = num_samp, ncol = num_lvls, byrow = TRUE)
}

#' @export
format.quantile_pred <- function(x, ...) {
  # With a single quantile level the (rounded) predicted values themselves are
  # returned; with several levels each element is shown as its median wrapped
  # in square brackets.
  #
  # The former per-element `range()` computation was removed: its result was
  # always overwritten by the median output, it warned for all-NA elements,
  # and it errored for zero-length input.
  if (length(x) == 0L) {
    return(character())
  }

  quantile_levels <- attr(x, "quantile_levels")
  if (length(quantile_levels) == 1L) {
    x <- unlist(x)
    out <- round(x, 3L)
    out[is.na(x)] <- NA_real_
  } else {
    m <- median(x)
    out <- paste0("[", round(m, 3L), "]")
  }
  out
}

#' @export
median.quantile_pred <- function(x, ...) {
  # Per-prediction median (the 0.5 quantile), always returned as a double
  # vector with one entry per element of `x`.
  lvls <- attr(x, "quantile_levels")

  # If 0.5 is (numerically) one of the stored levels, read it off directly.
  is_median <- abs(lvls - 0.5) < sqrt(.Machine$double.eps)
  if (any(is_median)) {
    median_col <- which(is_median)[1L]
    return(map_dbl(x, function(.x) .x[median_col]))
  }

  # 0.5 is not bracketed by the available levels: nothing to interpolate.
  if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) {
    return(rep(NA_real_, vctrs::vec_size(x)))
  }

  # Otherwise interpolate linearly between the neighbouring levels.
  # `stats::approx()` errors with fewer than two non-missing points, so
  # such elements yield NA instead.
  map_dbl(x, function(.x) {
    if (sum(!is.na(.x)) < 2L) {
      return(NA_real_)
    }
    stats::approx(lvls, .x, xout = 0.5)$y
  })
}

# Re-export the vctrs generics `vec_ptype_abbr()` and `vec_ptype_full()` from
# this package.
#' @importFrom vctrs vec_ptype_abbr
#' @export
vctrs::vec_ptype_abbr

#' @importFrom vctrs vec_ptype_full
#' @export
vctrs::vec_ptype_full

#' @export
vec_ptype_abbr.quantile_pred <- function(x, ...) {
  # Abbreviated type label built from the number of quantile levels; `{?s}`
  # is cli pluralization markup keyed on `n_lvls`.
  lvls <- attr(x, "quantile_levels")
  n_lvls <- length(lvls)
  cli::format_inline("qtl{?s}({n_lvls})")
}

#' @export
vec_ptype_full.quantile_pred <- function(x, ...) {
  # Full type label; it does not depend on the contents of `x`.
  "quantiles"
}

# Re-export the vctrs generic `obj_print_footer()` from this package.
#' @importFrom vctrs obj_print_footer
#' @export
vctrs::obj_print_footer

#' @export
obj_print_footer.quantile_pred <- function(x, digits = 3, ...) {
  # Print a one-line footer listing the quantile levels, formatted to
  # `digits` significant digits.
  shown_levels <- format(attr(x, "quantile_levels"), digits = digits)
  cat("# Quantile levels: ", shown_levels, "\n", sep = " ")
}

# Wrap raw quantile predictions `x` in a one-column tibble.
#
# `x` is coerced to a matrix if needed and its row names are dropped. The
# quantile levels are read from `object$spec$quantile_level`
# (NOTE(review): presumably `object` is a fitted model carrying its spec -
# confirm against the caller). The result has a single `.pred_quantile` column
# holding a `quantile_pred` vector.
restructure_rq_pred <- function(x, object) {
  if (!is.matrix(x)) {
    x <- as.matrix(x)
  }
  rownames(x) <- NULL
  quantile_level <- object$spec$quantile_level

  tibble::new_tibble(
    x = list(.pred_quantile = quantile_pred(x, quantile_level))
  )
}
4 changes: 3 additions & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ reference:
- forge

- title: Prediction
contents: contains("spruce")
contents:
- contains("spruce")
- quantile_pred

- title: Utility
contents:
Expand Down
61 changes: 61 additions & 0 deletions man/quantile_pred.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading