-
Notifications
You must be signed in to change notification settings - Fork 44
Compute grid info dplyr #961
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
38fe8a0
2a78da3
a5d94aa
d719360
7450c12
3935987
cd1a79f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -318,50 +318,71 @@ compute_grid_info <- function(workflow, grid) { | |
|
||
res <- min_grid(extract_spec_parsnip(workflow), grid) | ||
|
||
syms_pre <- rlang::syms(parameters_preprocessor$id) | ||
syms_mod <- rlang::syms(parameters_model$id) | ||
|
||
# ---------------------------------------------------------------------------- | ||
# Create an order of execution to train the preprocessor (if any). This will | ||
# define a loop over any preprocessing tuning parameter combinations. | ||
if (any_parameters_preprocessor) { | ||
res$.iter_preprocessor <- seq_len(nrow(res)) | ||
pp_df <- | ||
dplyr::distinct(res, !!!syms_pre) %>% | ||
dplyr::arrange(!!!syms_pre) %>% | ||
dplyr::mutate( | ||
.iter_preprocessor = dplyr::row_number(), | ||
.lab_pre = recipes::names0(max(dplyr::n()), "Preprocessor") | ||
) | ||
res <- | ||
dplyr::full_join(res, pp_df, by = parameters_preprocessor$id) %>% | ||
dplyr::arrange(.iter_preprocessor) | ||
} else { | ||
res$.iter_preprocessor <- 1L | ||
res$.lab_pre <- "Preprocessor1" | ||
} | ||
|
||
# Make the label shown in the grid and in logging | ||
res$.msg_preprocessor <- | ||
new_msgs_preprocessor( | ||
seq_len(max(res$.iter_preprocessor)), | ||
res$.iter_preprocessor, | ||
max(res$.iter_preprocessor) | ||
) | ||
|
||
if (nrow(res) != nrow(grid) || | ||
(any_parameters_model && !any_parameters_preprocessor)) { | ||
res$.iter_model <- seq_len(dplyr::n_distinct(res[parameters_model$id])) | ||
} else { | ||
res$.iter_model <- 1L | ||
} | ||
|
||
res$.iter_config <- list(list()) | ||
for (row in seq_len(nrow(res))) { | ||
res$.iter_config[row] <- list(iter_config(res[row, ])) | ||
} | ||
# ---------------------------------------------------------------------------- | ||
# Now make a similar iterator across models. Conditioning on each unique | ||
# preprocessing candidate set, make an iterator for the model candidate sets | ||
# (if any) | ||
|
||
res <- | ||
res %>% | ||
dplyr::group_nest(.iter_preprocessor, keep = TRUE) %>% | ||
dplyr::mutate( | ||
.iter_config = purrr::map(data, make_iter_config), | ||
.model = purrr::map(data, ~ tibble::tibble(.iter_model = seq_len(nrow(.x)))), | ||
.num_models = purrr::map_int(.model, nrow) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is needed for |
||
) %>% | ||
dplyr::select(-.iter_preprocessor) %>% | ||
tidyr::unnest(cols = c(data, .model, .iter_config)) %>% | ||
dplyr::select(-.lab_pre) %>% | ||
dplyr::relocate(dplyr::starts_with(".iter")) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This (and another |
||
|
||
res$.msg_model <- | ||
new_msgs_model(i = res$.iter_model, n = max(res$.iter_model), res$.msg_preprocessor) | ||
new_msgs_model(i = res$.iter_model, | ||
n = res$.num_models, | ||
res$.msg_preprocessor) | ||
|
||
res | ||
res %>% | ||
dplyr::select(-.num_models) %>% | ||
dplyr::relocate(dplyr::starts_with(".msg")) | ||
} | ||
|
||
iter_config <- function(res_row) { | ||
submodels <- res_row$.submodels[[1]] | ||
if (identical(submodels, list())) { | ||
models <- res_row$.iter_model | ||
} else { | ||
models <- seq_len(length(submodels[[1]]) + 1) | ||
} | ||
|
||
paste0( | ||
"Preprocessor", | ||
res_row$.iter_preprocessor, | ||
"_Model", | ||
format_with_padding(models) | ||
) | ||
make_iter_config <- function(dat) { | ||
# Compute labels for the models *within* each preprocessing loop. | ||
num_submodels <- purrr::map_int(dat$.submodels, ~ length(unlist(.x))) | ||
num_models <- sum(num_submodels + 1) # +1 for the model being trained | ||
.mod_label <- recipes::names0(num_models, "Model") | ||
.iter_config <- paste(dat$.lab_pre[1], .mod_label, sep = "_") | ||
.iter_config <- vctrs::vec_chop(.iter_config, sizes = num_submodels + 1) | ||
tibble::tibble(.iter_config = .iter_config) | ||
} | ||
|
||
# This generates a "dummy" grid_info object that has the same | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ test_that("compute_grid_info - recipe only", { | |
|
||
expect_equal(res$.iter_preprocessor, 1:5) | ||
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:5, "/5")) | ||
expect_equal(res$deg_free, grid$deg_free) | ||
expect_equal(sort(res$deg_free), sort(grid$deg_free)) | ||
expect_equal(res$.iter_model, rep(1, 5)) | ||
expect_equal(res$.iter_config, as.list(paste0("Preprocessor", 1:5, "_Model1"))) | ||
expect_equal(res$.msg_model, paste0("preprocessor ", 1:5, "/5, model 1/1")) | ||
|
@@ -27,6 +27,7 @@ test_that("compute_grid_info - recipe only", { | |
ignore.order = TRUE | ||
) | ||
expect_equal(nrow(res), 5) | ||
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid)) | ||
}) | ||
|
||
test_that("compute_grid_info - model only (no submodels)", { | ||
|
@@ -57,6 +58,7 @@ test_that("compute_grid_info - model only (no submodels)", { | |
ignore.order = TRUE | ||
) | ||
expect_equal(nrow(res), 5) | ||
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid)) | ||
}) | ||
|
||
test_that("compute_grid_info - model only (with submodels)", { | ||
|
@@ -107,8 +109,8 @@ test_that("compute_grid_info - recipe and model (no submodels)", { | |
|
||
expect_equal(res$.iter_preprocessor, 1:5) | ||
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:5, "/5")) | ||
expect_equal(res$learn_rate, grid$learn_rate) | ||
expect_equal(res$deg_free, grid$deg_free) | ||
expect_equal(sort(res$learn_rate), sort(grid$learn_rate)) | ||
expect_equal(sort(res$deg_free), sort(grid$deg_free)) | ||
expect_equal(res$.iter_model, rep(1, 5)) | ||
expect_equal(res$.iter_config, as.list(paste0("Preprocessor", 1:5, "_Model1"))) | ||
expect_equal(res$.msg_model, paste0("preprocessor ", 1:5, "/5, model 1/1")) | ||
|
@@ -120,6 +122,7 @@ test_that("compute_grid_info - recipe and model (no submodels)", { | |
ignore.order = TRUE | ||
) | ||
expect_equal(nrow(res), 5) | ||
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid)) | ||
}) | ||
|
||
test_that("compute_grid_info - recipe and model (with submodels)", { | ||
|
@@ -169,6 +172,7 @@ test_that("compute_grid_info - recipe and model (with submodels)", { | |
) | ||
expect_equal(nrow(res), 3) | ||
}) | ||
|
||
test_that("compute_grid_info - recipe and model (with and without submodels)", { | ||
library(workflows) | ||
library(parsnip) | ||
|
@@ -185,25 +189,30 @@ test_that("compute_grid_info - recipe and model (with and without submodels)", { | |
# use grid_regular to (partially) trigger submodel trick | ||
set.seed(1) | ||
param_set <- extract_parameter_set_dials(wflow) | ||
grid <- bind_rows(grid_regular(param_set), grid_space_filling(param_set)) | ||
grid <- | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I was going to add more tests for unbalanced grids but this one covers it well. |
||
bind_rows(grid_regular(param_set), grid_space_filling(param_set)) %>% | ||
arrange(deg_free, loss_reduction, trees) | ||
res <- compute_grid_info(wflow, grid) | ||
|
||
expect_equal(length(unique(res$.iter_preprocessor)), 5) | ||
expect_equal( | ||
unique(res$.msg_preprocessor), | ||
paste0("preprocessor ", 1:5, "/5") | ||
) | ||
expect_equal(res$trees, c(rep(max(grid$trees), 10), 1)) | ||
expect_equal(sort(res$trees), sort(c(rep(max(grid$trees), 10), 1))) | ||
expect_equal(unique(res$.iter_model), 1:3) | ||
expect_equal( | ||
res$.iter_config[1:3], | ||
res$.iter_config[res$.iter_preprocessor == 1], | ||
list( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Your new enumeration is correct (and fixes the case where I had broken previously), but it does look like we're missing a
|
||
c("Preprocessor1_Model1", "Preprocessor1_Model2", "Preprocessor1_Model3", "Preprocessor1_Model4"), | ||
c("Preprocessor2_Model1", "Preprocessor2_Model2", "Preprocessor2_Model3"), | ||
c("Preprocessor3_Model1", "Preprocessor3_Model2", "Preprocessor3_Model3") | ||
c("Preprocessor1_Model01", "Preprocessor1_Model02", "Preprocessor1_Model03", "Preprocessor1_Model04"), | ||
c("Preprocessor1_Model05", "Preprocessor1_Model06", "Preprocessor1_Model07"), | ||
c("Preprocessor1_Model08", "Preprocessor1_Model09", "Preprocessor1_Model10") | ||
) | ||
) | ||
expect_equal(res$.msg_model[1:3], paste0("preprocessor ", 1:3, "/5, model 1/3")) | ||
expect_equal( | ||
res$.msg_model[res$.iter_preprocessor == 1], | ||
paste0("preprocessor 1/5, model ", 1:3, "/3") | ||
) | ||
expect_equal( | ||
res$.submodels[1:3], | ||
list( | ||
|
@@ -212,6 +221,12 @@ test_that("compute_grid_info - recipe and model (with and without submodels)", { | |
list(trees = c(1L, 1000L)) | ||
) | ||
) | ||
expect_equal( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can't be too paranoid about this. |
||
res %>% | ||
mutate(num_models = purrr::map_int(.iter_config, length)) %>% | ||
summarize(n = sum(num_models), .by = c(deg_free)), | ||
grid %>% count(deg_free) | ||
) | ||
expect_named( | ||
res, | ||
c(".iter_preprocessor", ".msg_preprocessor", "deg_free", "trees", | ||
|
@@ -325,4 +340,5 @@ test_that("compute_grid_info - recipe and model (no submodels but has inner grid | |
ignore.order = TRUE | ||
) | ||
expect_equal(nrow(res), 9) | ||
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid)) | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will make the ordering of the preprocessors predictable. Previously, it would order them as-is. It's no big deal, but someone might wonder why `deg_free = 10` is executed before `deg_free=2`. This is why there are several `sort()` calls in the unit tests.