
Commit 8801942

update tunables and tests
1 parent 2390375 · commit 8801942

File tree: 13 files changed, +452 −38 lines

R/tunable.R

Lines changed: 61 additions & 11 deletions
@@ -211,7 +211,7 @@ flexsurvspline_engine_args <-
 tune_activations <- c("relu", "tanh", "elu", "log_sigmoid", "tanhshrink")
 tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
 
-brulee_args <-
+brulee_mlp_args <-
   tibble::tibble(
     name = c('epochs', 'hidden_units', 'hidden_units_2', 'activation', 'activation_2',
              'penalty', 'mixture', 'dropout', 'learn_rate', 'momentum', 'batch_size',
@@ -232,12 +232,17 @@ brulee_args <-
       list(pkg = "dials", fun = "class_weights"),
       list(pkg = "dials", fun = "rate_schedule", values = tune_sched)
     )
+  ) %>%
+  dplyr::mutate(source = "model_spec")
+
+brulee_mlp_only_args <-
+  tibble::tibble(
+    name =
+      c('hidden_units', 'hidden_units_2', 'activation', 'activation_2', 'dropout')
   )
 
 # ------------------------------------------------------------------------------
 
-
-
 #' @export
 tunable.linear_reg <- function(x, ...) {
   res <- NextMethod()
@@ -246,18 +251,57 @@ tunable.linear_reg <- function(x, ...) {
       list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
   } else if (x$engine == "brulee") {
     res <-
-      brulee_args %>%
-      dplyr::filter(name %in% tune_args(x)$name) %>%
-      dplyr::full_join(res %>% dplyr::select(-call_info), by = "name")
+      brulee_mlp_args %>%
+      dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
+      dplyr::filter(name != "class_weights") %>%
+      dplyr::mutate(
+        component = "linear_reg",
+        component_id = ifelse(name %in% names(formals("linear_reg")), "main", "engine")
+      ) %>%
+      dplyr::select(name, call_info, source, component, component_id)
  }
  res
 }
 
 #' @export
-tunable.logistic_reg <- tunable.linear_reg
 
 #' @export
-tunable.multinom_reg <- tunable.linear_reg
+tunable.logistic_reg <- function(x, ...) {
+  res <- NextMethod()
+  if (x$engine == "glmnet") {
+    res$call_info[res$name == "mixture"] <-
+      list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
+  } else if (x$engine == "brulee") {
+    res <-
+      brulee_mlp_args %>%
+      dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
+      dplyr::mutate(
+        component = "logistic_reg",
+        component_id = ifelse(name %in% names(formals("logistic_reg")), "main", "engine")
+      ) %>%
+      dplyr::select(name, call_info, source, component, component_id)
+  }
+  res
+}
+
+#' @export
+tunable.multinom_reg <- function(x, ...) {
+  res <- NextMethod()
+  if (x$engine == "glmnet") {
+    res$call_info[res$name == "mixture"] <-
+      list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
+  } else if (x$engine == "brulee") {
+    res <-
+      brulee_mlp_args %>%
+      dplyr::anti_join(brulee_mlp_only_args, by = "name") %>%
+      dplyr::mutate(
+        component = "multinom_reg",
+        component_id = ifelse(name %in% names(formals("multinom_reg")), "main", "engine")
+      ) %>%
+      dplyr::select(name, call_info, source, component, component_id)
+  }
+  res
+}
 
 #' @export
 tunable.boost_tree <- function(x, ...) {
@@ -335,9 +379,15 @@ tunable.mlp <- function(x, ...) {
   res <- NextMethod()
   if (grepl("brulee", x$engine)) {
     res <-
-      brulee_args %>%
-      dplyr::filter(name %in% tune_args(x)$name) %>%
-      dplyr::full_join(res %>% dplyr::select(-call_info), by = "name")
+      brulee_mlp_args %>%
+      dplyr::mutate(
+        component = "mlp",
+        component_id = ifelse(name %in% names(formals("mlp")), "main", "engine")
+      ) %>%
+      dplyr::select(name, call_info, source, component, component_id)
+    if (x$engine == "brulee") {
+      res <- res[!grepl("_2", res$name), ]
+    }
   }
   res
 }
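For orientation, a minimal usage sketch of the reworked methods, mirroring the snapshot tests added further down in this commit:

``` r
library(parsnip)

# Tunable parameters for a brulee-engined linear regression: one row per
# argument, with call_info, source, component, and component_id
# ("main" for linear_reg() arguments, "engine" for brulee-only ones).
linear_reg() %>%
  set_engine("brulee") %>%
  tunable()
```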

man/details_mlp_brulee.Rd

Lines changed: 8 additions & 9 deletions
(Generated file; diff not rendered by default.)

man/parsnip-package.Rd

Lines changed: 1 addition & 1 deletion
(Generated file; diff not rendered by default.)

man/rmd/mlp_brulee.md

Lines changed: 10 additions & 8 deletions
@@ -7,7 +7,7 @@ For this engine, there are multiple modes: classification and regression
 
 
 
-This model has 6 tuning parameters:
+This model has 7 tuning parameters:
 
 - `epochs`: # Epochs (type: integer, default: 100L)
 
@@ -17,6 +17,8 @@ This model has 6 tuning parameters:
 
 - `penalty`: Amount of Regularization (type: double, default: 0.001)
 
+- `mixture`: Proportion of Lasso Penalty (type: double, default: 0.0)
+
 - `dropout`: Dropout Rate (type: double, default: 0.0)
 
 - `learn_rate`: Learning Rate (type: double, default: 0.01)
@@ -27,16 +29,16 @@ Both `penalty` and `dropout` should be not be used in the same model.
 
 Other engine arguments of interest:
 
-- `momentum()`: A number used to use historical gradient infomration during optimization.
-- `batch_size()`: An integer for the number of training set points in each batch.
-- `class_weights()`: Numeric class weights. See [brulee::brulee_mlp()].
-- `stop_iter()`: A non-negative integer for how many iterations with no improvement before stopping. (default: 5L).
-- `rate_schedule()`: A function to change the learning rate over epochs. See [brulee::schedule_decay_time()] for details.
+- `momentum`: A number used to use historical gradient infomration during optimization.
+- `batch_size`: An integer for the number of training set points in each batch.
+- `class_weights`: Numeric class weights. See [brulee::brulee_mlp()].
+- `stop_iter`: A non-negative integer for how many iterations with no improvement before stopping. (default: 5L).
+- `rate_schedule`: A function to change the learning rate over epochs. See [brulee::schedule_decay_time()] for details.
 
 ## Translation from parsnip to the original package (regression)
 
 
-```r
+``` r
 mlp(
   hidden_units = integer(1),
   penalty = double(1),
@@ -74,7 +76,7 @@ Note that parsnip automatically sets linear activation in the last layer.
 ## Translation from parsnip to the original package (classification)
 
 
-```r
+``` r
 mlp(
   hidden_units = integer(1),
   penalty = double(1),
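The engine arguments listed above are supplied through set_engine() rather than mlp() itself; a minimal sketch of that pattern (the values are illustrative, not defaults from this page):

``` r
library(parsnip)

# Main tuning parameters go in mlp(); brulee-specific arguments such as
# momentum, batch_size, and stop_iter are passed via set_engine().
mlp(hidden_units = 8, penalty = 0.001, epochs = 100, learn_rate = 0.01) %>%
  set_engine("brulee", momentum = 0.9, batch_size = 64, stop_iter = 5) %>%
  set_mode("regression")
```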

man/rmd/mlp_brulee_two_layer.md

Lines changed: 11 additions & 9 deletions
@@ -7,7 +7,7 @@ For this engine, there are multiple modes: classification and regression
 
 
 
-This model has 6 tuning parameters:
+This model has 7 tuning parameters:
 
 - `epochs`: # Epochs (type: integer, default: 100L)
 
@@ -17,6 +17,8 @@ This model has 6 tuning parameters:
 
 - `penalty`: Amount of Regularization (type: double, default: 0.001)
 
+- `mixture`: Proportion of Lasso Penalty (type: double, default: 0.0)
+
 - `dropout`: Dropout Rate (type: double, default: 0.0)
 
 - `learn_rate`: Learning Rate (type: double, default: 0.01)
@@ -28,17 +30,17 @@ Both `penalty` and `dropout` should be not be used in the same model.
 Other engine arguments of interest:
 
 - `hidden_layer_2` and `activation_2` control the format of the second layer.
-- `momentum()`: A number used to use historical gradient information during optimization.
-- `batch_size()`: An integer for the number of training set points in each batch.
-- `class_weights()`: Numeric class weights. See [brulee::brulee_mlp()].
-- `stop_iter()`: A non-negative integer for how many iterations with no improvement before stopping. (default: 5L).
-- `rate_schedule()`: A function to change the learning rate over epochs. See [brulee::schedule_decay_time()] for details.
+- `momentum`: A number used to use historical gradient information during optimization.
+- `batch_size`: An integer for the number of training set points in each batch.
+- `class_weights`: Numeric class weights. See [brulee::brulee_mlp()].
+- `stop_iter`: A non-negative integer for how many iterations with no improvement before stopping. (default: 5L).
+- `rate_schedule`: A function to change the learning rate over epochs. See [brulee::schedule_decay_time()] for details.
 
 
 ## Translation from parsnip to the original package (regression)
 
 
-```r
+``` r
 mlp(
   hidden_units = integer(1),
   penalty = double(1),
@@ -78,12 +80,12 @@ mlp(
 ##   hidden_units_2 = integer(1), activation_2 = character(1))
 ```
 
-Note that parsnip automatically sets linear activation in the last layer.
+Note that parsnip automatically sets the linear activation in the last layer.
 
 ## Translation from parsnip to the original package (classification)
 
 
-```r
+``` r
 mlp(
   hidden_units = integer(1),
   penalty = double(1),
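For the two-layer variant, the second layer is likewise configured through engine arguments; a sketch assuming the `brulee_two_layer` engine this page documents (values illustrative):

``` r
library(parsnip)

# hidden_units_2 and activation_2 (engine arguments) describe the second
# hidden layer; the first layer uses the usual mlp() main arguments.
mlp(hidden_units = 16, activation = "relu", epochs = 100) %>%
  set_engine("brulee_two_layer", hidden_units_2 = 8, activation_2 = "relu") %>%
  set_mode("classification")
```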

tests/testthat/_snaps/linear_reg.md

Lines changed: 56 additions & 0 deletions
@@ -175,3 +175,59 @@
       Error in `fit()`:
       ! Please install the glmnet package to use this engine.
 
+# tunables
+
+    Code
+      linear_reg() %>% tunable()
+    Output
+      # A tibble: 0 x 5
+      # i 5 variables: name <chr>, call_info <list>, source <chr>, component <chr>,
+      #   component_id <chr>
+
+---
+
+    Code
+      linear_reg() %>% set_engine("brulee") %>% tunable()
+    Output
+      # A tibble: 8 x 5
+        name          call_info        source     component  component_id
+        <chr>         <list>           <chr>      <chr>      <chr>
+      1 epochs        <named list [3]> model_spec linear_reg engine
+      2 penalty       <named list [2]> model_spec linear_reg main
+      3 mixture       <named list [2]> model_spec linear_reg main
+      4 learn_rate    <named list [3]> model_spec linear_reg engine
+      5 momentum      <named list [3]> model_spec linear_reg engine
+      6 batch_size    <named list [2]> model_spec linear_reg engine
+      7 stop_iter     <named list [2]> model_spec linear_reg engine
+      8 rate_schedule <named list [3]> model_spec linear_reg engine
+
+---
+
+    Code
+      linear_reg() %>% set_engine("glmnet") %>% tunable()
+    Output
+      # A tibble: 2 x 5
+        name    call_info        source     component  component_id
+        <chr>   <list>           <chr>      <chr>      <chr>
+      1 penalty <named list [2]> model_spec linear_reg main
+      2 mixture <named list [3]> model_spec linear_reg main
+
+---
+
+    Code
+      linear_reg() %>% set_engine("quantreg") %>% tunable()
+    Output
+      # A tibble: 0 x 5
+      # i 5 variables: name <chr>, call_info <list>, source <chr>, component <chr>,
+      #   component_id <chr>
+
+---
+
+    Code
+      linear_reg() %>% set_engine("keras") %>% tunable()
+    Output
+      # A tibble: 1 x 5
+        name    call_info        source     component  component_id
+        <chr>   <list>           <chr>      <chr>      <chr>
+      1 penalty <named list [2]> model_spec linear_reg main
+
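Each Code/Output pair above corresponds to one expect_snapshot() call; the companion test (its file is not shown here) would look roughly like this sketch:

``` r
library(testthat)
library(parsnip)

# Hypothetical sketch of the test that would generate the snapshot above;
# the actual test-file changes are not rendered on this page.
test_that("tunables", {
  expect_snapshot(linear_reg() %>% tunable())
  expect_snapshot(linear_reg() %>% set_engine("brulee") %>% tunable())
  expect_snapshot(linear_reg() %>% set_engine("glmnet") %>% tunable())
  expect_snapshot(linear_reg() %>% set_engine("quantreg") %>% tunable())
  expect_snapshot(linear_reg() %>% set_engine("keras") %>% tunable())
})
```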

tests/testthat/_snaps/logistic_reg.md

Lines changed: 48 additions & 0 deletions
@@ -139,3 +139,51 @@
       Error in `fit()`:
       ! For the LiblineaR engine, `penalty` must be `> 0`, not 0.
 
+# tunables
+
+    Code
+      logistic_reg() %>% tunable()
+    Output
+      # A tibble: 0 x 5
+      # i 5 variables: name <chr>, call_info <list>, source <chr>, component <chr>,
+      #   component_id <chr>
+
+---
+
+    Code
+      logistic_reg() %>% set_engine("brulee") %>% tunable()
+    Output
+      # A tibble: 9 x 5
+        name          call_info        source     component    component_id
+        <chr>         <list>           <chr>      <chr>        <chr>
+      1 epochs        <named list [3]> model_spec logistic_reg engine
+      2 penalty       <named list [2]> model_spec logistic_reg main
+      3 mixture       <named list [2]> model_spec logistic_reg main
+      4 learn_rate    <named list [3]> model_spec logistic_reg engine
+      5 momentum      <named list [3]> model_spec logistic_reg engine
+      6 batch_size    <named list [2]> model_spec logistic_reg engine
+      7 class_weights <named list [2]> model_spec logistic_reg engine
+      8 stop_iter     <named list [2]> model_spec logistic_reg engine
+      9 rate_schedule <named list [3]> model_spec logistic_reg engine
+
+---
+
+    Code
+      logistic_reg() %>% set_engine("glmnet") %>% tunable()
+    Output
+      # A tibble: 2 x 5
+        name    call_info        source     component    component_id
+        <chr>   <list>           <chr>      <chr>        <chr>
+      1 penalty <named list [2]> model_spec logistic_reg main
+      2 mixture <named list [3]> model_spec logistic_reg main
+
+---
+
+    Code
+      logistic_reg() %>% set_engine("keras") %>% tunable()
+    Output
+      # A tibble: 1 x 5
+        name    call_info        source     component    component_id
+        <chr>   <list>           <chr>      <chr>        <chr>
+      1 penalty <named list [2]> model_spec logistic_reg main
+
