Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
59176dc
Add fold weights functionality
tjburch May 23, 2025
848a30a
Remove manually added .Rd files - these are generated by roxygen2
tjburch May 23, 2025
cc8bad1
Add print options for fold weights
Jul 22, 2025
24caeea
Merge branch 'tidymodels:main' into variable-fold-weights
tjburch Jul 22, 2025
5d97e2a
Export exposed functions
Jul 22, 2025
22754b8
printing manual_rset too?
Jul 22, 2025
03c038f
Merge branch 'tidymodels:main' into variable-fold-weights
tjburch Jul 30, 2025
6da3d80
Merge branch 'tidymodels:main' into variable-fold-weights
tjburch Jul 30, 2025
27a5ab5
Merge branch 'main' into variable-fold-weights
Sep 11, 2025
7fa521d
Run air
Sep 11, 2025
b53856f
Run air on tests
Sep 11, 2025
4c45658
Fix piping error
Sep 11, 2025
87c9e3b
Improve docs
Sep 12, 2025
b6d7836
Improve docs and actually merge the file...
Sep 12, 2025
c397d67
Merge branch 'main' into tjburch-variable-fold-weights
topepo Oct 17, 2025
5a2f855
use statistics based on functions from the stats package
topepo Oct 17, 2025
ba3b420
move away from glmnet
topepo Oct 17, 2025
44551b4
spelling update
topepo Oct 17, 2025
0a87de7
redoc
topepo Oct 17, 2025
a38c55e
news entry
topepo Oct 17, 2025
a053787
global data false positives
topepo Oct 17, 2025
79f8c65
change to base R pipe
topepo Oct 17, 2025
5cd334d
fix test case
topepo Oct 17, 2025
ac04e3f
remove print methods - made an issue in rsample for that
topepo Oct 17, 2025
612325e
fold_weight -> resample_weight
topepo Oct 17, 2025
8349ded
name user-facing function with "extract" instead of "get"
topepo Oct 17, 2025
8aea0ea
Make weights NULL when equal
Oct 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ S3method(pretty,tune_results)
S3method(print,control_bayes)
S3method(print,control_grid)
S3method(print,control_last_fit)
S3method(print,manual_rset)
S3method(print,prob_improve)
S3method(print,rset)
S3method(print,tune_results)
S3method(select_best,default)
S3method(select_best,tune_results)
Expand Down Expand Up @@ -149,8 +151,10 @@ export(.get_tune_workflow)
export(.load_namespace)
export(.stash_last_result)
export(.use_case_weights_with_yardstick)
export(add_fold_weights)
export(augment)
export(autoplot)
export(calculate_fold_weights)
export(check_eval_time_arg)
export(check_initial)
export(check_metric_in_tune_results)
Expand Down Expand Up @@ -203,6 +207,7 @@ export(fit_resamples)
export(forge_from_workflow)
export(future_installed)
export(get_future_workers)
export(get_fold_weights)
export(get_metric_time)
export(get_mirai_workers)
export(get_parallel_seeds)
Expand Down
20 changes: 20 additions & 0 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@ check_rset <- function(x) {
if (inherits(x, "permutations")) {
cli::cli_abort("Permutation samples are not suitable for tuning.")
}

# Check fold weights if present
check_fold_weights(x)

invisible(NULL)
}

#' Check fold weights in rset objects
#'
#' @param x An rset object.
#' @return `NULL` invisibly, or error if weights are invalid.
#' @keywords internal
check_fold_weights <- function(x) {
weights <- attr(x, ".fold_weights")
if (is.null(weights)) {
return(invisible(NULL))
}

.validate_fold_weights(weights, nrow(x))

invisible(NULL)
}

Expand Down
63 changes: 49 additions & 14 deletions R/collect.R
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ estimate_tune_results <- function(x, ..., col_name = ".metrics") {
)
}

fold_weights <- .get_fold_weights(x)

# The mapping of tuning parameters and .config.
config_key <- .config_key_from_metrics(x)

Expand Down Expand Up @@ -602,21 +604,54 @@ estimate_tune_results <- function(x, ..., col_name = ".metrics") {
x <- dplyr::distinct(x)
}

x <- tibble::as_tibble(x)
x <- vctrs::vec_slice(x, x$id != "Apparent")
x <- x |>
dplyr::group_by(
!!!rlang::syms(param_names),
.metric,
.estimator,
!!!rlang::syms(group_cols)
) |>
dplyr::summarize(
mean = mean(.estimate, na.rm = TRUE),
n = sum(!is.na(.estimate)),
std_err = sd(.estimate, na.rm = TRUE) / sqrt(n),
.groups = "drop"
)
tibble::as_tibble() |>
vctrs::vec_slice(., .$id != "Apparent")

