Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Suggests:
VignetteBuilder:
knitr
ByteCompile: true
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
Config/Needs/website: C50, dbarts, earth, glmnet, grf, keras, kernlab, kknn,
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
xgboost
Expand Down
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ S3method(extract_spec_parsnip,model_fit)
S3method(fit,model_spec)
S3method(fit_xy,gen_additive_mod)
S3method(fit_xy,model_spec)
S3method(format,vctrs_quantiles)
S3method(glance,model_fit)
S3method(has_multi_predict,default)
S3method(has_multi_predict,model_fit)
Expand All @@ -54,6 +55,7 @@ S3method(multi_predict_args,default)
S3method(multi_predict_args,model_fit)
S3method(multi_predict_args,workflow)
S3method(nullmodel,default)
S3method(obj_print_footer,vctrs_quantiles)
S3method(predict,"_elnet")
S3method(predict,"_glmnetfit")
S3method(predict,"_lognet")
Expand Down Expand Up @@ -280,6 +282,7 @@ export(new_model_spec)
export(null_model)
export(null_value)
export(nullmodel)
export(obj_print_footer)
export(parsnip_addin)
export(pls)
export(poisson_reg)
Expand Down Expand Up @@ -350,6 +353,9 @@ export(update_model_info_file)
export(update_spec)
export(varying)
export(varying_args)
export(vec_ptype_abbr.vctrs_quantiles)
export(vec_ptype_full.vctrs_quantiles)
export(vec_quantiles)
export(xgb_predict)
export(xgb_train)
import(rlang)
Expand Down Expand Up @@ -396,6 +402,9 @@ importFrom(purrr,map)
importFrom(purrr,map_chr)
importFrom(purrr,map_dbl)
importFrom(purrr,map_lgl)
importFrom(rlang,"!!!")
importFrom(rlang,is_double)
importFrom(rlang,is_list)
importFrom(stats,.checkMFClasses)
importFrom(stats,.getXlevels)
importFrom(stats,as.formula)
Expand Down Expand Up @@ -426,5 +435,6 @@ importFrom(utils,globalVariables)
importFrom(utils,head)
importFrom(utils,methods)
importFrom(utils,stack)
importFrom(vctrs,obj_print_footer)
importFrom(vctrs,vec_size)
importFrom(vctrs,vec_unique)
120 changes: 108 additions & 12 deletions R/aaa_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,117 @@ check_quantile_level <- function(x, object, call) {
x
}


# -------------------------------------------------------------------------
# A column vector of quantiles with an attribute

#' @export
vec_ptype_abbr.vctrs_quantiles <- function(x, ...) "qntls"

#' @export
vec_ptype_full.vctrs_quantiles <- function(x, ...) "quantiles"

#' @importFrom rlang is_list is_double !!!
new_vec_quantiles <- function(values = list(), quantile_levels = double()) {
quantile_levels <- vctrs::vec_cast(quantile_levels, double())
num_values <- vctrs::vec_size_common(!!!values)
if (length(quantile_levels) != num_values) {
cli::cli_abort(
"{.arg quantile_levels} must have the same length as {.arg values}. It has
length {.val {length(quantile_levels)}} not {.val {num_values}}."
)
}
purrr::walk(
quantile_levels,
~ check_number_decimal(.x, min = 0, max = 1, arg = "quantile_levels")
)
vctrs::new_vctr(
values, quantile_levels = quantile_levels, class = "vctrs_quantiles"
)
}


#' A vector containing sets of quantiles
#'
#' @param values A data.frame/matrix/vector of values. If a named data.frame,
#' the column names will be used as the `quantile_levels` if those are missing.
#' @param quantile_levels A vector of probabilities corresponding to `values`.
#' May be `NULL` if `values` is a named data.frame.
#'
#' @export
#'
#' @examples
#' preds <- vec_quantiles(list(1:4, 8:11), c(.2, .4, .6, .8))
#'
#' vec_quantiles(1:4, 1:4 / 5)
vec_quantiles <- function(values, quantile_levels = NULL) {
check_vec_quantiles_inputs(values, quantile_levels)
# TODO save call reference
quantile_levels <- vctrs::vec_cast(quantile_levels, double())

num_lvls <- length(quantile_levels)

if (is.data.frame(values) || (is.matrix(values) && length(dim(values)) == 2)) {
values <- lapply(vctrs::vec_chop(values), function(v) sort(drop(v)))
} else if (is.list(values)) {
values <- values
} else if (is.null(dim(values))) {
if (length(values) != num_lvls) {
values <- vctrs::vec_chop(values)
}
} else {
cli::cli_abort(
"{.arg values} must be a {.cls list}, {.cls matrix}, or {.cls data.frame},
not a {.cls {class(values)}}."
)
}
new_vec_quantiles(values, quantile_levels)
}

check_vec_quantiles_inputs <- function(values, levels) {
if (is.null(levels)) {
if (!is.data.frame(values)) {
cli::cli_abort("If {.arg quantile_levels} is `NULL`, {.arg values} must
be a data.frame.")
}
levels <- as.numeric(names(values))
if (any(is.na(levels))) {
cli::cli_abort("If {.arg quantile_levels} is `NULL`, {.arg values} must
be a data.frame with numeric names.")
}
}
invisible(NULL)
}

#' @export
format.vctrs_quantiles <- function(x, ...) {
quantile_levels <- attr(x, "levels")
if (length(quantile_levels) == 1L) {
x <- unlist(x)
out <- round(x, 3L)
out[is.na(x)] <- NA
} else {
rng <- sapply(x, range)
out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]")
out[is.na(rng[1, ]) | is.na(rng[2, ])] <- NA
}
out
}

#' @importFrom vctrs obj_print_footer
#' @export
vctrs::obj_print_footer

#' @export
obj_print_footer.vctrs_quantiles <- function(x, ...) {
lvls <- attr(x, "quantile_levels")
cat("# Quantile levels: ", format(lvls, digits = 3), "\n", sep = " ")
}

# Assumes the columns have the same order as quantile_level
restructure_rq_pred <- function(x, object) {
n <- nrow(x)
p <- ncol(x)
# TODO check p = length(quantile_level)
# check p = 1 case
quantile_level <- object$spec$quantile_level
res <-
tibble::tibble(
.pred_quantile = as.vector(x),
.quantile_level = rep(quantile_level, each = n),
.row = rep(1:n, p))
res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"])
res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val)))
res$.row <- NULL
res <- tibble(.pred_quantile = vec_quantiles(x, quantile_level))
res
}

18 changes: 0 additions & 18 deletions R/linear_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -612,24 +612,6 @@ set_encoding(
)
)

set_pred(
model = "linear_reg",
eng = "quantreg",
mode = "quantile regression",
type = "numeric",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "response",
rankdeficient = "simple"
)
)
)

set_pred(
model = "linear_reg",
Expand Down
Loading