Skip to content

Commit e16bb0d

Browse files
authored
Compute grid info dplyr (#961)
* refactored compute_grid_info() using dplyr, purrr, and tidyr * remove padding in .config * sort values for tests * update test specification for different sorting * fix bug in the messages * update snapshots with new remotes * added padding back
1 parent f85eac9 commit e16bb0d

File tree

7 files changed

+110
-57
lines changed

7 files changed

+110
-57
lines changed

R/0_imports.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ utils::globalVariables(
4747
"rowwise", ".best", "location", "msg", "..object", ".eval_time",
4848
".pred_survival", ".pred_time", ".weight_censored", "nice_time",
4949
"time_metric", ".lower", ".upper", "i", "results", "term", ".alpha",
50-
".method", "old_term"
50+
".method", "old_term", ".lab_pre", ".model", ".num_models"
5151
)
5252
)
5353

R/grid_helpers.R

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -318,50 +318,71 @@ compute_grid_info <- function(workflow, grid) {
318318

319319
res <- min_grid(extract_spec_parsnip(workflow), grid)
320320

321+
syms_pre <- rlang::syms(parameters_preprocessor$id)
322+
syms_mod <- rlang::syms(parameters_model$id)
323+
324+
# ----------------------------------------------------------------------------
325+
# Create an order of execution to train the preprocessor (if any). This will
326+
# define a loop over any preprocessing tuning parameter combinations.
321327
if (any_parameters_preprocessor) {
322-
res$.iter_preprocessor <- seq_len(nrow(res))
328+
pp_df <-
329+
dplyr::distinct(res, !!!syms_pre) %>%
330+
dplyr::arrange(!!!syms_pre) %>%
331+
dplyr::mutate(
332+
.iter_preprocessor = dplyr::row_number(),
333+
.lab_pre = recipes::names0(max(dplyr::n()), "Preprocessor")
334+
)
335+
res <-
336+
dplyr::full_join(res, pp_df, by = parameters_preprocessor$id) %>%
337+
dplyr::arrange(.iter_preprocessor)
323338
} else {
324339
res$.iter_preprocessor <- 1L
340+
res$.lab_pre <- "Preprocessor1"
325341
}
326342

343+
# Make the label shown in the grid and in logging
327344
res$.msg_preprocessor <-
328345
new_msgs_preprocessor(
329-
seq_len(max(res$.iter_preprocessor)),
346+
res$.iter_preprocessor,
330347
max(res$.iter_preprocessor)
331348
)
332349

333-
if (nrow(res) != nrow(grid) ||
334-
(any_parameters_model && !any_parameters_preprocessor)) {
335-
res$.iter_model <- seq_len(dplyr::n_distinct(res[parameters_model$id]))
336-
} else {
337-
res$.iter_model <- 1L
338-
}
339-
340-
res$.iter_config <- list(list())
341-
for (row in seq_len(nrow(res))) {
342-
res$.iter_config[row] <- list(iter_config(res[row, ]))
343-
}
350+
# ----------------------------------------------------------------------------
351+
# Now make a similar iterator across models. Conditioning on each unique
352+
# preprocessing candidate set, make an iterator for the model candidate sets
353+
# (if any)
354+
355+
res <-
356+
res %>%
357+
dplyr::group_nest(.iter_preprocessor, keep = TRUE) %>%
358+
dplyr::mutate(
359+
.iter_config = purrr::map(data, make_iter_config),
360+
.model = purrr::map(data, ~ tibble::tibble(.iter_model = seq_len(nrow(.x)))),
361+
.num_models = purrr::map_int(.model, nrow)
362+
) %>%
363+
dplyr::select(-.iter_preprocessor) %>%
364+
tidyr::unnest(cols = c(data, .model, .iter_config)) %>%
365+
dplyr::select(-.lab_pre) %>%
366+
dplyr::relocate(dplyr::starts_with(".iter"))
344367

345368
res$.msg_model <-
346-
new_msgs_model(i = res$.iter_model, n = max(res$.iter_model), res$.msg_preprocessor)
369+
new_msgs_model(i = res$.iter_model,
370+
n = res$.num_models,
371+
res$.msg_preprocessor)
347372

348-
res
373+
res %>%
374+
dplyr::select(-.num_models) %>%
375+
dplyr::relocate(dplyr::starts_with(".msg"))
349376
}
350377

