Skip to content

Commit 631fe3c

Browse files
topepo‘topepo’
andauthored
Two layer neural networks (#80)
* new function for two layer models * two layer unit tests * avoid overflow issues * fix doc error * test when loss cannot be computed * tests for tunable values * update GHA * update spelling * use current CRAN torch * add missing snapshot and remove print snapshots (due to OS differences) * don't test on R 4.1 * test overflow on M1 mac --------- Co-authored-by: ‘topepo’ <‘[email protected]’>
1 parent 9b01415 commit 631fe3c

File tree

16 files changed

+1562
-555
lines changed

16 files changed

+1562
-555
lines changed

.github/workflows/R-CMD-check.yaml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ on:
1212

1313
name: R-CMD-check
1414

15+
permissions: read-all
16+
1517
jobs:
1618
R-CMD-check:
1719
runs-on: ${{ matrix.config.os }}
@@ -26,18 +28,20 @@ jobs:
2628

2729
- {os: windows-latest, r: 'release'}
2830

29-
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
30-
- {os: ubuntu-latest, r: 'release'}
31-
- {os: ubuntu-latest, r: 'oldrel-1'}
32-
- {os: ubuntu-latest, r: 'oldrel-2'}
31+
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
32+
- {os: ubuntu-latest, r: 'release'}
33+
- {os: ubuntu-latest, r: 'oldrel-1'}
34+
- {os: ubuntu-latest, r: 'oldrel-2'}
35+
- {os: ubuntu-latest, r: 'oldrel-3'}
36+
- {os: ubuntu-latest, r: 'oldrel-4'}
3337

3438
env:
3539
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
3640
R_KEEP_PKG_SOURCE: yes
3741
TORCH_INSTALL: 1
3842

3943
steps:
40-
- uses: actions/checkout@v3
44+
- uses: actions/checkout@v4
4145

4246
- uses: r-lib/actions/setup-pandoc@v2
4347

@@ -55,4 +59,4 @@ jobs:
5559
- uses: r-lib/actions/check-r-package@v2
5660
with:
5761
upload-snapshots: true
58-
args: 'c("--no-multiarch", "--no-manual")'
62+
build_args: 'c("--no-multiarch", "--no-manual","--compact-vignettes=gs+qpdf")'

.github/workflows/test-coverage.yaml

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ on:
88

99
name: test-coverage
1010

11+
permissions: read-all
12+
1113
jobs:
1214
test-coverage:
1315
runs-on: ubuntu-latest
@@ -16,36 +18,45 @@ jobs:
1618
TORCH_INSTALL: 1
1719

1820
steps:
19-
- uses: actions/checkout@v3
21+
- uses: actions/checkout@v4
2022

2123
- uses: r-lib/actions/setup-r@v2
2224
with:
2325
use-public-rspm: true
2426

2527
- uses: r-lib/actions/setup-r-dependencies@v2
2628
with:
27-
extra-packages: any::covr
29+
extra-packages: any::covr, any::xml2
2830
needs: coverage
2931

3032
- name: Test coverage
3133
run: |
32-
covr::codecov(
34+
cov <- covr::package_coverage(
3335
quiet = FALSE,
3436
clean = FALSE,
35-
install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package")
37+
install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package")
3638
)
39+
covr::to_cobertura(cov)
3740
shell: Rscript {0}
3841

42+
- uses: codecov/codecov-action@v4
43+
with:
44+
fail_ci_if_error: ${{ github.event_name != 'pull_request' && true || false }}
45+
file: ./cobertura.xml
46+
plugin: noop
47+
disable_search: true
48+
token: ${{ secrets.CODECOV_TOKEN }}
49+
3950
- name: Show testthat output
4051
if: always()
4152
run: |
4253
## --------------------------------------------------------------------
43-
find ${{ runner.temp }}/package -name 'testthat.Rout*' -exec cat '{}' \; || true
54+
find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true
4455
shell: bash
4556

4657
- name: Upload test results
4758
if: failure()
48-
uses: actions/upload-artifact@v3
59+
uses: actions/upload-artifact@v4
4960
with:
5061
name: coverage-test-failures
5162
path: ${{ runner.temp }}/package

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Imports:
2525
rlang (>= 1.1.1),
2626
stats,
2727
tibble,
28-
torch (>= 0.11.0),
28+
torch (>= 0.13.0),
2929
utils
3030
Suggests:
3131
covr,
@@ -40,4 +40,4 @@ Config/testthat/edition: 3
4040
Encoding: UTF-8
4141
Language: en-US
4242
Roxygen: list(markdown = TRUE)
43-
RoxygenNote: 7.2.3
43+
RoxygenNote: 7.3.1

NAMESPACE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ S3method(brulee_mlp,default)
1919
S3method(brulee_mlp,formula)
2020
S3method(brulee_mlp,matrix)
2121
S3method(brulee_mlp,recipe)
22+
S3method(brulee_mlp_two_layer,data.frame)
23+
S3method(brulee_mlp_two_layer,default)
24+
S3method(brulee_mlp_two_layer,formula)
25+
S3method(brulee_mlp_two_layer,matrix)
26+
S3method(brulee_mlp_two_layer,recipe)
2227
S3method(brulee_multinomial_reg,data.frame)
2328
S3method(brulee_multinomial_reg,default)
2429
S3method(brulee_multinomial_reg,formula)
@@ -39,13 +44,15 @@ S3method(print,brulee_multinomial_reg)
3944
S3method(tunable,brulee_linear_reg)
4045
S3method(tunable,brulee_logistic_reg)
4146
S3method(tunable,brulee_mlp)
47+
S3method(tunable,brulee_mlp_two_layer)
4248
S3method(tunable,brulee_multinomial_reg)
4349
export("%>%")
4450
export(autoplot)
4551
export(brulee_activations)
4652
export(brulee_linear_reg)
4753
export(brulee_logistic_reg)
4854
export(brulee_mlp)
55+
export(brulee_mlp_two_layer)
4956
export(brulee_multinomial_reg)
5057
export(coef)
5158
export(matrix_to_dataset)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# brulee (development version)
22

3+
* Added a convenience function, `brulee_mlp_two_layer()`, to more easily fit two-layer networks with parsnip.
4+
35
# brulee 0.3.0
46

57
* Fixed bug where `coef()` didn't would error if used on a `brulee_logistic_reg()` that was trained with a recipe. (#66)

0 commit comments

Comments
 (0)