Skip to content

Commit 057b541

Browse files
authored
add LBFGS optimizer with mlp fitting (#57)
* add LBFGS optimizer with mlp fitting * update readme * update snapshots and skips
1 parent fe9ebff commit 057b541

17 files changed

+89
-51
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
* Several learning rate schedulers were added to the modeling functions (#12).
44

5+
* An `optimizer` was added to [brulee_mlp()], with a new default being LBFGS instead of stochastic gradient descent.
6+
57
# brulee 0.1.0
68

79
* Modeling functions gained a `mixture` argument for the proportion of L1 penalty that is used. (#50)

R/linear_reg-fit.R

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,6 @@
2424
#'
2525
#' @inheritParams brulee_mlp
2626
#'
27-
#' @param optimizer The method used in the optimization procedure. Possible choices
28-
#' are 'LBFGS' and 'SGD'. Default is 'LBFGS'.
29-
#' @param learn_rate A positive number that controls the rapidity that the model
30-
#' moves along the descent path. Values less that 0.1 are typical.
31-
#' (`optimizer = "SGD"` only)
32-
#' @param momentum A positive number usually on `[0.50, 0.99]` for the momentum
33-
#' parameter in gradient descent. (`optimizer = "SGD"` only)
3427
#' @details
3528
#'
3629
#' This function fits a linear combination of coefficients and predictors to

R/mlp-fit.R

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
#' "relu", "elu", "tanh", and "linear". If `hidden_units` is a vector, `activation`
3838
#' can be a character vector with length equals to `length(hidden_units)` specifying
3939
#' the activation for each hidden layer.
40+
#' @param optimizer The method used in the optimization procedure. Possible choices
41+
#' are 'LBFGS' and 'SGD'. Default is 'LBFGS'.
4042
#' @param learn_rate A positive number that controls the initial rapidity that
4143
#' the model moves along the descent path. Values around 0.1 or less are
4244
#' typical.
@@ -45,7 +47,7 @@
4547
#' `"none"` (the default), `"decay_time"`, `"decay_expo"`, `"cyclic"` and
4648
#' `"step"`. See [schedule_decay_time()] for more details.
4749
#' @param momentum A positive number usually on `[0.50, 0.99]` for the momentum
48-
#' parameter in gradient descent.
50+
#' parameter in gradient descent. (`optimizer = "SGD"` only)
4951
#' @param dropout The proportion of parameters set to zero.
5052
#' @param class_weights Numeric class weights (classification only). The value
5153
#' can be:
@@ -59,7 +61,7 @@
5961
#' @param validation The proportion of the data randomly assigned to a
6062
#' validation set.
6163
#' @param batch_size An integer for the number of training set points in each
62-
#' batch.
64+
#' batch. (`optimizer = "SGD"` only)
6365
#' @param stop_iter A non-negative integer for how many iterations with no
6466
#' improvement before stopping.
6567
#' @param verbose A logical that prints out the iteration history.
@@ -239,6 +241,7 @@ brulee_mlp.data.frame <-
239241
mixture = 0,
240242
dropout = 0,
241243
validation = 0.1,
244+
optimizer = "LBFGS",
242245
learn_rate = 0.01,
243246
rate_schedule = "none",
244247
momentum = 0.0,
@@ -260,6 +263,7 @@ brulee_mlp.data.frame <-
260263
mixture = mixture,
261264
dropout = dropout,
262265
validation = validation,
266+
optimizer = optimizer,
263267
momentum = momentum,
264268
batch_size = batch_size,
265269
class_weights = class_weights,
@@ -282,6 +286,7 @@ brulee_mlp.matrix <- function(x,
282286
mixture = 0,
283287
dropout = 0,
284288
validation = 0.1,
289+
optimizer = "LBFGS",
285290
learn_rate = 0.01,
286291
rate_schedule = "none",
287292
momentum = 0.0,
@@ -304,6 +309,7 @@ brulee_mlp.matrix <- function(x,
304309
mixture = mixture,
305310
dropout = dropout,
306311
validation = validation,
312+
optimizer = optimizer,
307313
batch_size = batch_size,
308314
class_weights = class_weights,
309315
stop_iter = stop_iter,
@@ -326,6 +332,7 @@ brulee_mlp.formula <-
326332
mixture = 0,
327333
dropout = 0,
328334
validation = 0.1,
335+
optimizer = "LBFGS",
329336
learn_rate = 0.01,
330337
rate_schedule = "none",
331338
momentum = 0.0,
@@ -348,6 +355,7 @@ brulee_mlp.formula <-
348355
mixture = mixture,
349356
dropout = dropout,
350357
validation = validation,
358+
optimizer = optimizer,
351359
batch_size = batch_size,
352360
class_weights = class_weights,
353361
stop_iter = stop_iter,
@@ -370,6 +378,7 @@ brulee_mlp.recipe <-
370378
mixture = 0,
371379
dropout = 0,
372380
validation = 0.1,
381+
optimizer = "LBFGS",
373382
learn_rate = 0.01,
374383
rate_schedule = "none",
375384
momentum = 0.0,
@@ -392,6 +401,7 @@ brulee_mlp.recipe <-
392401
mixture = mixture,
393402
dropout = dropout,
394403
validation = validation,
404+
optimizer = optimizer,
395405
batch_size = batch_size,
396406
class_weights = class_weights,
397407
stop_iter = stop_iter,
@@ -405,7 +415,7 @@ brulee_mlp.recipe <-
405415

406416
brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
407417
learn_rate, rate_schedule, momentum, penalty,
408-
mixture, dropout, class_weights, validation,
418+
mixture, dropout, class_weights, validation, optimizer,
409419
batch_size, stop_iter, verbose, ...) {
410420
if(!torch::torch_is_installed()) {
411421
rlang::abort("The torch backend has not been installed; use `torch::install_torch()`.")
@@ -426,6 +436,10 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
426436
rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'")
427437
}
428438

439+
if (optimizer == "LBFGS" & !is.null(batch_size)) {
440+
rlang::warn("'batch_size' is only use for the SGD optimizer.")
441+
}
442+
429443
check_integer(epochs, single = TRUE, 1, fn = f_nm)
430444
if (!is.null(batch_size)) {
431445
if (is.numeric(batch_size) & !is.integer(batch_size)) {
@@ -487,6 +501,7 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
487501
mixture = mixture,
488502
dropout = dropout,
489503
validation = validation,
504+
optimizer = optimizer,
490505
batch_size = batch_size,
491506
class_weights = class_weights,
492507
stop_iter = stop_iter,
@@ -555,6 +570,7 @@ mlp_fit_imp <-
555570
mixture = 0,
556571
dropout = 0,
557572
validation = 0.1,
573+
optimizer = "LBFGS",
558574
learn_rate = 0.01,
559575
rate_schedule = "none",
560576
momentum = 0.0,
@@ -640,6 +656,17 @@ mlp_fit_imp <-
640656
model <- mlp_module(ncol(x), hidden_units, activation, dropout, y_dim)
641657
loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
642658

659+
# Set the optimizer
660+
if (optimizer == "LBFGS") {
661+
optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
662+
history_size = 5)
663+
} else if (optimizer == "SGD") {
664+
optimizer <-
665+
torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
666+
} else {
667+
rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
668+
}
669+
643670
## ---------------------------------------------------------------------------
644671

645672
loss_prev <- 10^38
@@ -671,14 +698,16 @@ mlp_fit_imp <-
671698

672699
# training loop
673700
coro::loop(
674-
for (batch in dl) {
675-
pred <- model(batch$x)
676-
loss <- loss_fn(pred, batch$y, class_weights)
677-
678-
optimizer$zero_grad()
679-
loss$backward()
680-
optimizer$step()
701+
for (batch in dl) {
702+
cl <- function() {
703+
optimizer$zero_grad()
704+
pred <- model(batch$x)
705+
loss <- loss_fn(pred, batch$y, class_weights)
706+
loss$backward()
707+
loss
681708
}
709+
optimizer$step(cl)
710+
}
682711
)
683712

684713
# calculate loss on the full datasets
@@ -750,6 +779,7 @@ mlp_fit_imp <-
750779
mixture = mixture,
751780
dropout = dropout,
752781
validation = validation,
782+
optimizer = optimizer,
753783
batch_size = batch_size,
754784
momentum = momentum,
755785
sched = rate_schedule,

R/schedulers.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
#' The details for how the schedulers change the rates:
2525
#'
2626
#' * `schedule_decay_time()`: \eqn{rate(epoch) = initial/(1 + decay \times epoch)}
27-
#' * `schedule_decay_expo()`: \eqn{initial\exp(-decay \times epoch)}
28-
#' * `schedule_step()`: \eqn{initial \times reduction^{floor(epoch / steps)}}
27+
#' * `schedule_decay_expo()`: \eqn{rate(epoch) = initial\exp(-decay \times epoch)}
28+
#' * `schedule_step()`: \eqn{rate(epoch) = initial \times reduction^{floor(epoch / steps)}}
2929
#' * `schedule_cyclic()`: \eqn{cycle = floor( 1 + (epoch / 2 / step size) )},
3030
#' \eqn{x = abs( ( epoch / step size ) - ( 2 * cycle) + 1 )}, and
31-
#' \eqn{rate = initial + ( largest - initial ) * \max( 0, 1 - x)}
31+
#' \eqn{rate(epoch) = initial + ( largest - initial ) * \max( 0, 1 - x)}
3232
#'
3333
#'
3434
#' @seealso [brulee_mlp()]

README.Rmd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ library(yardstick)
6060
data(bivariate, package = "modeldata")
6161
set.seed(20)
6262
nn_log_biv <- brulee_mlp(Class ~ log(A) + log(B), data = bivariate_train,
63-
epochs = 150, hidden_units = 3, batch_size = 64)
63+
epochs = 150, hidden_units = 3)
6464
6565
# We use the tidymodels semantics to always return a tibble when predicting
6666
predict(nn_log_biv, bivariate_test, type = "prob") %>%
@@ -80,7 +80,7 @@ rec <-
8080
8181
set.seed(20)
8282
nn_rec_biv <- brulee_mlp(rec, data = bivariate_train,
83-
epochs = 150, hidden_units = 3, batch_size = 64)
83+
epochs = 150, hidden_units = 3)
8484
8585
# A little better
8686
predict(nn_rec_biv, bivariate_test, type = "prob") %>%

README.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](h
1515
The R `brulee` package contains several basic modeling functions that
1616
use the `torch` package infrastructure, such as:
1717

18-
- [neural
19-
networks](https://tidymodels.github.io/brulee/reference/brulee_mlp.html)
20-
- [linear
21-
regression](https://tidymodels.github.io/brulee/reference/brulee_linear_reg.html)
22-
- [logistic
23-
regression](https://tidymodels.github.io/brulee/reference/brulee_logistic_reg.html)
24-
- [multinomial
25-
regression](https://tidymodels.github.io/brulee/reference/brulee_multinomial_reg.html)
18+
- [neural
19+
networks](https://tidymodels.github.io/brulee/reference/brulee_mlp.html)
20+
- [linear
21+
regression](https://tidymodels.github.io/brulee/reference/brulee_linear_reg.html)
22+
- [logistic
23+
regression](https://tidymodels.github.io/brulee/reference/brulee_logistic_reg.html)
24+
- [multinomial
25+
regression](https://tidymodels.github.io/brulee/reference/brulee_multinomial_reg.html)
2626

2727
## Installation
2828

@@ -54,7 +54,7 @@ library(yardstick)
5454
data(bivariate, package = "modeldata")
5555
set.seed(20)
5656
nn_log_biv <- brulee_mlp(Class ~ log(A) + log(B), data = bivariate_train,
57-
epochs = 150, hidden_units = 3, batch_size = 64)
57+
epochs = 150, hidden_units = 3)
5858

5959
# We use the tidymodels semantics to always return a tibble when predicting
6060
predict(nn_log_biv, bivariate_test, type = "prob") %>%
@@ -63,7 +63,7 @@ predict(nn_log_biv, bivariate_test, type = "prob") %>%
6363
#> # A tibble: 1 × 3
6464
#> .metric .estimator .estimate
6565
#> <chr> <chr> <dbl>
66-
#> 1 roc_auc binary 0.608
66+
#> 1 roc_auc binary 0.839
6767
```
6868

6969
A recipe can also be used if the data require some sort of preprocessing
@@ -79,7 +79,7 @@ rec <-
7979

8080
set.seed(20)
8181
nn_rec_biv <- brulee_mlp(rec, data = bivariate_train,
82-
epochs = 150, hidden_units = 3, batch_size = 64)
82+
epochs = 150, hidden_units = 3)
8383

8484
# A little better
8585
predict(nn_rec_biv, bivariate_test, type = "prob") %>%
@@ -88,7 +88,7 @@ predict(nn_rec_biv, bivariate_test, type = "prob") %>%
8888
#> # A tibble: 1 × 3
8989
#> .metric .estimator .estimate
9090
#> <chr> <chr> <dbl>
91-
#> 1 roc_auc binary 0.865
91+
#> 1 roc_auc binary 0.866
9292
```
9393

9494
## Code of Conduct

man/brulee-package.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/brulee_linear_reg.Rd

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/brulee_logistic_reg.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/brulee_mlp.Rd

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)