351-
iter_config <- function(res_row) {
352-
submodels <- res_row$.submodels[[1]]
353-
if (identical(submodels, list())) {
354-
models <- res_row$.iter_model
355-
} else {
356-
models <- seq_len(length(submodels[[1]]) + 1)
357-
}
358-
359-
paste0(
360-
"Preprocessor",
361-
res_row$.iter_preprocessor,
362-
"_Model",
363-
format_with_padding(models)
364-
)
378+
make_iter_config <- function(dat) {
379+
# Compute labels for the models *within* each preprocessing loop.
380+
num_submodels <- purrr::map_int(dat$.submodels, ~ length(unlist(.x)))
381+
num_models <- sum(num_submodels + 1) # +1 for the model being trained
382+
.mod_label <- recipes::names0(num_models, "Model")
383+
.iter_config <- paste(dat$.lab_pre[1], .mod_label, sep = "_")
384+
.iter_config <- vctrs::vec_chop(.iter_config, sizes = num_submodels + 1)
385+
tibble::tibble(.iter_config = .iter_config)
365386
}
366387

367388
# This generates a "dummy" grid_info object that has the same

tests/testthat/_snaps/bayes.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -393,12 +393,12 @@
393393
Message
394394
x Fold1: preprocessor 1/1:
395395
Error in `step_spline_b()`:
396-
Caused by error in `spline_msg()`:
397-
! Error in if (df < 0) : missing value where TRUE/FALSE needed
396+
Caused by error in `prep()`:
397+
! `deg_free` must be a whole number, not a numeric `NA`.
398398
x Fold2: preprocessor 1/1:
399399
Error in `step_spline_b()`:
400-
Caused by error in `spline_msg()`:
401-
! Error in if (df < 0) : missing value where TRUE/FALSE needed
400+
Caused by error in `prep()`:
401+
! `deg_free` must be a whole number, not a numeric `NA`.
402402
Condition
403403
Warning:
404404
All models failed. Run `show_notes(.Last.tune.result)` for more information.
@@ -415,10 +415,10 @@
415415
Message
416416
x Fold1: preprocessor 1/1:
417417
Error in `get_all_predictors()`:
418-
! The following predictors were not found in `data`: 'z'.
418+
! The following predictor was not found in `data`: "z".
419419
x Fold2: preprocessor 1/1:
420420
Error in `get_all_predictors()`:
421-
! The following predictors were not found in `data`: 'z'.
421+
! The following predictor was not found in `data`: "z".
422422
Condition
423423
Warning:
424424
All models failed. Run `show_notes(.Last.tune.result)` for more information.

tests/testthat/_snaps/checks.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,22 @@
9292
Error in `tune:::check_workflow()`:
9393
! A parsnip model is required.
9494

95+
# errors informatively when needed package isn't installed
96+
97+
Code
98+
check_workflow(stan_wflow)
99+
Condition
100+
Error:
101+
! Package install is required for rstanarm.
102+
103+
---
104+
105+
Code
106+
fit_resamples(stan_wflow, rsample::bootstraps(mtcars))
107+
Condition
108+
Error in `fit_resamples()`:
109+
! Package install is required for rstanarm.
110+
95111
# workflow objects (will not tune, tidymodels/tune#548)
96112

97113
Code

tests/testthat/_snaps/grid.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
Message
99
x Fold1: preprocessor 1/1:
1010
Error in `step_spline_b()`:
11-
Caused by error in `spline_msg()`:
12-
! Error in if (df < 0) : missing value where TRUE/FALSE needed
11+
Caused by error in `prep()`:
12+
! `deg_free` must be a whole number, not a numeric `NA`.
1313
x Fold2: preprocessor 1/1:
1414
Error in `step_spline_b()`:
15-
Caused by error in `spline_msg()`:
16-
! Error in if (df < 0) : missing value where TRUE/FALSE needed
15+
Caused by error in `prep()`:
16+
! `deg_free` must be a whole number, not a numeric `NA`.
1717
Condition
1818
Warning:
1919
All models failed. Run `show_notes(.Last.tune.result)` for more information.
@@ -28,10 +28,10 @@
2828
Message
2929
x Fold1: preprocessor 1/1:
3030
Error in `get_all_predictors()`:
31-
! The following predictors were not found in `data`: 'z'.
31+
! The following predictor was not found in `data`: "z".
3232
x Fold2: preprocessor 1/1:
3333
Error in `get_all_predictors()`:
34-
! The following predictors were not found in `data`: 'z'.
34+
! The following predictor was not found in `data`: "z".
3535
Condition
3636
Warning:
3737
All models failed. Run `show_notes(.Last.tune.result)` for more information.

tests/testthat/_snaps/resample.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
Message
66
x Fold1: preprocessor 1/1:
77
Error in `step_spline_natural()`:
8-
Caused by error in `spline_msg()`:
9-
! Error in if (df < 2) : missing value where TRUE/FALSE needed
8+
Caused by error in `prep()`:
9+
! `deg_free` must be a whole number, not a numeric `NA`.
1010
x Fold2: preprocessor 1/1:
1111
Error in `step_spline_natural()`:
12-
Caused by error in `spline_msg()`:
13-
! Error in if (df < 2) : missing value where TRUE/FALSE needed
12+
Caused by error in `prep()`:
13+
! `deg_free` must be a whole number, not a numeric `NA`.
1414
Condition
1515
Warning:
1616
All models failed. Run `show_notes(.Last.tune.result)` for more information.
@@ -20,7 +20,7 @@
2020
Code
2121
note
2222
Output
23-
[1] "Error in `step_spline_natural()`:\nCaused by error in `spline_msg()`:\n! Error in if (df < 2) { : missing value where TRUE/FALSE needed"
23+
[1] "Error in `step_spline_natural()`:\nCaused by error in `prep()`:\n! `deg_free` must be a whole number, not a numeric `NA`."
2424

2525
# failure in variables tidyselect specification is caught elegantly
2626

tests/testthat/test-grid_helpers.R

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ test_that("compute_grid_info - recipe only", {
1515

1616
expect_equal(res$.iter_preprocessor, 1:5)
1717
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:5, "/5"))
18-
expect_equal(res$deg_free, grid$deg_free)
18+
expect_equal(sort(res$deg_free), sort(grid$deg_free))
1919
expect_equal(res$.iter_model, rep(1, 5))
2020
expect_equal(res$.iter_config, as.list(paste0("Preprocessor", 1:5, "_Model1")))
2121
expect_equal(res$.msg_model, paste0("preprocessor ", 1:5, "/5, model 1/1"))
@@ -27,6 +27,7 @@ test_that("compute_grid_info - recipe only", {
2727
ignore.order = TRUE
2828
)
2929
expect_equal(nrow(res), 5)
30+
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
3031
})
3132

