Skip to content
Open
11 changes: 11 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,28 @@ S3method(run_mold,default)
S3method(run_mold,default_formula_blueprint)
S3method(run_mold,default_recipe_blueprint)
S3method(run_mold,default_xy_blueprint)
S3method(snap,numeric)
S3method(snap,quantile_pred)
S3method(standardize,array)
S3method(standardize,data.frame)
S3method(standardize,default)
S3method(standardize,double)
S3method(standardize,factor)
S3method(standardize,integer)
S3method(standardize,matrix)
S3method(vec_arith,quantile_pred)
S3method(vec_arith.numeric,quantile_pred)
S3method(vec_arith.quantile_pred,numeric)
S3method(vec_cast,double.hardhat_frequency_weights)
S3method(vec_cast,double.hardhat_importance_weights)
S3method(vec_cast,hardhat_frequency_weights.hardhat_frequency_weights)
S3method(vec_cast,hardhat_importance_weights.hardhat_importance_weights)
S3method(vec_cast,integer.hardhat_frequency_weights)
S3method(vec_cast,quantile_pred.quantile_pred)
S3method(vec_math,quantile_pred)
S3method(vec_ptype2,hardhat_frequency_weights.hardhat_frequency_weights)
S3method(vec_ptype2,hardhat_importance_weights.hardhat_importance_weights)
S3method(vec_ptype2,quantile_pred.quantile_pred)
S3method(vec_ptype_abbr,hardhat_frequency_weights)
S3method(vec_ptype_abbr,hardhat_importance_weights)
S3method(vec_ptype_abbr,quantile_pred)
Expand Down Expand Up @@ -88,10 +96,12 @@ export(get_data_classes)
export(get_levels)
export(get_outcome_levels)
export(importance_weights)
export(impute_quantiles)
export(is_blueprint)
export(is_case_weights)
export(is_frequency_weights)
export(is_importance_weights)
export(is_quantile_pred)
export(model_frame)
export(model_matrix)
export(model_offset)
Expand All @@ -114,6 +124,7 @@ export(run_forge)
export(run_mold)
export(scream)
export(shrink)
export(snap)
export(spruce_class)
export(spruce_class_multiple)
export(spruce_numeric)
Expand Down
213 changes: 213 additions & 0 deletions R/impute-quantile_pred.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
#' Clamp values to the closed interval \[lower, upper\]
#'
#' `snap()` is an S3 generic that dispatches on the class of `x`.
#'
#' @param x a numeric vector
#' @param lower number, the lower bound
#' @param upper number, the upper bound
#' @param ... unused
#' @export
#'
#' @return An object of the same type as `x`
#'
#' @keywords internal
snap <- function(x, lower, upper, ...) {
  UseMethod("snap")
}

#' @export
snap.numeric <- function(x, lower, upper, ...) {
  rlang::check_dots_empty()
  check_number_decimal(lower)
  check_number_decimal(upper)

  # Raise anything below `lower` first, then cap anything above `upper`.
  floored <- pmax(x, lower)
  pmin(floored, upper)
}

#' @export
snap.quantile_pred <- function(x, lower, upper, ...) {
  # `...` exists only for the generic's signature; reject stray arguments the
  # same way the numeric method does.
  rlang::check_dots_empty()
  if (vec_size(x) == 0) {
    return(x)
  }
  quantile_levels <- extract_quantile_levels(x)
  # Clamp the whole value matrix in one call; the numeric method is
  # element-wise and keeps the matrix dimensions, so no per-row loop is needed.
  values <- snap(as.matrix(x), lower, upper)
  quantile_pred(values, quantile_levels = quantile_levels)
}




