diff --git a/DESCRIPTION b/DESCRIPTION index a2dbbcb0..7f0ad675 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -23,7 +23,7 @@ Imports: cli, dplyr (>= 1.1.0), generics (>= 0.1.2), - hardhat (>= 1.3.0), + hardhat (>= 1.4.2.9000), lifecycle (>= 1.0.3), rlang (>= 1.1.4), tibble, @@ -49,7 +49,7 @@ Config/usethis/last-upkeep: 2025-04-24 Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 Collate: 'aaa-metrics.R' 'import-standalone-types-check.R' @@ -116,6 +116,7 @@ Collate: 'prob-roc_aunp.R' 'prob-roc_aunu.R' 'prob-roc_curve.R' + 'quant-weighted_interval.R' 'reexports.R' 'surv-brier_survival.R' 'surv-brier_survival_integrated.R' diff --git a/NAMESPACE b/NAMESPACE index acddadbc..ad601a76 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -149,6 +149,7 @@ export(check_dynamic_survival_metric) export(check_numeric_metric) export(check_ordered_prob_metric) export(check_prob_metric) +export(check_quantile_metric) export(check_static_survival_metric) export(class_metric_summarizer) export(classification_cost) @@ -262,6 +263,8 @@ export(specificity_vec) export(static_survival_metric_summarizer) export(tidy) export(validate_estimator) +export(weighted_interval_score) +export(weighted_interval_score_vec) export(yardstick_any_missing) export(yardstick_remove_missing) import(rlang) diff --git a/R/check-metric.R b/R/check-metric.R index eebdfb46..8b2042a4 100644 --- a/R/check-metric.R +++ b/R/check-metric.R @@ -15,6 +15,7 @@ #' - For `check_ordered_prob_metric()`, an ordered factor. #' - For `check_dynamic_survival_metric()`, a Surv object. #' - For `check_static_survival_metric()`, a Surv object. +#' - For `check_quantile_metric()`, a numeric vector. #' #' @param estimate The realized `estimate` result. #' - For `check_numeric_metric()`, a numeric vector. @@ -25,6 +26,7 @@ #' a numeric matrix for multic-class `truth`. #' - For `check_dynamic_survival_metric()`, list-column of data.frames. 
#'   - For `check_static_survival_metric()`, a numeric vector.
+#'   - For `check_quantile_metric()`, a `hardhat::quantile_pred` vector.
 #'
 #' @param case_weights The realized case weights, as a numeric vector. This must
 #'   be the same length as `truth`.
@@ -120,3 +122,17 @@ check_static_survival_metric <- function(
   validate_case_weights(case_weights, size = nrow(truth), call = call)
   validate_surv_truth_numeric_estimate(truth, estimate, call = call)
 }
