diff --git a/DESCRIPTION b/DESCRIPTION index 6dbb2399e..00af5f0d1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -68,7 +68,7 @@ Suggests: VignetteBuilder: knitr ByteCompile: true -Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn, +Config/Needs/website: C50, dbarts, earth, glmnet, grf, keras, kernlab, kknn, LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm, tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate, xgboost diff --git a/NAMESPACE b/NAMESPACE index d48f33586..87f9ab7ab 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -36,6 +36,7 @@ S3method(extract_spec_parsnip,model_fit) S3method(fit,model_spec) S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) +S3method(format,vctrs_quantiles) S3method(glance,model_fit) S3method(has_multi_predict,default) S3method(has_multi_predict,model_fit) @@ -54,6 +55,7 @@ S3method(multi_predict_args,default) S3method(multi_predict_args,model_fit) S3method(multi_predict_args,workflow) S3method(nullmodel,default) +S3method(obj_print_footer,vctrs_quantiles) S3method(predict,"_elnet") S3method(predict,"_glmnetfit") S3method(predict,"_lognet") @@ -280,6 +282,7 @@ export(new_model_spec) export(null_model) export(null_value) export(nullmodel) +export(obj_print_footer) export(parsnip_addin) export(pls) export(poisson_reg) @@ -350,6 +353,9 @@ export(update_model_info_file) export(update_spec) export(varying) export(varying_args) +export(vec_ptype_abbr.vctrs_quantiles) +export(vec_ptype_full.vctrs_quantiles) +export(vec_quantiles) export(xgb_predict) export(xgb_train) import(rlang) @@ -396,6 +402,9 @@ importFrom(purrr,map) importFrom(purrr,map_chr) importFrom(purrr,map_dbl) importFrom(purrr,map_lgl) +importFrom(rlang,"!!!") +importFrom(rlang,is_double) +importFrom(rlang,is_list) importFrom(stats,.checkMFClasses) importFrom(stats,.getXlevels) importFrom(stats,as.formula) @@ -426,5 +435,6 @@ importFrom(utils,globalVariables) importFrom(utils,head) importFrom(utils,methods) importFrom(utils,stack) +importFrom(vctrs,obj_print_footer) importFrom(vctrs,vec_size) importFrom(vctrs,vec_unique) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 857da4124..8ad86a33c 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -21,21 +21,117 @@ check_quantile_level <- function(x, object, call) { x } + +# ------------------------------------------------------------------------- +# A column vector of quantiles with an attribute + +#' @export +vec_ptype_abbr.vctrs_quantiles <- function(x, ...) "qntls" + +#' @export +vec_ptype_full.vctrs_quantiles <- function(x, ...) "quantiles" + +#' @importFrom rlang is_list is_double !!! +new_vec_quantiles <- function(values = list(), quantile_levels = double()) { + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) + num_values <- vctrs::vec_size_common(!!!values) + if (length(quantile_levels) != num_values) { + cli::cli_abort( + "{.arg quantile_levels} must have the same length as {.arg values}. It has + length {.val {length(quantile_levels)}} not {.val {num_values}}." + ) + } + purrr::walk( + quantile_levels, + ~ check_number_decimal(.x, min = 0, max = 1, arg = "quantile_levels") + ) + vctrs::new_vctr( + values, quantile_levels = quantile_levels, class = "vctrs_quantiles" + ) +} + + +#' A vector containing sets of quantiles +#' +#' @param values A data.frame/matrix/vector of values. If a named data.frame, +#' the column names will be used as the `quantile_levels` if those are missing. +#' @param quantile_levels A vector of probabilities corresponding to `values`. +#' May be `NULL` if `values` is a named data.frame. +#' +#' @export +#' +#' @examples +#' preds <- vec_quantiles(list(1:4, 8:11), c(.2, .4, .6, .8)) +#' +#' vec_quantiles(1:4, 1:4 / 5) +vec_quantiles <- function(values, quantile_levels = NULL) { + check_vec_quantiles_inputs(values, quantile_levels) +# TODO save call reference + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) + + num_lvls <- length(quantile_levels) + + if (is.data.frame(values) || (is.matrix(values) && length(dim(values)) == 2)) { + values <- lapply(vctrs::vec_chop(values), function(v) sort(drop(v))) + } else if (is.list(values)) { + values <- values + } else if (is.null(dim(values))) { + if (length(values) != num_lvls) { + values <- vctrs::vec_chop(values) + } + } else { + cli::cli_abort( + "{.arg values} must be a {.cls list}, {.cls matrix}, or {.cls data.frame}, + not a {.cls {class(values)}}." + ) + } + new_vec_quantiles(values, quantile_levels) +} + +check_vec_quantiles_inputs <- function(values, levels) { + if (is.null(levels)) { + if (!is.data.frame(values)) { + cli::cli_abort("If {.arg quantile_levels} is `NULL`, {.arg values} must + be a data.frame.") + } + levels <- as.numeric(names(values)) + if (any(is.na(levels))) { + cli::cli_abort("If {.arg quantile_levels} is `NULL`, {.arg values} must + be a data.frame with numeric names.") + } + } + invisible(NULL) +} + +#' @export +format.vctrs_quantiles <- function(x, ...) { + quantile_levels <- attr(x, "levels") + if (length(quantile_levels) == 1L) { + x <- unlist(x) + out <- round(x, 3L) + out[is.na(x)] <- NA + } else { + rng <- sapply(x, range) + out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") + out[is.na(rng[1, ]) | is.na(rng[2, ])] <- NA + } + out +} + +#' @importFrom vctrs obj_print_footer +#' @export +vctrs::obj_print_footer + +#' @export +obj_print_footer.vctrs_quantiles <- function(x, ...) { + lvls <- attr(x, "quantile_levels") + cat("# Quantile levels: ", format(lvls, digits = 3), "\n", sep = " ") +} + # Assumes the columns have the same order as quantile_level restructure_rq_pred <- function(x, object) { - n <- nrow(x) - p <- ncol(x) - # TODO check p = length(quantile_level) - # check p = 1 case quantile_level <- object$spec$quantile_level - res <- - tibble::tibble( - .pred_quantile = as.vector(x), - .quantile_level = rep(quantile_level, each = n), - .row = rep(1:n, p)) - res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"]) - res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val))) - res$.row <- NULL + res <- tibble(.pred_quantile = vec_quantiles(x, quantile_level)) res } diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index f504607e4..895316bbc 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -612,24 +612,6 @@ set_encoding( ) ) -set_pred( - model = "linear_reg", - eng = "quantreg", - mode = "quantile regression", - type = "numeric", - value = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response", - rankdeficient = "simple" - ) - ) -) set_pred( model = "linear_reg", diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 5ea5e6f1d..ac3540635 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -82,6 +82,7 @@ set_new_model("rand_forest") set_model_mode("rand_forest", "classification") set_model_mode("rand_forest", "regression") set_model_mode("rand_forest", "censored regression") +set_model_mode("rand_forest", "quantile regression") # ------------------------------------------------------------------------------ # ranger components @@ -608,3 +609,201 @@ set_pred( dataset = quote(new_data)) ) ) + + +# ------------------------------------------------------------------------- +# wrappers for grf +process_quantile_forest_preds <- function(x, object) { + quantile_levels <- extract_fit_engine(object)$quantiles.orig + out <- lapply(vctrs::vec_chop(x$predictions), function(x) sort(drop(x))) + tibble(.pred_quantile = vec_quantiles(out, quantile_levels)) +} +process_regression_forest_preds <- function(x, object) { + tibble(.pred = x$predictions) +} +process_probability_forest_class <- function(x, object) { + x <- x$predictions + max_class <- factor( + colnames(x)[apply(x, 1, which.max)], + levels = colnames(x) + ) + tibble(.pred_class = max_class) +} +process_probability_forest_prob <- function(x, object) { + as_tibble(x$predictions) +} + +# grf components + +set_model_engine("rand_forest", "quantile regression", "grf") +set_model_engine("rand_forest", "regression", "grf") +set_model_engine("rand_forest", "classification", "grf") +set_dependency( + model = "rand_forest", + eng = "grf", + pkg = "grf" +) +set_model_arg( + model = "rand_forest", + eng = "grf", + parsnip = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "grf", + parsnip = "trees", + original = "num.trees", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "grf", + parsnip = "min_n", + original = "min.node.size", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + + +set_fit( + model = "rand_forest", + eng = "grf", + mode = "quantile regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + data = c(x = "X", y = "Y"), + func = c(pkg = "grf", fun = "quantile_forest"), + defaults = list( + quantiles = expr(quantile_level), + num.threads = 1L, + seed = expr(runif(1, 0, .Machine$integer.max)) + ) + ) +) +set_encoding( + model = "rand_forest", + eng = "grf", + mode = "quantile regression", + options = list( + predictor_indicators = "one_hot", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) +) + +set_fit( + model = "rand_forest", + eng = "grf", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + data = c(x = "X", y = "Y", weights = "sample.weights"), + func = c(pkg = "grf", fun = "regression_forest"), + defaults = list( + num.threads = 1L, + seed = rlang::expr(stats::runif(1, 0, .Machine$integer.max)) + ) + ) +) +set_encoding( + model = "rand_forest", + eng = "grf", + mode = "regression", + options = list( + predictor_indicators = "one_hot", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) +) + +set_fit( + model = "rand_forest", + eng = "grf", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + data = c(x = "X", y = "Y", weights = "sample.weights"), + func = c(pkg = "grf", fun = "probability_forest"), + defaults = list( + num.threads = 1L, + seed = rlang::expr(stats::runif(1, 0, .Machine$integer.max)) + ) + ) +) +set_encoding( + model = "rand_forest", + eng = "grf", + mode = "classification", + options = list( + predictor_indicators = "one_hot", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) +) + +set_pred( + model = "rand_forest", + eng = "grf", + mode = "quantile regression", + type = "quantile", + value = pred_value_template( + pre = NULL, + post = process_quantile_forest_preds, + func = c(fun = "predict"), + object = expr(object$fit), + newdata = expr(new_data), + seed = expr(sample.int(10^5, 1)), + verbose = FALSE + ) +) +set_pred( + model = "rand_forest", + eng = "grf", + mode = "regression", + type = "numeric", + value = pred_value_template( + pre = NULL, + post = process_regression_forest_preds, + func = c(fun = "predict"), + object = quote(object$fit), + newdata = quote(new_data) + ) +) +set_pred( + model = "rand_forest", + eng = "grf", + mode = "classification", + type = "class", + value = pred_value_template( + pre = NULL, + post = process_probability_forest_class, + func = c(fun = "predict"), + object = quote(object$fit), + newdata = quote(new_data) + ) +) +set_pred( + model = "rand_forest", + eng = "grf", + mode = "classification", + type = "prob", + value = pred_value_template( + pre = NULL, + post = process_probability_forest_prob, + func = c(fun = "predict"), + object = quote(object$fit), + newdata = quote(new_data) + ) +) + + diff --git a/man/reexports.Rd b/man/reexports.Rd index f87bde459..f051744e2 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -1,8 +1,9 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/reexports.R, R/varying.R +% Please edit documentation in R/aaa_quantiles.R, R/reexports.R, R/varying.R \docType{import} \name{reexports} \alias{reexports} +\alias{obj_print_footer} \alias{autoplot} \alias{\%>\%} \alias{fit} @@ -34,5 +35,7 @@ below to see their documentation. \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_fit_time}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{frequency_weights}}, \code{\link[hardhat]{importance_weights}}, \code{\link[hardhat]{tune}}} \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} + + \item{vctrs}{\code{\link[vctrs:obj_print]{obj_print_footer}}} }} diff --git a/man/vec_quantiles.Rd b/man/vec_quantiles.Rd new file mode 100644 index 000000000..e47ada958 --- /dev/null +++ b/man/vec_quantiles.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aaa_quantiles.R +\name{vec_quantiles} +\alias{vec_quantiles} +\title{A vector containing sets of quantiles} +\usage{ +vec_quantiles(values, quantile_levels = NULL) +} +\arguments{ +\item{values}{A data.frame/matrix/vector of values. If a named data.frame, +the column names will be used as the \code{quantile_levels} if those are missing.} + +\item{quantile_levels}{A vector of probabilities corresponding to \code{values}. +May be \code{NULL} if \code{values} is a named data.frame.} +} +\description{ +A vector containing sets of quantiles +} +\examples{ +preds <- vec_quantiles(list(1:4, 8:11), c(.2, .4, .6, .8)) + +vec_quantiles(1:4, 1:4 / 5) +}