Skip to content

Commit fe9ebff

Browse files
authored
Learning rate schedulers (#56)
* initial work for mlp models * update tests * unit tests * rename _learn_rate_* to schedule_* * Update news * begin to harmonize learn_rate options and schedulers * change "constant" scheduler to "none" * fix some documentation issues * GHA update * differences in OS results; skipping some snapshots on linux * donttest and fewer significant digits * update snapshots * another skip due to OS differences * dont check for on windows * added back args: 'c("--no-multiarch", "--no-manual")' * note on optimizers
1 parent 95acab6 commit fe9ebff

35 files changed

+845
-352
lines changed

.github/workflows/R-CMD-check.yaml

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
1+
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
22
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
33
#
44
# NOTE: This workflow is overkill for most R packages and
@@ -25,14 +25,11 @@ jobs:
2525
- {os: macOS-latest, r: 'release'}
2626

2727
- {os: windows-latest, r: 'release'}
28-
# Use 3.6 to trigger usage of RTools35
29-
- {os: windows-latest, r: '3.6'}
3028

31-
# Use older ubuntu to maximise backward compatibility
32-
- {os: ubuntu-18.04, r: 'devel', http-user-agent: 'release'}
33-
- {os: ubuntu-18.04, r: 'release'}
34-
- {os: ubuntu-18.04, r: 'oldrel-1'}
35-
- {os: ubuntu-18.04, r: 'oldrel-2'}
29+
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
30+
- {os: ubuntu-latest, r: 'release'}
31+
- {os: ubuntu-latest, r: 'oldrel-1'}
32+
- {os: ubuntu-latest, r: 'oldrel-2'}
3633

3734
env:
3835
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
@@ -42,18 +39,20 @@ jobs:
4239
steps:
4340
- uses: actions/checkout@v2
4441

45-
- uses: r-lib/actions/setup-pandoc@v1
42+
- uses: r-lib/actions/setup-pandoc@v2
4643

47-
- uses: r-lib/actions/setup-r@v1
44+
- uses: r-lib/actions/setup-r@v2
4845
with:
4946
r-version: ${{ matrix.config.r }}
5047
http-user-agent: ${{ matrix.config.http-user-agent }}
5148
use-public-rspm: true
5249

53-
- uses: r-lib/actions/setup-r-dependencies@v1
50+
- uses: r-lib/actions/setup-r-dependencies@v2
5451
with:
55-
extra-packages: rcmdcheck
52+
extra-packages: any::rcmdcheck
53+
needs: check
5654

57-
- uses: r-lib/actions/check-r-package@v1
55+
- uses: r-lib/actions/check-r-package@v2
5856
with:
57+
upload-snapshots: true
5958
args: 'c("--no-multiarch", "--no-manual")'

.github/workflows/pkgdown.yaml

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
1+
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
22
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
33
on:
44
push:
55
branches: [main, master]
6+
pull_request:
7+
branches: [main, master]
68
release:
79
types: [published]
810
workflow_dispatch:
@@ -12,25 +14,34 @@ name: pkgdown
1214
jobs:
1315
pkgdown:
1416
runs-on: ubuntu-latest
17+
# Only restrict concurrency for non-PR jobs
18+
concurrency:
19+
group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }}
1520
env:
1621
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
1722
TORCH_INSTALL: 1
1823
steps:
1924
- uses: actions/checkout@v2
2025

21-
- uses: r-lib/actions/setup-pandoc@v1
26+
- uses: r-lib/actions/setup-pandoc@v2
2227

23-
- uses: r-lib/actions/setup-r@v1
28+
- uses: r-lib/actions/setup-r@v2
2429
with:
2530
use-public-rspm: true
2631

27-
- uses: r-lib/actions/setup-r-dependencies@v1
32+
- uses: r-lib/actions/setup-r-dependencies@v2
2833
with:
29-
extra-packages: pkgdown
34+
extra-packages: any::pkgdown, local::.
3035
needs: website
3136