#' Impute additional quantiles from a `quantile_pred`
#'
#' While a [quantile_pred] describes evaluations for the inverse
#' cumulative distribution function (CDF, sometimes called the "quantile
#' function") at particular quantile levels, this is not enough
#' to fully describe the distribution. For example,
#' ```r
#' p <- c(.1, .5, .9)
#' quantile_pred(matrix(qnorm(p), nrow = 1), p)
#' ```
#' encapsulates the 10%, 50%, and 90% quantile levels of the standard normal
#' distribution. But, what if we need, say, the 25% and 75% levels? This
#' function imputes them if possible.
#'
#' @details
#' If `probs` is simply a subset of `quantile_levels` that already exist in `x`,
#' then these will be returned (up to numeric error). Small errors are possible
#' due to difficulties matching double vectors.
#'
#' For `probs` that do not exist in `x`, these will be interpolated or
#' extrapolated as needed. The process has 3 steps.
#'
#' First, by default (`middle = "cubic"`), missing _internal_ quantile levels
#' are interpolated using a cubic spline fit to the observed values + quantile
#' levels with [stats::splinefun]. Second, if cubic interpolation fails (or if
#' `middle = "linear"`), linear interpolation is used via [stats::approx].
#' Finally, missing _external_ quantile levels (those outside the range of
#' `quantile_levels`) are extrapolated. This is done using a linear fit on the
#' logistic scale to the two closest tail values.
#'
#' This procedure results in sorted quantiles that interpolate/extrapolate
#' smoothly, while also enforcing heavy tails beyond the range.
#'
#' Optionally, the resulting quantiles can be constrained to a compact interval
#' using `lower` and/or `upper`. This is done after extrapolation, so it may
#' result in multiple quantile levels having the same value (a CDF with a
#' spike).
#'
#'
#' @param x an object of class `quantile_pred`
#' @param probs vector. probabilities at which to evaluate the inverse CDF.
#'   Unsorted input is sorted into increasing order before use.
#' @param lower number. lower bound for the resulting values
#' @param upper number. upper bound for the resulting values
#' @param middle character. The method used to interpolate _internal_
#'   quantile levels: either `"cubic"` (the default) or `"linear"`.
#'
#' @returns A matrix with `length(probs)` columns and `length(x)` rows. Each
#'   row contains the inverse CDF (quantile function) given by `x`,
#'   extrapolated/interpolated to `probs`. Columns follow `probs` in
#'   increasing order.
#' @export
#'
#' @examples
#' p <- c(.1, .5, .9)
#' qp <- quantile_pred(matrix(c(qnorm(p), qexp(p)), nrow = 2, byrow = TRUE), p)
#' impute_quantiles(qp, p)
#' as.matrix(qp) # same as the imputation
#'
#' p1 <- c(.05, .25, .75, .95)
#' impute_quantiles(qp, p1)
#' rbind(qnorm(p1), qexp(p1)) # exact values, for comparison
impute_quantiles <- function(
  x,
  probs = seq(0, 1, 0.25),
  lower = -Inf,
  upper = Inf,
  middle = c("cubic", "linear")
) {
  if (!is_quantile_pred(x)) {
    cli::cli_abort(
      "{.arg x} must be a {.cls quantile_pred} object, not
      {.obj_type_friendly {x}}."
    )
  }
  if (length(extract_quantile_levels(x)) < 2) {
    cli::cli_abort(
      "Quantile interpolation is not possible when fewer than 2 quantiles
      are available."
    )
  }
  # `is.unsorted()` returns NA when `probs` contains missing values. Sort only
  # on a definite TRUE so that such input is handed to the validator below
  # rather than failing inside `if ()`.
  if (isTRUE(is.unsorted(probs))) {
    probs <- sort(probs)
  }
  check_quantile_level_values(probs, "probs", call = caller_env())
  check_number_decimal(lower)
  check_number_decimal(upper)
  if (lower > upper) {
    cli::cli_abort("`lower` must be less than `upper`.")
  }
  middle <- rlang::arg_match(middle)

  # Interpolate/extrapolate first, then clamp to [lower, upper].
  snap(impute_quantile_internal(x, probs, middle), lower, upper)
}

# Impute the quantile levels `tau_out` for every element of `x`.
#
# Returns a matrix with one row per element of `x` and one column per entry
# of `tau_out`. `middle` is passed through to `impute_quantiles_single()`.
impute_quantile_internal <- function(x, tau_out, middle) {
  tau <- extract_quantile_levels(x)
  qvals <- as.matrix(x)

  # Fast path: every requested level is already stored and nothing is missing,
  # so the answer is just a column selection.
  if (all(tau_out %in% tau) && !anyNA(qvals)) {
    return(qvals[, match(tau_out, tau), drop = FALSE])
  }

  rows <- map(
    vctrs::vec_chop(qvals),
    ~ impute_quantiles_single(.x, tau, tau_out, middle)
  )

  # `do.call(rbind, list())` is NULL; return an empty matrix of the right
  # width instead so zero-length input keeps the documented return shape.
  if (length(rows) == 0L) {
    return(matrix(NA_real_, nrow = 0L, ncol = length(tau_out)))
  }
  do.call(rbind, rows)
}

