
Commit 7f7a421

fix optimizer bug in #61 (#70)
* fix optimizer bug in #61
* regenerate snapshots on intel hardware
* rephrase
1 parent: 087129b

File tree: 12 files changed (+153, -293 lines)

12 files changed

+153
-293
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,7 @@
 # brulee (development version)
 
+* Fixed a bug where SGD was always being used as the optimizer (#61).
+
 # brulee 0.2.0
 
 * Several learning rate schedulers were added to the modeling functions (#12).
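For context, the root cause of #61 is visible in the R/mlp-fit.R diff below: the user's choice, stored in the character argument optimizer, was overwritten by the torch optimizer object it selected, and the learning-rate schedule then rebuilt the optimizer each epoch with a hard-coded torch::optim_sgd() call, so a request for LBFGS was silently replaced by SGD. A simplified, runnable sketch of the buggy pattern (an illustration only, not the package code; the model and loop are stand-ins):

library(torch)

model      <- nn_linear(2, 1)
optimizer  <- "LBFGS"   # the user's choice, kept as a string
learn_rate <- 0.1
momentum   <- 0.0

# the string is shadowed by the optimizer object it selects
optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate)

for (epoch in 1:3) {
  # BUG: the per-epoch rebuild for the learning-rate schedule was
  # hard-coded to SGD, discarding the original choice
  optimizer <- torch::optim_sgd(model$parameters, lr = learn_rate,
                                momentum = momentum)
}

class(optimizer)  # an SGD optimizer, despite LBFGS being requested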

R/linear_reg-fit.R

Lines changed: 8 additions & 13 deletions

@@ -452,12 +452,17 @@ linear_reg_fit_imp <-
     }
     loss_label <- "\tLoss (scaled):"
 
+    if (optimizer == "LBFGS" & !is.null(batch_size)) {
+      rlang::warn("'batch_size' is only used for the SGD optimizer.")
+      batch_size <- NULL
+    }
     if (is.null(batch_size)) {
       batch_size <- nrow(x)
     } else {
       batch_size <- min(batch_size, nrow(x))
     }
 
+
     ## ---------------------------------------------------------------------------
     # Convert to index sampler and data loader
     ds <- brulee::matrix_to_dataset(x, y)

@@ -472,17 +477,7 @@ linear_reg_fit_imp <-
     # Initialize model and optimizer
     model <- linear_reg_module(ncol(x))
     loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
-
-    # Write a optim wrapper
-    if (optimizer == "LBFGS") {
-      optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
-                                      history_size = 5)
-    } else if (optimizer == "SGD") {
-      optimizer <-
-        torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
-    } else {
-      rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
-    }
+    optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)
 
     ## ---------------------------------------------------------------------------
 

@@ -505,13 +500,13 @@ linear_reg_fit_imp <-
     coro::loop(
       for (batch in dl) {
         cl <- function() {
-          optimizer$zero_grad()
+          optimizer_obj$zero_grad()
           pred <- model(batch$x)
           loss <- loss_fn(pred, batch$y)
           loss$backward()
           loss
         }
-        optimizer$step(cl)
+        optimizer_obj$step(cl)
       }
     )
 
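A hypothetical call illustrating the new guard (assuming brulee_linear_reg() exposes the optimizer argument the same way the other fitting functions touched by this commit do): supplying batch_size together with LBFGS now warns and falls back to full-batch fitting instead of silently mixing the two settings.

library(brulee)

fit <- brulee_linear_reg(mpg ~ ., data = mtcars,
                         optimizer = "LBFGS", batch_size = 16)
#> Warning: 'batch_size' is only used for the SGD optimizer.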

R/logistic_reg-fit.R

Lines changed: 7 additions & 13 deletions

@@ -463,6 +463,10 @@ logistic_reg_fit_imp <-
     y_stats <- list(mean = NA_real_, sd = NA_real_)
     loss_label <- "\tLoss:"
 
+    if (optimizer == "LBFGS" & !is.null(batch_size)) {
+      rlang::warn("'batch_size' is only used for the SGD optimizer.")
+      batch_size <- NULL
+    }
     if (is.null(batch_size)) {
       batch_size <- nrow(x)
     } else {

@@ -483,17 +487,7 @@ logistic_reg_fit_imp <-
     # Initialize model and optimizer
     model <- logistic_module(ncol(x), y_dim)
     loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
-
-    # Write a optim wrapper
-    if (optimizer == "LBFGS") {
-      optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
-                                      history_size = 5)
-    } else if (optimizer == "SGD") {
-      optimizer <-
-        torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
-    } else {
-      rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
-    }
+    optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)
 
     ## ---------------------------------------------------------------------------
 

@@ -517,13 +511,13 @@ logistic_reg_fit_imp <-
     coro::loop(
       for (batch in dl) {
         cl <- function() {
-          optimizer$zero_grad()
+          optimizer_obj$zero_grad()
           pred <- model(batch$x)
           loss <- loss_fn(pred, batch$y, class_weights)
           loss$backward()
           loss
         }
-        optimizer$step(cl)
+        optimizer_obj$step(cl)
      }
    )
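The training loops above pass a closure to step() rather than calling it after backward(). That is the torch convention that lets one loop serve both optimizers: LBFGS may re-evaluate the loss several times per step during its line search, while SGD simply invokes the closure once. A standalone sketch of the pattern (illustrative, not package code):

library(torch)

model <- nn_linear(2, 1)
opt   <- optim_lbfgs(model$parameters, lr = 0.1, history_size = 5)
x <- torch_randn(8, 2)
y <- torch_randn(8, 1)

cl <- function() {
  opt$zero_grad()                       # clear accumulated gradients
  loss <- nnf_mse_loss(model(x), y)     # forward pass
  loss$backward()                       # new gradients
  loss
}

# LBFGS calls `cl` as many times as its line search needs
opt$step(cl)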

R/mlp-fit.R

Lines changed: 19 additions & 16 deletions

@@ -201,7 +201,7 @@
 #' set.seed(2)
 #' cls_fit <- brulee_mlp(class ~ ., data = parabolic_tr, hidden_units = 2,
 #'                       epochs = 200L, learn_rate = 0.1, activation = "elu",
-#'                       penalty = 0.1, batch_size = 2^8)
+#'                       penalty = 0.1, batch_size = 2^8, optimizer = "SGD")
 #' autoplot(cls_fit)
 #'
 #' grid_points <- seq(-4, 4, length.out = 100)

@@ -435,9 +435,9 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
   if (length(hidden_units) != length(activation)) {
     rlang::abort("'activation' must be a single value or a vector with the same length as 'hidden_units'")
   }
-
   if (optimizer == "LBFGS" & !is.null(batch_size)) {
-    rlang::warn("'batch_size' is only use for the SGD optimizer.")
+    rlang::warn("'batch_size' is only used for the SGD optimizer.")
+    batch_size <- NULL
   }
 
   check_integer(epochs, single = TRUE, 1, fn = f_nm)

@@ -656,16 +656,8 @@ mlp_fit_imp <-
     model <- mlp_module(ncol(x), hidden_units, activation, dropout, y_dim)
     loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
 
-    # Set the optimizer
-    if (optimizer == "LBFGS") {
-      optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
-                                      history_size = 5)
-    } else if (optimizer == "SGD") {
-      optimizer <-
-        torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
-    } else {
-      rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
-    }
+    # Set the optimizer (will be set again below)
+    optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)
 
     ## ---------------------------------------------------------------------------
 

@@ -694,19 +686,19 @@ mlp_fit_imp <-
     # resetting them can interfere in training."
 
     learn_rate <- set_learn_rate(epoch - 1, learn_rate, type = rate_schedule, ...)
-    optimizer <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
+    optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)
 
     # training loop
     coro::loop(
       for (batch in dl) {
         cl <- function() {
-          optimizer$zero_grad()
+          optimizer_obj$zero_grad()
           pred <- model(batch$x)
           loss <- loss_fn(pred, batch$y, class_weights)
           loss$backward()
           loss
         }
-        optimizer$step(cl)
+        optimizer_obj$step(cl)
       }
     )
 

@@ -874,3 +866,14 @@ get_activation_fn <- function(arg, ...) {
   }
   res
 }
+
+set_optimizer <- function(optimizer, model, learn_rate, momentum) {
+  if (optimizer == "LBFGS") {
+    res <- torch::optim_lbfgs(model$parameters, lr = learn_rate, history_size = 5)
+  } else if (optimizer == "SGD") {
+    res <- torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
+  } else {
+    rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
+  }
+  res
+}
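The new set_optimizer() helper centralizes the dispatch that was previously copy-pasted into each fitting function. Keeping the user's choice as a string and binding the torch object to a separate name (optimizer_obj) is what lets mlp_fit_imp() rebuild the optimizer each epoch for the learning-rate schedule without losing that choice. A minimal sketch of the helper's contract (assuming the internal set_optimizer() defined above is in scope):

library(torch)

model <- nn_linear(4, 1)

# returns a fresh torch optimizer for a recognized name ...
opt_obj <- set_optimizer("LBFGS", model, learn_rate = 0.05, momentum = 0)
class(opt_obj)

# ... and aborts for anything else
try(set_optimizer("Adam", model, learn_rate = 0.05, momentum = 0))
#> Error : Unknown optimizer 'Adam'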

R/multinomial_reg-fit.R

Lines changed: 7 additions & 13 deletions

@@ -444,6 +444,10 @@ multinomial_reg_fit_imp <-
     y_stats <- list(mean = NA_real_, sd = NA_real_)
     loss_label <- "\tLoss:"
 
+    if (optimizer == "LBFGS" & !is.null(batch_size)) {
+      rlang::warn("'batch_size' is only used for the SGD optimizer.")
+      batch_size <- NULL
+    }
     if (is.null(batch_size)) {
       batch_size <- nrow(x)
     } else {

@@ -464,17 +468,7 @@ multinomial_reg_fit_imp <-
     # Initialize model and optimizer
     model <- multinomial_module(ncol(x), y_dim)
     loss_fn <- make_penalized_loss(loss_fn, model, penalty, mixture)
-
-    # Write a optim wrapper
-    if (optimizer == "LBFGS") {
-      optimizer <- torch::optim_lbfgs(model$parameters, lr = learn_rate,
-                                      history_size = 5)
-    } else if (optimizer == "SGD") {
-      optimizer <-
-        torch::optim_sgd(model$parameters, lr = learn_rate, momentum = momentum)
-    } else {
-      rlang::abort(paste0("Unknown optimizer '", optimizer, "'"))
-    }
+    optimizer_obj <- set_optimizer(optimizer, model, learn_rate, momentum)
 
     ## ---------------------------------------------------------------------------
 

@@ -498,13 +492,13 @@ multinomial_reg_fit_imp <-
     coro::loop(
       for (batch in dl) {
         cl <- function() {
-          optimizer$zero_grad()
+          optimizer_obj$zero_grad()
           pred <- model(batch$x)
           loss <- loss_fn(pred, batch$y, class_weights)
           loss$backward()
           loss
         }
-        optimizer$step(cl)
+        optimizer_obj$step(cl)
      }
    )
inst/WORDLIST

Lines changed: 5 additions & 17 deletions

@@ -1,28 +1,16 @@
 CMD
 Codecov
-Hoerl
-Kennard
 LBFGS
 Lifecycle
-Multilayer
-Nonorthogonal
 ORCID
+PBC
 SGD
-Springer
-Technometrics
 elu
 extensibility
-optimizers
-perceptron
+funder
+mlp
+multilayer
+perceptrons
 relu
-relu’
 tanh
-CMD
-Lifecycle
-LBFGS
-SGD
-optimizers
-ggplot
-mlp
 tibble
-multilayer

man/brulee_mlp.Rd

Lines changed: 1 addition & 1 deletion
(Generated file; diff not rendered.)

tests/testthat/_snaps/logistic_reg-fit.md

Lines changed: 3 additions & 3 deletions

@@ -3,7 +3,7 @@
     Code
       set.seed(1)
       fit <- brulee_logistic_reg(y ~ ., df, epochs = 2, verbose = TRUE, penalty = 0)
-    Message <rlang_message>
+    Message
      epoch: 1 	Loss: 0.173
      epoch: 2 	Loss: 0.173 x
 

@@ -25,7 +25,7 @@
       set.seed(1)
       fit_imbal <- brulee_logistic_reg(y ~ ., df_imbal, verbose = TRUE,
         class_weights = 20, optimizer = "SGD", penalty = 0)
-    Message <rlang_message>
+    Message
      epoch: 1 	Loss: 0.329
      epoch: 2 	Loss: 0.302
      epoch: 3 	Loss: 0.282

@@ -53,7 +53,7 @@
       set.seed(1)
       fit <- brulee_logistic_reg(y ~ ., df_imbal, epochs = 2, verbose = TRUE,
         class_weights = c(a = 12, b = 1), penalty = 0)
-    Message <rlang_message>
+    Message
      epoch: 1 	Loss: 0.113
      epoch: 2 	Loss: 0.113 x
 
