
Commit 88d6002

More activation functions (#74)
* add more activation functions for #69
* update error checking and tests
* fix unit test
* unit tests
* make a function to get possible values
* small updates
* update snapshot
* redoc with function link
* add skips; will re-write tests in next PR
1 parent ec4756c commit 88d6002

11 files changed (+424, -89 lines)


NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ S3method(tunable,brulee_mlp)
 S3method(tunable,brulee_multinomial_reg)
 export("%>%")
 export(autoplot)
+export(brulee_activations)
 export(brulee_linear_reg)
 export(brulee_logistic_reg)
 export(brulee_mlp)
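The newly exported helper can be called directly to see which activation names the models accept. A minimal usage sketch (the return value is the allowed_activation vector added in R/activation.R below):

    library(brulee)

    # Character vector of supported activation names,
    # e.g. "celu", "elu", "gelu", ..., "tanh", "tanhshrink"
    brulee_activations()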

R/activation.R

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+allowed_activation <-
+  c("celu", "elu", "gelu", "hardshrink", "hardsigmoid",
+    "hardtanh", "leaky_relu", "linear", "log_sigmoid", "relu", "relu6",
+    "rrelu", "selu", "sigmoid", "silu", "softplus", "softshrink",
+    "softsign", "tanh", "tanhshrink")
+
+#' Activation functions for neural networks in brulee
+#'
+#' @return A character vector of values.
+#' @export
+brulee_activations <- function() {
+  allowed_activation
+}
+
+get_activation_fn <- function(arg, ...) {
+
+  if (arg == "linear") {
+    res <- identity
+  } else {
+    cl <- rlang::call2(paste0("nn_", arg), .ns = "torch")
+    res <- rlang::eval_bare(cl)
+  }
+
+  res
+}
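For readers unfamiliar with the dispatch trick above: get_activation_fn() pastes the requested name onto the "nn_" prefix and evaluates the resulting call in the torch namespace, so "gelu" resolves to torch::nn_gelu() while "linear" falls back to identity(). A standalone sketch of the same idea, where make_activation() is a hypothetical stand-in rather than the package's internal (unexported) function; it assumes the torch and rlang packages are installed:

    # Build a call such as torch::nn_gelu() from the activation name and evaluate it
    make_activation <- function(name) {
      if (name == "linear") {
        return(identity)
      }
      cl <- rlang::call2(paste0("nn_", name), .ns = "torch")
      rlang::eval_bare(cl)
    }

    act <- make_activation("gelu")           # an nn_gelu module
    act(torch::torch_tensor(c(-1, 0, 1)))    # apply it to a small tensor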

R/mlp-fit.R

Lines changed: 16 additions & 22 deletions
@@ -33,10 +33,11 @@
 #' @param hidden_units An integer for the number of hidden units, or a vector
 #' of integers. If a vector of integers, the model will have `length(hidden_units)`
 #' layers each with `hidden_units[i]` hidden units.
-#' @param activation A string for the activation function. Possible values are
-#' "relu", "elu", "tanh", and "linear". If `hidden_units` is a vector, `activation`
-#' can be a character vector with length equals to `length(hidden_units)` specifying
-#' the activation for each hidden layer.
+#' @param activation A character vector for the activation function )such as
+#' "relu", "tanh", "sigmoid", and so on). See [brulee_activations()] for
+#' a list of possible values. If `hidden_units` is a vector, `activation`
+#' can be a character vector with length equals to `length(hidden_units)`
+#' specifying the activation for each hidden layer.
 #' @param optimizer The method used in the optimization procedure. Possible choices
 #' are 'LBFGS' and 'SGD'. Default is 'LBFGS'.
 #' @param learn_rate A positive number that controls the initial rapidity that
@@ -435,18 +436,26 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
   if (length(hidden_units) != length(activation)) {
     rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'")
   }
+
+  allowed_activation <- brulee_activations()
+  good_activation <- activation %in% allowed_activation
+  if (!all(good_activation)) {
+    rlang::abort(paste("'activation' should be one of: ", paste0(allowed_activation, collapse = ", ")))
+  }
+
   if (optimizer == "LBFGS" & !is.null(batch_size)) {
     rlang::warn("'batch_size' is only used for the SGD optimizer.")
     batch_size <- NULL
   }
 
-  check_integer(epochs, single = TRUE, 1, fn = f_nm)
-  if (!is.null(batch_size)) {
+  if (!is.null(batch_size) & optimizer == "SGD") {
     if (is.numeric(batch_size) & !is.integer(batch_size)) {
       batch_size <- as.integer(batch_size)
     }
     check_integer(batch_size, single = TRUE, 1, fn = f_nm)
   }
+
+  check_integer(epochs, single = TRUE, 1, fn = f_nm)
   check_integer(hidden_units, single = FALSE, 1, fn = f_nm)
   check_double(penalty, single = TRUE, 0, incl = c(TRUE, TRUE), fn = f_nm)
   check_double(mixture, single = TRUE, 0, 1, incl = c(TRUE, TRUE), fn = f_nm)
@@ -457,8 +466,6 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
   check_logical(verbose, single = TRUE, fn = f_nm)
   check_character(activation, single = FALSE, fn = f_nm)
 
-
-
   ## -----------------------------------------------------------------------------
 
   predictors <- processed$predictors
@@ -635,7 +642,7 @@ mlp_fit_imp <-
     loss_label <- "\tLoss:"
   }
 
-  if (is.null(batch_size)) {
+  if (is.null(batch_size) & optimizer == "SGD") {
     batch_size <- nrow(x)
   } else {
     batch_size <- min(batch_size, nrow(x))
@@ -854,19 +861,6 @@ print.brulee_mlp <- function(x, ...) {
 
 ## -----------------------------------------------------------------------------
 
-get_activation_fn <- function(arg, ...) {
-  if (arg == "relu") {
-    res <- torch::nn_relu(...)
-  } else if (arg == "elu") {
-    res <- torch::nn_elu(...)
-  } else if (arg == "tanh") {
-    res <- torch::nn_tanh(...)
-  } else {
-    res <- identity
-  }
-  res
-}
-
 set_optimizer <- function(optimizer, model, learn_rate, momentum) {
   if (optimizer == "LBFGS") {
     res <- torch::optim_lbfgs(model$parameters, lr = learn_rate, history_size = 5)
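To show the revised activation argument end to end, a hedged usage sketch: per the updated roxygen text, activation may hold one entry per hidden layer, and any name outside brulee_activations() now hits the new abort() check. The data frame and argument values below are made up for illustration; all other brulee_mlp() arguments are left at their defaults:

    library(brulee)

    set.seed(1)
    dat <- data.frame(y = rnorm(200), x1 = rnorm(200), x2 = rnorm(200))

    # Two hidden layers, each with its own activation (lengths must match)
    fit <- brulee_mlp(y ~ ., data = dat,
                      hidden_units = c(16, 8),
                      activation = c("gelu", "tanh"),
                      epochs = 20)

    # A name not listed by brulee_activations() now fails early, e.g.:
    # brulee_mlp(y ~ ., data = dat, hidden_units = 8, activation = "swish")
    # Error: 'activation' should be one of:  celu, elu, gelu, ...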

inst/WORDLIST

Lines changed: 1 addition & 0 deletions
@@ -12,5 +12,6 @@ mlp
 multilayer
 perceptrons
 relu
+sigmoid
 tanh
 tibble

man/brulee_activations.Rd

Lines changed: 14 additions & 0 deletions
Generated documentation file; diff not rendered.

man/brulee_mlp.Rd

Lines changed: 5 additions & 4 deletions
Generated documentation file; diff not rendered.
