From 10c7d30d763e47cab4890cbfb27a6019de1d089f Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Tue, 25 Nov 2025 14:50:24 +0000 Subject: [PATCH 1/3] add `linear_pred_survival_metric` --- NAMESPACE | 3 + R/aaa-metrics.R | 126 ++++++++++++++++++++++++++++++-------- R/aaa-new.R | 7 +++ R/check-metric.R | 12 ++++ R/template.R | 95 ++++++++++++++++++++++++++++ man/check_metric.Rd | 8 +++ man/metric-summarizers.Rd | 14 +++++ man/new-metric.Rd | 3 + 8 files changed, 243 insertions(+), 25 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index acddadbc..e998640c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -146,6 +146,7 @@ export(ccc) export(ccc_vec) export(check_class_metric) export(check_dynamic_survival_metric) +export(check_linear_pred_survival_metric) export(check_numeric_metric) export(check_ordered_prob_metric) export(check_prob_metric) @@ -184,6 +185,7 @@ export(j_index_vec) export(kap) export(kap_vec) export(lift_curve) +export(linear_pred_survival_metric_summarizer) export(mae) export(mae_vec) export(mape) @@ -207,6 +209,7 @@ export(new_class_metric) export(new_dynamic_survival_metric) export(new_groupwise_metric) export(new_integrated_survival_metric) +export(new_linear_pred_survival_metric) export(new_numeric_metric) export(new_ordered_prob_metric) export(new_prob_metric) diff --git a/R/aaa-metrics.R b/R/aaa-metrics.R index ef0cbae3..e9af130d 100644 --- a/R/aaa-metrics.R +++ b/R/aaa-metrics.R @@ -272,7 +272,8 @@ metric_set <- function(...) { c( "dynamic_survival_metric", "static_survival_metric", - "integrated_survival_metric" + "integrated_survival_metric", + "linear_pred_survival_metric" ) ) { make_survival_metric_function(fns) @@ -547,37 +548,110 @@ make_survival_metric_function <- function(fns) { # Construct common argument set for each metric call # Doing this dynamically inside the generated function means # we capture the correct arguments - dynamic_call_args <- quos( - data = data, - truth = !!enquo(truth), - ... = ..., - na_rm = na_rm, - case_weights = !!enquo(case_weights), - ... = ... - ) - static_call_args <- quos( - data = data, - truth = !!enquo(truth), - estimate = !!enquo(estimate), - na_rm = na_rm, - case_weights = !!enquo(case_weights), - ... = ... + is_static <- vapply( + fns, + inherits, + logical(1), + "static_survival_metric" ) - - call_class_ind <- vapply( + is_linear_pred <- vapply( fns, inherits, - "static_survival_metric", + logical(1), + "linear_pred_survival_metric" + ) + is_dynamic_or_integrated <- vapply( + fns, + function(fn) { + inherits(fn, "dynamic_survival_metric") || + inherits(fn, "integrated_survival_metric") + }, FUN.VALUE = logical(1) ) - # Construct calls from the functions + arguments - dynamic_calls <- lapply(fns[!call_class_ind], call2, !!!dynamic_call_args) - static_calls <- lapply(fns[call_class_ind], call2, !!!static_call_args) + # Static and linear pred metrics both use the `estimate` argument + # so we need route the columns to the correct metric functions + is_set_of_static_and_linear_pred <- any(is_static) && any(is_linear_pred) + + if (is_set_of_static_and_linear_pred) { + estimate_eval <- tidyselect::eval_select( + expr = enquo(estimate), + data = data, + allow_rename = TRUE, + allow_empty = FALSE, + error_call = current_env() + ) + + estimate_names <- names(estimate_eval) + expected_names <- c("static", "linear_pred") + + if (!setequal(estimate_names, expected_names)) { + cli::cli_abort( + c( + "When mixing static and linear predictor survival metrics, + {.arg estimate} must use named selection.", + "i" = "Use {.code estimate = c(static = col1, linear_pred = col2)}.", + "i" = "Expected names: {.val {expected_names}}.", + "x" = "Received names: {.val {estimate_names}}." + ), + call = current_env() + ) + } + + static_col_name <- names(data)[estimate_eval["static"]] + linear_pred_col_name <- names(data)[estimate_eval["linear_pred"]] - calls <- c(dynamic_calls, static_calls) + args_static <- quos( + data = data, + truth = !!enquo(truth), + estimate = !!sym(static_col_name), + na_rm = na_rm, + case_weights = !!enquo(case_weights), + ... = ... + ) + + args_linear_pred <- quos( + data = data, + truth = !!enquo(truth), + estimate = !!sym(linear_pred_col_name), + na_rm = na_rm, + case_weights = !!enquo(case_weights), + ... = ... + ) + + calls_static <- lapply(fns[is_static], call2, !!!args_static) + calls_linear_pred <- lapply( + fns[is_linear_pred], + call2, + !!!args_linear_pred + ) + + calls_estimate <- c(calls_static, calls_linear_pred) + } else { + args_estimate <- quos( + data = data, + truth = !!enquo(truth), + estimate = !!enquo(estimate), + na_rm = na_rm, + case_weights = !!enquo(case_weights), + ... = ... + ) + + needs_estimate_arg <- is_static | is_linear_pred + calls_estimate <- lapply(fns[needs_estimate_arg], call2, !!!args_estimate) + } + + args_dots <- quos( + data = data, + truth = !!enquo(truth), + ... = ..., + na_rm = na_rm, + case_weights = !!enquo(case_weights) + ) + calls_dots <- lapply(fns[is_dynamic_or_integrated], call2, !!!args_dots) + calls <- c(calls_dots, calls_estimate) calls <- mapply(call_remove_static_arguments, calls, fns) # Evaluate @@ -644,7 +718,8 @@ validate_function_class <- function(fns) { "numeric_metric", "dynamic_survival_metric", "static_survival_metric", - "integrated_survival_metric" + "integrated_survival_metric", + "linear_pred_survival_metric" ) if (n_unique == 1L) { @@ -664,7 +739,8 @@ validate_function_class <- function(fns) { surv_cls <- c( "dynamic_survival_metric", "static_survival_metric", - "integrated_survival_metric" + "integrated_survival_metric", + "linear_pred_survival_metric" ) if (any(fn_cls_unique %in% surv_cls) && all(fn_cls_unique %in% surv_cls)) { return(invisible(fns)) diff --git a/R/aaa-new.R b/R/aaa-new.R index 14953925..0e4c4ec7 100644 --- a/R/aaa-new.R +++ b/R/aaa-new.R @@ -70,6 +70,12 @@ new_static_survival_metric <- function(fn, direction) { new_metric(fn, direction, class = "static_survival_metric") } +#' @rdname new-metric +#' @export +new_linear_pred_survival_metric <- function(fn, direction) { + new_metric(fn, direction, class = "linear_pred_survival_metric") +} + #' @include import-standalone-types-check.R new_metric <- function(fn, direction, class = NULL, call = caller_env()) { check_function(fn, call = call) @@ -121,6 +127,7 @@ format.metric <- function(x, ...) { "dynamic_survival_metric" = "dynamic survival metric", "static_survival_metric" = "static survival metric", "integrated_survival_metric" = "integrated survival metric", + "linear_pred_survival_metric" = "linear predictor survival metric", "metric" ) diff --git a/R/check-metric.R b/R/check-metric.R index eebdfb46..9aee511d 100644 --- a/R/check-metric.R +++ b/R/check-metric.R @@ -120,3 +120,15 @@ check_static_survival_metric <- function( validate_case_weights(case_weights, size = nrow(truth), call = call) validate_surv_truth_numeric_estimate(truth, estimate, call = call) } + +#' @rdname check_metric +#' @export +check_linear_pred_survival_metric <- function( + truth, + estimate, + case_weights, + call = caller_env() +) { + validate_case_weights(case_weights, size = nrow(truth), call = call) + validate_surv_truth_numeric_estimate(truth, estimate, call = call) +} diff --git a/R/template.R b/R/template.R index 5a678a24..5b54ea59 100644 --- a/R/template.R +++ b/R/template.R @@ -823,6 +823,101 @@ curve_survival_metric_summarizer <- function( out } + +#' @rdname metric-summarizers +#' @export +linear_pred_survival_metric_summarizer <- function( + name, + fn, + data, + truth, + estimate, + ..., + na_rm = TRUE, + case_weights = NULL, + fn_options = list(), + error_call = caller_env() +) { + check_dots_empty(call = error_call) + + truth <- enquo(truth) + estimate <- enquo(estimate) + case_weights <- enquo(case_weights) + + truth <- yardstick_eval_select( + expr = truth, + data = data, + arg = "truth", + error_call = error_call + ) + estimate <- yardstick_eval_select( + expr = estimate, + data = data, + arg = "estimate", + error_call = error_call + ) + + if (quo_is_null(case_weights)) { + group_case_weights <- NULL + } else { + case_weights <- yardstick_eval_select( + expr = case_weights, + data = data, + arg = "case_weights", + error_call = error_call + ) + } + + group_rows <- dplyr::group_rows(data) + group_keys <- dplyr::group_keys(data) + data <- dplyr::ungroup(data) + groups <- vec_chop(data, indices = group_rows) + out <- vector("list", length = length(groups)) + + for (i in seq_along(groups)) { + group <- groups[[i]] + + group_truth <- group[[truth]] + group_estimate <- group[[estimate]] + + if (is_string(case_weights)) { + group_case_weights <- group[[case_weights]] + } + + elt_out <- list( + .metric = name, + .estimator = finalize_estimator( + group_truth, + metric_class = name, + call = error_call + ), + .estimate = inject( + withCallingHandlers( + fn( + truth = group_truth, + estimate = group_estimate, + case_weights = group_case_weights, + na_rm = na_rm, + !!!fn_options + ), + error = function(cnd) { + cnd$call <- error_call + cnd_signal(cnd) + } + ) + ) + ) + + out[[i]] <- tibble::new_tibble(elt_out) + } + + group_keys <- vec_rep_each(group_keys, times = list_sizes(out)) + out <- vec_rbind(!!!out) + out <- vec_cbind(group_keys, out) + + out +} + prob_estimate_convert <- function(estimate) { check_data_frame(estimate, .internal = TRUE) diff --git a/man/check_metric.Rd b/man/check_metric.Rd index 9281ae66..19c4b51c 100644 --- a/man/check_metric.Rd +++ b/man/check_metric.Rd @@ -8,6 +8,7 @@ \alias{check_ordered_prob_metric} \alias{check_dynamic_survival_metric} \alias{check_static_survival_metric} +\alias{check_linear_pred_survival_metric} \title{Developer function for checking inputs in new metrics} \usage{ check_numeric_metric(truth, estimate, case_weights, call = caller_env()) @@ -49,6 +50,13 @@ check_static_survival_metric( case_weights, call = caller_env() ) + +check_linear_pred_survival_metric( + truth, + estimate, + case_weights, + call = caller_env() +) } \arguments{ \item{truth}{The realized vector of \code{truth}. diff --git a/man/metric-summarizers.Rd b/man/metric-summarizers.Rd index 5c9cae07..eb2f7149 100644 --- a/man/metric-summarizers.Rd +++ b/man/metric-summarizers.Rd @@ -10,6 +10,7 @@ \alias{dynamic_survival_metric_summarizer} \alias{static_survival_metric_summarizer} \alias{curve_survival_metric_summarizer} +\alias{linear_pred_survival_metric_summarizer} \title{Developer function for summarizing new metrics} \usage{ numeric_metric_summarizer( @@ -118,6 +119,19 @@ curve_survival_metric_summarizer( fn_options = list(), error_call = caller_env() ) + +linear_pred_survival_metric_summarizer( + name, + fn, + data, + truth, + estimate, + ..., + na_rm = TRUE, + case_weights = NULL, + fn_options = list(), + error_call = caller_env() +) } \arguments{ \item{name}{A single character representing the name of the metric to diff --git a/man/new-metric.Rd b/man/new-metric.Rd index 284c7a41..ce027523 100644 --- a/man/new-metric.Rd +++ b/man/new-metric.Rd @@ -9,6 +9,7 @@ \alias{new_dynamic_survival_metric} \alias{new_integrated_survival_metric} \alias{new_static_survival_metric} +\alias{new_linear_pred_survival_metric} \title{Construct a new metric function} \usage{ new_class_metric(fn, direction) @@ -24,6 +25,8 @@ new_dynamic_survival_metric(fn, direction) new_integrated_survival_metric(fn, direction) new_static_survival_metric(fn, direction) + +new_linear_pred_survival_metric(fn, direction) } \arguments{ \item{fn}{A function. The metric function to attach a metric-specific class From ca8c18bbcb43b3ba9677f1dd41523ff787563025 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Tue, 2 Dec 2025 14:31:30 +0000 Subject: [PATCH 2/3] pull out validation into a helper --- R/aaa-metrics.R | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/R/aaa-metrics.R b/R/aaa-metrics.R index e9af130d..add35de0 100644 --- a/R/aaa-metrics.R +++ b/R/aaa-metrics.R @@ -583,21 +583,7 @@ make_survival_metric_function <- function(fns) { error_call = current_env() ) - estimate_names <- names(estimate_eval) - expected_names <- c("static", "linear_pred") - - if (!setequal(estimate_names, expected_names)) { - cli::cli_abort( - c( - "When mixing static and linear predictor survival metrics, - {.arg estimate} must use named selection.", - "i" = "Use {.code estimate = c(static = col1, linear_pred = col2)}.", - "i" = "Expected names: {.val {expected_names}}.", - "x" = "Received names: {.val {estimate_names}}." - ), - call = current_env() - ) - } + validate_estimate_static_linear_pred(estimate_eval, call = current_env()) static_col_name <- names(data)[estimate_eval["static"]] linear_pred_col_name <- names(data)[estimate_eval["linear_pred"]] @@ -819,6 +805,35 @@ validate_function_class <- function(fns) { ) } +validate_estimate_static_linear_pred <- function( + estimate_eval, + call = caller_env() +) { + if (length(estimate_eval) != 2L) { + cli::cli_abort( + "{.arg estimate} must select exactly 2 columns from {.arg data}, + not {length(estimate_eval)}.", + call = call + ) + } + + estimate_names <- names(estimate_eval) + expected_names <- c("static", "linear_pred") + + if (!setequal(estimate_names, expected_names)) { + cli::cli_abort( + c( + "When mixing static and linear predictor survival metrics, + {.arg estimate} must use named selection.", + "i" = "Use {.code estimate = c(static = col1, linear_pred = col2)}.", + "i" = "Expected names: {.val {expected_names}}.", + "x" = "Received names: {.val {estimate_names}}." + ), + call = call + ) + } +} + # Safely evaluate metrics in such a way that we can capture the # error and inform the user of the metric that failed eval_safely <- function(expr, expr_nm, data = NULL, env = caller_env()) { From b8d27540d49906de0fcc9dbd52f08398d64bfb60 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Tue, 2 Dec 2025 14:40:44 +0000 Subject: [PATCH 3/3] Update NEWS --- NEWS.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 3932ad66..78690207 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,7 +4,9 @@ * `poisson_log_loss()` has been enhanced to handle 0 valued estimates, no longer returning `Inf` or `NaN`. (#513) -* Fixed bug where ranked probability metrics didn't work in combination with other classificiation metrics in `metric_set()`. (#539) +* Fixed bug where ranked probability metrics didn't work in combination with other classification metrics in `metric_set()`. (#539) + +* Added infrastructure for survival metrics on the linear predictor. (#551) # yardstick 1.3.2