# Impute the quantile levels `tau_out` for a single set of predictions.
#
# `qvals` holds the predicted values at quantile levels `tau` (may contain
# NA); `middle` is "cubic" or "linear". Returns a double vector with one
# entry per element of `tau_out`.
impute_quantiles_single <- function(qvals, tau, tau_out, middle) {
  # NA_real_ keeps the result double even when nothing can be imputed.
  qvals_out <- rep(NA_real_, length(tau_out))
  good <- !is.na(qvals)
  if (!any(good)) {
    return(qvals_out)
  }
  qvals <- qvals[good]
  tau <- tau[good]

  # Only one usable point: no interpolation is possible, so fill in just the
  # requested levels that coincide with it and leave the rest missing.
  if (length(qvals) < 2) {
    qvals_out[tau_out %in% tau] <- qvals
    return(qvals_out)
  }

  indl <- tau_out < min(tau)
  indr <- tau_out > max(tau)
  indm <- !indl & !indr

  # Try the monotone cubic spline; fall back to linear interpolation when the
  # fit errors or produces missing values at the quartiles.
  use_linear <- middle == "linear"
  if (!use_linear) {
    probe <- tryCatch(
      {
        Q <- stats::splinefun(tau, qvals, method = "hyman")
        Q(c(.25, .5, .75))
      },
      error = function(e) NA
    )
    use_linear <- anyNA(probe)
  }

  if (any(indm)) {
    qvals_out[indm] <- if (use_linear) {
      stats::approx(tau, qvals, tau_out[indm])$y
    } else {
      Q(tau_out[indm])
    }
  }

  if (any(indl) || any(indr)) {
    # Pool the observed and freshly interpolated points, de-duplicated and
    # ordered by level, so each tail uses its two nearest points.
    qv <- data.frame(
      q = c(tau, tau_out[indm]),
      v = c(qvals, qvals_out[indm])
    )
    qv <- qv[vctrs::vec_unique_loc(qv$q), ]
    qv <- qv[vctrs::vec_order(qv$q), ]
  }
  if (any(indl)) {
    qvals_out[indl] <- tail_extrapolate(tau_out[indl], utils::head(qv, 2))
  }
  if (any(indr)) {
    qvals_out[indr] <- tail_extrapolate(tau_out[indr], utils::tail(qv, 2))
  }
  qvals_out
}

# Log-odds of `p`, with `p` first clamped to [0, 1].
logit <- function(p) {
  clamped <- pmin(p, 1)
  clamped <- pmax(clamped, 0)
  log(clamped) - log(1 - clamped)
}

# Extrapolates linearly on the logistic scale using the two points nearest
# the tail.
#
# `tau_out` are the quantile levels to extrapolate to; `qv` is a data frame
# with columns `q` (levels) and `v` (values). Returns one value per element
# of `tau_out`.
tail_extrapolate <- function(tau_out, qv) {
  if (nrow(qv) == 1L) {
    return(rep(qv$v[1], length(tau_out)))
  }
  x <- logit(qv$q)
  x0 <- logit(tau_out)
  y <- qv$v
  dy <- diff(y)
  # A flat tail extrapolates to the same constant. Handle it explicitly:
  # at levels 0 or 1 the logit is infinite, and a zero slope times an
  # infinite distance would evaluate to NaN.
  if (isTRUE(dy == 0)) {
    return(rep(y[1], length(tau_out)))
  }
  m <- dy / diff(x)
  m * (x0 - x[1]) + y[1]
}
84 changes: 82 additions & 2 deletions R/quantile-pred.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#' `".quantile_levels"`, and `".row"`.
#' * `as.matrix()` returns an unnamed matrix with rows as samples, columns as
#' quantile levels, and entries are predictions.
#' * `is_quantile_pred()` tests for the "quantile_pred" class
#' @examples
#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8))
#'
Expand All @@ -42,7 +43,7 @@ quantile_pred <- function(values, quantile_levels = double()) {

rownames(values) <- NULL
colnames(values) <- NULL
values <- lapply(vctrs::vec_chop(values), drop)
values <- map(vctrs::vec_chop(values), drop)
new_quantile_pred(values, quantile_levels)
}

Expand All @@ -55,10 +56,18 @@ new_quantile_pred <- function(values = list(), quantile_levels = double()) {
)
}



#' @export
#' @rdname quantile_pred
is_quantile_pred <- function(x) {
  # TRUE when `x` carries the "quantile_pred" class.
  inherits(x, what = "quantile_pred")
}

