@@ -169,3 +169,160 @@ test_that("compute_grid_info - recipe and model (with submodels)", {
169
169
)
170
170
expect_equal(nrow(res ), 3 )
171
171
})
172
+ test_that(" compute_grid_info - recipe and model (with and without submodels)" , {
173
+ library(workflows )
174
+ library(parsnip )
175
+ library(recipes )
176
+ library(dials )
177
+
178
+ rec <- recipe(mpg ~ . , mtcars ) %> % step_spline_natural(deg_free = tune())
179
+ spec <- boost_tree(mode = " regression" , trees = tune(), loss_reduction = tune())
180
+
181
+ wflow <- workflow()
182
+ wflow <- add_model(wflow , spec )
183
+ wflow <- add_recipe(wflow , rec )
184
+
185
+ # use grid_regular to (partially) trigger submodel trick
186
+ set.seed(1 )
187
+ param_set <- extract_parameter_set_dials(wflow )
188
+ grid <- bind_rows(grid_regular(param_set ), grid_space_filling(param_set ))
189
+ res <- compute_grid_info(wflow , grid )
190
+
191
+ expect_equal(length(unique(res $ .iter_preprocessor )), 5 )
192
+ expect_equal(
193
+ unique(res $ .msg_preprocessor ),
194
+ paste0(" preprocessor " , 1 : 5 , " /5" )
195
+ )
196
+ expect_equal(res $ trees , c(rep(max(grid $ trees ), 10 ), 1 ))
197
+ expect_equal(unique(res $ .iter_model ), 1 : 3 )
198
+ expect_equal(
199
+ res $ .iter_config [1 : 3 ],
200
+ list (
201
+ c(" Preprocessor1_Model1" , " Preprocessor1_Model2" , " Preprocessor1_Model3" , " Preprocessor1_Model4" ),
202
+ c(" Preprocessor2_Model1" , " Preprocessor2_Model2" , " Preprocessor2_Model3" ),
203
+ c(" Preprocessor3_Model1" , " Preprocessor3_Model2" , " Preprocessor3_Model3" )
204
+ )
205
+ )
206
+ expect_equal(res $ .msg_model [1 : 3 ], paste0(" preprocessor " , 1 : 3 , " /5, model 1/3" ))
207
+ expect_equal(
208
+ res $ .submodels [1 : 3 ],
209
+ list (
210
+ list (trees = c(1L , 1000L , 1000L )),
211
+ list (trees = c(1L , 1000L )),
212
+ list (trees = c(1L , 1000L ))
213
+ )
214
+ )
215
+ expect_named(
216
+ res ,
217
+ c(" .iter_preprocessor" , " .msg_preprocessor" , " deg_free" , " trees" ,
218
+ " loss_reduction" , " .iter_model" , " .iter_config" , " .msg_model" , " .submodels" ),
219
+ ignore.order = TRUE
220
+ )
221
+ expect_equal(nrow(res ), 11 )
222
+ })
223
+
224
+ test_that(" compute_grid_info - model (with and without submodels)" , {
225
+ library(workflows )
226
+ library(parsnip )
227
+ library(recipes )
228
+ library(dials )
229
+
230
+ rec <- recipe(mpg ~ . , mtcars )
231
+ spec <- mars(num_terms = tune(), prod_degree = tune(), prune_method = tune()) %> %
232
+ set_mode(" classification" ) %> %
233
+ set_engine(" earth" )
234
+
235
+ wflow <- workflow()
236
+ wflow <- add_model(wflow , spec )
237
+ wflow <- add_recipe(wflow , rec )
238
+
239
+ set.seed(123 )
240
+ params_grid <- grid_space_filling(
241
+ num_terms() %> % range_set(c(1L , 12L )),
242
+ prod_degree(),
243
+ prune_method(values = c(" backward" , " none" , " forward" )),
244
+ size = 7 ,
245
+ type = " latin_hypercube"
246
+ )
247
+
248
+ res <- compute_grid_info(wflow , params_grid )
249
+
250
+ expect_equal(res $ .iter_preprocessor , rep(1 , 5 ))
251
+ expect_equal(res $ .msg_preprocessor , rep(" preprocessor 1/1" , 5 ))
252
+ expect_equal(length(unique(res $ num_terms )), 5 )
253
+ expect_equal(res $ .iter_model , 1 : 5 )
254
+ expect_equal(
255
+ res $ .iter_config ,
256
+ list (
257
+ c(" Preprocessor1_Model1" , " Preprocessor1_Model2" ),
258
+ c(" Preprocessor1_Model3" , " Preprocessor1_Model4" ),
259
+ " Preprocessor1_Model5" , " Preprocessor1_Model6" , " Preprocessor1_Model7"
260
+ )
261
+ )
262
+ expect_equal(
263
+ unique(res $ .msg_model ),
264
+ paste0(" preprocessor 1/1, model " , 1 : 5 ," /5" )
265
+ )
266
+ expect_equal(
267
+ res $ .submodels ,
268
+ list (
269
+ list (num_terms = c(1 )),
270
+ list (num_terms = c(3 )),
271
+ list (), list (), list ()
272
+ )
273
+ )
274
+ expect_named(
275
+ res ,
276
+ c(" .iter_preprocessor" , " .msg_preprocessor" , " num_terms" , " prod_degree" ,
277
+ " prune_method" , " .iter_model" , " .iter_config" , " .msg_model" , " .submodels" ),
278
+ ignore.order = TRUE
279
+ )
280
+ expect_equal(nrow(res ), 5 )
281
+ })
282
+
283
+ test_that(" compute_grid_info - recipe and model (no submodels but has inner grid)" , {
284
+ library(workflows )
285
+ library(parsnip )
286
+ library(recipes )
287
+ library(dials )
288
+
289
+ set.seed(1 )
290
+
291
+ helper_objects <- helper_objects_tune()
292
+
293
+ wflow <- workflow() %> %
294
+ add_recipe(helper_objects $ rec_tune_1 ) %> %
295
+ add_model(helper_objects $ svm_mod )
296
+
297
+ pset <- extract_parameter_set_dials(wflow ) %> %
298
+ update(num_comp = dials :: num_comp(c(1 , 3 )))
299
+
300
+ grid <- dials :: grid_regular(pset , levels = 3 )
301
+
302
+ res <- compute_grid_info(wflow , grid )
303
+
304
+ expect_equal(res $ .iter_preprocessor , rep(1 : 3 , each = 3 ))
305
+ expect_equal(res $ .msg_preprocessor , rep(paste0(" preprocessor " , 1 : 3 , " /3" ), each = 3 ))
306
+ expect_equal(res $ .iter_model , rep(1 : 3 , times = 3 ))
307
+ expect_equal(
308
+ res $ .iter_config ,
309
+ as.list(paste0(
310
+ rep(paste0(" Preprocessor" , 1 : 3 , " _Model" ), each = 3 ),
311
+ rep(1 : 3 , times = 3 )
312
+ ))
313
+ )
314
+ expect_equal(
315
+ unique(res $ .msg_model ),
316
+ paste0(
317
+ rep(paste0(" preprocessor " , 1 : 3 , " /3, model " ), each = 3 ),
318
+ paste0(rep(1 : 3 , times = 3 ), " /3" )
319
+ )
320
+ )
321
+ expect_named(
322
+ res ,
323
+ c(" cost" , " num_comp" , " .submodels" , " .iter_preprocessor" , " .msg_preprocessor" ,
324
+ " .iter_model" , " .iter_config" , " .msg_model" ),
325
+ ignore.order = TRUE
326
+ )
327
+ expect_equal(nrow(res ), 9 )
328
+ })
0 commit comments