+
+#' @rdname check_metric
+#' @export
+check_quantile_metric <- function(
+  truth,
+  estimate,
+  case_weights,
+  call = caller_env()
+) {
+  validate_numeric_truth_quantile_estimate(truth, estimate, call = call)
+  # `truth` is a plain numeric vector here (matrices are rejected above), so
+  # its size is `length(truth)`; `nrow(truth)` would be `NULL`.
+  validate_case_weights(case_weights, size = length(truth), call = call)
+}
diff --git a/R/quant-weighted_interval.R b/R/quant-weighted_interval.R
new file mode 100644
index 00000000..f519a73e
--- /dev/null
+++ b/R/quant-weighted_interval.R
@@ -0,0 +1,180 @@
+#' Compute weighted interval score
+#'
+#' Weighted interval score (WIS), a well-known quantile-based
+#' approximation of the commonly-used continuous ranked probability score
+#' (CRPS). WIS is a proper score, and can be thought of as a distributional
+#' generalization of absolute error. For example, see [Bracher et
+#' al. (2020)](https://arxiv.org/abs/2005.12881) for discussion in the context
+#' of COVID-19 forecasting.
+#'
+#' @param data The object that the `weighted_interval_score()` generic
+#'   dispatches on.
+#' @param truth double. Actual value(s).
+#' @param estimate A vector of class `quantile_pred`.
+#' @param quantile_levels probabilities. If specified, the score will be
+#'   computed at this set of levels. Otherwise, those present in `estimate`
+#'   will be used. If `quantile_levels` do not exactly match those available
+#'   in `estimate`, then some quantiles will have implicit missingness.
+#'   Handling of these is determined by `quantile_estimate_nas`.
+#' @param quantile_estimate_nas character. This argument applies only to
+#'   `estimate`. It handles imputation of individual `quantile_levels` that
+#'   are necessary to compute a score. Because each element of `estimate` is
+#'   a [hardhat::quantile_pred], it is possible for these to be missing for
+#'   particular `quantile_levels`. There are a number of different
+#'   possibilities for such missingness. The options are as follows:
+#'   * For `"impute"`, both explicit and implicit missing values will be
+#'     imputed using [hardhat::impute_quantiles()] prior to the calculation of
+#'     the score. So the score will be `NA` only if imputation fails.
+#'   * For `"drop"`, any explicit missing values will be removed
+#'     before calculating the score for a particular prediction. This may be
+#'     reasonable due to the weighting. For example, if the estimate has
+#'     `quantile_levels = c(.25, .5, .75)` but the median is `NA` for a
+#'     particular prediction, it may be reasonable to average the accuracy of
+#'     `c(.25, .75)` for that prediction with others that don't have
+#'     missingness. This option only works if `quantile_levels = NULL` or is a
+#'     subset of the `quantile_levels` in `estimate`.
+#'   * For `"propagate"`, any missing value predictions will result in that
+#'     element of `estimate` having a score of `NA`. If `na_rm = TRUE`, then
+#'     these will be removed before averaging.
+#' @param na_rm logical. If `TRUE`, missing values in `truth` or both implicit
+#'   and explicit (values of `NA` present in `estimate`), will be ignored
+#'   (dropped) in the calculation of the summary score. If `FALSE` (the
+#'   default), any `NA`s will result in the summary being `NA`.
+#' @param case_weights The case weights, as a numeric vector of the same
+#'   length as `truth`, or `NULL` (the default).
+#' @param ... not used
+#'
+#' @return A single numeric value: the per-prediction scores averaged via
+#'   `yardstick_mean()` (using `case_weights` when supplied), or `NA` when
+#'   missing values are present and not removed.
+#'
+#' @export
+#' @examples
+#' quantile_levels <- c(.2, .4, .6, .8)
+#' pred1 <- 1:4
+#' pred2 <- 8:11
+#' preds <- hardhat::quantile_pred(rbind(pred1, pred2), quantile_levels)
+#' truth <- c(3.3, 7.1)
+#' weighted_interval_score_vec(truth, preds)
+#' weighted_interval_score_vec(truth, preds, quantile_levels = c(.25, .5, .75))
+#'
+#' # Missing value behaviours
+#'
+#' preds_na <- hardhat::quantile_pred(rbind(pred1, c(1, 2, NA, 4)), 1:4 / 5)
+#' truth <- c(2.5, 2.5)
+#' weighted_interval_score_vec(truth, preds_na)
+#' weighted_interval_score_vec(truth, preds_na, quantile_levels = 1:9 / 10)
+#' try(weighted_interval_score_vec(
+#'   truth,
+#'   preds_na,
+#'   quantile_levels = 1:9 / 10,
+#'   quantile_estimate_nas = "drop"
+#' ))
+#' weighted_interval_score_vec(
+#'   truth,
+#'   preds_na,
+#'   quantile_levels = c(2, 3) / 5,
+#'   quantile_estimate_nas = "drop"
+#' )
+#' weighted_interval_score_vec(
+#'   truth, preds_na, na_rm = TRUE, quantile_estimate_nas = "propagate"
+#' )
+#' weighted_interval_score_vec(
+#'   truth, preds_na, quantile_estimate_nas = "propagate"
+#' )
+#'
+weighted_interval_score <- function(data, ...) {
+  UseMethod("weighted_interval_score")
+}
+# Register the generic defined just above as the metric function, so the
+# exported object is this metric rather than a re-wrapped `mae`.
+weighted_interval_score <- new_numeric_metric(
+  weighted_interval_score,
+  direction = "minimize"
+)
+
+#' @export
+#' @rdname weighted_interval_score
+weighted_interval_score_vec <- function(
+  truth,
+  estimate,
+  quantile_levels = NULL,
+  na_rm = FALSE,
+  quantile_estimate_nas = c("impute", "drop", "propagate"),
+  case_weights = NULL,
+  ...
+) {
+  check_quantile_metric(truth, estimate, case_weights)
+  estimate_quantile_levels <- hardhat::extract_quantile_levels(estimate)
+  quantile_estimate_nas <- rlang::arg_match(quantile_estimate_nas)
+  if (!is.null(quantile_levels)) {
+    hardhat::check_quantile_levels(quantile_levels)
+    all_levels_estimated <- all(quantile_levels %in% estimate_quantile_levels)
+    if (quantile_estimate_nas == "drop" && !all_levels_estimated) {
+      cli::cli_abort(
+        "When `quantile_levels` is not a subset of those available in `estimate`,
+        `quantile_estimate_nas` may not be `'drop'`."
+      )
+    }
+    if (!all_levels_estimated && (quantile_estimate_nas == "propagate")) {
+      # We requested particular levels, but the levels aren't all there,
+      # and NAs propagate, so return NA
+      return(NA_real_)
+    }
+  }
+
+  quantile_levels <- quantile_levels %||% estimate_quantile_levels
+  if (quantile_estimate_nas %in% c("drop", "propagate")) {
+    levels_estimated <- estimate_quantile_levels %in% quantile_levels
+    estimate <- as.matrix(estimate)[, levels_estimated, drop = FALSE]
+  } else {
+    estimate <- as.matrix(hardhat::impute_quantiles(estimate, quantile_levels))
+  }
+
+  vec_wis <- wis_impl(
+    truth = truth,
+    estimate = estimate,
+    quantile_levels = quantile_levels,
+    rowwise_na_rm = (quantile_estimate_nas == "drop")
+  )
+
+  if (na_rm) {
+    result <- yardstick_remove_missing(truth, vec_wis, case_weights)
+
+    truth <- result$truth
+    vec_wis <- result$estimate
+    case_weights <- result$case_weights
+  } else if (yardstick_any_missing(truth, vec_wis, case_weights)) {
+    return(NA_real_)
+  }
+
+  yardstick_mean(vec_wis, case_weights = case_weights)
+}
+
+# Score each prediction: `vctrs::vec_chop()` slices `estimate` and each slice
+# is scored against the matching `truth` value; returns a double vector.
+wis_impl <- function(truth, estimate, quantile_levels, rowwise_na_rm = TRUE) {
+  as.vector(
+    mapply(
+      FUN = function(.x, .y) {
+        wis_one_quantile(.x, quantile_levels, .y, rowwise_na_rm)
+      },
+      vctrs::vec_chop(estimate),
+      truth
+    ),
+    "double"
+  )
+}
+
+# Score one prediction: twice the mean, over quantile levels, of
+# `pmax(level * (truth - value), (1 - level) * (value - truth))`.
+wis_one_quantile <- function(values, quantile_levels, truth, na_rm) {
+  2 *
+    mean(
+      pmax(
+        quantile_levels * (truth - values),
+        (1 - quantile_levels) * (values - truth)
+      ),
+      na.rm = na_rm
+    )
+}
diff --git a/R/validation.R b/R/validation.R
index 547e072a..d75be0de 100644
--- a/R/validation.R
+++ b/R/validation.R
@@ -443,3 +443,46 @@ validate_case_weights <- function(case_weights, size, call = caller_env()) {
 
   invisible(NULL)
 }