# Join weights to the data if available
if (!is.null(fold_weights)) {
weight_data <- .create_weight_mapping(fold_weights, id_names, x)
if (!is.null(weight_data)) {
x <- dplyr::left_join(x, weight_data, by = id_names)
} else {
# If weight mapping failed, fall back to unweighted
fold_weights <- NULL
}
}

if (!is.null(fold_weights)) {
# Use weighted aggregation
x <- x |>
dplyr::group_by(
!!!rlang::syms(param_names),
.metric,
.estimator,
!!!rlang::syms(group_cols)
) |>
dplyr::summarize(
mean = .weighted_mean(.estimate, .fold_weight),
n = sum(!is.na(.estimate)),
effective_n = .effective_sample_size(.fold_weight[!is.na(.estimate)]),
std_err = .weighted_sd(.estimate, .fold_weight) /
sqrt(pmax(effective_n, 1)),
.groups = "drop"
) |>
dplyr::select(-effective_n)
} else {
x <- x |>
dplyr::group_by(
!!!rlang::syms(param_names),
.metric,
.estimator,
!!!rlang::syms(group_cols)
) |>
dplyr::summarize(
mean = mean(.estimate, na.rm = TRUE),
n = sum(!is.na(.estimate)),
std_err = sd(.estimate, na.rm = TRUE) / sqrt(n),
.groups = "drop"
)
}

