201 | 201 | #' set.seed(2) |
202 | 202 | #' cls_fit <- brulee_mlp(class ~ ., data = parabolic_tr, hidden_units = 2, |
203 | 203 | #' epochs = 200L, learn_rate = 0.1, activation = "elu", |
204 | | -#' penalty = 0.1, batch_size = 2^8) |
| 204 | +#' penalty = 0.1, batch_size = 2^8, optimizer = "SGD") |
205 | 205 | #' autoplot(cls_fit) |
206 | 206 | #' |
207 | 207 | #' grid_points <- seq(-4, 4, length.out = 100) |
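Note that the example now pins `optimizer = "SGD"` explicitly. With the new guard added below, supplying `batch_size` while using the LBFGS optimizer warns and drops the value, so setting SGD here presumably keeps the example's `batch_size = 2^8` in effect.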
@@ -435,9 +435,9 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation, |
435 | 435 | if (length(hidden_units) != length(activation)) { |
436 | 436 | rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'") |
437 | 437 | } |
438 | | - |
439 | 438 | if (optimizer == "LBFGS" & !is.null(batch_size)) { |
440 | | - rlang::warn("'batch_size' is only use for the SGD optimizer.") |
| 439 | + rlang::warn("'batch_size' is only used for the SGD optimizer.") |
| 440 | + batch_size <- NULL |
441 | 441 | } |
442 | 442 |
443 | 443 | check_integer(epochs, single = TRUE, 1, fn = f_nm) |
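A minimal sketch of the caller-facing effect of this new guard, reusing the `parabolic_tr` data from the package examples above (a hypothetical call, not part of this diff):

    # Passing batch_size with LBFGS now warns, and the value is dropped
    fit <- brulee_mlp(class ~ ., data = parabolic_tr,
                      optimizer = "LBFGS", batch_size = 2^8)
    #> Warning: 'batch_size' is only used for the SGD optimizer.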
@@ -656,16 +656,8 @@ mlp_fit_imp <- |
656 | 656 | model <- mlp_module(ncol(x), hidden_units, activation, dropout, y_dim) |
657 | 657 | loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture) |
658 | 658 |
659 | | - # Set the optimizer |
660 | | - if (optimizer == "LBFGS") { |
661 | | - optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate, |
662 | | - history_size = 5) |
663 | | - } else if (optimizer == "SGD") { |
664 | | - optimizer <- |
665 | | - torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum) |
666 | | - } else { |
667 | | - rlang::abort(paste0("Unknown optimizer '", optimizer, "'")) |
668 | | - } |
| 659 | +# Set the optimizer (re-created each epoch below when the learning rate is rescheduled) |
| 660 | + optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum) |
669 | 661 |
670 | 662 | ## --------------------------------------------------------------------------- |
671 | 663 |
@@ -694,19 +686,19 @@ mlp_fit_imp <- |
694 | 686 | # resetting them can interfere in training." |
695 | 687 |
696 | 688 | learn_rate <- set_learn_rate(epoch - 1, learn_rate, type = rate_schedule, ...) |
697 | | - optimizer <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum) |
| 689 | + optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum) |
698 | 690 |
699 | 691 | # training loop |
700 | 692 | coro::loop( |
701 | 693 | for (batch in dl) { |
702 | 694 | cl <- function() { |
703 | | - optimizer$zero_grad() |
| 695 | + optimizer_obj$zero_grad() |
704 | 696 | pred <- model(batch$x) |
705 | 697 | loss <- loss_fn(pred, batch$y, class_weights) |
706 | 698 | loss$backward() |
707 | 699 | loss |
708 | 700 | } |
709 | | - optimizer$step(cl) |
| 701 | + optimizer_obj$step(cl) |
710 | 702 | } |
711 | 703 | ) |
712 | 704 |
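One note on the closure in the loop above: passing `cl` to `$step()` is what lets a single training loop drive both optimizers, since LBFGS must be able to re-evaluate the loss (possibly several times) within one step, while SGD simply calls the closure once. A stripped-down sketch of the pattern, with `opt`, `model`, `loss_fn`, and `batch` standing in for the objects built earlier (the real `loss_fn` also takes `class_weights`):

    # One optimization step; LBFGS may call `cl` repeatedly, SGD calls it once
    cl <- function() {
      opt$zero_grad()                            # clear gradients from the last step
      loss <- loss_fn(model(batch$x), batch$y)   # forward pass and loss
      loss$backward()                            # back-propagate
      loss                                       # step() uses the returned loss
    }
    opt$step(cl)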
@@ -874,3 +866,14 @@ get_activation_fn <- function(arg, ...) { |
874 | 866 | } |
875 | 867 | res |
876 | 868 | } |
| 869 | + |
| 870 | +set_optimizer <- function(optimizer, model, learn_rate, momentum) { |
| 871 | + if (optimizer == "LBFGS") { |
| 872 | + res <- torch::optim_lbfgs(model$parameters, lr = learn_rate, history_size = 5) |
| 873 | + } else if (optimizer == "SGD") { |
| 874 | + res <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum) |
| 875 | + } else { |
| 876 | + rlang::abort(paste0("Unknown optimizer '", optimizer, "'")) |
| 877 | + } |
| 878 | + res |
| 879 | +} |
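A quick usage sketch for the new helper; `model` stands in for the `nn_module` returned by `mlp_module()` above, and the calls mirror how the fitting code uses it:

    # SGD with momentum, re-created each epoch after the learning-rate update
    opt <- set_optimizer("SGD", model, learn_rate = 0.01, momentum = 0.9)

    # Anything other than "LBFGS" or "SGD" fails fast
    set_optimizer("adam", model, learn_rate = 0.01, momentum = 0.9)
    #> Error: Unknown optimizer 'adam'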