32-
- name: Deploy package
33-
run: |
34-
git config --local user.name "$GITHUB_ACTOR"
35-
git config --local user.email "[email protected]"
36-
Rscript -e 'pkgdown::deploy_to_branch(new_process = FALSE)'
37+
- name: Build site
38+
run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE)
39+
shell: Rscript {0}
40+
41+
- name: Deploy to GitHub pages 🚀
42+
if: github.event_name != 'pull_request'
43+
uses: JamesIves/[email protected]
44+
with:
45+
clean: false
46+
branch: gh-pages
47+
folder: docs

.github/workflows/pr-commands.yaml

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
1+
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
22
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
33
on:
44
issue_comment:
@@ -16,20 +16,22 @@ jobs:
1616
steps:
1717
- uses: actions/checkout@v2
1818

19-
- uses: r-lib/actions/pr-fetch@v1
19+
- uses: r-lib/actions/pr-fetch@v2
2020
with:
2121
repo-token: ${{ secrets.GITHUB_TOKEN }}
2222

23-
- uses: r-lib/actions/setup-r@v1
23+
- uses: r-lib/actions/setup-r@v2
2424
with:
2525
use-public-rspm: true
2626

27-
- uses: r-lib/actions/setup-r-dependencies@v1
27+
- uses: r-lib/actions/setup-r-dependencies@v2
2828
with:
29-
extra-packages: roxygen2
29+
extra-packages: any::roxygen2
30+
needs: pr-document
3031

3132
- name: Document
32-
run: Rscript -e 'roxygen2::roxygenise()'
33+
run: roxygen2::roxygenise()
34+
shell: Rscript {0}
3335

3436
- name: commit
3537
run: |
@@ -38,7 +40,7 @@ jobs:
3840
git add man/\* NAMESPACE
3941
git commit -m 'Document'
4042
41-
- uses: r-lib/actions/pr-push@v1
43+
- uses: r-lib/actions/pr-push@v2
4244
with:
4345
repo-token: ${{ secrets.GITHUB_TOKEN }}
4446

@@ -51,17 +53,19 @@ jobs:
5153
steps:
5254
- uses: actions/checkout@v2
5355

54-
- uses: r-lib/actions/pr-fetch@v1
56+
- uses: r-lib/actions/pr-fetch@v2
5557
with:
5658
repo-token: ${{ secrets.GITHUB_TOKEN }}
5759

58-
- uses: r-lib/actions/setup-r@v1
60+
- uses: r-lib/actions/setup-r@v2
5961

6062
- name: Install dependencies
61-
run: Rscript -e 'install.packages("styler")'
63+
run: install.packages("styler")
64+
shell: Rscript {0}
6265

6366
- name: Style
64-
run: Rscript -e 'styler::style_pkg()'
67+
run: styler::style_pkg()
68+
shell: Rscript {0}
6569

6670
- name: commit
6771
run: |
@@ -70,6 +74,6 @@ jobs:
7074
git add \*.R
7175
git commit -m 'Style'
7276
73-
- uses: r-lib/actions/pr-push@v1
77+
- uses: r-lib/actions/pr-push@v2
7478
with:
7579
repo-token: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/test-coverage.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
1+
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
22
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
33
on:
44
push:
@@ -18,14 +18,15 @@ jobs:
1818
steps:
1919
- uses: actions/checkout@v2
2020

21-
- uses: r-lib/actions/setup-r@v1
21+
- uses: r-lib/actions/setup-r@v2
2222
with:
2323
use-public-rspm: true
2424

25-
- uses: r-lib/actions/setup-r-dependencies@v1
25+
- uses: r-lib/actions/setup-r-dependencies@v2
2626
with:
27-
extra-packages: covr
27+
extra-packages: any::covr
28+
needs: coverage
2829

2930
- name: Test coverage
30-
run: covr::codecov()
31+
run: covr::codecov(quiet = FALSE)
3132
shell: Rscript {0}

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Imports:
3030
Suggests:
3131
covr,
3232
modeldata,
33+
purrr,
3334
recipes,
3435
spelling,
3536
testthat,
@@ -39,4 +40,4 @@ Config/testthat/edition: 3
3940
Encoding: UTF-8
4041
Language: en-US
4142
Roxygen: list(markdown = TRUE)
42-
RoxygenNote: 7.1.2
43+
RoxygenNote: 7.2.1.9000

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ export(brulee_mlp)
4848
export(brulee_multinomial_reg)
4949
export(coef)
5050
export(matrix_to_dataset)
51+
export(schedule_cyclic)
52+
export(schedule_decay_expo)
53+
export(schedule_decay_time)
54+
export(schedule_step)
55+
export(set_learn_rate)
5156
export(tunable)
5257
import(torch)
5358
importFrom(dplyr,"%>%")

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# brulee (development version)
22

