 #' "relu", "elu", "tanh", and "linear". If `hidden_units` is a vector, `activation`
 #' can be a character vector with length equal to `length(hidden_units)` specifying
 #' the activation for each hidden layer.
+#' @param optimizer The method used in the optimization procedure. Possible choices
+#' are 'LBFGS' and 'SGD'. Default is 'LBFGS'.
 #' @param learn_rate A positive number that controls the initial rapidity with
 #' which the model moves along the descent path. Values around 0.1 or less are
 #' typical.
 #' `"none"` (the default), `"decay_time"`, `"decay_expo"`, `"cyclic"` and
 #' `"step"`. See [schedule_decay_time()] for more details.
 #' @param momentum A positive number usually on `[0.50, 0.99]` for the momentum
-#' parameter in gradient descent.
+#' parameter in gradient descent. (`optimizer = "SGD"` only)
 #' @param dropout The proportion of parameters set to zero.
 #' @param class_weights Numeric class weights (classification only). The value
 #' can be:
 #' @param validation The proportion of the data randomly assigned to a
 #' validation set.
 #' @param batch_size An integer for the number of training set points in each
-#' batch.
+#' batch. (`optimizer = "SGD"` only)
 #' @param stop_iter A non-negative integer for how many iterations with no
 #' improvement before stopping.
 #' @param verbose A logical that prints out the iteration history.
@@ -239,6 +241,7 @@ brulee_mlp.data.frame <-
            mixture = 0,
            dropout = 0,
            validation = 0.1,
+           optimizer = "LBFGS",
            learn_rate = 0.01,
            rate_schedule = "none",
            momentum = 0.0,
@@ -260,6 +263,7 @@ brulee_mlp.data.frame <-
       mixture = mixture,
       dropout = dropout,
       validation = validation,
+      optimizer = optimizer,
       momentum = momentum,
       batch_size = batch_size,
       class_weights = class_weights,
@@ -282,6 +286,7 @@ brulee_mlp.matrix <- function(x,
            mixture = 0,
            dropout = 0,
            validation = 0.1,
+           optimizer = "LBFGS",
            learn_rate = 0.01,
            rate_schedule = "none",
            momentum = 0.0,
@@ -304,6 +309,7 @@ brulee_mlp.matrix <- function(x,
       mixture = mixture,
       dropout = dropout,
       validation = validation,
+      optimizer = optimizer,
       batch_size = batch_size,
       class_weights = class_weights,
       stop_iter = stop_iter,
@@ -326,6 +332,7 @@ brulee_mlp.formula <-
            mixture = 0,
            dropout = 0,
            validation = 0.1,
+           optimizer = "LBFGS",
            learn_rate = 0.01,
            rate_schedule = "none",
            momentum = 0.0,
@@ -348,6 +355,7 @@ brulee_mlp.formula <-
       mixture = mixture,
       dropout = dropout,
       validation = validation,
+      optimizer = optimizer,
       batch_size = batch_size,
       class_weights = class_weights,
       stop_iter = stop_iter,
@@ -370,6 +378,7 @@ brulee_mlp.recipe <-
            mixture = 0,
            dropout = 0,
            validation = 0.1,
+           optimizer = "LBFGS",
            learn_rate = 0.01,
            rate_schedule = "none",
            momentum = 0.0,
@@ -392,6 +401,7 @@ brulee_mlp.recipe <-
       mixture = mixture,
       dropout = dropout,
       validation = validation,
+      optimizer = optimizer,
       batch_size = batch_size,
       class_weights = class_weights,
       stop_iter = stop_iter,
@@ -405,7 +415,7 @@ brulee_mlp.recipe <-
 
 brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
                               learn_rate, rate_schedule, momentum, penalty,
-                              mixture, dropout, class_weights, validation,
+                              mixture, dropout, class_weights, validation, optimizer,
                               batch_size, stop_iter, verbose, ...) {
   if (!torch::torch_is_installed()) {
     rlang::abort("The torch backend has not been installed; use `torch::install_torch()`.")
@@ -426,6 +436,10 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
     rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'")
   }
 
+  if (optimizer == "LBFGS" & !is.null(batch_size)) {
+    rlang::warn("'batch_size' is only used for the SGD optimizer.")
+  }
+
   check_integer(epochs, single = TRUE, 1, fn = f_nm)
   if (!is.null(batch_size)) {
     if (is.numeric(batch_size) & !is.integer(batch_size)) {
@@ -487,6 +501,7 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
     mixture = mixture,
     dropout = dropout,
     validation = validation,
+    optimizer = optimizer,
     batch_size = batch_size,
     class_weights = class_weights,
     stop_iter = stop_iter,
@@ -555,6 +570,7 @@ mlp_fit_imp <-
            mixture = 0,
            dropout = 0,
            validation = 0.1,
+           optimizer = "LBFGS",
            learn_rate = 0.01,
            rate_schedule = "none",
            momentum = 0.0,
@@ -640,6 +656,17 @@ mlp_fit_imp <-
   model <- mlp_module(ncol(x), hidden_units, activation, dropout, y_dim)
   loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
 
+  # Set the optimizer
+  if (optimizer == "LBFGS") {
+    optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
+                                    history_size = 5)
+  } else if (optimizer == "SGD") {
+    optimizer <-
+      torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
+  } else {
+    rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
+  }
+
   ## ---------------------------------------------------------------------------
 
   loss_prev <- 10^38
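
Both constructors come from the torch package: `optim_lbfgs()` is a quasi-Newton method that builds a curvature estimate from a short gradient history (`history_size = 5` here; torch's own default is much larger), while `optim_sgd()` is first-order and is the only branch that consumes `momentum`. A standalone sketch of the two calls on a toy parameter, just to isolate the API in use:

```r
library(torch)

w <- torch_tensor(rnorm(3), requires_grad = TRUE)

# Quasi-Newton: keeps the last 5 gradients to approximate curvature
opt_lbfgs <- optim_lbfgs(list(w), lr = 0.01, history_size = 5)

# First-order: one gradient step per call; momentum smooths successive updates
opt_sgd <- optim_sgd(list(w), lr = 0.01, momentum = 0.9)
```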
@@ -671,14 +698,16 @@ mlp_fit_imp <-
 
     # training loop
     coro::loop(
-      for (batch in dl) {
-        pred <- model(batch$x)
-        loss <- loss_fn(pred, batch$y, class_weights)
-
-        optimizer$zero_grad()
-        loss$backward()
-        optimizer$step()
+      for (batch in dl) {
+        cl <- function() {
+          optimizer$zero_grad()
+          pred <- model(batch$x)
+          loss <- loss_fn(pred, batch$y, class_weights)
+          loss$backward()
+          loss
         }
+        optimizer$step(cl)
+      }
     )
 
     # calculate loss on the full datasets
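
The closure is the crux of this change: torch's `optim_lbfgs()` evaluates the objective more than once per `step()`, so `step()` must be handed a function that zeroes the gradients, recomputes the loss, calls `backward()`, and returns the loss. Since `optim_sgd()`'s `step()` also accepts an optional closure, the one loop now serves both optimizers. A minimal self-contained sketch of the pattern (the quadratic objective is illustrative only, not brulee's loss):

```r
library(torch)

w <- torch_tensor(c(5, -3), requires_grad = TRUE)
opt <- optim_lbfgs(list(w), lr = 0.1, history_size = 5)

closure <- function() {
  opt$zero_grad()
  loss <- torch_sum((w - 1)^2)  # toy objective, minimized at w = c(1, 1)
  loss$backward()
  loss
}

for (i in 1:5) {
  loss <- opt$step(closure)  # LBFGS may call `closure` several times per step
}
```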
@@ -750,6 +779,7 @@ mlp_fit_imp <-
     mixture = mixture,
     dropout = dropout,
     validation = validation,
+    optimizer = optimizer,
     batch_size = batch_size,
     momentum = momentum,
     sched = rate_schedule,