#' @export
#' @rdname quantile_pred
extract_quantile_levels <- function(x) {
if (!inherits(x, "quantile_pred")) {
if (!is_quantile_pred(x)) {
cli::cli_abort(
"{.arg x} must be a {.cls quantile_pred} object, not
{.obj_type_friendly {x}}."
Expand Down Expand Up @@ -208,3 +217,74 @@ check_quantile_level_values <- function(levels, arg, call) {
}
invisible(TRUE)
}


# vctrs behaviours --------------------------------------------------------

#' @export
#' @keywords internal
vec_ptype2.quantile_pred.quantile_pred <- function(
  x, y, ..., x_arg = "", y_arg = "", call = caller_env()
) {
  x_lvls <- extract_quantile_levels(x)
  y_lvls <- extract_quantile_levels(y)

  # The common type is whichever input's quantile levels contain the other's.
  if (all(y_lvls %in% x_lvls)) {
    return(x)
  }
  if (all(x_lvls %in% y_lvls)) {
    return(y)
  }

  # Neither set of levels contains the other. Forward `call` so the error is
  # reported against the caller's frame rather than this method.
  stop_incompatible_type(
    x, y, x_arg = x_arg, y_arg = y_arg,
    details = "`quantile_levels` must be compatible (a superset/subset relation).",
    call = call
  )
}

#' @export
vec_cast.quantile_pred.quantile_pred <- function(x, to, ..., x_arg = "", to_arg = "") {
  x_lvls <- extract_quantile_levels(x)
  to_lvls <- extract_quantile_levels(to)

  # For each level of `x`, the column it occupies in `to` (NA when `to` does
  # not have that level). Matching by position keeps values aligned with
  # their own level regardless of the order of either level vector.
  target_col <- match(x_lvls, to_lvls)
  shared <- !is.na(target_col)

  # Start from an all-missing double matrix so levels present only in `to`
  # stay NA and the result is numeric even when no levels are shared.
  new_qdata <- matrix(NA_real_, nrow = vec_size(x), ncol = length(to_lvls))
  # `drop = FALSE` keeps a single row or column as a matrix.
  new_qdata[, target_col[shared]] <- as.matrix(x)[, shared, drop = FALSE]
  quantile_pred(new_qdata, quantile_levels = to_lvls)
}


#' @export
#' @method vec_math quantile_pred
vec_math.quantile_pred <- function(.fn, .x, ...) {
  fn <- .fn
  # Reject the unsupported summaries and cumulative operations up front.
  if (fn %in% c("any", "all", "prod", "sum", "cumsum", "cummax", "cummin", "cumprod")) {
    cli::cli_abort("{.fn {fn}} is not a supported operation for {.cls quantile_pred}.")
  }
  .fn <- getExportedValue("base", fn)
  quantile_levels <- .x %@% "quantile_levels"
  .x <- as.matrix(.x)
  # Forward `...` so extra arguments (e.g. `digits` for `round()`, `base` for
  # `log()`) reach the underlying base function instead of being dropped.
  quantile_pred(.fn(.x, ...), quantile_levels)
}

#' @export
#' @method vec_arith quantile_pred
vec_arith.quantile_pred <- function(op, x, y, ...) {
  # Second stage of double dispatch: pick the method from the class of `y`.
  UseMethod("vec_arith.quantile_pred", y)
}

#' @export
#' @method vec_arith.quantile_pred numeric
vec_arith.quantile_pred.numeric <- function(op, x, y, ...) {
  # Look up the base operator named by `op`, then bring both operands to a
  # common size before applying it to the value matrix.
  base_op <- getExportedValue("base", op)
  recycled <- vctrs::vec_recycle_common(x = x, y = y)
  values <- as.matrix(recycled$x)
  result <- base_op(values, recycled$y)
  quantile_pred(result, x %@% "quantile_levels")
}

#' @export
#' @method vec_arith.numeric quantile_pred
vec_arith.numeric.quantile_pred <- function(op, x, y, ...) {
  # Mirror of the quantile_pred-on-the-left method: the numeric operand stays
  # on the left of the operator and the value matrix on the right.
  base_op <- getExportedValue("base", op)
  recycled <- vctrs::vec_recycle_common(x = x, y = y)
  values <- as.matrix(recycled$y)
  result <- base_op(recycled$x, values)
  quantile_pred(result, y %@% "quantile_levels")
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ reference:
- add_intercept_column
- weighted_table
- fct_encode_one_hot
- impute_quantiles

- title: Validation
contents:
Expand Down
Loading
Loading