1
-
2
1
# Unit tests are in extratests
3
2
# nocov start
4
3
5
4
# ' @export
6
5
tunable.model_spec <- function (x , ... ) {
7
-
8
6
mod_env <- get_model_env()
9
7
10
8
if (is.null(x $ engine )) {
@@ -13,9 +11,14 @@ tunable.model_spec <- function(x, ...) {
13
11
14
12
arg_name <- paste0(mod_type(x ), " _args" )
15
13
if (! (any(arg_name == names(mod_env )))) {
16
- stop(" The `parsnip` model database doesn't know about the arguments for " ,
17
- " model `" , mod_type(x ), " `. Was it registered?" ,
18
- sep = " " , call. = FALSE )
14
+ stop(
15
+ " The `parsnip` model database doesn't know about the arguments for " ,
16
+ " model `" ,
17
+ mod_type(x ),
18
+ " `. Was it registered?" ,
19
+ sep = " " ,
20
+ call. = FALSE
21
+ )
19
22
}
20
23
21
24
arg_vals <- mod_env [[arg_name ]]
@@ -28,7 +31,10 @@ tunable.model_spec <- function(x, ...) {
28
31
29
32
extra_args_tbl <-
30
33
tibble :: new_tibble(
31
- list (name = extra_args , call_info = vector(" list" , vctrs :: vec_size(extra_args ))),
34
+ list (
35
+ name = extra_args ,
36
+ call_info = vector(" list" , vctrs :: vec_size(extra_args ))
37
+ ),
32
38
nrow = vctrs :: vec_size(extra_args )
33
39
)
34
40
@@ -57,7 +63,7 @@ add_engine_parameters <- function(pset, engines) {
57
63
is_engine_param <- pset $ name %in% engines $ name
58
64
if (any(is_engine_param )) {
59
65
engine_names <- pset $ name [is_engine_param ]
60
- pset <- pset [! is_engine_param ,]
66
+ pset <- pset [! is_engine_param , ]
61
67
pset <-
62
68
dplyr :: bind_rows(pset , engines | > dplyr :: filter(name %in% engines $ name ))
63
69
}
@@ -213,9 +219,22 @@ tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
213
219
214
220
brulee_mlp_args <-
215
221
tibble :: tibble(
216
- name = c(' epochs' , ' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' ,
217
- ' penalty' , ' mixture' , ' dropout' , ' learn_rate' , ' momentum' , ' batch_size' ,
218
- ' class_weights' , ' stop_iter' , ' rate_schedule' ),
222
+ name = c(
223
+ ' epochs' ,
224
+ ' hidden_units' ,
225
+ ' hidden_units_2' ,
226
+ ' activation' ,
227
+ ' activation_2' ,
228
+ ' penalty' ,
229
+ ' mixture' ,
230
+ ' dropout' ,
231
+ ' learn_rate' ,
232
+ ' momentum' ,
233
+ ' batch_size' ,
234
+ ' class_weights' ,
235
+ ' stop_iter' ,
236
+ ' rate_schedule'
237
+ ),
219
238
call_info = list (
220
239
list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )),
221
240
list (pkg = " dials" , fun = " hidden_units" , range = c(2L , 50L )),
@@ -225,9 +244,9 @@ brulee_mlp_args <-
225
244
list (pkg = " dials" , fun = " penalty" ),
226
245
list (pkg = " dials" , fun = " mixture" ),
227
246
list (pkg = " dials" , fun = " dropout" ),
228
- list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 5 )),
229
- list (pkg = " dials" , fun = " momentum" , range = c(0.50 , 0.95 )),
230
- list (pkg = " dials" , fun = " batch_size" ),
247
+ list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 5 )),
248
+ list (pkg = " dials" , fun = " momentum" , range = c(0.00 , 0.99 )),
249
+ list (pkg = " dials" , fun = " batch_size" , range = c( 3L , 8L ) ),
231
250
list (pkg = " dials" , fun = " class_weights" ),
232
251
list (pkg = " dials" , fun = " stop_iter" ),
233
252
list (pkg = " dials" , fun = " rate_schedule" , values = tune_sched )
@@ -237,8 +256,13 @@ brulee_mlp_args <-
237
256
238
257
brulee_mlp_only_args <-
239
258
tibble :: tibble(
240
- name =
241
- c(' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' , ' dropout' )
259
+ name = c(
260
+ ' hidden_units' ,
261
+ ' hidden_units_2' ,
262
+ ' activation' ,
263
+ ' activation_2' ,
264
+ ' dropout'
265
+ )
242
266
)
243
267
244
268
# ------------------------------------------------------------------------------
@@ -256,7 +280,11 @@ tunable.linear_reg <- function(x, ...) {
256
280
dplyr :: filter(name != " class_weights" ) | >
257
281
dplyr :: mutate(
258
282
component = " linear_reg" ,
259
- component_id = ifelse(name %in% names(formals(" linear_reg" )), " main" , " engine" )
283
+ component_id = ifelse(
284
+ name %in% names(formals(" linear_reg" )),
285
+ " main" ,
286
+ " engine"
287
+ )
260
288
) | >
261
289
dplyr :: select(name , call_info , source , component , component_id )
262
290
}
@@ -277,7 +305,11 @@ tunable.logistic_reg <- function(x, ...) {
277
305
dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) | >
278
306
dplyr :: mutate(
279
307
component = " logistic_reg" ,
280
- component_id = ifelse(name %in% names(formals(" logistic_reg" )), " main" , " engine" )
308
+ component_id = ifelse(
309
+ name %in% names(formals(" logistic_reg" )),
310
+ " main" ,
311
+ " engine"
312
+ )
281
313
) | >
282
314
dplyr :: select(name , call_info , source , component , component_id )
283
315
}
@@ -296,7 +328,11 @@ tunable.multinom_reg <- function(x, ...) {
296
328
dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) | >
297
329
dplyr :: mutate(
298
330
component = " multinom_reg" ,
299
- component_id = ifelse(name %in% names(formals(" multinom_reg" )), " main" , " engine" )
331
+ component_id = ifelse(
332
+ name %in% names(formals(" multinom_reg" )),
333
+ " main" ,
334
+ " engine"
335
+ )
300
336
) | >
301
337
dplyr :: select(name , call_info , source , component , component_id )
302
338
}
@@ -311,7 +347,7 @@ tunable.boost_tree <- function(x, ...) {
311
347
res $ call_info [res $ name == " sample_size" ] <-
312
348
list (list (pkg = " dials" , fun = " sample_prop" ))
313
349
res $ call_info [res $ name == " learn_rate" ] <-
314
- list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
350
+ list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
315
351
} else if (x $ engine == " C5.0" ) {
316
352
res <- add_engine_parameters(res , c5_boost_engine_args )
317
353
res $ call_info [res $ name == " trees" ] <-
@@ -357,9 +393,11 @@ tunable.decision_tree <- function(x, ...) {
357
393
res <- add_engine_parameters(res , c5_tree_engine_args )
358
394
} else if (x $ engine == " partykit" ) {
359
395
res <-
360
- add_engine_parameters(res ,
361
- partykit_engine_args | >
362
- dplyr :: mutate(component = " decision_tree" ))
396
+ add_engine_parameters(
397
+ res ,
398
+ partykit_engine_args | >
399
+ dplyr :: mutate(component = " decision_tree" )
400
+ )
363
401
}
364
402
res
365
403
}
@@ -386,7 +424,7 @@ tunable.mlp <- function(x, ...) {
386
424
) | >
387
425
dplyr :: select(name , call_info , source , component , component_id )
388
426
if (x $ engine == " brulee" ) {
389
- res <- res [! grepl(" _2" , res $ name ),]
427
+ res <- res [! grepl(" _2" , res $ name ), ]
390
428
}
391
429
}
392
430
res
@@ -402,4 +440,3 @@ tunable.survival_reg <- function(x, ...) {
402
440
}
403
441
404
442
# nocov end
405
-
0 commit comments