# only join when parameters are being tuned (#600)
if (length(param_names) == 0) {
Expand Down
4 changes: 4 additions & 0 deletions R/tune_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ pull_rset_attributes <- function(x) {
att$class <- setdiff(class(x), class(tibble::new_tibble(list())))
att$class <- att$class[att$class != "rset"]

if (!is.null(attr(x, ".fold_weights"))) {
att[[".fold_weights"]] <- attr(x, ".fold_weights")
}

lab <- try(pretty(x), silent = TRUE)
if (inherits(lab, "try-error")) {
lab <- NA_character_
Expand Down
242 changes: 242 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,248 @@ pretty.tune_results <- function(x, ...) {
attr(x, "rset_info")$label
}

#' Fold weights utility functions
#'
#' These are internal functions for handling variable fold weights in
#' hyperparameter tuning.
#'
#' @param x A tune_results object.
#' @param weights Numeric vector of weights.
#' @param id_names Character vector of ID column names.
#' @param metrics_data The metrics data frame.
#' @param w Numeric vector of weights.
#' @param n_folds Integer number of folds.
#'
#' @return Various return values depending on the function.
#' @keywords internal
#' @name fold_weights_utils
#' @aliases .create_weight_mapping .weighted_mean .weighted_sd .effective_sample_size .validate_fold_weights
#' @export
#' @rdname fold_weights_utils
.get_fold_weights <- function(x) {
rset_info <- attr(x, "rset_info")
if (is.null(rset_info)) {
return(NULL)
}

# Access weights from rset_info attributes using correct path
weights <- rset_info$att[[".fold_weights"]]

weights
}

#' @export
#' @rdname fold_weights_utils
.create_weight_mapping <- function(weights, id_names, metrics_data) {
# Get unique combinations of ID columns from the metrics data
unique_ids <- dplyr::distinct(metrics_data, !!!rlang::syms(id_names))

if (nrow(unique_ids) != length(weights)) {
cli::cli_warn(
"Number of weights ({length(weights)}) does not match number of resamples ({nrow(unique_ids)}). Weights will be ignored."
)
return(NULL)
}

# Add weights to the unique ID combinations
unique_ids$.fold_weight <- weights
unique_ids
}

#' @export
#' @rdname fold_weights_utils
.weighted_mean <- function(x, w) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use stats ::weighted.mean()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So looking further down, we do the same with the variance calculation. We do have an unexported API in recipes wt_calcs() that does a lot of this.

We can copy it over here (to avoid duplication) and think about moving that core code to hardat so that we get it everywhere from one source.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Erm actually, we should probably put those weighting functions in a "standalone file" here, and then recipes and anyone else can import that.

@EmilHvitfeldt do you have any thoughts on that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be happy to have a standalone file for weighted functions

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll amend my comments a third time. We don't export wt_calcs() but we do export the single statistic versions like recipes::averages() and recipes::variances() so maybe it's that simple.

if (all(is.na(x))) {
return(NA_real_)
}

# Remove NA values and corresponding weights
valid <- !is.na(x)
x_valid <- x[valid]
w_valid <- w[valid]

if (length(x_valid) == 0) {
return(NA_real_)
}

# Normalize weights
w_valid <- w_valid / sum(w_valid)

sum(x_valid * w_valid)
}

#' @export
#' @rdname fold_weights_utils
.weighted_sd <- function(x, w) {
if (all(is.na(x))) {
return(NA_real_)
}

# Remove NA values and corresponding weights
valid <- !is.na(x)
x_valid <- x[valid]
w_valid <- w[valid]

if (length(x_valid) <= 1) {
return(NA_real_)
}

# Normalize weights
w_valid <- w_valid / sum(w_valid)

# Calculate weighted mean
weighted_mean <- sum(x_valid * w_valid)

# Calculate weighted variance
weighted_var <- sum(w_valid * (x_valid - weighted_mean)^2)

sqrt(weighted_var)
}

#' @export
#' @rdname fold_weights_utils
.effective_sample_size <- function(w) {
# Remove NA weights
w <- w[!is.na(w)]

if (length(w) == 0) {
return(0)
}

# Calculate effective sample size: (sum of weights)^2 / sum of squared weights
sum_w <- sum(w)
sum_w_sq <- sum(w^2)

if (sum_w_sq == 0) {
return(0)
}

sum_w^2 / sum_w_sq
}

#' @export
#' @rdname fold_weights_utils
.validate_fold_weights <- function(weights, n_folds) {
if (is.null(weights)) {
return(NULL)
}

if (!is.numeric(weights)) {
cli::cli_abort("{.arg weights} must be numeric.")
}

if (length(weights) != n_folds) {
cli::cli_abort(
"Length of {.arg weights} ({length(weights)}) must equal number of folds ({n_folds})."
)
}

if (any(weights < 0)) {
cli::cli_abort("{.arg weights} must be non-negative.")
}

if (all(weights == 0)) {
cli::cli_abort("At least one weight must be positive.")
}

# Return normalized weights
weights / sum(weights)
}

#' Add fold weights to an rset object
#'
#' @param rset An rset object.
#' @param weights A numeric vector of weights.
#' @return The rset object with weights added as an attribute.
#' @export
add_fold_weights <- function(rset, weights) {
if (!inherits(rset, "rset")) {
cli::cli_abort("{.arg rset} must be an rset object.")
}

# Validate weights
weights <- .validate_fold_weights(weights, nrow(rset))

# Add weights as an attribute
attr(rset, ".fold_weights") <- weights

rset
}

#' Calculate fold weights from fold sizes
#'
#' @param rset An rset object.
#' @return A numeric vector of weights proportional to fold sizes.
#' @export
calculate_fold_weights <- function(rset) {
if (!inherits(rset, "rset")) {
cli::cli_abort("{.arg rset} must be an rset object.")
}

# Calculate the size of each analysis set
fold_sizes <- purrr::map_int(rset$splits, ~ nrow(rsample::analysis(.x)))

# Return weights proportional to fold sizes
fold_sizes / sum(fold_sizes)
}

#' Extract fold weights from rset or tune_results objects
#'
#' This function provides a consistent interface to access fold weights
#' regardless of whether they were added to an rset object or are stored
#' in tune_results after tuning.
#'
#' @param x An rset object with fold weights, or a tune_results object.
#' @return A numeric vector of fold weights, or NULL if no weights are present.
#' @export
#' @examples
#' \dontrun{
#' library(rsample)
#' folds <- vfold_cv(mtcars, v = 3)
#' weighted_folds <- add_fold_weights(folds, c(0.2, 0.3, 0.5))
#' get_fold_weights(weighted_folds)
#' }
get_fold_weights <- function(x) {
if (inherits(x, "rset")) {
# For rset objects, weights are stored as an attribute
return(attr(x, ".fold_weights"))
} else if (inherits(x, c("tune_results", "resample_results"))) {
# For tune results, use the internal function
return(.get_fold_weights(x))
} else {
cli::cli_abort("{.arg x} must be an rset or tune_results object.")
}
}

#' @export
print.rset <- function(x, ...) {
fold_weights <- attr(x, ".fold_weights")

if (!is.null(fold_weights)) {
# Create a tibble with fold weights as a column
x_tbl <- tibble::as_tibble(x)
x_tbl$fold_weight <- fold_weights
print(x_tbl, ...)
} else {
# Use default behavior
NextMethod("print")
}
}

#' @export
print.manual_rset <- function(x, ...) {
fold_weights <- attr(x, ".fold_weights")

if (!is.null(fold_weights)) {
# Create a tibble with fold weights as a column
x_tbl <- tibble::as_tibble(x)
x_tbl$fold_weight <- fold_weights
print(x_tbl, ...)
} else {
# Use default behavior for manual_rset
NextMethod("print")
}
}

# ------------------------------------------------------------------------------

Expand Down
Loading
Loading