Skip to content

Commit 9f0ba3f

Browse files
toggle mlp activation by engine (#1246)
* toggle mlp activation by engine * normalize hard sigmoid activation name with brulee * add skip_if_not_installed
1 parent d2d8014 commit 9f0ba3f

File tree

4 files changed

+30
-2
lines changed

4 files changed

+30
-2
lines changed

R/mlp.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ keras_mlp <-
200200
{.val {activation}}."
201201
)
202202
}
203+
activation <- get_activation_fn(activation)
203204

204205
if (penalty > 0 & dropout > 0) {
205206
cli::cli_abort("Please use either dropout or weight decay.", call = NULL)
@@ -351,7 +352,7 @@ mlp_num_weights <- function(p, hidden_units, classes) {
351352
}
352353

353354
allowed_keras_activation <-
354-
c("elu", "exponential", "gelu", "hard_sigmoid", "linear", "relu", "selu",
355+
c("elu", "exponential", "gelu", "hardsigmoid", "linear", "relu", "selu",
355356
"sigmoid", "softmax", "softplus", "softsign", "swish", "tanh")
356357

357358
#' Activation functions for neural networks in keras
@@ -363,6 +364,13 @@ keras_activations <- function() {
363364
allowed_keras_activation
364365
}
365366

367+
get_activation_fn <- function(arg, ...) {
368+
if (arg == "hardsigmoid") {
369+
arg <- "hard_sigmoid"
370+
}
371+
arg
372+
}
373+
366374
## -----------------------------------------------------------------------------
367375

368376
#' @importFrom purrr map

R/tunable.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,18 @@ tunable.mlp <- function(x, ...) {
355355
list(list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/2)))
356356
res$call_info[res$name == "epochs"] <-
357357
list(list(pkg = "dials", fun = "epochs", range = c(5L, 500L)))
358+
activation_values <- rlang::eval_tidy(
359+
rlang::call2("brulee_activations", .ns = "brulee")
360+
)
361+
res$call_info[res$name == "activation"] <-
362+
list(list(pkg = "dials", fun = "activation", values = activation_values))
363+
} else if (x$engine == "keras") {
364+
activation_values <- parsnip::keras_activations()
365+
res$call_info[res$name == "activation"] <-
366+
list(list(pkg = "dials", fun = "activation", values = activation_values))
358367
}
359368
res
360-
}
369+
}
361370

362371
#' @export
363372
tunable.survival_reg <- function(x, ...) {

tests/testthat/_snaps/mlp_keras.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,13 @@
66
Error:
77
! object 'novar' not found
88

9+
# all keras activation functions
10+
11+
Code
12+
mlp(mode = "classification", hidden_units = 2, penalty = 0.01, epochs = 2,
13+
activation = "invalid") %>% set_engine("keras", verbose = 0) %>% parsnip::fit(
14+
Class ~ A + B, data = modeldata::two_class_dat)
15+
Condition
16+
Error in `parsnip::keras_mlp()`:
17+
! `activation` should be one of: elu, exponential, gelu, hardsigmoid, linear, relu, selu, sigmoid, softmax, softplus, softsign, swish, and tanh, not "invalid".
18+

tests/testthat/test-tunable.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
test_that('brulee has mixture object', {
2+
skip_if_not_installed("brulee")
23
# for issue 1236
34
mlp_spec <-
45
mlp(

0 commit comments

Comments
 (0)