+
+# Validate inputs for quantile metrics: `truth` must be a numeric vector (not
+# a matrix), `estimate` must inherit from `quantile_pred`, and both must have
+# the same size. Aborts via `cli::cli_abort()` with `call` otherwise.
+validate_numeric_truth_quantile_estimate <- function(
+  truth,
+  estimate,
+  call = caller_env()
+) {
+  if (!is.numeric(truth)) {
+    cli::cli_abort(
+      "{.arg truth} should be a numeric vector,
+      not {.obj_type_friendly {truth}}.",
+      call = call
+    )
+  }
+
+  if (!inherits(estimate, "quantile_pred")) {
+    cli::cli_abort(
+      "{.arg estimate} should be a {.cls quantile_pred} object,
+      not {.obj_type_friendly {estimate}}.",
+      call = call
+    )
+  }
+
+  if (is.matrix(truth)) {
+    cli::cli_abort(
+      "{.arg truth} should be a numeric vector, not a numeric matrix.",
+      call = call
+    )
+  }
+
+  n_truth <- length(truth)
+  n_estimate <- vctrs::vec_size(estimate)
+
+  if (n_truth != n_estimate) {
+    cli::cli_abort(
+      "{.arg truth} ({n_truth}) and
+      {.arg estimate} ({n_estimate}) must be the same length.",
+      call = call
+    )
+  }
+}
diff --git a/man/check_metric.Rd b/man/check_metric.Rd
index 9281ae66..2c574037 100644
--- a/man/check_metric.Rd
+++ b/man/check_metric.Rd
@@ -8,6 +8,7 @@
 \alias{check_ordered_prob_metric}
 \alias{check_dynamic_survival_metric}
 \alias{check_static_survival_metric}
