Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Imports:
cli,
dplyr (>= 1.1.0),
generics (>= 0.1.2),
hardhat (>= 1.3.0),
hardhat (>= 1.4.2.9000),
lifecycle (>= 1.0.3),
rlang (>= 1.1.4),
tibble,
Expand All @@ -49,7 +49,7 @@ Config/usethis/last-upkeep: 2025-04-24
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
RoxygenNote: 7.3.3
Collate:
'aaa-metrics.R'
'import-standalone-types-check.R'
Expand Down Expand Up @@ -116,6 +116,7 @@ Collate:
'prob-roc_aunp.R'
'prob-roc_aunu.R'
'prob-roc_curve.R'
'quant-weighted_interval.R'
'reexports.R'
'surv-brier_survival.R'
'surv-brier_survival_integrated.R'
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ export(check_dynamic_survival_metric)
export(check_numeric_metric)
export(check_ordered_prob_metric)
export(check_prob_metric)
export(check_quantile_metric)
export(check_static_survival_metric)
export(class_metric_summarizer)
export(classification_cost)
Expand Down Expand Up @@ -262,6 +263,8 @@ export(specificity_vec)
export(static_survival_metric_summarizer)
export(tidy)
export(validate_estimator)
export(weighted_interval_score)
export(weighted_interval_score_vec)
export(yardstick_any_missing)
export(yardstick_remove_missing)
import(rlang)
Expand Down
14 changes: 14 additions & 0 deletions R/check-metric.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#' - For `check_ordered_prob_metric()`, an ordered factor.
#' - For `check_dynamic_survival_metric()`, a Surv object.
#' - For `check_static_survival_metric()`, a Surv object.
#' - For `check_quantile_metric()`, a numeric vector.
#'
#' @param estimate The realized `estimate` result.
#' - For `check_numeric_metric()`, a numeric vector.
Expand All @@ -25,6 +26,7 @@
#' a numeric matrix for multi-class `truth`.
#' - For `check_dynamic_survival_metric()`, list-column of data.frames.
#' - For `check_static_survival_metric()`, a numeric vector.
#' - For `check_quantile_metric()`, a `hardhat::quantile_pred` vector.
#'
#' @param case_weights The realized case weights, as a numeric vector. This must
#' be the same length as `truth`.
Expand Down Expand Up @@ -120,3 +122,15 @@ check_static_survival_metric <- function(
validate_case_weights(case_weights, size = nrow(truth), call = call)
validate_surv_truth_numeric_estimate(truth, estimate, call = call)
}

#' @rdname check_metric
#' @export
check_quantile_metric <- function(
  truth,
  estimate,
  case_weights,
  call = caller_env()
) {
  # Validate `truth` / `estimate` first: this rejects non-numeric and matrix
  # `truth`, so below `truth` is known to be a plain numeric vector.
  validate_numeric_truth_quantile_estimate(truth, estimate, call = call)
  # `truth` is a vector here, so its size is `length(truth)`; `nrow(truth)`
  # would be `NULL` and break the case-weight size comparison.
  validate_case_weights(case_weights, size = length(truth), call = call)
}
174 changes: 174 additions & 0 deletions R/quant-weighted_interval.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#' Compute weighted interval score
#'
#' Weighted interval score (WIS), a well-known quantile-based
#' approximation of the commonly-used continuous ranked probability score
#' (CRPS). WIS is a proper score, and can be thought of as a distributional
#' generalization of absolute error. For example, see [Bracher et
#' al. (2020)](https://arxiv.org/abs/2005.12881) for discussion in the context
#' of COVID-19 forecasting.
#'
#' @param estimate A vector of class `quantile_pred`.
#' @param truth double. Actual value(s)
#' @param quantile_levels probabilities. If specified, the score will be
#' computed at this set of levels. Otherwise, those present in `estimate` will
#' be used. If `quantile_levels` do not exactly match those available in
#' `estimate`, then some quantiles will have implicit missingness. Handling of
#' these is determined by `quantile_estimate_nas`.
#' @param quantile_estimate_nas character. This argument applies only to
#' `estimate`. It handles imputation of individual `quantile_levels` that are
#' necessary to compute a score. Because each element of `estimate` is a
#' [hardhat::quantile_pred], it is possible for these to be missing for
#' particular `quantile_levels`. There are a number of different possibilities
#' for such missingness. The options are as follows:
#' * For `"impute"`, both explicit and implicit missing values will be imputed
#' using [hardhat::impute_quantiles()] prior to the calculation of the score.
#' So the score will be `NA` only if imputation fails.
#' * For `"drop"`, any explicit missing values will be removed
#' before calculating the score for a particular prediction. This may be
#' reasonable due to the weighting. For example, if the estimate has
#' `quantile_levels = c(.25, .5, .75)` but the median is `NA` for a particular
#' prediction, it may be reasonable to average the accuracy of `c(.25, .75)`
#' for that prediction with others that don't have missingness. This option
#' only works if `quantile_levels = NULL` or is a subset of the
#' `quantile_levels` in `estimate`.
#' * For `"propagate"`, any missing value predictions will result in that
#' element of `estimate` having a score of `NA`. If `na_rm = TRUE`, then these
#' will be removed before averaging.
#' @param na_rm logical. If `TRUE`, missing values in `truth` or in `estimate`
#' (both implicit and explicit, i.e. values of `NA` present in `estimate`)
#' will be ignored (dropped) in the calculation of the summary score. If
#' `FALSE` (the default), any `NA`s will result in the summary being `NA`.
#' @param ... not used
#'
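#' @return For `weighted_interval_score_vec()`, a single nonnegative numeric
#' value (or `NA`): the (optionally case-weighted) mean of the per-prediction
#' scores.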
#'
#' @export
#' @examples
#' quantile_levels <- c(.2, .4, .6, .8)
#' pred1 <- 1:4
#' pred2 <- 8:11
#' preds <- quantile_pred(rbind(pred1, pred2), quantile_levels)
#' truth <- c(3.3, 7.1)
#' weighted_interval_score_vec(truth, preds)
#' weighted_interval_score_vec(truth, preds, quantile_levels = c(.25, .5, .75))
#'
#' # Missing value behaviours
#'
#' preds_na <- quantile_pred(rbind(pred1, c(1, 2, NA, 4)), 1:4 / 5)
#' truth <- c(2.5, 2.5)
#' weighted_interval_score_vec(truth, preds_na)
#' weighted_interval_score_vec(truth, preds_na, quantile_levels = 1:9 / 10)
#' try(weighted_interval_score_vec(
#' truth,
#' preds_na,
#' quantile_levels = 1:9 / 10,
#' quantile_estimate_nas = "drop"
#' ))
#' weighted_interval_score_vec(
#' truth,
#' preds_na,
#' quantile_levels = c(2, 3) / 5,
#' quantile_estimate_nas = "drop"
#' )
#' weighted_interval_score_vec(
#' truth, preds_na, na_rm = TRUE, quantile_estimate_nas = "propagate"
#' )
#' weighted_interval_score_vec(
#' truth, preds_na, quantile_estimate_nas = "propagate"
#' )
#'
weighted_interval_score <- function(data, ...) {
  UseMethod("weighted_interval_score")
}
# Register the generic defined above as a metric where smaller values are
# better. The generic itself must be wrapped here: wrapping `mae` would make
# `weighted_interval_score()` dispatch to and report mean absolute error.
# NOTE(review): no S3 methods (e.g. a `data.frame` method) for this generic
# are visible in this file -- confirm one is defined before relying on the
# data-frame interface.
weighted_interval_score <- new_numeric_metric(
  weighted_interval_score,
  direction = "minimize"
)

#' @export
#' @rdname weighted_interval_score
weighted_interval_score_vec <- function(
  truth,
  estimate,
  quantile_levels = NULL,
  na_rm = FALSE,
  quantile_estimate_nas = c("impute", "drop", "propagate"),
  case_weights = NULL,
  ...
) {
  check_quantile_metric(truth, estimate, case_weights)
  available_levels <- hardhat::extract_quantile_levels(estimate)
  quantile_estimate_nas <- rlang::arg_match(quantile_estimate_nas)

  if (is.null(quantile_levels)) {
    # Nothing requested explicitly: score at every level carried by `estimate`.
    quantile_levels <- available_levels
  } else {
    hardhat::check_quantile_levels(quantile_levels)
    some_levels_absent <- !all(quantile_levels %in% available_levels)

    if (some_levels_absent && quantile_estimate_nas == "drop") {
      cli::cli_abort(
        "When `quantile_levels` is not a subset of those available in `estimate`,
        `quantile_estimate_nas` may not be `'drop'`."
      )
    }

    if (some_levels_absent && quantile_estimate_nas == "propagate") {
      # Requested levels are implicitly missing and missingness propagates,
      # so the summary score is missing as well.
      return(NA_real_)
    }
  }

  if (quantile_estimate_nas == "impute") {
    # Fill in both explicit and implicit missing quantiles at the requested
    # levels before scoring.
    estimate <- as.matrix(hardhat::impute_quantiles(estimate, quantile_levels))
  } else {
    # "drop" / "propagate": keep only the columns for the requested levels.
    keep_column <- available_levels %in% quantile_levels
    estimate <- as.matrix(estimate)[, keep_column, drop = FALSE]
  }

  # One score per prediction; under "drop", NAs are removed within each row.
  scores <- wis_impl(
    truth = truth,
    estimate = estimate,
    quantile_levels = quantile_levels,
    rowwise_na_rm = (quantile_estimate_nas == "drop")
  )

  if (na_rm) {
    complete <- yardstick_remove_missing(truth, scores, case_weights)
    scores <- complete$estimate
    case_weights <- complete$case_weights
  } else if (yardstick_any_missing(truth, scores, case_weights)) {
    return(NA_real_)
  }

  yardstick_mean(scores, case_weights = case_weights)
}

# Score every prediction: pairs each row of `estimate` (one row per
# prediction, one column per quantile level) with the matching element of
# `truth` and returns the per-prediction scores as a double vector.
wis_impl <- function(
  truth,
  estimate,
  quantile_levels,
  rowwise_na_rm = TRUE
) {
  # Split the matrix into a list of single-row slices, one per prediction.
  estimate_rows <- vctrs::vec_chop(estimate)

  score_row <- function(row_values, row_truth) {
    wis_one_quantile(row_values, quantile_levels, row_truth, rowwise_na_rm)
  }

  per_row_scores <- mapply(FUN = score_row, estimate_rows, truth)
  as.vector(per_row_scores, "double")
}


# Score a single prediction: twice the mean, across quantile levels, of the
# larger of the two weighted deviations between `truth` and the predicted
# quantile `values`. `na_rm` is forwarded to `mean()`.
wis_one_quantile <- function(values, quantile_levels, truth, na_rm) {
  # Deviation weighted by the level when measured as truth minus prediction...
  truth_minus_pred <- quantile_levels * (truth - values)
  # ...and by one minus the level when measured as prediction minus truth.
  pred_minus_truth <- (1 - quantile_levels) * (values - truth)

  per_level_loss <- pmax(truth_minus_pred, pred_minus_truth)
  2 * mean(per_level_loss, na.rm = na_rm)
}
40 changes: 40 additions & 0 deletions R/validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,43 @@ validate_case_weights <- function(case_weights, size, call = caller_env()) {

invisible(NULL)
}

# Validate the inputs of a quantile metric. Aborts (reporting `call`) unless
# `truth` is a numeric, non-matrix vector, `estimate` inherits from
# `quantile_pred`, and the two have the same size. Returns `NULL` invisibly
# when every check passes.
validate_numeric_truth_quantile_estimate <- function(
  truth,
  estimate,
  call = caller_env()
) {
  truth_is_numeric <- is.numeric(truth)
  if (!truth_is_numeric) {
    cli::cli_abort(
      "{.arg truth} should be a numeric vector,
      not {.obj_type_friendly {truth}}.",
      call = call
    )
  }

  estimate_is_quantile_pred <- inherits(estimate, "quantile_pred")
  if (!estimate_is_quantile_pred) {
    cli::cli_abort(
      "{.arg estimate} should be a {.cls quantile_pred} object,
      not {.obj_type_friendly {estimate}}.",
      call = call
    )
  }

  # A numeric matrix passes `is.numeric()`, so it is rejected separately.
  if (is.matrix(truth)) {
    cli::cli_abort(
      "{.arg truth} should be a numeric vector, not a numeric matrix.",
      call = call
    )
  }

  # These names are interpolated into the message below.
  n_truth <- length(truth)
  n_estimate <- vctrs::vec_size(estimate)

  sizes_agree <- n_truth == n_estimate
  if (!sizes_agree) {
    cli::cli_abort(
      "{.arg truth} ({n_truth}) and
      {.arg estimate} ({n_estimate}) must be the same length.",
      call = call
    )
  }

  invisible(NULL)
}
5 changes: 5 additions & 0 deletions man/check_metric.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading