diff --git a/NAMESPACE b/NAMESPACE index 051f192e..1f37bc36 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -134,10 +134,13 @@ S3method(vec_restore,resample_results) S3method(vec_restore,tune_results) export(.catch_and_log) export(.config_key_from_metrics) +export(.create_weight_mapping) +export(.effective_sample_size) export(.estimate_metrics) export(.filter_perf_metrics) export(.get_extra_col_names) export(.get_fingerprint) +export(.get_resample_weights) export(.get_tune_eval_time_target) export(.get_tune_eval_times) export(.get_tune_metric_names) @@ -149,8 +152,12 @@ export(.get_tune_workflow) export(.load_namespace) export(.stash_last_result) export(.use_case_weights_with_yardstick) +export(.validate_resample_weights) +export(.weighted_sd) +export(add_resample_weights) export(augment) export(autoplot) +export(calculate_resample_weights) export(check_eval_time_arg) export(check_initial) export(check_metric_in_tune_results) @@ -187,6 +194,7 @@ export(extract_mold) export(extract_parameter_set_dials) export(extract_preprocessor) export(extract_recipe) +export(extract_resample_weights) export(extract_spec_parsnip) export(extract_workflow) export(filter_parameters) diff --git a/NEWS.md b/NEWS.md index 09833b69..2c819235 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # tune (development version) +* When calculating resampling estimates, we can now use a weighted mean based on the number of rows in the assessment set. You can opt in to this using the new `add_resample_weights()` function. See `?calculate_resample_weights` (#990). + # tune 2.0.1 * Fixed a bug where `int_pctl()` wouldn't work on `last_fit()` outcomes when future parallelism was enabled. 
(#1099) diff --git a/R/0_imports.R b/R/0_imports.R index 54409ce4..6aa46594 100644 --- a/R/0_imports.R +++ b/R/0_imports.R @@ -125,7 +125,9 @@ utils::globalVariables( "model_stage", "predict_stage", "user", - "num" + "num", + ".resample_weight", + "effective_n" ) ) diff --git a/R/checks.R b/R/checks.R index fb76d712..193d1548 100644 --- a/R/checks.R +++ b/R/checks.R @@ -22,6 +22,26 @@ check_rset <- function(x) { if (inherits(x, "permutations")) { cli::cli_abort("Permutation samples are not suitable for tuning.") } + + # Check fold weights if present + check_resample_weights(x) + + invisible(NULL) +} + +#' Check fold weights in rset objects +#' +#' @param x An rset object. +#' @return `NULL` invisibly, or error if weights are invalid. +#' @keywords internal +check_resample_weights <- function(x) { + weights <- attr(x, ".resample_weights") + if (is.null(weights)) { + return(invisible(NULL)) + } + + .validate_resample_weights(weights, nrow(x)) + invisible(NULL) } diff --git a/R/collect.R b/R/collect.R index f9eb631a..5e6b004c 100644 --- a/R/collect.R +++ b/R/collect.R @@ -565,6 +565,8 @@ estimate_tune_results <- function(x, ..., col_name = ".metrics") { ) } + resample_weights <- .get_resample_weights(x) + # The mapping of tuning parameters and .config. 
config_key <- .config_key_from_metrics(x) @@ -604,19 +606,53 @@ estimate_tune_results <- function(x, ..., col_name = ".metrics") { x <- tibble::as_tibble(x) x <- vctrs::vec_slice(x, x$id != "Apparent") - x <- x |> - dplyr::group_by( - !!!rlang::syms(param_names), - .metric, - .estimator, - !!!rlang::syms(group_cols) - ) |> - dplyr::summarize( - mean = mean(.estimate, na.rm = TRUE), - n = sum(!is.na(.estimate)), - std_err = sd(.estimate, na.rm = TRUE) / sqrt(n), - .groups = "drop" - ) + + # Join weights to the data if available + if (!is.null(resample_weights)) { + weight_data <- .create_weight_mapping(resample_weights, id_names, x) + if (!is.null(weight_data)) { + x <- dplyr::left_join(x, weight_data, by = id_names) + } else { + # If weight mapping failed, fall back to unweighted + resample_weights <- NULL + } + } + + if (!is.null(resample_weights)) { + # Weighted aggregation (missing estimates dropped, as when unweighted) + x <- x |> + dplyr::group_by( + !!!rlang::syms(param_names), + .metric, + .estimator, + !!!rlang::syms(group_cols) + ) |> + dplyr::summarize( + mean = stats::weighted.mean(.estimate, .resample_weight, na.rm = TRUE), + n = sum(!is.na(.estimate)), + effective_n = .effective_sample_size(.resample_weight[ + !is.na(.estimate) + ]), + std_err = .weighted_sd(.estimate, .resample_weight) / + sqrt(pmax(effective_n, 1)), + .groups = "drop" + ) |> + dplyr::select(-effective_n) + } else { + x <- x |> + dplyr::group_by( + !!!rlang::syms(param_names), + .metric, + .estimator, + !!!rlang::syms(group_cols) + ) |> + dplyr::summarize( + mean = mean(.estimate, na.rm = TRUE), + n = sum(!is.na(.estimate)), + std_err = sd(.estimate, na.rm = TRUE) / sqrt(n), + .groups = "drop" + ) + } # only join when parameters are being tuned (#600) if (length(param_names) == 0) { diff --git a/R/tune_grid.R b/R/tune_grid.R index 5080c5a9..d4b7bafe 100644 --- a/R/tune_grid.R +++ b/R/tune_grid.R @@ -425,6 +425,10 @@ pull_rset_attributes <- function(x) { att$class <- setdiff(class(x), class(tibble::new_tibble(list()))) att$class <- 
att$class[att$class != "rset"] + if (!is.null(attr(x, ".resample_weights"))) { + att[[".resample_weights"]] <- attr(x, ".resample_weights") + } + lab <- try(pretty(x), silent = TRUE) if (inherits(lab, "try-error")) { lab <- NA_character_ diff --git a/R/utils.R b/R/utils.R index 30f6f23f..5f637411 100644 --- a/R/utils.R +++ b/R/utils.R @@ -247,6 +247,241 @@ pretty.tune_results <- function(x, ...) { attr(x, "rset_info")$label } +#' Resampling weights utility functions +#' +#' These are internal functions for handling variable resampling weights in +#' hyperparameter tuning. +#' +#' @param x A tune_results object. +#' @param weights Numeric vector of weights. +#' @param id_names Character vector of ID column names. +#' @param metrics_data The metrics data frame. +#' @param w Numeric vector of weights. +#' @param num_resamples Integer number of resamples. +#' +#' @return Various return values depending on the function. +#' @keywords internal +#' @name resample_weights_utils +#' @aliases .create_weight_mapping .weighted_sd .effective_sample_size .validate_resample_weights +#' @export +#' @rdname resample_weights_utils +.get_resample_weights <- function(x) { + rset_info <- attr(x, "rset_info") + if (is.null(rset_info)) { + return(NULL) + } + + # Access weights from rset_info attributes using correct path + weights <- rset_info$att[[".resample_weights"]] + + weights +} + +#' @export +#' @rdname resample_weights_utils +.create_weight_mapping <- function(weights, id_names, metrics_data) { + # Get unique combinations of ID columns from the metrics data + unique_ids <- dplyr::distinct(metrics_data, !!!rlang::syms(id_names)) + + if (nrow(unique_ids) != length(weights)) { + cli::cli_warn( + c( + "Number of weights ({length(weights)}) does not match number of resamples ({nrow(unique_ids)}).", + "Weights will be ignored." 
+ ) + ) + return(NULL) + } + + # Add weights to the unique ID combinations + unique_ids$.resample_weight <- weights + unique_ids +} + +#' @export +#' @rdname resample_weights_utils +.weighted_sd <- function(x, w) { + if (all(is.na(x))) { + return(NA_real_) + } + + # Keep only non-missing values and their matching weights + valid <- !is.na(x) + x_valid <- x[valid] + w_valid <- w[valid] + + if (length(x_valid) <= 1) { + return(NA_real_) + } + + # Weighted variance of the non-missing values (cov.wt() rejects NA) + weighted_var <- + tibble::as_tibble_col(x_valid) |> + stats::cov.wt(wt = w_valid, cor = FALSE) + + weighted_var <- weighted_var$cov[1, 1] + + sqrt(weighted_var) +} + +#' @export +#' @rdname resample_weights_utils +.effective_sample_size <- function(w) { + # Remove NA weights + w <- w[!is.na(w)] + + if (length(w) == 0) { + return(0) + } + + # Calculate effective sample size: (sum of weights)^2 / sum of squared weights + sum_w <- sum(w) + sum_w_sq <- sum(w^2) + + if (sum_w_sq == 0) { + return(0) + } + + sum_w^2 / sum_w_sq +} + +#' @export +#' @rdname resample_weights_utils +.validate_resample_weights <- function(weights, num_resamples) { + if (is.null(weights)) { + return(NULL) + } + + if (!is.numeric(weights)) { + cli::cli_abort("{.arg weights} must be numeric.") + } + + if (length(weights) != num_resamples) { + cli::cli_abort( + "Length of {.arg weights} ({length(weights)}) must equal number of resamples ({num_resamples})." 
+ ) + } + + if (any(weights < 0)) { + cli::cli_abort("{.arg weights} must be non-negative.") + } + + if (all(weights == 0)) { + cli::cli_abort("At least one weight must be positive.") + } + + # Return normalized weights + normalized_weights <- weights / sum(weights) + + # If equal, equivalent to not weighting + expected_equal <- 1 / num_resamples + if ( + isTRUE(all.equal(normalized_weights, rep(expected_equal, num_resamples))) + ) { + return(NULL) + } + + return(normalized_weights) +} + +#' Add resample weights to an rset object +#' +#' This function allows you to specify custom weights for resamples. Weights +#' are automatically normalized to sum to 1. +#' +#' @param rset An rset object from \pkg{rsample}. +#' @param weights A numeric vector of weights, one per resample. Weights will be +#' normalized. +#' @return The rset object with weights added as an attribute. +#' @details +#' Resampling weights are useful when assessment sets (i.e., held out data) have +#' different sizes or when you want to upweight certain resamples in the evaluation. +#' The weights are stored as an attribute and used automatically during +#' metric aggregation. 
+#' @seealso [calculate_resample_weights()], [extract_resample_weights()] +#' @examples +#' library(rsample) +#' folds <- vfold_cv(mtcars, v = 3) +#' # Give equal weight to all folds +#' weighted_folds <- add_resample_weights(folds, c(1, 1, 1)) +#' # Emphasize the first fold +#' weighted_folds <- add_resample_weights(folds, c(0.5, 0.25, 0.25)) +#' @export +add_resample_weights <- function(rset, weights) { + if (!inherits(rset, "rset")) { + cli::cli_abort("{.arg rset} must be an rset object.") + } + + # Validate weights + weights <- .validate_resample_weights(weights, nrow(rset)) + + # Add weights as an attribute + attr(rset, ".resample_weights") <- weights + + rset +} + +#' Calculate resample weights from resample sizes +#' +#' This convenience function calculates weights proportional to the number of +#' observations in each resample's analysis set. Larger resamples get higher weights. +#' This ensures that resamples with more data have proportionally more influence +#' on the final aggregated metrics. +#' +#' @param rset An rset object from \pkg{rsample}. +#' @return A numeric vector of weights proportional to resample sizes, normalized +#' to sum to 1. +#' @details +#' This is particularly useful for time-based resamples (e.g., expanding window CV) +#' or stratified sampling where resamples might have slightly different sizes, in +#' which resamples are imbalanced. 
+#' @seealso [add_resample_weights()], [extract_resample_weights()] +#' @examples +#' library(rsample) +#' folds <- vfold_cv(mtcars, v = 3) +#' weights <- calculate_resample_weights(folds) +#' weighted_folds <- add_resample_weights(folds, weights) +#' @export +calculate_resample_weights <- function(rset) { + if (!inherits(rset, "rset")) { + cli::cli_abort("{.arg rset} must be an rset object.") + } + + # Calculate the size of each analysis set + resample_sizes <- purrr::map_int(rset$splits, ~ nrow(rsample::analysis(.x))) + + # Return weights proportional to resample sizes + resample_sizes / sum(resample_sizes) +} + +#' Extract resample weights from rset or tuning objects +#' +#' This function provides a consistent interface to access resample weights +#' regardless of whether they were added to an rset object or are stored +#' in `tune_results` after tuning. +#' +#' @param x An rset object with resample weights, or a `tune_results` object. +#' @return A numeric vector of resample weights, or NULL if no weights are present. 
+#' @export +#' @examples +#' \dontrun{ +#' library(rsample) +#' folds <- vfold_cv(mtcars, v = 3) +#' weighted_folds <- add_resample_weights(folds, c(0.2, 0.3, 0.5)) +#' extract_resample_weights(weighted_folds) +#' } +extract_resample_weights <- function(x) { + if (inherits(x, "rset")) { + # For rset objects, weights are stored as an attribute + res <- attr(x, ".resample_weights") + } else if (inherits(x, c("tune_results", "resample_results"))) { + # For tune results, use the internal function + res <- .get_resample_weights(x) + } else { + cli::cli_abort("{.arg x} must be an rset or tune_results object.") + } + res +} # ------------------------------------------------------------------------------ diff --git a/_pkgdown.yml b/_pkgdown.yml index 58936fac..7999cdf5 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -66,6 +66,7 @@ reference: - starts_with("compute") - augment.tune_results - example_ames_knn + - contains("resample_weight") - title: Developer functions contents: - merge.recipe diff --git a/inst/WORDLIST b/inst/WORDLIST index 1e048a29..22f83db4 100644 --- a/inst/WORDLIST +++ b/inst/WORDLIST @@ -49,12 +49,15 @@ pre preprocessor preprocessors reprex +resample's rsample's +rset tibble tibbles tidymodels tunable unsummarized +upweight urations vectorization wiggliness diff --git a/man/add_resample_weights.Rd b/man/add_resample_weights.Rd new file mode 100644 index 00000000..44e02627 --- /dev/null +++ b/man/add_resample_weights.Rd @@ -0,0 +1,38 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{add_resample_weights} +\alias{add_resample_weights} +\title{Add resample weights to an rset object} +\usage{ +add_resample_weights(rset, weights) +} +\arguments{ +\item{rset}{An rset object from \pkg{rsample}.} + +\item{weights}{A numeric vector of weights, one per resample. Weights will be +normalized.} +} +\value{ +The rset object with weights added as an attribute. 
+} +\description{ +This function allows you to specify custom weights for resamples. Weights +are automatically normalized to sum to 1. +} +\details{ +Resampling weights are useful when assessment sets (i.e., held out data) have +different sizes or when you want to upweight certain resamples in the evaluation. +The weights are stored as an attribute and used automatically during +metric aggregation. +} +\examples{ +library(rsample) +folds <- vfold_cv(mtcars, v = 3) +# Give equal weight to all folds +weighted_folds <- add_resample_weights(folds, c(1, 1, 1)) +# Emphasize the first fold +weighted_folds <- add_resample_weights(folds, c(0.5, 0.25, 0.25)) +} +\seealso{ +\code{\link[=calculate_resample_weights]{calculate_resample_weights()}}, \code{\link[=extract_resample_weights]{extract_resample_weights()}} +} diff --git a/man/calculate_resample_weights.Rd b/man/calculate_resample_weights.Rd new file mode 100644 index 00000000..3e971867 --- /dev/null +++ b/man/calculate_resample_weights.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{calculate_resample_weights} +\alias{calculate_resample_weights} +\title{Calculate resample weights from resample sizes} +\usage{ +calculate_resample_weights(rset) +} +\arguments{ +\item{rset}{An rset object from \pkg{rsample}.} +} +\value{ +A numeric vector of weights proportional to resample sizes, normalized +to sum to 1. +} +\description{ +This convenience function calculates weights proportional to the number of +observations in each resample's analysis set. Larger resamples get higher weights. +This ensures that resamples with more data have proportionally more influence +on the final aggregated metrics. +} +\details{ +This is particularly useful for time-based resamples (e.g., expanding window CV) +or stratified sampling where resamples might have slightly different sizes, in +which resamples are imbalanced. 
+} +\examples{ +library(rsample) +folds <- vfold_cv(mtcars, v = 3) +weights <- calculate_resample_weights(folds) +weighted_folds <- add_resample_weights(folds, weights) +} +\seealso{ +\code{\link[=add_resample_weights]{add_resample_weights()}}, \code{\link[=extract_resample_weights]{extract_resample_weights()}} +} diff --git a/man/check_resample_weights.Rd b/man/check_resample_weights.Rd new file mode 100644 index 00000000..adc1376a --- /dev/null +++ b/man/check_resample_weights.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/checks.R +\name{check_resample_weights} +\alias{check_resample_weights} +\title{Check fold weights in rset objects} +\usage{ +check_resample_weights(x) +} +\arguments{ +\item{x}{An rset object.} +} +\value{ +\code{NULL} invisibly, or error if weights are invalid. +} +\description{ +Check fold weights in rset objects +} +\keyword{internal} diff --git a/man/extract_resample_weights.Rd b/man/extract_resample_weights.Rd new file mode 100644 index 00000000..335648cf --- /dev/null +++ b/man/extract_resample_weights.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{extract_resample_weights} +\alias{extract_resample_weights} +\title{Extract resample weights from rset or tuning objects} +\usage{ +extract_resample_weights(x) +} +\arguments{ +\item{x}{An rset object with resample weights, or a \code{tune_results} object.} +} +\value{ +A numeric vector of resample weights, or NULL if no weights are present. +} +\description{ +This function provides a consistent interface to access resample weights +regardless of whether they were added to an rset object or are stored +in \code{tune_results} after tuning. 
+} +\examples{ +\dontrun{ +library(rsample) +folds <- vfold_cv(mtcars, v = 3) +weighted_folds <- add_resample_weights(folds, c(0.2, 0.3, 0.5)) +extract_resample_weights(weighted_folds) +} +} diff --git a/man/resample_weights_utils.Rd b/man/resample_weights_utils.Rd new file mode 100644 index 00000000..130b1926 --- /dev/null +++ b/man/resample_weights_utils.Rd @@ -0,0 +1,42 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{resample_weights_utils} +\alias{resample_weights_utils} +\alias{.get_resample_weights} +\alias{.create_weight_mapping} +\alias{.weighted_sd} +\alias{.effective_sample_size} +\alias{.validate_resample_weights} +\title{Resampling weights utility functions} +\usage{ +.get_resample_weights(x) + +.create_weight_mapping(weights, id_names, metrics_data) + +.weighted_sd(x, w) + +.effective_sample_size(w) + +.validate_resample_weights(weights, num_resamples) +} +\arguments{ +\item{x}{A tune_results object.} + +\item{weights}{Numeric vector of weights.} + +\item{id_names}{Character vector of ID column names.} + +\item{metrics_data}{The metrics data frame.} + +\item{w}{Numeric vector of weights.} + +\item{num_resamples}{Integer number of resamples.} +} +\value{ +Various return values depending on the function. +} +\description{ +These are internal functions for handling variable resampling weights in +hyperparameter tuning. +} +\keyword{internal} diff --git a/tests/testthat/_snaps/checks.md b/tests/testthat/_snaps/checks.md index 2eea7eb3..6d13c88b 100644 --- a/tests/testthat/_snaps/checks.md +++ b/tests/testthat/_snaps/checks.md @@ -335,3 +335,27 @@ ! Some model parameters require finalization but there are recipe parameters that require tuning. i Please use `extract_parameter_set_dials()` to set parameter ranges manually and supply the output to the `param_info` argument. +# check fold weights + + Code + add_resample_weights(folds, c("a", "b", "c")) + Condition + Error in `.validate_resample_weights()`: + ! 
`weights` must be numeric. + +--- + + Code + add_resample_weights(folds, c(0.5, 0.3)) + Condition + Error in `.validate_resample_weights()`: + ! Length of `weights` (2) must equal number of resamples (3). + +--- + + Code + add_resample_weights(folds, c(-0.1, 0.5, 0.6)) + Condition + Error in `.validate_resample_weights()`: + ! `weights` must be non-negative. + diff --git a/tests/testthat/_snaps/weights.md b/tests/testthat/_snaps/weights.md new file mode 100644 index 00000000..b4aa4f06 --- /dev/null +++ b/tests/testthat/_snaps/weights.md @@ -0,0 +1,56 @@ +# add_resample_weights() validates inputs correctly + + Code + add_resample_weights("not_an_rset", c(0.5, 0.3, 0.2)) + Condition + Error in `add_resample_weights()`: + ! `rset` must be an rset object. + +--- + + Code + add_resample_weights(folds, c("a", "b", "c")) + Condition + Error in `.validate_resample_weights()`: + ! `weights` must be numeric. + +--- + + Code + add_resample_weights(folds, c(0.5, 0.3)) + Condition + Error in `.validate_resample_weights()`: + ! Length of `weights` (2) must equal number of resamples (3). + +--- + + Code + add_resample_weights(folds, c(-0.1, 0.5, 0.6)) + Condition + Error in `.validate_resample_weights()`: + ! `weights` must be non-negative. + +--- + + Code + add_resample_weights(folds, c(0, 0, 0)) + Condition + Error in `.validate_resample_weights()`: + ! At least one weight must be positive. + +# extract_resample_weights() validates input types + + Code + extract_resample_weights("not_valid_input") + Condition + Error in `extract_resample_weights()`: + ! `x` must be an rset or tune_results object. + +--- + + Code + extract_resample_weights(data.frame(x = 1:3)) + Condition + Error in `extract_resample_weights()`: + ! `x` must be an rset or tune_results object. 
+ diff --git a/tests/testthat/test-checks.R b/tests/testthat/test-checks.R index 5d887244..4b5eb456 100644 --- a/tests/testthat/test-checks.R +++ b/tests/testthat/test-checks.R @@ -16,7 +16,7 @@ test_that("grid objects", { skip_if_not_installed("splines2") skip_if_not_installed("kernlab") data("Chicago", package = "modeldata") - data("Chicago", package = "modeldata") + spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) |> recipes::step_date(date) |> @@ -541,3 +541,161 @@ test_that("check parameter finalization", { ) expect_true(inherits(p5, "parameters")) }) + +test_that("check fold weights", { + folds <- rsample::vfold_cv(mtcars, v = 3) + + # No weights should pass silently + expect_no_error(tune:::check_resample_weights(folds)) + + # Valid weights should pass + weights <- c(0.1, 0.5, 0.4) + weighted_folds <- add_resample_weights(folds, weights) + expect_no_error(tune:::check_resample_weights(weighted_folds)) + + # Invalid weights should error + expect_snapshot( + add_resample_weights(folds, c("a", "b", "c")), + error = TRUE + ) + + expect_snapshot( + add_resample_weights(folds, c(0.5, 0.3)), + error = TRUE + ) + + expect_snapshot( + add_resample_weights(folds, c(-0.1, 0.5, 0.6)), + error = TRUE + ) +}) + +test_that("fold weights integration test", { + # Create simple data and resamples + set.seed(1234) + data_small <- mtcars[1:20, ] + folds <- rsample::vfold_cv(data_small, v = 3) + + # Create simple model and recipe + simple_rec <- recipes::recipe(mpg ~ wt + hp, data = data_small) + simple_mod <- parsnip::linear_reg() |> parsnip::set_engine("lm") + simple_wflow <- workflows::workflow() |> + workflows::add_recipe(simple_rec) |> + workflows::add_model(simple_mod) + + # Test with equal weights (should match unweighted results) + equal_weights <- c(1, 1, 1) + weighted_folds_equal <- add_resample_weights(folds, equal_weights) + + # Fit both weighted and unweighted + unweighted_results <- fit_resamples( + simple_wflow, + folds, + control = 
control_resamples(save_pred = FALSE) + ) + weighted_results_equal <- fit_resamples( + simple_wflow, + weighted_folds_equal, + control = control_resamples(save_pred = FALSE) + ) + + # Extract metrics + unweighted_metrics <- collect_metrics(unweighted_results) + weighted_metrics_equal <- collect_metrics(weighted_results_equal) + + # Should be nearly identical (allowing for small numerical differences) + expect_equal( + unweighted_metrics$mean, + weighted_metrics_equal$mean, + tolerance = 1e-10 + ) + + # Test with unequal weights + unequal_weights <- c(0.1, 0.3, 0.6) # Higher weight on last fold + weighted_folds_unequal <- add_resample_weights(folds, unequal_weights) + + weighted_results_unequal <- fit_resamples( + simple_wflow, + weighted_folds_unequal, + control = control_resamples(save_pred = FALSE) + ) + weighted_metrics_unequal <- collect_metrics(weighted_results_unequal) + + # Should be different from unweighted results + expect_false(all( + abs(unweighted_metrics$mean - weighted_metrics_unequal$mean) < 1e-10 + )) + + # Verify that weights are properly stored and retrieved + expect_equal( + attr(weighted_folds_unequal, ".resample_weights"), + unequal_weights + ) + + # Test fold size calculation + calculated_weights <- calculate_resample_weights(folds) + expect_length(calculated_weights, nrow(folds)) + expect_true(all(calculated_weights > 0)) + expect_equal(sum(calculated_weights), 1) # Should sum to 1 now +}) + +test_that("fold weights with tune_grid", { + skip_if_not_installed("kernlab") + + # Create simple tuning scenario + set.seed(5678) + data_small <- mtcars[1:15, ] + folds <- rsample::vfold_cv(data_small, v = 3) + + # Create tunable workflow + tune_rec <- recipes::recipe(mpg ~ wt + hp, data = data_small) |> + recipes::step_normalize(recipes::all_predictors()) + tune_mod <- parsnip::svm_rbf( + cost = tune(), + rbf_sigma = 0.001, + mode = "regression" + ) + + tune_wflow <- workflows::workflow() |> + workflows::add_recipe(tune_rec) |> + 
workflows::add_model(tune_mod) + + # Create simple grid + simple_grid <- tibble::tibble(cost = c(1, 10, 100)) + + # Test with unequal weights + weights <- c(0.2, 0.3, 0.5) + weighted_folds <- add_resample_weights(folds, weights) + + # Tune with weights + weighted_tune_results <- tune_grid( + tune_wflow, + weighted_folds, + grid = simple_grid, + control = control_grid(save_pred = FALSE) + ) + + # Verify results structure + expect_s3_class(weighted_tune_results, "tune_results") + + # Extract metrics and verify they're computed + weighted_metrics <- collect_metrics(weighted_tune_results) + expect_true(nrow(weighted_metrics) > 0) + expect_true(all(c("mean", "std_err") %in% names(weighted_metrics))) + + # Compare with unweighted results + unweighted_tune_results <- tune_grid( + tune_wflow, + folds, + grid = simple_grid, + control = control_grid(save_pred = FALSE) + ) + unweighted_metrics <- collect_metrics(unweighted_tune_results) + + # Results should differ due to weighting + expect_false(all( + abs(weighted_metrics$mean - unweighted_metrics$mean) < 1e-10 + )) +}) + +# ------------------------------------------------------------------------------ diff --git a/tests/testthat/test-weights.R b/tests/testthat/test-weights.R new file mode 100644 index 00000000..eb9f62de --- /dev/null +++ b/tests/testthat/test-weights.R @@ -0,0 +1,430 @@ +# Test file for variable fold weights functionality + +# Setup test data +set.seed(42) +test_data <- data.frame( + x1 = rnorm(50), + x2 = rnorm(50), + x3 = rnorm(50) +) +test_data$y <- 2 * test_data$x1 + 3 * test_data$x2 + rnorm(50, sd = 0.5) + +set.seed(123) +folds <- rsample::vfold_cv(mtcars, v = 3) + +# Helper function to create a simple model +create_test_model <- function() { + parsnip::linear_reg() |> parsnip::set_engine("lm") +} + + +test_that("add_resample_weights() validates inputs correctly", { + expect_snapshot( + add_resample_weights("not_an_rset", c(0.5, 0.3, 0.2)), + error = TRUE + ) + + expect_snapshot( + 
add_resample_weights(folds, c("a", "b", "c")), + error = TRUE + ) + + expect_snapshot( + add_resample_weights(folds, c(0.5, 0.3)), + error = TRUE + ) + + expect_snapshot( + add_resample_weights(folds, c(-0.1, 0.5, 0.6)), + error = TRUE + ) + + expect_snapshot( + add_resample_weights(folds, c(0, 0, 0)), + error = TRUE + ) +}) + +test_that("add_resample_weights() adds weights correctly", { + weights <- c(0.1, 0.5, 0.4) + weighted_folds <- add_resample_weights(folds, weights) + + # Weights get normalized to sum to 1 + expected_weights <- weights / sum(weights) + + expect_s3_class(weighted_folds, "rset") + expect_equal(attr(weighted_folds, ".resample_weights"), expected_weights) + expect_equal(nrow(weighted_folds), nrow(folds)) +}) + +test_that("calculate_resample_weights() works correctly", { + auto_weights <- calculate_resample_weights(folds) + + expect_type(auto_weights, "double") + expect_length(auto_weights, nrow(folds)) + expect_true(all(auto_weights > 0)) + expect_true(abs(sum(auto_weights) - 1) < 1e-10) +}) + +test_that("weights are preserved through tuning pipeline", { + weights <- c(0.1, 0.5, 0.4) + weighted_folds <- add_resample_weights(folds, weights) + + mod <- create_test_model() + + suppressWarnings({ + res <- tune_grid( + mod, + mpg ~ ., + resamples = weighted_folds, + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + }) + + metrics <- collect_metrics(res) + expect_equal(nrow(metrics), 1) + expect_true("mean" %in% names(metrics)) + expect_true(is.numeric(metrics$mean)) +}) + +test_that("weights affect metric aggregation", { + weights <- c(0.1, 0.5, 0.4) + weighted_folds <- add_resample_weights(folds, weights) + + mod <- create_test_model() + + suppressWarnings({ + # Unweighted results + res_unweighted <- tune_grid( + mod, + mpg ~ ., + resamples = folds, + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + + # Weighted results + 
res_weighted <- tune_grid( + mod, + mpg ~ ., + resamples = weighted_folds, + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + }) + + unweighted_rmse <- collect_metrics(res_unweighted)$mean[1] + weighted_rmse <- collect_metrics(res_weighted)$mean[1] + + expect_true(is.numeric(unweighted_rmse)) + expect_true(is.numeric(weighted_rmse)) + expect_false(is.na(unweighted_rmse)) + expect_false(is.na(weighted_rmse)) +}) + +test_that("extreme weights show larger effect", { + skip_if_not_installed("kknn") + + # Create folds for this specific test + set.seed(42) + test_folds <- rsample::vfold_cv(test_data, v = 3) + + # Regular weights + weights <- c(0.6, 0.2, 0.2) + weighted_folds <- add_resample_weights(test_folds, weights) + + # Extreme weights + extreme_weights <- c(0.95, 0.025, 0.025) + extreme_weighted_folds <- add_resample_weights(test_folds, extreme_weights) + + # Create a model with tuning parameter + knn_spec <- parsnip::nearest_neighbor(neighbors = tune()) |> + parsnip::set_engine("kknn") |> + parsnip::set_mode("regression") + + param_grid <- data.frame(neighbors = c(3, 5)) + + suppressWarnings({ + # Unweighted + res_unweighted <- tune_grid( + knn_spec, + y ~ ., + resamples = test_folds, + grid = param_grid, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + + # Regular weights + res_weighted <- tune_grid( + knn_spec, + y ~ ., + resamples = weighted_folds, + grid = param_grid, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + + # Extreme weights + res_extreme <- tune_grid( + knn_spec, + y ~ ., + resamples = extreme_weighted_folds, + grid = param_grid, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + }) + + unweighted_metrics <- collect_metrics(res_unweighted) + weighted_metrics <- collect_metrics(res_weighted) + extreme_metrics <- collect_metrics(res_extreme) 
+ + # Check that results exist and are sensible + expect_equal(nrow(unweighted_metrics), 2) + expect_equal(nrow(weighted_metrics), 2) + expect_equal(nrow(extreme_metrics), 2) + + # Calculate differences + regular_diff <- max(abs(unweighted_metrics$mean - weighted_metrics$mean)) + extreme_diff <- max(abs(unweighted_metrics$mean - extreme_metrics$mean)) + + expect_true(regular_diff >= 0) + expect_true(extreme_diff >= 0) + expect_true(all(is.finite(c(regular_diff, extreme_diff)))) +}) + +test_that("weight normalization works correctly", { + expect_equal( + tune:::.validate_resample_weights(c(3, 6, 9), 3), + c(1 / 6, 1 / 3, 1 / 2) # normalized to sum to 1 + ) + + expect_equal( + tune:::.validate_resample_weights(c(0.2, 0.3, 0.5), 3), + c(0.2, 0.3, 0.5) # already normalized to sum to 1 + ) +}) + +test_that("equal weights return NULL", { + # Simplest integer match + expect_null(tune:::.validate_resample_weights(c(2, 2, 2), 3)) + + # Fractional match + expect_null(tune:::.validate_resample_weights(c(1 / 3, 1 / 3, 1 / 3), 3)) + + # Check more resamples + expect_null(tune:::.validate_resample_weights(c(1, 1, 1, 1, 1), 5)) +}) + +test_that("unequal weights do not return NULL", { + # Check non-null decimal values + result <- tune:::.validate_resample_weights(c(0.1, 0.5, 0.4), 3) + expect_false(is.null(result)) + expect_equal(result, c(0.1, 0.5, 0.4)) + + # Non-null fractional values + result2 <- tune:::.validate_resample_weights(c(1, 2, 3), 3) + expect_false(is.null(result2)) + expect_equal(result2, c(1 / 6, 2 / 6, 3 / 6)) +}) + +test_that("add_resample_weights with equal weights returns NULL attribute", { + # Adding equal weights should trigger NULL assignment + equal_weighted_folds <- add_resample_weights(folds, c(1, 1, 1)) + expect_null(attr(equal_weighted_folds, ".resample_weights")) + + # Verify it's still an rset object + expect_s3_class(equal_weighted_folds, "rset") +}) + +test_that("equal weights produce same results as no weights", { + mod <- create_test_model() + 
+ suppressWarnings({ + # Results with no weights + res_no_weights <- tune_grid( + mod, + mpg ~ ., + resamples = folds, + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + + # Results with equal weights + equal_weighted_folds <- add_resample_weights(folds, c(1, 1, 1)) + res_equal_weights <- tune_grid( + mod, + mpg ~ ., + resamples = equal_weighted_folds, + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + }) + + metrics_no_weights <- collect_metrics(res_no_weights) + metrics_equal_weights <- collect_metrics(res_equal_weights) + + # Results should match + expect_equal(metrics_no_weights$mean, metrics_equal_weights$mean) + expect_equal(metrics_no_weights$std_err, metrics_equal_weights$std_err) +}) + +test_that("weighted statistics functions work correctly", { + x <- c(1, 2, 3, 4, 5) + w <- c(0.1, 0.2, 0.3, 0.2, 0.2) + + weighted_sd <- tune:::.weighted_sd(x, w) + + expect_true(is.numeric(weighted_sd)) + expect_false(is.na(weighted_sd)) + expect_true(weighted_sd >= 0) + + # Test with NA values + x_na <- c(1, 2, NA, 4, 5) + weighted_sd_na <- tune:::.weighted_sd(x_na[!is.na(x_na)], w[!is.na(x_na)]) + + expect_true(is.numeric(weighted_sd_na)) + + # Test edge cases + expect_true(is.na(tune:::.weighted_sd(c(1), c(1)))) # single value +}) + +test_that("fold weight extraction works", { + weights <- c(0.1, 0.5, 0.4) + weighted_folds <- add_resample_weights(folds, weights) + + # Weights get normalized to sum to 1 + expected_weights <- weights / sum(weights) + + mod <- create_test_model() + + suppressWarnings({ + res <- tune_grid( + mod, + mpg ~ ., + resamples = weighted_folds, + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + }) + + extracted_weights <- tune:::.get_resample_weights(res) + expect_equal(extracted_weights, expected_weights) +}) + +test_that("individual fold metrics can be collected", { 
+ weights <- c(0.1, 0.5, 0.4) + weighted_folds <- add_resample_weights(folds, weights) + + mod <- create_test_model() + + suppressWarnings({ + res <- tune_grid( + mod, + mpg ~ ., + resamples = weighted_folds, + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + }) + + # Collect individual fold metrics + individual_metrics <- collect_metrics(res, summarize = FALSE) + + expect_true(nrow(individual_metrics) >= 3) # At least one metric per fold + expect_true("id" %in% names(individual_metrics)) + expect_true(".estimate" %in% names(individual_metrics)) + expect_true(all(is.finite(individual_metrics$.estimate))) +}) + +test_that("backwards compatibility - no weights", { + mod <- create_test_model() + + suppressWarnings({ + res <- tune_grid( + mod, + mpg ~ ., + resamples = folds, # No weights + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + }) + + metrics <- collect_metrics(res) + expect_equal(nrow(metrics), 1) + expect_true("mean" %in% names(metrics)) + expect_true(is.numeric(metrics$mean)) + expect_false(is.na(metrics$mean)) +}) + +test_that("rset tibble conversion includes fold weights", { + weights <- c(0.1, 0.4, 0.5) + weighted_folds <- add_resample_weights(folds, weights) + + # Convert to tibble manually (this is what our print method does) + x_tbl <- tibble::as_tibble(weighted_folds) + x_tbl$resample_weight <- weights + + # Verify the structure + expect_true("resample_weight" %in% names(x_tbl)) + expect_equal(x_tbl$resample_weight, weights) + expect_equal(nrow(x_tbl), 3) +}) + +test_that("extract_resample_weights() works with rset objects", { + weights <- c(0.2, 0.3, 0.5) + weighted_folds <- add_resample_weights(folds, weights) + + # Should return the weights + extracted_weights <- extract_resample_weights(weighted_folds) + expect_equal(extracted_weights, weights) + + # Should return NULL for unweighted rsets + unweighted_result <- 
extract_resample_weights(folds) + expect_null(unweighted_result) +}) + +test_that("extract_resample_weights() works with tune_results objects", { + weights <- c(0.1, 0.5, 0.4) + weighted_folds <- add_resample_weights(folds, weights) + + mod <- create_test_model() + + suppressWarnings({ + res <- tune_grid( + mod, + mpg ~ ., + resamples = weighted_folds, + grid = 1, + metrics = yardstick::metric_set(yardstick::rmse), + control = control_grid(verbose = FALSE) + ) + }) + + # Should extract weights from tune results + extracted_weights <- extract_resample_weights(res) + expected_weights <- weights / sum(weights) # normalized + expect_equal(extracted_weights, expected_weights) +}) + +test_that("extract_resample_weights() validates input types", { + expect_snapshot( + extract_resample_weights("not_valid_input"), + error = TRUE + ) + + expect_snapshot( + extract_resample_weights(data.frame(x = 1:3)), + error = TRUE + ) +}) diff --git a/vignettes/extras/optimizations.Rmd b/vignettes/extras/optimizations.Rmd index cd5b7264..ee97957e 100644 --- a/vignettes/extras/optimizations.Rmd +++ b/vignettes/extras/optimizations.Rmd @@ -31,8 +31,8 @@ methods("multi_predict") # There are arguments for the parameter(s) that can create multiple predictions. 
# For xgboost, `trees` are cheap to evaluate: -parsnip:::multi_predict._xgb.Booster %>% - formals() %>% +parsnip:::multi_predict._xgb.Booster |> + formals() |> names() ``` @@ -49,13 +49,13 @@ For example, suppose that [Isomap multi-dimensional scaling](https://en.wikipedi #| results: hide data(Chicago) iso_rec <- - recipe(ridership ~ ., data = Chicago) %>% - step_dummy(all_nominal()) %>% + recipe(ridership ~ ., data = Chicago) |> + step_dummy(all_nominal()) |> step_isomap(all_predictors(), num_terms = tune()) knn_mod <- - nearest_neighbor(neighbors = tune(), weight_func = tune()) %>% - set_engine("kknn") %>% + nearest_neighbor(neighbors = tune(), weight_func = tune()) |> + set_engine("kknn") |> set_mode("regression") ``` @@ -63,8 +63,8 @@ With the following grid: ```{r} grid <- - parameters(num_terms(c(1, 9)), neighbors(), weight_func()) %>% - grid_regular(levels = c(5, 10, 7)) %>% + parameters(num_terms(c(1, 9)), neighbors(), weight_func()) |> + grid_regular(levels = c(5, 10, 7)) |> arrange(num_terms, neighbors, weight_func) grid ``` @@ -155,11 +155,11 @@ Some helpful advice to avoid errors in parallel processing is to not use variabl #| eval: false num_pcs <- 3 -recipe(mpg ~ ., data = mtcars) %>% +recipe(mpg ~ ., data = mtcars) |> # Bad since num_pcs might not be found by a worker process step_pca(all_predictors(), num_comp = num_pcs) -recipe(mpg ~ ., data = mtcars) %>% +recipe(mpg ~ ., data = mtcars) |> # Good since the value is injected into the object step_pca(all_predictors(), num_comp = !!num_pcs) ``` diff --git a/vignettes/tune.Rmd b/vignettes/tune.Rmd index 0fe8a59e..5767c534 100644 --- a/vignettes/tune.Rmd +++ b/vignettes/tune.Rmd @@ -45,8 +45,8 @@ library(tidymodels) data(ames) set.seed(4595) -data_split <- ames %>% - mutate(Sale_Price = log10(Sale_Price)) %>% +data_split <- ames |> + mutate(Sale_Price = log10(Sale_Price)) |> initial_split(strata = Sale_Price) ames_train <- training(data_split) ames_test <- testing(data_split) @@ -59,10 +59,10 @@ For 
simplicity, the sale price of a house will be modeled as a function of its g #| fig-alt: A ggplot2 scatterplot. x axes plot the latitude and longitude, in side-by-side #| facets, and the log sale price is on the y axis. The clouds of points follow highly #| non-linear trends, traced by a blue trend line. -ames_train %>% - dplyr::select(Sale_Price, Longitude, Latitude) %>% +ames_train |> + dplyr::select(Sale_Price, Longitude, Latitude) |> tidyr::pivot_longer(cols = c(Longitude, Latitude), - names_to = "predictor", values_to = "value") %>% + names_to = "predictor", values_to = "value") |> ggplot(aes(x = value, Sale_Price)) + geom_point(alpha = .2) + geom_smooth(se = FALSE) + @@ -76,8 +76,8 @@ We can tag these parameters for optimization using the `tune()` function: ```{r} #| label: tag-rec ames_rec <- - recipe(Sale_Price ~ Gr_Liv_Area + Longitude + Latitude, data = ames_train) %>% - step_log(Gr_Liv_Area, base = 10) %>% + recipe(Sale_Price ~ Gr_Liv_Area + Longitude + Latitude, data = ames_train) |> + step_log(Gr_Liv_Area, base = 10) |> step_spline_natural(Longitude, Latitude, deg_free = tune()) ``` @@ -92,9 +92,9 @@ To accomplish this, individual `step_spline_natural()` terms can be added to the ```{r} #| label: tag-rec-d ames_rec <- - recipe(Sale_Price ~ Gr_Liv_Area + Longitude + Latitude, data = ames_train) %>% - step_log(Gr_Liv_Area, base = 10) %>% - step_spline_natural(Longitude, deg_free = tune("long df")) %>% + recipe(Sale_Price ~ Gr_Liv_Area + Longitude + Latitude, data = ames_train) |> + step_log(Gr_Liv_Area, base = 10) |> + step_spline_natural(Longitude, deg_free = tune("long df")) |> step_spline_natural(Latitude, deg_free = tune("lat df")) ``` @@ -124,8 +124,8 @@ The parameter objects can be easily changed using the `update()` function: ```{r} #| label: updated ames_param <- - ames_rec %>% - extract_parameter_set_dials() %>% + ames_rec |> + extract_parameter_set_dials() |> update( `long df` = spline_degree(), `lat df` = spline_degree() @@ -168,7 +168,7 @@ 
First is a model specification. Using functions in parsnip, a basic linear model ```{r} #| label: mod -lm_mod <- linear_reg() %>% set_engine("lm") +lm_mod <- linear_reg() |> set_engine("lm") ``` No tuning parameters here. @@ -218,8 +218,8 @@ The values in the `mean` column are the averages of the `r nrow(cv_splits)` resa ```{r} #| label: best-rmse rmse_vals <- - estimates %>% - dplyr::filter(.metric == "rmse") %>% + estimates |> + dplyr::filter(.metric == "rmse") |> arrange(mean) rmse_vals ``` @@ -245,10 +245,10 @@ Let's plot these spline functions over the data for both good and bad values of #| fig-alt: A scatterplot much like the first one, except that a smoother, red line, #| representing a spline term with fewer degrees of freedom, is also plotted. The red #| line is much smoother but accounts for the less of the variation shown. -ames_train %>% - dplyr::select(Sale_Price, Longitude, Latitude) %>% +ames_train |> + dplyr::select(Sale_Price, Longitude, Latitude) |> tidyr::pivot_longer(cols = c(Longitude, Latitude), - names_to = "predictor", values_to = "value") %>% + names_to = "predictor", values_to = "value") |> ggplot(aes(x = value, Sale_Price)) + geom_point(alpha = .2) + geom_smooth(se = FALSE, method = lm, formula = y ~ splines::ns(x, df = 3), col = "red") + @@ -269,8 +269,8 @@ Instead of a linear regression, a nonlinear model might provide good performance #| label: knn # requires the kknn package knn_mod <- - nearest_neighbor(neighbors = tune(), weight_func = tune()) %>% - set_engine("kknn") %>% + nearest_neighbor(neighbors = tune(), weight_func = tune()) |> + set_engine("kknn") |> set_mode("regression") ``` @@ -280,8 +280,8 @@ The easiest approach to optimize the pre-processing and model parameters is to b #| label: knn-wflow library(workflows) knn_wflow <- - workflow() %>% - add_model(knn_mod) %>% + workflow() |> + add_model(knn_mod) |> add_recipe(ames_rec) ``` @@ -290,8 +290,8 @@ From this, the parameter set can be used to modify the range and values 
of param ```{r} #| label: knn-set knn_param <- - knn_wflow %>% - extract_parameter_set_dials() %>% + knn_wflow |> + extract_parameter_set_dials() |> update( `long df` = spline_degree(c(2, 18)), `lat df` = spline_degree(c(2, 18)), @@ -337,8 +337,8 @@ The best results here were: ```{r} #| label: bo-best -collect_metrics(knn_search) %>% - dplyr::filter(.metric == "rmse") %>% +collect_metrics(knn_search) |> + dplyr::filter(.metric == "rmse") |> arrange(mean) ```