+\alias{check_quantile_metric}
 \title{Developer function for checking inputs in new metrics}
 \usage{
 check_numeric_metric(truth, estimate, case_weights, call = caller_env())
@@ -49,6 +50,8 @@ check_static_survival_metric(
   case_weights,
   call = caller_env()
 )
+
+check_quantile_metric(truth, estimate, case_weights, call = caller_env())
 }
 \arguments{
 \item{truth}{The realized vector of \code{truth}.
@@ -59,6 +62,7 @@ check_static_survival_metric(
 \item For \code{check_ordered_prob_metric()}, an ordered factor.
\item For \code{check_dynamic_survival_metric()}, a Surv object.
 \item For \code{check_static_survival_metric()}, a Surv object.
+\item For \code{check_quantile_metric()}, a numeric vector.
 }}
 
 \item{estimate}{The realized \code{estimate} result.
@@ -71,6 +75,7 @@ a numeric matrix for multic-class \code{truth}.
 a numeric matrix for multic-class \code{truth}.
 \item For \code{check_dynamic_survival_metric()}, list-column of data.frames.
 \item For \code{check_static_survival_metric()}, a numeric vector.
+\item For \code{check_quantile_metric()}, a \code{hardhat::quantile_pred} vector.
 }}
 
 \item{case_weights}{The realized case weights, as a numeric vector. This must
diff --git a/man/weighted_interval_score.Rd b/man/weighted_interval_score.Rd
new file mode 100644
index 00000000..5791decc
--- /dev/null
+++ b/man/weighted_interval_score.Rd
@@ -0,0 +1,113 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/quant-weighted_interval.R
+\name{weighted_interval_score}
+\alias{weighted_interval_score}
+\alias{weighted_interval_score_vec}
+\title{Compute weighted interval score}
+\usage{
+weighted_interval_score(data, ...)
+
+weighted_interval_score_vec(
+  truth,
+  estimate,
+  quantile_levels = NULL,
+  na_rm = FALSE,
+  quantile_estimate_nas = c("impute", "drop", "propagate"),
+  case_weights = NULL,
+  ...
+)
+}
+\arguments{
+\item{data}{The object that the \code{weighted_interval_score()} generic
+dispatches on.}
+
+\item{...}{not used}
+
+\item{truth}{double. Actual value(s).}
+
+\item{estimate}{A vector of class \code{quantile_pred}.}
+
+\item{quantile_levels}{probabilities. If specified, the score will be
+computed at this set of levels. Otherwise, those present in \code{estimate}
+will be used. If \code{quantile_levels} do not exactly match those available
+in \code{estimate}, then some quantiles will have implicit missingness.
+Handling of these is determined by \code{quantile_estimate_nas}.}
+
+\item{na_rm}{logical. If \code{TRUE}, missing values in \code{truth} or both implicit
+and explicit (values of \code{NA} present in \code{estimate}), will be ignored
+(dropped) in the calculation of the summary score. If \code{FALSE} (the
+default), any \code{NA}s will result in the summary being \code{NA}.}
+
+\item{quantile_estimate_nas}{character. This argument applies only to
+\code{estimate}. It handles imputation of individual \code{quantile_levels} that
+are necessary to compute a score. Because each element of \code{estimate} is
+a \link[hardhat:quantile_pred]{hardhat::quantile_pred}, it is possible for these to be missing for
+particular \code{quantile_levels}. There are a number of different
+possibilities for such missingness. The options are as follows:
+\itemize{
+\item For \code{"impute"}, both explicit and implicit missing values will be
+imputed using \code{\link[hardhat:impute_quantiles]{hardhat::impute_quantiles()}} prior to the calculation of
+the score. So the score will be \code{NA} only if imputation fails.
+\item For \code{"drop"}, any explicit missing values will be removed
+before calculating the score for a particular prediction. This may be
+reasonable due to the weighting. For example, if the estimate has
+\code{quantile_levels = c(.25, .5, .75)} but the median is \code{NA} for a
+particular prediction, it may be reasonable to average the accuracy of
+\code{c(.25, .75)} for that prediction with others that don't have
+missingness. This option only works if \code{quantile_levels = NULL} or is a
+subset of the \code{quantile_levels} in \code{estimate}.
+\item For \code{"propagate"}, any missing value predictions will result in that
+element of \code{estimate} having a score of \code{NA}. If \code{na_rm = TRUE}, then
+these will be removed before averaging.
+}}
+
+\item{case_weights}{The case weights, as a numeric vector of the same
+length as \code{truth}, or \code{NULL} (the default).}
+}
+\value{
+A single numeric value: the per-prediction scores averaged via
+\code{yardstick_mean()} (using \code{case_weights} when supplied), or \code{NA} when
+missing values are present and not removed.
+}
+\description{
+Weighted interval score (WIS), a well-known quantile-based
+approximation of the commonly-used continuous ranked probability score
+(CRPS). WIS is a proper score, and can be thought of as a distributional
+generalization of absolute error. For example, see \href{https://arxiv.org/abs/2005.12881}{Bracher et al. (2020)} for discussion in the context
+of COVID-19 forecasting.
+}
+\examples{
+quantile_levels <- c(.2, .4, .6, .8)
+pred1 <- 1:4
+pred2 <- 8:11
+preds <- hardhat::quantile_pred(rbind(pred1, pred2), quantile_levels)
+truth <- c(3.3, 7.1)
+weighted_interval_score_vec(truth, preds)
+weighted_interval_score_vec(truth, preds, quantile_levels = c(.25, .5, .75))
+
+# Missing value behaviours
+
+preds_na <- hardhat::quantile_pred(rbind(pred1, c(1, 2, NA, 4)), 1:4 / 5)
+truth <- c(2.5, 2.5)
+weighted_interval_score_vec(truth, preds_na)
+weighted_interval_score_vec(truth, preds_na, quantile_levels = 1:9 / 10)
+try(weighted_interval_score_vec(
+  truth,
+  preds_na,
+  quantile_levels = 1:9 / 10,
+  quantile_estimate_nas = "drop"
+))
+weighted_interval_score_vec(
+  truth,
+  preds_na,
+  quantile_levels = c(2, 3) / 5,
+  quantile_estimate_nas = "drop"
+)
+weighted_interval_score_vec(
+  truth, preds_na, na_rm = TRUE, quantile_estimate_nas = "propagate"
+)
+weighted_interval_score_vec(
+  truth, preds_na, quantile_estimate_nas = "propagate"
+)
+
+}