3+
* Several learning rate schedulers were added to the modeling functions (#12).
4+
35
# brulee 0.1.0
46

57
* Modeling functions gained a `mixture` argument for the proportion of L1 penalty that is used. (#50)

R/0_utils.R

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,28 @@ brulee_print <- function(x, ...) {
3232
}
3333
cat("batch size:", x$parameters$batch_size, "\n")
3434

35+
if (all(c("sched", "sched_opt") %in% names(x$parameters))) {
36+
cat_schedule(x$parameters)
37+
}
38+
3539
if (!is.null(x$loss)) {
3640
it <- x$best_epoch
3741
chr_it <- cli::pluralize("{it} epoch{?s}:")
3842
if(x$parameters$validation > 0) {
3943
if (is.na(x$y_stats$mean)) {
4044
cat("validation loss after", chr_it,
41-
signif(x$loss[it], 5), "\n")
45+
signif(x$loss[it], 3), "\n")
4246
} else {
4347
cat("scaled validation loss after", chr_it,
44-
signif(x$loss[it], 5), "\n")
48+
signif(x$loss[it], 3), "\n")
4549
}
4650
} else {
4751
if (is.na(x$y_stats$mean)) {
4852
cat("training set loss after", chr_it,
49-
signif(x$loss[it], 5), "\n")
53+
signif(x$loss[it], 3), "\n")
5054
} else {
5155
cat("scaled training set loss after", chr_it,
52-
signif(x$loss[it], 5), "\n")
56+
signif(x$loss[it], 3), "\n")
5357
}
5458
}
5559
}
@@ -58,6 +62,21 @@ brulee_print <- function(x, ...) {
5862

5963
# ------------------------------------------------------------------------------
6064

65+
cat_schedule <- function(x) {
66+
if (x$sched == "none") {
67+
cat("learn rate:", x$learn_rate, "\n")
68+
} else {
69+
.fn <- paste0("schedule_", x$sched)
70+
cl <- rlang::call2(.fn, !!!x$sched_opt)
71+
chr_cl <- rlang::expr_deparse(cl, width = 200)
72+
73+
cat(gsub("^schedule_", "schedule: ", chr_cl), "\n")
74+
}
75+
invisible(NULL)
76+
}
77+
78+
# ------------------------------------------------------------------------------
79+
6180

6281
model_to_raw <- function(model) {
6382
con <- rawConnection(raw(), open = "w")

R/linear_reg-fit.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ linear_reg_fit_imp <-
558558

559559
if (verbose) {
560560
msg <- paste("epoch:", epoch_chr[epoch], loss_label,
561-
signif(loss_curr, 5), loss_note)
561+
signif(loss_curr, 3), loss_note)
562562

563563
rlang::inform(msg)
564564
}

R/logistic_reg-fit.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
#' * `blueprint`: The `hardhat` blueprint data.
7575
#'
7676
#' @examples
77+
#' \donttest{
7778
#' if (torch::torch_is_installed()) {
7879
#'
7980
#' library(recipes)
@@ -119,7 +120,7 @@
119120
#' bind_cols(cells_test) %>%
120121
#' roc_auc(class, .pred_PS)
121122
#' }
122-
#'
123+
#' }
123124
#' @export
124125
brulee_logistic_reg <- function(x, ...) {
125126
UseMethod("brulee_logistic_reg")
@@ -561,7 +562,7 @@ logistic_reg_fit_imp <-
561562

562563
if (verbose) {
563564
msg <- paste("epoch:", epoch_chr[epoch], loss_label,
564-
signif(loss_curr, 5), loss_note)
565+
signif(loss_curr, 3), loss_note)
565566

566567
rlang::inform(msg)
567568
}

0 commit comments

Comments
 (0)