Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
59176dc
Add fold weights functionality
tjburch May 23, 2025
848a30a
Remove manually added .Rd files - these are generated by roxygen2
tjburch May 23, 2025
cc8bad1
Add print options for fold weights
Jul 22, 2025
24caeea
Merge branch 'tidymodels:main' into variable-fold-weights
tjburch Jul 22, 2025
5d97e2a
Export exposed functions
Jul 22, 2025
22754b8
printing manual_rset too?
Jul 22, 2025
03c038f
Merge branch 'tidymodels:main' into variable-fold-weights
tjburch Jul 30, 2025
6da3d80
Merge branch 'tidymodels:main' into variable-fold-weights
tjburch Jul 30, 2025
27a5ab5
Merge branch 'main' into variable-fold-weights
Sep 11, 2025
7fa521d
Run air
Sep 11, 2025
b53856f
Run air on tests
Sep 11, 2025
4c45658
Fix piping error
Sep 11, 2025
87c9e3b
Improve docs
Sep 12, 2025
b6d7836
Improve docs and actually merge the file...
Sep 12, 2025
c397d67
Merge branch 'main' into tjburch-variable-fold-weights
topepo Oct 17, 2025
5a2f855
use statistics based on functions from the stats package
topepo Oct 17, 2025
ba3b420
move away from glmnet
topepo Oct 17, 2025
44551b4
spelling update
topepo Oct 17, 2025
0a87de7
redoc
topepo Oct 17, 2025
a38c55e
news entry
topepo Oct 17, 2025
a053787
global data false positives
topepo Oct 17, 2025
79f8c65
change to base R pipe
topepo Oct 17, 2025
5cd334d
fix test case
topepo Oct 17, 2025
ac04e3f
remove print methods - made an issue in rsample for that
topepo Oct 17, 2025
612325e
fold_weight -> resample_weight
topepo Oct 17, 2025
8349ded
name user-facing function with "extract" instead of "get"
topepo Oct 17, 2025
8aea0ea
Make weights NULL when equal
Oct 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ S3method(pretty,tune_results)
S3method(print,control_bayes)
S3method(print,control_grid)
S3method(print,control_last_fit)
S3method(print,manual_rset)
S3method(print,prob_improve)
S3method(print,rset)
S3method(print,tune_results)
S3method(select_best,default)
S3method(select_best,tune_results)
Expand Down Expand Up @@ -149,8 +151,10 @@ export(.get_tune_workflow)
export(.load_namespace)
export(.stash_last_result)
export(.use_case_weights_with_yardstick)
export(add_fold_weights)
export(augment)
export(autoplot)
export(calculate_fold_weights)
export(check_eval_time_arg)
export(check_initial)
export(check_metric_in_tune_results)
Expand Down Expand Up @@ -203,6 +207,7 @@ export(fit_resamples)
export(forge_from_workflow)
export(future_installed)
export(get_future_workers)
export(get_fold_weights)
export(get_metric_time)
export(get_mirai_workers)
export(get_parallel_seeds)
Expand Down
20 changes: 20 additions & 0 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@ check_rset <- function(x) {
if (inherits(x, "permutations")) {
cli::cli_abort("Permutation samples are not suitable for tuning.")
}

# Check fold weights if present
check_fold_weights(x)

invisible(NULL)
}

#' Check fold weights in rset objects
#'
#' @param x An rset object.
#' @return `NULL` invisibly, or error if weights are invalid.
#' @keywords internal
check_fold_weights <- function(x) {
weights <- attr(x, ".fold_weights")
if (is.null(weights)) {
return(invisible(NULL))
}

.validate_fold_weights(weights, nrow(x))

invisible(NULL)
}

Expand Down
60 changes: 47 additions & 13 deletions R/collect.R
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ estimate_tune_results <- function(x, ..., col_name = ".metrics") {
)
}

fold_weights <- .get_fold_weights(x)

# The mapping of tuning parameters and .config.
config_key <- .config_key_from_metrics(x)

Expand Down Expand Up @@ -604,19 +606,51 @@ estimate_tune_results <- function(x, ..., col_name = ".metrics") {

x <- tibble::as_tibble(x)
x <- vctrs::vec_slice(x, x$id != "Apparent")
x <- x |>
dplyr::group_by(
!!!rlang::syms(param_names),
.metric,
.estimator,
!!!rlang::syms(group_cols)
) |>
dplyr::summarize(
mean = mean(.estimate, na.rm = TRUE),
n = sum(!is.na(.estimate)),
std_err = sd(.estimate, na.rm = TRUE) / sqrt(n),
.groups = "drop"
)

# Join weights to the data if available
if (!is.null(fold_weights)) {
weight_data <- .create_weight_mapping(fold_weights, id_names, x)
if (!is.null(weight_data)) {
x <- dplyr::left_join(x, weight_data, by = id_names)
} else {
# If weight mapping failed, fall back to unweighted
fold_weights <- NULL
}
}

if (!is.null(fold_weights)) {
# Use weighted aggregation
x <- x |>
dplyr::group_by(
!!!rlang::syms(param_names),
.metric,
.estimator,
!!!rlang::syms(group_cols)
) |>
dplyr::summarize(
mean = .weighted_mean(.estimate, .fold_weight),
n = sum(!is.na(.estimate)),
effective_n = .effective_sample_size(.fold_weight[!is.na(.estimate)]),
std_err = .weighted_sd(.estimate, .fold_weight) /
sqrt(pmax(effective_n, 1)),
.groups = "drop"
) |>
dplyr::select(-effective_n)
} else {
x <- x |>
dplyr::group_by(
!!!rlang::syms(param_names),
.metric,
.estimator,
!!!rlang::syms(group_cols)
) |>
dplyr::summarize(
mean = mean(.estimate, na.rm = TRUE),
n = sum(!is.na(.estimate)),
std_err = sd(.estimate, na.rm = TRUE) / sqrt(n),
.groups = "drop"
)
}

# only join when parameters are being tuned (#600)
if (length(param_names) == 0) {
Expand Down
4 changes: 4 additions & 0 deletions R/tune_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ pull_rset_attributes <- function(x) {
att$class <- setdiff(class(x), class(tibble::new_tibble(list())))
att$class <- att$class[att$class != "rset"]

if (!is.null(attr(x, ".fold_weights"))) {
att[[".fold_weights"]] <- attr(x, ".fold_weights")
}

lab <- try(pretty(x), silent = TRUE)
if (inherits(lab, "try-error")) {
lab <- NA_character_
Expand Down
Loading
Loading