Spark #282

4 changes: 2 additions & 2 deletions .github/workflows/CRAN-R-CMD-check.yaml
@@ -6,9 +6,9 @@
 # usethis::use_github_action("check-standard") will install it.
 on:
   push:
-    branches: [main, master]
+    branches: main
   pull_request:
-    branches: [main, master]
+    branches: main
   schedule:
     # runs tests every day at 1am EST
     - cron: '0 5 * * *'
4 changes: 2 additions & 2 deletions .github/workflows/GH-R-CMD-check.yaml
@@ -6,9 +6,9 @@
 # usethis::use_github_action("check-standard") will install it.
 on:
   push:
-    branches: [main, master]
+    branches: main
  pull_request:
-    branches: [main, master]
+    branches: main
   schedule:
     # runs tests every day at 1am EST
     - cron: '0 5 * * *'
18 changes: 12 additions & 6 deletions .github/workflows/spark-R-CMD-check.yaml
@@ -6,9 +6,9 @@
 # usethis::use_github_action("check-standard") will install it.
 on:
   push:
-    branches: [main, master]
+    branches: main
   pull_request:
-    branches: [main, master]
+    branches: main
   schedule:
     # runs tests every day at 1am EST
     - cron: '0 5 * * *'
@@ -26,13 +26,14 @@ jobs:
       fail-fast: false
       matrix:
         config:
-          - {os: ubuntu-latest, r: 'release'}
+          - {os: ubuntu-latest, r: 'release', spark: '4.0.1'}

     env:
       R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
       GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
       R_COMPILE_AND_INSTALL_PACKAGES: 'always'
       R_KEEP_PKG_SOURCE: yes
+      SPARK_VERSION: ${{ matrix.config.spark }}

     steps:
       - uses: actions/checkout@v3
@@ -52,20 +53,25 @@ jobs:

       - uses: r-lib/actions/setup-r-dependencies@v2
         with:
-          extra-packages: any::rcmdcheck
+          extra-packages: |
+            any::rcmdcheck
+            any::devtools
           needs: check

       - name: Install Spark
         run: |
-          try(sparklyr::spark_install(version = "3.4.0", verbose = TRUE), silent = TRUE)
+          sparklyr::spark_install(version = Sys.getenv("SPARK_VERSION"))
         shell: Rscript {0}

       - name: Install devel versions
         run: |
           try(pak::pkg_install("tidymodels/parsnip"))
         shell: Rscript {0}

-      - uses: r-lib/actions/check-r-package@v2
+      - name: Run Spark tests
+        run: |
+          devtools::test(filter = "spark", stop_on_failure = TRUE)
+        shell: Rscript {0}

       - name: Notify slack fail
         if: failure() && github.event_name == 'schedule'
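For reference, a rough local equivalent of the two new Spark steps above (a sketch only; it assumes a local R session with sparklyr and devtools installed and the package checkout as the working directory):

# Sketch: mirror the workflow's "Install Spark" and "Run Spark tests" steps locally.
Sys.setenv(SPARK_VERSION = "4.0.1")  # same version as the matrix pin above
sparklyr::spark_install(version = Sys.getenv("SPARK_VERSION"))
devtools::test(filter = "spark", stop_on_failure = TRUE)  # only test files whose names match "spark"
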
44 changes: 41 additions & 3 deletions tests/testthat/helper-objects.R
@@ -1,6 +1,10 @@
 library(modeldata)
 library(parsnip)

+# Test tracker variables
+.env_tests <- new.env()
+.env_tests$spark_connection <- NULL
+
 ## -----------------------------------------------------------------------------

 data("wa_churn")
@@ -13,13 +17,34 @@ ctrl <- control_parsnip(verbosity = 1, catch = FALSE)
 caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE)
 quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE)

-run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0
+run_glmnet <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0

-## -----------------------------------------------------------------------------
+## --------------------------- Spark -------------------------------------------

+skip_if_no_spark <- function() {
+  skip <- NULL
+  test_spark <- as.logical(Sys.getenv("TEST_SPARK", unset = TRUE))
+  if (!isTRUE(test_spark)) {
+    skip <- TRUE
+  }
+  if (is.null(skip)) {
+    if (spark_not_installed()) {
+      skip <- TRUE
+    } else {
+      if (is.null(spark_test_connection())) {
+        skip <- TRUE
+      }
+    }
+  }
+  if (skip %||% FALSE) {
+    skip("Skipping Spark related tests")
+  } else {
+    return(invisible())
+  }
+}

 spark_not_installed <- function() {
   need_install <- purrr:::quietly(sparklyr::spark_install_find)()

   if (inherits(need_install, "try-error")) {
     need_install <- TRUE
   } else {
@@ -28,6 +53,19 @@ spark_not_installed <- function() {
   need_install
 }

+spark_test_connection <- function() {
+  suppressPackageStartupMessages(library(sparklyr))
+  if (is.null(.env_tests$spark_connection)) {
+    version <- Sys.getenv("SPARK_VERSION", unset = "3.5.7")
+    sc <- try(sparklyr::spark_connect("local", version = version), silent = TRUE)
+    if (inherits(sc, "try-error")) {
+      return(NULL)
+    }
+    .env_tests$spark_connection <- sc
+  }
+  .env_tests$spark_connection
+}

 # ------------------------------------------------------------------------------

 expect_ptype <- function(x, ptype) {
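Taken together, the helpers are meant to be called at the top of each Spark test file and inside individual tests, roughly as in this sketch (not part of the diff; the data and model are placeholders):

# Sketch of the intended usage pattern for skip_if_no_spark() and spark_test_connection().
skip_if_no_spark()  # skips the remaining tests in the file when Spark or sparklyr is unavailable

test_that("spark linear regression fits", {
  sc <- spark_test_connection()  # reuses the connection cached in .env_tests
  cars_tbl <- sparklyr::copy_to(sc, mtcars, "cars_tbl", overwrite = TRUE)
  mod <- linear_reg() %>% set_engine("spark") %>% fit(mpg ~ ., data = cars_tbl)
  expect_s3_class(mod, "model_fit")
})

Because the connection is cached in .env_tests, individual tests no longer need their own spark_connect()/spark_disconnect_all() calls.
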
145 changes: 0 additions & 145 deletions tests/testthat/test-parsnip-case-weights.R
@@ -4,11 +4,9 @@ skip_if_not_installed("hardhat", minimum_version = "1.2.0")
 skip_if_not_installed("yardstick", minimum_version = "1.0.0")
 skip_if_not_installed("workflows", minimum_version = "1.0.0")
 skip_if_not_installed("recipes", minimum_version = "1.0.0")
-skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000")

 # load all extension packages to register the engines
 library(parsnip)
-suppressPackageStartupMessages(library(sparklyr))

 # boosted trees -----------------------------------------------------------

@@ -184,53 +182,6 @@ test_that('linear_reg - glmnet case weights', {
   expect_equal(sub_fit$fit$beta, wt_fit$fit$beta)
 })

-test_that('linear_reg - spark case weights', {
-  skip_if_not_installed("sparklyr")
-
-  sc <- try(spark_connect(master = "local"), silent = TRUE)
-
-  skip_if(inherits(sc, "try-error"))
-
-  dat <- make_mtcars_wts()
-
-  mtcars_wts <- copy_to(
-    sc,
-    mtcars %>% mutate(wts = as.double(dat$wts)),
-    "dat_wts",
-    overwrite = TRUE
-  )
-
-  mtcars_subset <- copy_to(
-    sc,
-    dat$subset,
-    "mtcars_subset",
-    overwrite = TRUE
-  )
-
-  expect_error(
-    {
-      wt_fit <-
-        linear_reg() %>%
-        set_engine("spark") %>%
-        fit(
-          mpg ~ . - wts,
-          data = mtcars_wts,
-          case_weights = "wts"
-        )
-    },
-    regexp = NA
-  )
-
-  sub_fit <-
-    linear_reg() %>%
-    set_engine("spark") %>%
-    fit(mpg ~ ., data = mtcars_subset)
-
-  expect_equal(coef(sub_fit$fit), coef(wt_fit$fit))
-
-  spark_disconnect_all()
-})
-

 # logistic_reg ------------------------------------------------------------

@@ -296,54 +247,6 @@ test_that('logistic_reg - stan case weights', {
   expect_snapshot(print(wt_fit$fit$call))
 })

-test_that('logistic_reg - spark case weights', {
-  skip_if_not_installed("sparklyr")
-
-  sc <- try(spark_connect(master = "local"), silent = TRUE)
-
-  skip_if(inherits(sc, "try-error"))
-
-  dat <- make_two_class_wts()
-
-  two_class_dat_wts <- copy_to(
-    sc,
-    two_class_dat %>% mutate(wts = as.double(dat$wts)),
-    "two_class_dat_wts",
-    overwrite = TRUE
-  )
-
-  dat_subset <- copy_to(
-    sc,
-    dat$subset,
-    "dat_subset",
-    overwrite = TRUE
-  )
-
-  expect_error(
-    {
-      wt_fit <-
-        logistic_reg() %>%
-        set_engine("spark") %>%
-        fit(
-          Class ~ . - wts,
-          data = two_class_dat_wts,
-          case_weights = "wts"
-        )
-    },
-    regexp = NA
-  )
-
-  sub_fit <-
-    logistic_reg() %>%
-    set_engine("spark") %>%
-    fit(Class ~ ., data = dat_subset)
-
-  expect_equal(coef(sub_fit$fit), coef(wt_fit$fit))
-
-  spark_disconnect_all()
-})
-
-
 # mars --------------------------------------------------------------------

 test_that('mars - earth case weights', {
@@ -407,54 +310,6 @@ test_that('multinom_reg - glmnet case weights', {
   expect_equal(sub_fit$fit$beta, wt_fit$fit$beta)
 })

-test_that('multinom_reg - spark case weights', {
-  skip_if_not_installed("sparklyr")
-
-  sc <- try(spark_connect(master = "local"), silent = TRUE)
-
-  skip_if(inherits(sc, "try-error"))
-
-  dat <- make_penguin_wts()
-
-  penguin_wts <- copy_to(
-    sc,
-    penguins[complete.cases(penguins), ] %>% mutate(wts = as.double(dat$wts)),
-    "penguin_wts",
-    overwrite = TRUE
-  )
-
-  penguin_subset <- copy_to(
-    sc,
-    dat$subset,
-    "penguin_subset",
-    overwrite = TRUE
-  )
-
-  expect_error(
-    {
-      wt_fit <-
-        multinom_reg() %>%
-        set_engine("spark") %>%
-        fit(
-          island ~ . - wts,
-          data = penguin_wts,
-          case_weights = "wts"
-        )
-    },
-    regexp = NA
-  )
-
-  sub_fit <-
-    multinom_reg() %>%
-    set_engine("spark") %>%
-    fit(island ~ ., data = penguin_subset)
-
-  expect_equal(coef(sub_fit$fit), coef(wt_fit$fit))
-
-  spark_disconnect_all()
-})
-
-
 # rand_forest -------------------------------------------------------------

 test_that('rand_forest - ranger case weights', {
12 changes: 2 additions & 10 deletions tests/testthat/test-spark-boost-tree.R
@@ -1,6 +1,4 @@
-## Skip entire file is Spark is not installed
-skip_if(spark_not_installed())
-skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000")
+skip_if_no_spark()

 library(testthat)
 library(parsnip)
@@ -23,12 +21,7 @@ test_that('Reminder to check for CRAN sparklyr > 1.9.1.9000', {


 test_that('spark execution', {
-  skip_if_not_installed("sparklyr")
-  suppressPackageStartupMessages(library(sparklyr))
-
-  sc <- try(spark_connect(master = "local"), silent = TRUE)
-
-  skip_if(inherits(sc, "try-error"))
+  sc <- spark_test_connection()

   hpc_bt_tr <- copy_to(sc, hpc[-(1:4), ], "hpc_bt_tr", overwrite = TRUE)
   hpc_bt_te <- copy_to(sc, hpc[1:4, -1], "hpc_bt_te", overwrite = TRUE)
@@ -206,5 +199,4 @@ test_that('spark execution', {
     as.data.frame(spark_class_dup_classprob)
   )

-  spark_disconnect_all()
 })