 #' @param hidden_units An integer for the number of hidden units, or a vector
 #' of integers. If a vector of integers, the model will have `length(hidden_units)`
 #' layers each with `hidden_units[i]` hidden units.
-#' @param activation A string for the activation function. Possible values are
-#' "relu", "elu", "tanh", and "linear". If `hidden_units` is a vector, `activation`
-#' can be a character vector with length equals to `length(hidden_units)` specifying
-#' the activation for each hidden layer.
+#' @param activation A character vector for the activation function (such as
+#' "relu", "tanh", "sigmoid", and so on). See [brulee_activations()] for
+#' a list of possible values. If `hidden_units` is a vector, `activation`
+#' can be a character vector with length equal to `length(hidden_units)`
+#' specifying the activation for each hidden layer.
 #' @param optimizer The method used in the optimization procedure. Possible choices
 #' are 'LBFGS' and 'SGD'. Default is 'LBFGS'.
 #' @param learn_rate A positive number that controls the initial rapidity that
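
A usage sketch of the vectorized `activation` argument documented above: the formula, data frame, and epoch count are hypothetical, while `hidden_units` and `activation` are the parameters described in the updated roxygen block.

    # Hypothetical data and formula; two hidden layers, each with its own
    # activation name drawn from brulee_activations().
    fit <- brulee_mlp(
      outcome ~ .,
      data = dat,
      hidden_units = c(32, 16),
      activation = c("relu", "tanh"),
      epochs = 50
    )
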
@@ -435,18 +436,26 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
   if (length(hidden_units) != length(activation)) {
     rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'")
   }
+
+  allowed_activation <- brulee_activations()
+  good_activation <- activation %in% allowed_activation
+  if (!all(good_activation)) {
+    rlang::abort(paste("'activation' should be one of:", paste0(allowed_activation, collapse = ", ")))
+  }
+
   if (optimizer == "LBFGS" & !is.null(batch_size)) {
     rlang::warn("'batch_size' is only used for the SGD optimizer.")
     batch_size <- NULL
   }

-  check_integer(epochs, single = TRUE, 1, fn = f_nm)
-  if (!is.null(batch_size)) {
+  if (!is.null(batch_size) & optimizer == "SGD") {
     if (is.numeric(batch_size) & !is.integer(batch_size)) {
       batch_size <- as.integer(batch_size)
     }
     check_integer(batch_size, single = TRUE, 1, fn = f_nm)
   }
+
+  check_integer(epochs, single = TRUE, 1, fn = f_nm)
   check_integer(hidden_units, single = FALSE, 1, fn = f_nm)
   check_double(penalty, single = TRUE, 0, incl = c(TRUE, TRUE), fn = f_nm)
   check_double(mixture, single = TRUE, 0, 1, incl = c(TRUE, TRUE), fn = f_nm)
@@ -457,8 +466,6 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
   check_logical(verbose, single = TRUE, fn = f_nm)
   check_character(activation, single = FALSE, fn = f_nm)

-
-
   ## -----------------------------------------------------------------------------

   predictors <- processed$predictors
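
A sketch of the failure mode the new `brulee_activations()` check above is meant to produce; the activation string is deliberately invalid and the data are placeholders.

    # Should abort before any fitting, listing the allowed activation names.
    brulee_mlp(outcome ~ ., data = dat,
               hidden_units = 8, activation = "not_an_activation")
    #> Error: 'activation' should be one of: ...
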
@@ -635,7 +642,7 @@ mlp_fit_imp <-
     loss_label <- "\tLoss:"
   }

-  if (is.null(batch_size)) {
+  if (is.null(batch_size) & optimizer == "SGD") {
     batch_size <- nrow(x)
   } else {
     batch_size <- min(batch_size, nrow(x))
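
The revised condition above only applies the full-batch default when the optimizer is SGD; a standalone sketch of the same branch, using a hypothetical helper name, shows the resulting batch sizes.

    # resolve_batch_size() is illustrative only; it reproduces the branch above,
    # with n_rows standing in for nrow(x).
    resolve_batch_size <- function(batch_size, optimizer, n_rows) {
      if (is.null(batch_size) & optimizer == "SGD") {
        batch_size <- n_rows
      } else {
        batch_size <- min(batch_size, n_rows)
      }
      batch_size
    }
    resolve_batch_size(NULL, "SGD", 1000)    # 1000: SGD defaults to a full batch
    resolve_batch_size(64, "SGD", 1000)      # 64: user-supplied size is kept
    resolve_batch_size(5000, "SGD", 1000)    # 1000: clamped to the number of rows
    resolve_batch_size(NULL, "LBFGS", 1000)  # 1000: min() drops the NULL
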
@@ -854,19 +861,6 @@ print.brulee_mlp <- function(x, ...) {

 ## -----------------------------------------------------------------------------

-get_activation_fn <- function(arg, ...) {
-  if (arg == "relu") {
-    res <- torch::nn_relu(...)
-  } else if (arg == "elu") {
-    res <- torch::nn_elu(...)
-  } else if (arg == "tanh") {
-    res <- torch::nn_tanh(...)
-  } else {
-    res <- identity
-  }
-  res
-}
-
 set_optimizer <- function(optimizer, model, learn_rate, momentum) {
   if (optimizer == "LBFGS") {
     res <- torch::optim_lbfgs(model$parameters, lr = learn_rate, history_size = 5)
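
For reference, a minimal sketch of how the two optimizers handled by `set_optimizer()` are constructed with the torch package; the module and learning rate are placeholders, and the SGD call is assumed to mirror the LBFGS branch shown above.

    library(torch)

    model <- nn_linear(10, 1)  # stand-in for the compiled MLP module

    # LBFGS, matching the branch shown above.
    opt_lbfgs <- optim_lbfgs(model$parameters, lr = 0.01, history_size = 5)

    # SGD with momentum (assumed counterpart for optimizer = "SGD").
    opt_sgd <- optim_sgd(model$parameters, lr = 0.01, momentum = 0.9)
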