Skip to content

Commit f85eac9

Browse files
committed
add additional compute_grid_info() tests [no ci]
These tests will fail with current `main` but pass with #960. Max will file a PR today with a new draft of the helper, and that helper ought to pass these tests.
1 parent abe9fa6 commit f85eac9

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

tests/testthat/test-grid_helpers.R

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,160 @@ test_that("compute_grid_info - recipe and model (with submodels)", {
169169
)
170170
expect_equal(nrow(res), 3)
171171
})
172+
test_that("compute_grid_info - recipe and model (with and without submodels)", {
173+
library(workflows)
174+
library(parsnip)
175+
library(recipes)
176+
library(dials)
177+
178+
rec <- recipe(mpg ~ ., mtcars) %>% step_spline_natural(deg_free = tune())
179+
spec <- boost_tree(mode = "regression", trees = tune(), loss_reduction = tune())
180+
181+
wflow <- workflow()
182+
wflow <- add_model(wflow, spec)
183+
wflow <- add_recipe(wflow, rec)
184+
185+
# use grid_regular to (partially) trigger submodel trick
186+
set.seed(1)
187+
param_set <- extract_parameter_set_dials(wflow)
188+
grid <- bind_rows(grid_regular(param_set), grid_space_filling(param_set))
189+
res <- compute_grid_info(wflow, grid)
190+
191+
expect_equal(length(unique(res$.iter_preprocessor)), 5)
192+
expect_equal(
193+
unique(res$.msg_preprocessor),
194+
paste0("preprocessor ", 1:5, "/5")
195+
)
196+
expect_equal(res$trees, c(rep(max(grid$trees), 10), 1))
197+
expect_equal(unique(res$.iter_model), 1:3)
198+
expect_equal(
199+
res$.iter_config[1:3],
200+
list(
201+
c("Preprocessor1_Model1", "Preprocessor1_Model2", "Preprocessor1_Model3", "Preprocessor1_Model4"),
202+
c("Preprocessor2_Model1", "Preprocessor2_Model2", "Preprocessor2_Model3"),
203+
c("Preprocessor3_Model1", "Preprocessor3_Model2", "Preprocessor3_Model3")
204+
)
205+
)
206+
expect_equal(res$.msg_model[1:3], paste0("preprocessor ", 1:3, "/5, model 1/3"))
207+
expect_equal(
208+
res$.submodels[1:3],
209+
list(
210+
list(trees = c(1L, 1000L, 1000L)),
211+
list(trees = c(1L, 1000L)),
212+
list(trees = c(1L, 1000L))
213+
)
214+
)
215+
expect_named(
216+
res,
217+
c(".iter_preprocessor", ".msg_preprocessor", "deg_free", "trees",
218+
"loss_reduction", ".iter_model", ".iter_config", ".msg_model", ".submodels"),
219+
ignore.order = TRUE
220+
)
221+
expect_equal(nrow(res), 11)
222+
})
223+
224+
test_that("compute_grid_info - model (with and without submodels)", {
225+
library(workflows)
226+
library(parsnip)
227+
library(recipes)
228+
library(dials)
229+
230+
rec <- recipe(mpg ~ ., mtcars)
231+
spec <- mars(num_terms = tune(), prod_degree = tune(), prune_method = tune()) %>%
232+
set_mode("classification") %>%
233+
set_engine("earth")
234+
235+
wflow <- workflow()
236+
wflow <- add_model(wflow, spec)
237+
wflow <- add_recipe(wflow, rec)
238+
239+
set.seed(123)
240+
params_grid <- grid_space_filling(
241+
num_terms() %>% range_set(c(1L, 12L)),
242+
prod_degree(),
243+
prune_method(values = c("backward", "none", "forward")),
244+
size = 7,
245+
type = "latin_hypercube"
246+
)
247+
248+
res <- compute_grid_info(wflow, params_grid)
249+
250+
expect_equal(res$.iter_preprocessor, rep(1, 5))
251+
expect_equal(res$.msg_preprocessor, rep("preprocessor 1/1", 5))
252+
expect_equal(length(unique(res$num_terms)), 5)
253+
expect_equal(res$.iter_model, 1:5)
254+
expect_equal(
255+
res$.iter_config,
256+
list(
257+
c("Preprocessor1_Model1", "Preprocessor1_Model2"),
258+
c("Preprocessor1_Model3", "Preprocessor1_Model4"),
259+
"Preprocessor1_Model5", "Preprocessor1_Model6", "Preprocessor1_Model7"
260+
)
261+
)
262+
expect_equal(
263+
unique(res$.msg_model),
264+
paste0("preprocessor 1/1, model ", 1:5,"/5")
265+
)
266+
expect_equal(
267+
res$.submodels,
268+
list(
269+
list(num_terms = c(1)),
270+
list(num_terms = c(3)),
271+
list(), list(), list()
272+
)
273+
)
274+
expect_named(
275+
res,
276+
c(".iter_preprocessor", ".msg_preprocessor", "num_terms", "prod_degree",
277+
"prune_method", ".iter_model", ".iter_config", ".msg_model", ".submodels"),
278+
ignore.order = TRUE
279+
)
280+
expect_equal(nrow(res), 5)
281+
})
282+
283+
test_that("compute_grid_info - recipe and model (no submodels but has inner grid)", {
284+
library(workflows)
285+
library(parsnip)
286+
library(recipes)
287+
library(dials)
288+
289+
set.seed(1)
290+
291+
helper_objects <- helper_objects_tune()
292+
293+
wflow <- workflow() %>%
294+
add_recipe(helper_objects$rec_tune_1) %>%
295+
add_model(helper_objects$svm_mod)
296+
297+
pset <- extract_parameter_set_dials(wflow) %>%
298+
update(num_comp = dials::num_comp(c(1, 3)))
299+
300+
grid <- dials::grid_regular(pset, levels = 3)
301+
302+
res <- compute_grid_info(wflow, grid)
303+
304+
expect_equal(res$.iter_preprocessor, rep(1:3, each = 3))
305+
expect_equal(res$.msg_preprocessor, rep(paste0("preprocessor ", 1:3, "/3"), each = 3))
306+
expect_equal(res$.iter_model, rep(1:3, times = 3))
307+
expect_equal(
308+
res$.iter_config,
309+
as.list(paste0(
310+
rep(paste0("Preprocessor", 1:3, "_Model"), each = 3),
311+
rep(1:3, times = 3)
312+
))
313+
)
314+
expect_equal(
315+
unique(res$.msg_model),
316+
paste0(
317+
rep(paste0("preprocessor ", 1:3, "/3, model "), each = 3),
318+
paste0(rep(1:3, times = 3), "/3")
319+
)
320+
)
321+
expect_named(
322+
res,
323+
c("cost", "num_comp", ".submodels", ".iter_preprocessor", ".msg_preprocessor",
324+
".iter_model", ".iter_config", ".msg_model"),
325+
ignore.order = TRUE
326+
)
327+
expect_equal(nrow(res), 9)
328+
})

0 commit comments

Comments
 (0)