-
Notifications
You must be signed in to change notification settings - Fork 45
Add Fold Weights for Variable Resample Weighting #1007
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
59176dc
848a30a
cc8bad1
24caeea
5d97e2a
22754b8
03c038f
6da3d80
27a5ab5
7fa521d
b53856f
4c45658
87c9e3b
b6d7836
c397d67
5a2f855
ba3b420
44551b4
0a87de7
a38c55e
a053787
79f8c65
5cd334d
ac04e3f
612325e
8349ded
8aea0ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -247,6 +247,248 @@ pretty.tune_results <- function(x, ...) { | |
attr(x, "rset_info")$label | ||
} | ||
|
||
#' Fold weights utility functions
#'
#' Internal helpers that support variable fold weights when resampling
#' results are aggregated during hyperparameter tuning.
#'
#' @param x A tune_results object (for `.get_fold_weights()`), or a vector of
#'   values to summarise (for `.weighted_mean()` and `.weighted_sd()`).
#' @param weights Numeric vector of weights.
#' @param id_names Character vector of ID column names.
#' @param metrics_data The metrics data frame.
#' @param w Numeric vector of weights.
#' @param n_folds Integer number of folds.
#'
#' @return Depends on the helper. `.get_fold_weights()` returns the stored
#'   fold weights, or `NULL` when `x` carries no `rset_info` attribute or no
#'   weights were recorded in it.
#' @keywords internal
#' @name fold_weights_utils
#' @aliases .create_weight_mapping .weighted_mean .weighted_sd .effective_sample_size .validate_fold_weights
#' @export
#' @rdname fold_weights_utils
.get_fold_weights <- function(x) {
  info <- attr(x, "rset_info")

  if (!is.null(info)) {
    # Weights are looked up under the `att` element of the `rset_info`
    # attribute; a missing entry yields NULL.
    info$att[[".fold_weights"]]
  } else {
    # No resampling metadata attached, so there are no weights to report.
    NULL
  }
}
|
||
#' @export
#' @rdname fold_weights_utils
.create_weight_mapping <- function(weights, id_names, metrics_data) {
  # Build a lookup table with one row per distinct combination of the ID
  # columns found in the metrics data.
  id_syms <- rlang::syms(id_names)
  unique_ids <- dplyr::distinct(metrics_data, !!!id_syms)

  # Weights are assigned positionally, so there must be exactly one weight
  # per distinct resample.
  if (nrow(unique_ids) == length(weights)) {
    unique_ids$.fold_weight <- weights
    return(unique_ids)
  }

  # Count mismatch: warn and signal "no usable weights" to the caller.
  cli::cli_warn(
    "Number of weights ({length(weights)}) does not match number of resamples ({nrow(unique_ids)}). Weights will be ignored."
  )
  NULL
}
|
||
#' @export
#' @rdname fold_weights_utils
.weighted_mean <- function(x, w) {
  # Weighted mean of `x` with weights `w`.
  #
  # Observations where either the value or its weight is missing are dropped.
  # Returns `NA_real_` when nothing usable remains or when the remaining
  # weights sum to zero (the mean is undefined in that case). Errors when
  # `x` and `w` differ in length, since positional pairing would be wrong.
  if (all(is.na(x))) {
    return(NA_real_)
  }

  if (length(w) != length(x)) {
    cli::cli_abort(
      "{.arg w} must have the same length as {.arg x} ({length(x)}), not {length(w)}."
    )
  }

  # Keep only complete (value, weight) pairs so an NA weight cannot turn the
  # whole result into NA.
  valid <- !is.na(x) & !is.na(w)
  x_valid <- x[valid]
  w_valid <- w[valid]

  total_weight <- sum(w_valid)

  # Guard against 0 / 0, which would otherwise produce NaN.
  if (length(x_valid) == 0 || total_weight == 0) {
    return(NA_real_)
  }

  # Normalize weights so they sum to one, then take the weighted sum.
  sum(x_valid * (w_valid / total_weight))
}
|
||
#' @export
#' @rdname fold_weights_utils
.weighted_sd <- function(x, w) {
  # Weighted standard deviation of `x` with weights `w`.
  #
  # Uses normalized weights and the weighted mean, i.e.
  # sqrt(sum(w_i * (x_i - mean_w)^2)) with sum(w_i) == 1; no small-sample
  # correction is applied. Observations where either the value or its weight
  # is missing are dropped. Returns `NA_real_` when fewer than two usable
  # observations remain or when the remaining weights sum to zero. Errors
  # when `x` and `w` differ in length.
  if (all(is.na(x))) {
    return(NA_real_)
  }

  if (length(w) != length(x)) {
    cli::cli_abort(
      "{.arg w} must have the same length as {.arg x} ({length(x)}), not {length(w)}."
    )
  }

  # Keep only complete (value, weight) pairs so an NA weight cannot turn the
  # whole result into NA.
  valid <- !is.na(x) & !is.na(w)
  x_valid <- x[valid]
  w_valid <- w[valid]

  # A spread needs at least two observations.
  if (length(x_valid) <= 1) {
    return(NA_real_)
  }

  total_weight <- sum(w_valid)

  # Guard against 0 / 0, which would otherwise produce NaN.
  if (total_weight == 0) {
    return(NA_real_)
  }

  # Normalize weights so they sum to one.
  w_valid <- w_valid / total_weight

  weighted_mean <- sum(x_valid * w_valid)
  weighted_var <- sum(w_valid * (x_valid - weighted_mean)^2)

  sqrt(weighted_var)
}
|
||
#' @export
#' @rdname fold_weights_utils
.effective_sample_size <- function(w) {
  # Effective sample size of a weight vector:
  #   (sum of weights)^2 / (sum of squared weights)
  # Missing weights are ignored; an empty or all-zero vector gives 0.
  observed <- w[!is.na(w)]

  if (length(observed) > 0) {
    squared_total <- sum(observed^2)

    if (squared_total != 0) {
      return(sum(observed)^2 / squared_total)
    }
  }

  0
}
|
||
#' @export
#' @rdname fold_weights_utils
.validate_fold_weights <- function(weights, n_folds) {
  # Check user-supplied fold weights and return them normalized to sum to 1.
  #
  # `NULL` is passed through unchanged (meaning "no weights"). Otherwise the
  # weights must be numeric, have one entry per fold, be finite and
  # non-negative, and include at least one positive value; any violation
  # raises an error.
  if (is.null(weights)) {
    return(NULL)
  }

  if (!is.numeric(weights)) {
    cli::cli_abort("{.arg weights} must be numeric.")
  }

  if (length(weights) != n_folds) {
    cli::cli_abort(
      "Length of {.arg weights} ({length(weights)}) must equal number of folds ({n_folds})."
    )
  }

  # Must come before the comparisons below: with an NA present,
  # `any(weights < 0)` is NA and `if ()` would fail with an unhelpful
  # message, and an infinite weight would normalize to NaN.
  if (!all(is.finite(weights))) {
    cli::cli_abort(
      "{.arg weights} must not contain missing or non-finite values."
    )
  }

  if (any(weights < 0)) {
    cli::cli_abort("{.arg weights} must be non-negative.")
  }

  if (all(weights == 0)) {
    cli::cli_abort("At least one weight must be positive.")
  }

  # Return normalized weights
  weights / sum(weights)
}
|
||
#' Attach fold weights to an rset object
#'
#' The weights are checked and normalized by `.validate_fold_weights()` and
#' then stored on the object in the `.fold_weights` attribute.
#'
#' @param rset An rset object.
#' @param weights A numeric vector of weights, one per row of `rset`.
#' @return The rset object with weights added as an attribute.
#' @export
add_fold_weights <- function(rset, weights) {
  is_rset <- inherits(rset, "rset")
  if (!is_rset) {
    cli::cli_abort("{.arg rset} must be an rset object.")
  }

  # Validation also rescales the weights, so the stored attribute holds the
  # normalized values rather than the raw input.
  attr(rset, ".fold_weights") <- .validate_fold_weights(weights, nrow(rset))

  rset
}
|
||
#' Calculate fold weights from fold sizes
#'
#' Each resample is weighted by the number of rows in its analysis set,
#' relative to the total across all resamples.
#'
#' @param rset An rset object.
#' @return A numeric vector of weights proportional to fold sizes.
#' @export
calculate_fold_weights <- function(rset) {
  if (!inherits(rset, "rset")) {
    cli::cli_abort("{.arg rset} must be an rset object.")
  }

  # Number of analysis-set rows for a single split.
  analysis_rows <- function(split) {
    nrow(rsample::analysis(split))
  }

  n_analysis <- purrr::map_int(rset$splits, analysis_rows)

  # Scale so the weights sum to one.
  n_analysis / sum(n_analysis)
}
|
||
#' Extract fold weights from rset or tune_results objects
#'
#' A single accessor for fold weights that works both on an rset object the
#' weights were added to and on the results produced after tuning.
#'
#' @param x An rset object with fold weights, or a tune_results object.
#' @return A numeric vector of fold weights, or NULL if no weights are present.
#' @export
#' @examples
#' \dontrun{
#' library(rsample)
#' folds <- vfold_cv(mtcars, v = 3)
#' weighted_folds <- add_fold_weights(folds, c(0.2, 0.3, 0.5))
#' get_fold_weights(weighted_folds)
#' }
get_fold_weights <- function(x) {
  # rset objects carry the weights directly as an attribute.
  if (inherits(x, "rset")) {
    return(attr(x, ".fold_weights"))
  }

  # Anything that is neither an rset nor a results object is rejected.
  if (!inherits(x, c("tune_results", "resample_results"))) {
    cli::cli_abort("{.arg x} must be an rset or tune_results object.")
  }

  # Results objects delegate to the internal accessor.
  .get_fold_weights(x)
}
|
||
# Shared implementation for the rset print methods.
#
# When `x` has a `.fold_weights` attribute, prints `x` as a tibble with the
# weights appended in a `fold_weight` column and returns TRUE. Otherwise
# prints nothing and returns FALSE so the caller can fall back to the next
# print method.
.print_with_fold_weights <- function(x, ...) {
  fold_weights <- attr(x, ".fold_weights")

  if (is.null(fold_weights)) {
    return(FALSE)
  }

  x_tbl <- tibble::as_tibble(x)
  x_tbl$fold_weight <- fold_weights
  print(x_tbl, ...)

  TRUE
}

#' @export
print.rset <- function(x, ...) {
  if (!.print_with_fold_weights(x, ...)) {
    # No weights attached: use default behavior
    NextMethod("print")
  }

  # Print methods return their argument invisibly so the object, not the
  # temporary tibble built for display, is what flows on to the caller.
  invisible(x)
}

#' @export
print.manual_rset <- function(x, ...) {
  if (!.print_with_fold_weights(x, ...)) {
    # No weights attached: use default behavior for manual_rset
    NextMethod("print")
  }

  # Print methods return their argument invisibly so the object, not the
  # temporary tibble built for display, is what flows on to the caller.
  invisible(x)
}
|
||
# ------------------------------------------------------------------------------ | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.