3233
test_that("compute_grid_info - model only (no submodels)", {
@@ -57,6 +58,7 @@ test_that("compute_grid_info - model only (no submodels)", {
5758
ignore.order = TRUE
5859
)
5960
expect_equal(nrow(res), 5)
61+
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
6062
})
6163

6264
test_that("compute_grid_info - model only (with submodels)", {
@@ -107,8 +109,8 @@ test_that("compute_grid_info - recipe and model (no submodels)", {
107109

108110
expect_equal(res$.iter_preprocessor, 1:5)
109111
expect_equal(res$.msg_preprocessor, paste0("preprocessor ", 1:5, "/5"))
110-
expect_equal(res$learn_rate, grid$learn_rate)
111-
expect_equal(res$deg_free, grid$deg_free)
112+
expect_equal(sort(res$learn_rate), sort(grid$learn_rate))
113+
expect_equal(sort(res$deg_free), sort(grid$deg_free))
112114
expect_equal(res$.iter_model, rep(1, 5))
113115
expect_equal(res$.iter_config, as.list(paste0("Preprocessor", 1:5, "_Model1")))
114116
expect_equal(res$.msg_model, paste0("preprocessor ", 1:5, "/5, model 1/1"))
@@ -120,6 +122,7 @@ test_that("compute_grid_info - recipe and model (no submodels)", {
120122
ignore.order = TRUE
121123
)
122124
expect_equal(nrow(res), 5)
125+
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
123126
})
124127

125128
test_that("compute_grid_info - recipe and model (with submodels)", {
@@ -169,6 +172,7 @@ test_that("compute_grid_info - recipe and model (with submodels)", {
169172
)
170173
expect_equal(nrow(res), 3)
171174
})
175+
172176
test_that("compute_grid_info - recipe and model (with and without submodels)", {
173177
library(workflows)
174178
library(parsnip)
@@ -185,25 +189,30 @@ test_that("compute_grid_info - recipe and model (with and without submodels)", {
185189
# use grid_regular to (partially) trigger submodel trick
186190
set.seed(1)
187191
param_set <- extract_parameter_set_dials(wflow)
188-
grid <- bind_rows(grid_regular(param_set), grid_space_filling(param_set))
192+
grid <-
193+
bind_rows(grid_regular(param_set), grid_space_filling(param_set)) %>%
194+
arrange(deg_free, loss_reduction, trees)
189195
res <- compute_grid_info(wflow, grid)
190196

191197
expect_equal(length(unique(res$.iter_preprocessor)), 5)
192198
expect_equal(
193199
unique(res$.msg_preprocessor),
194200
paste0("preprocessor ", 1:5, "/5")
195201
)
196-
expect_equal(res$trees, c(rep(max(grid$trees), 10), 1))
202+
expect_equal(sort(res$trees), sort(c(rep(max(grid$trees), 10), 1)))
197203
expect_equal(unique(res$.iter_model), 1:3)
198204
expect_equal(
199-
res$.iter_config[1:3],
205+
res$.iter_config[res$.iter_preprocessor == 1],
200206
list(
201-
c("Preprocessor1_Model1", "Preprocessor1_Model2", "Preprocessor1_Model3", "Preprocessor1_Model4"),
202-
c("Preprocessor2_Model1", "Preprocessor2_Model2", "Preprocessor2_Model3"),
203-
c("Preprocessor3_Model1", "Preprocessor3_Model2", "Preprocessor3_Model3")
207+
c("Preprocessor1_Model01", "Preprocessor1_Model02", "Preprocessor1_Model03", "Preprocessor1_Model04"),
208+
c("Preprocessor1_Model05", "Preprocessor1_Model06", "Preprocessor1_Model07"),
209+
c("Preprocessor1_Model08", "Preprocessor1_Model09", "Preprocessor1_Model10")
204210
)
205211
)
206-
expect_equal(res$.msg_model[1:3], paste0("preprocessor ", 1:3, "/5, model 1/3"))
212+
expect_equal(
213+
res$.msg_model[res$.iter_preprocessor == 1],
214+
paste0("preprocessor 1/5, model ", 1:3, "/3")
215+
)
207216
expect_equal(
208217
res$.submodels[1:3],
209218
list(
@@ -212,6 +221,12 @@ test_that("compute_grid_info - recipe and model (with and without submodels)", {
212221
list(trees = c(1L, 1000L))
213222
)
214223
)
224+
expect_equal(
225+
res %>%
226+
mutate(num_models = purrr::map_int(.iter_config, length)) %>%
227+
summarize(n = sum(num_models), .by = c(deg_free)),
228+
grid %>% count(deg_free)
229+
)
215230
expect_named(
216231
res,
217232
c(".iter_preprocessor", ".msg_preprocessor", "deg_free", "trees",
@@ -325,4 +340,5 @@ test_that("compute_grid_info - recipe and model (no submodels but has inner grid
325340
ignore.order = TRUE
326341
)
327342
expect_equal(nrow(res), 9)
343+
expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
328344
})

0 commit comments

Comments
 (0)