Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ export(ccc)
export(ccc_vec)
export(check_class_metric)
export(check_dynamic_survival_metric)
export(check_linear_pred_survival_metric)
export(check_numeric_metric)
export(check_ordered_prob_metric)
export(check_prob_metric)
Expand Down Expand Up @@ -184,6 +185,7 @@ export(j_index_vec)
export(kap)
export(kap_vec)
export(lift_curve)
export(linear_pred_survival_metric_summarizer)
export(mae)
export(mae_vec)
export(mape)
Expand All @@ -207,6 +209,7 @@ export(new_class_metric)
export(new_dynamic_survival_metric)
export(new_groupwise_metric)
export(new_integrated_survival_metric)
export(new_linear_pred_survival_metric)
export(new_numeric_metric)
export(new_ordered_prob_metric)
export(new_prob_metric)
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

* `poisson_log_loss()` has been enhanced to handle 0 valued estimates, no longer returning `Inf` or `NaN`. (#513)

* Fixed bug where ranked probability metrics didn't work in combination with other classificiation metrics in `metric_set()`. (#539)
* Fixed bug where ranked probability metrics didn't work in combination with other classification metrics in `metric_set()`. (#539)

* Added infrastructure for survival metrics on the linear predictor. (#551)

# yardstick 1.3.2

Expand Down
141 changes: 116 additions & 25 deletions R/aaa-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ metric_set <- function(...) {
c(
"dynamic_survival_metric",
"static_survival_metric",
"integrated_survival_metric"
"integrated_survival_metric",
"linear_pred_survival_metric"
)
) {
make_survival_metric_function(fns)
Expand Down Expand Up @@ -547,37 +548,96 @@ make_survival_metric_function <- function(fns) {
# Construct common argument set for each metric call
# Doing this dynamically inside the generated function means
# we capture the correct arguments
dynamic_call_args <- quos(
data = data,
truth = !!enquo(truth),
... = ...,
na_rm = na_rm,
case_weights = !!enquo(case_weights),
... = ...
)

static_call_args <- quos(
data = data,
truth = !!enquo(truth),
estimate = !!enquo(estimate),
na_rm = na_rm,
case_weights = !!enquo(case_weights),
... = ...
is_static <- vapply(
fns,
inherits,
logical(1),
"static_survival_metric"
)

call_class_ind <- vapply(
is_linear_pred <- vapply(
fns,
inherits,
"static_survival_metric",
logical(1),
"linear_pred_survival_metric"
)
is_dynamic_or_integrated <- vapply(
fns,
function(fn) {
inherits(fn, "dynamic_survival_metric") ||
inherits(fn, "integrated_survival_metric")
},
FUN.VALUE = logical(1)
)

# Construct calls from the functions + arguments
dynamic_calls <- lapply(fns[!call_class_ind], call2, !!!dynamic_call_args)
static_calls <- lapply(fns[call_class_ind], call2, !!!static_call_args)
# Static and linear pred metrics both use the `estimate` argument
# so we need route the columns to the correct metric functions
is_set_of_static_and_linear_pred <- any(is_static) && any(is_linear_pred)

if (is_set_of_static_and_linear_pred) {
estimate_eval <- tidyselect::eval_select(
expr = enquo(estimate),
data = data,
allow_rename = TRUE,
allow_empty = FALSE,
error_call = current_env()
)

validate_estimate_static_linear_pred(estimate_eval, call = current_env())

static_col_name <- names(data)[estimate_eval["static"]]
linear_pred_col_name <- names(data)[estimate_eval["linear_pred"]]

args_static <- quos(
data = data,
truth = !!enquo(truth),
estimate = !!sym(static_col_name),
na_rm = na_rm,
case_weights = !!enquo(case_weights),
... = ...
)

args_linear_pred <- quos(
data = data,
truth = !!enquo(truth),
estimate = !!sym(linear_pred_col_name),
na_rm = na_rm,
case_weights = !!enquo(case_weights),
... = ...
)

calls_static <- lapply(fns[is_static], call2, !!!args_static)
calls_linear_pred <- lapply(
fns[is_linear_pred],
call2,
!!!args_linear_pred
)

calls_estimate <- c(calls_static, calls_linear_pred)
} else {
args_estimate <- quos(
data = data,
truth = !!enquo(truth),
estimate = !!enquo(estimate),
na_rm = na_rm,
case_weights = !!enquo(case_weights),
... = ...
)

needs_estimate_arg <- is_static | is_linear_pred
calls_estimate <- lapply(fns[needs_estimate_arg], call2, !!!args_estimate)
}

calls <- c(dynamic_calls, static_calls)
args_dots <- quos(
data = data,
truth = !!enquo(truth),
... = ...,
na_rm = na_rm,
case_weights = !!enquo(case_weights)
)
calls_dots <- lapply(fns[is_dynamic_or_integrated], call2, !!!args_dots)

calls <- c(calls_dots, calls_estimate)
calls <- mapply(call_remove_static_arguments, calls, fns)

# Evaluate
Expand Down Expand Up @@ -644,7 +704,8 @@ validate_function_class <- function(fns) {
"numeric_metric",
"dynamic_survival_metric",
"static_survival_metric",
"integrated_survival_metric"
"integrated_survival_metric",
"linear_pred_survival_metric"
)

if (n_unique == 1L) {
Expand All @@ -664,7 +725,8 @@ validate_function_class <- function(fns) {
surv_cls <- c(
"dynamic_survival_metric",
"static_survival_metric",
"integrated_survival_metric"
"integrated_survival_metric",
"linear_pred_survival_metric"
)
if (any(fn_cls_unique %in% surv_cls) && all(fn_cls_unique %in% surv_cls)) {
return(invisible(fns))
Expand Down Expand Up @@ -743,6 +805,35 @@ validate_function_class <- function(fns) {
)
}

validate_estimate_static_linear_pred <- function(
estimate_eval,
call = caller_env()
) {
if (length(estimate_eval) != 2L) {
cli::cli_abort(
"{.arg estimate} must select exactly 2 columns from {.arg data},
not {length(estimate_eval)}.",
call = call
)
}

estimate_names <- names(estimate_eval)
expected_names <- c("static", "linear_pred")

if (!setequal(estimate_names, expected_names)) {
cli::cli_abort(
c(
"When mixing static and linear predictor survival metrics,
{.arg estimate} must use named selection.",
"i" = "Use {.code estimate = c(static = col1, linear_pred = col2)}.",
"i" = "Expected names: {.val {expected_names}}.",
"x" = "Received names: {.val {estimate_names}}."
),
call = call
)
}
}

# Safely evaluate metrics in such a way that we can capture the
# error and inform the user of the metric that failed
eval_safely <- function(expr, expr_nm, data = NULL, env = caller_env()) {
Expand Down
7 changes: 7 additions & 0 deletions R/aaa-new.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ new_static_survival_metric <- function(fn, direction) {
new_metric(fn, direction, class = "static_survival_metric")
}

#' @rdname new-metric
#' @export
new_linear_pred_survival_metric <- function(fn, direction) {
new_metric(fn, direction, class = "linear_pred_survival_metric")
}

#' @include import-standalone-types-check.R
new_metric <- function(fn, direction, class = NULL, call = caller_env()) {
check_function(fn, call = call)
Expand Down Expand Up @@ -121,6 +127,7 @@ format.metric <- function(x, ...) {
"dynamic_survival_metric" = "dynamic survival metric",
"static_survival_metric" = "static survival metric",
"integrated_survival_metric" = "integrated survival metric",
"linear_pred_survival_metric" = "linear predictor survival metric",
"metric"
)

Expand Down
12 changes: 12 additions & 0 deletions R/check-metric.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,15 @@ check_static_survival_metric <- function(
validate_case_weights(case_weights, size = nrow(truth), call = call)
validate_surv_truth_numeric_estimate(truth, estimate, call = call)
}

#' @rdname check_metric
#' @export
check_linear_pred_survival_metric <- function(
truth,
estimate,
case_weights,
call = caller_env()
) {
validate_case_weights(case_weights, size = nrow(truth), call = call)
validate_surv_truth_numeric_estimate(truth, estimate, call = call)
}
95 changes: 95 additions & 0 deletions R/template.R
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,101 @@ curve_survival_metric_summarizer <- function(
out
}


#' @rdname metric-summarizers
#' @export
linear_pred_survival_metric_summarizer <- function(
name,
fn,
data,
truth,
estimate,
...,
na_rm = TRUE,
case_weights = NULL,
fn_options = list(),
error_call = caller_env()
) {
check_dots_empty(call = error_call)

truth <- enquo(truth)
estimate <- enquo(estimate)
case_weights <- enquo(case_weights)

truth <- yardstick_eval_select(
expr = truth,
data = data,
arg = "truth",
error_call = error_call
)
estimate <- yardstick_eval_select(
expr = estimate,
data = data,
arg = "estimate",
error_call = error_call
)

if (quo_is_null(case_weights)) {
group_case_weights <- NULL
} else {
case_weights <- yardstick_eval_select(
expr = case_weights,
data = data,
arg = "case_weights",
error_call = error_call
)
}

group_rows <- dplyr::group_rows(data)
group_keys <- dplyr::group_keys(data)
data <- dplyr::ungroup(data)
groups <- vec_chop(data, indices = group_rows)
out <- vector("list", length = length(groups))

for (i in seq_along(groups)) {
group <- groups[[i]]

group_truth <- group[[truth]]
group_estimate <- group[[estimate]]

if (is_string(case_weights)) {
group_case_weights <- group[[case_weights]]
}

elt_out <- list(
.metric = name,
.estimator = finalize_estimator(
group_truth,
metric_class = name,
call = error_call
),
.estimate = inject(
withCallingHandlers(
fn(
truth = group_truth,
estimate = group_estimate,
case_weights = group_case_weights,
na_rm = na_rm,
!!!fn_options
),
error = function(cnd) {
cnd$call <- error_call
cnd_signal(cnd)
}
)
)
)

out[[i]] <- tibble::new_tibble(elt_out)
}

group_keys <- vec_rep_each(group_keys, times = list_sizes(out))
out <- vec_rbind(!!!out)
out <- vec_cbind(group_keys, out)

out
}

prob_estimate_convert <- function(estimate) {
check_data_frame(estimate, .internal = TRUE)

Expand Down
8 changes: 8 additions & 0 deletions man/check_metric.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading