diff --git a/.github/workflows/CRAN-R-CMD-check.yaml b/.github/workflows/CRAN-R-CMD-check.yaml index cae83ca..2477d83 100644 --- a/.github/workflows/CRAN-R-CMD-check.yaml +++ b/.github/workflows/CRAN-R-CMD-check.yaml @@ -6,9 +6,9 @@ # usethis::use_github_action("check-standard") will install it. on: push: - branches: [main, master] + branches: main pull_request: - branches: [main, master] + branches: main schedule: # runs tests every day at 1am EST - cron: '0 5 * * *' diff --git a/.github/workflows/GH-R-CMD-check.yaml b/.github/workflows/GH-R-CMD-check.yaml index a23c081..806a3da 100644 --- a/.github/workflows/GH-R-CMD-check.yaml +++ b/.github/workflows/GH-R-CMD-check.yaml @@ -6,9 +6,9 @@ # usethis::use_github_action("check-standard") will install it. on: push: - branches: [main, master] + branches: main pull_request: - branches: [main, master] + branches: main schedule: # runs tests every day at 1am EST - cron: '0 5 * * *' diff --git a/.github/workflows/spark-R-CMD-check.yaml b/.github/workflows/spark-R-CMD-check.yaml index b377f21..f4d0281 100644 --- a/.github/workflows/spark-R-CMD-check.yaml +++ b/.github/workflows/spark-R-CMD-check.yaml @@ -6,9 +6,9 @@ # usethis::use_github_action("check-standard") will install it. on: push: - branches: [main, master] + branches: main pull_request: - branches: [main, master] + branches: main schedule: # runs tests every day at 1am EST - cron: '0 5 * * *' @@ -26,13 +26,14 @@ jobs: fail-fast: false matrix: config: - - {os: ubuntu-latest, r: 'release'} + - {os: ubuntu-latest, r: 'release', spark: '4.0.1'} env: R_REMOTES_NO_ERRORS_FROM_WARNINGS: true GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} R_COMPILE_AND_INSTALL_PACKAGES: 'always' R_KEEP_PKG_SOURCE: yes + SPARK_VERSION: ${{ matrix.config.spark }} steps: - uses: actions/checkout@v3 @@ -52,12 +53,14 @@ jobs: - uses: r-lib/actions/setup-r-dependencies@v2 with: - extra-packages: any::rcmdcheck + extra-packages: | + any::rcmdcheck + any::devtools needs: check - name: Install Spark run: | - try(sparklyr::spark_install(version = "3.4.0", verbose = TRUE), silent = TRUE) + sparklyr::spark_install(version = Sys.getenv("SPARK_VERSION")) shell: Rscript {0} - name: Install devel versions @@ -65,7 +68,10 @@ jobs: try(pak::pkg_install("tidymodels/parsnip")) shell: Rscript {0} - - uses: r-lib/actions/check-r-package@v2 + - name: Run Spark tests + run: | + devtools::test(filter = "spark", stop_on_failure = TRUE) + shell: Rscript {0} - name: Notify slack fail if: failure() && github.event_name == 'schedule' diff --git a/tests/testthat/helper-objects.R b/tests/testthat/helper-objects.R index 3572f0b..bf3468d 100644 --- a/tests/testthat/helper-objects.R +++ b/tests/testthat/helper-objects.R @@ -1,6 +1,10 @@ library(modeldata) library(parsnip) +# Test tracker variables +.env_tests <- new.env() +.env_tests$spark_connection <- NULL + ## ----------------------------------------------------------------------------- data("wa_churn") @@ -13,13 +17,34 @@ ctrl <- control_parsnip(verbosity = 1, catch = FALSE) caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) -run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0 +run_glmnet <- utils::compareVersion("3.6.0", as.character(getRversion())) > 0 -## ----------------------------------------------------------------------------- +## --------------------------- Spark ------------------------------------------- + +skip_if_no_spark <- function() { + skip <- NULL + test_spark <- as.logical(Sys.getenv("TEST_SPARK", unset = TRUE)) + if (!isTRUE(test_spark)) { + skip <- TRUE + } + if (is.null(skip)) { + if (spark_not_installed()) { + skip <- TRUE + } else { + if (is.null(spark_test_connection())) { + skip <- TRUE + } + } + } + if (skip %||% FALSE) { + skip("Skipping Spark related tests") + } else { + return(invisible()) + } +} spark_not_installed <- function() { need_install <- purrr:::quietly(sparklyr::spark_install_find)() - if (inherits(need_install, "try-error")) { need_install <- TRUE } else { @@ -28,6 +53,19 @@ spark_not_installed <- function() { need_install } +spark_test_connection <- function() { + suppressPackageStartupMessages(library(sparklyr)) + if (is.null(.env_tests$spark_connection)) { + version <- Sys.getenv("SPARK_VERSION", unset = "3.5.7") + sc <- try(sparklyr::spark_connect("local", version = version), silent = TRUE) + if (inherits(sc, "try-error")) { + return(NULL) + } + .env_tests$spark_connection <- sc + } + .env_tests$spark_connection +} + # ------------------------------------------------------------------------------ expect_ptype <- function(x, ptype) { diff --git a/tests/testthat/test-parsnip-case-weights.R b/tests/testthat/test-parsnip-case-weights.R index 2508dcf..1336caa 100644 --- a/tests/testthat/test-parsnip-case-weights.R +++ b/tests/testthat/test-parsnip-case-weights.R @@ -4,11 +4,9 @@ skip_if_not_installed("hardhat", minimum_version = "1.2.0") skip_if_not_installed("yardstick", minimum_version = "1.0.0") skip_if_not_installed("workflows", minimum_version = "1.0.0") skip_if_not_installed("recipes", minimum_version = "1.0.0") -skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000") # load all extension packages to register the engines library(parsnip) -suppressPackageStartupMessages(library(sparklyr)) # boosted trees ----------------------------------------------------------- @@ -184,53 +182,6 @@ test_that('linear_reg - glmnet case weights', { expect_equal(sub_fit$fit$beta, wt_fit$fit$beta) }) -test_that('linear_reg - spark case weights', { - skip_if_not_installed("sparklyr") - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) - - dat <- make_mtcars_wts() - - mtcars_wts <- copy_to( - sc, - mtcars %>% mutate(wts = as.double(dat$wts)), - "dat_wts", - overwrite = TRUE - ) - - mtcars_subset <- copy_to( - sc, - dat$subset, - "mtcars_subset", - overwrite = TRUE - ) - - expect_error( - { - wt_fit <- - linear_reg() %>% - set_engine("spark") %>% - fit( - mpg ~ . - wts, - data = mtcars_wts, - case_weights = "wts" - ) - }, - regexp = NA - ) - - sub_fit <- - linear_reg() %>% - set_engine("spark") %>% - fit(mpg ~ ., data = mtcars_subset) - - expect_equal(coef(sub_fit$fit), coef(wt_fit$fit)) - - spark_disconnect_all() -}) - # logistic_reg ------------------------------------------------------------ @@ -296,54 +247,6 @@ test_that('logistic_reg - stan case weights', { expect_snapshot(print(wt_fit$fit$call)) }) -test_that('logistic_reg - spark case weights', { - skip_if_not_installed("sparklyr") - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) - - dat <- make_two_class_wts() - - two_class_dat_wts <- copy_to( - sc, - two_class_dat %>% mutate(wts = as.double(dat$wts)), - "two_class_dat_wts", - overwrite = TRUE - ) - - dat_subset <- copy_to( - sc, - dat$subset, - "dat_subset", - overwrite = TRUE - ) - - expect_error( - { - wt_fit <- - logistic_reg() %>% - set_engine("spark") %>% - fit( - Class ~ . - wts, - data = two_class_dat_wts, - case_weights = "wts" - ) - }, - regexp = NA - ) - - sub_fit <- - logistic_reg() %>% - set_engine("spark") %>% - fit(Class ~ ., data = dat_subset) - - expect_equal(coef(sub_fit$fit), coef(wt_fit$fit)) - - spark_disconnect_all() -}) - - # mars -------------------------------------------------------------------- test_that('mars - earth case weights', { @@ -407,54 +310,6 @@ test_that('multinom_reg - glmnet case weights', { expect_equal(sub_fit$fit$beta, wt_fit$fit$beta) }) -test_that('multinom_reg - spark case weights', { - skip_if_not_installed("sparklyr") - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) - - dat <- make_penguin_wts() - - penguin_wts <- copy_to( - sc, - penguins[complete.cases(penguins), ] %>% mutate(wts = as.double(dat$wts)), - "penguin_wts", - overwrite = TRUE - ) - - penguin_subset <- copy_to( - sc, - dat$subset, - "penguin_subset", - overwrite = TRUE - ) - - expect_error( - { - wt_fit <- - multinom_reg() %>% - set_engine("spark") %>% - fit( - island ~ . - wts, - data = penguin_wts, - case_weights = "wts" - ) - }, - regexp = NA - ) - - sub_fit <- - multinom_reg() %>% - set_engine("spark") %>% - fit(island ~ ., data = penguin_subset) - - expect_equal(coef(sub_fit$fit), coef(wt_fit$fit)) - - spark_disconnect_all() -}) - - # rand_forest ------------------------------------------------------------- test_that('rand_forest - ranger case weights', { diff --git a/tests/testthat/test-spark-boost-tree.R b/tests/testthat/test-spark-boost-tree.R index 8087c30..80e3ce5 100644 --- a/tests/testthat/test-spark-boost-tree.R +++ b/tests/testthat/test-spark-boost-tree.R @@ -1,6 +1,4 @@ -## Skip entire file is Spark is not installed -skip_if(spark_not_installed()) -skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000") +skip_if_no_spark() library(testthat) library(parsnip) @@ -23,12 +21,7 @@ test_that('Reminder to check for CRAN sparklyr > 1.9.1.9000', { test_that('spark execution', { - skip_if_not_installed("sparklyr") - suppressPackageStartupMessages(library(sparklyr)) - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) + sc <- spark_test_connection() hpc_bt_tr <- copy_to(sc, hpc[-(1:4), ], "hpc_bt_tr", overwrite = TRUE) hpc_bt_te <- copy_to(sc, hpc[1:4, -1], "hpc_bt_te", overwrite = TRUE) @@ -206,5 +199,4 @@ test_that('spark execution', { as.data.frame(spark_class_dup_classprob) ) - spark_disconnect_all() }) diff --git a/tests/testthat/test-spark-case-weights.R b/tests/testthat/test-spark-case-weights.R new file mode 100644 index 0000000..ce69180 --- /dev/null +++ b/tests/testthat/test-spark-case-weights.R @@ -0,0 +1,123 @@ +skip_if_no_spark() + +test_that("linear_reg - spark case weights", { + sc <- spark_test_connection() + dat <- make_mtcars_wts() + + mtcars_wts <- copy_to( + sc, + mtcars %>% mutate(wts = as.double(dat$wts)), + "dat_wts", + overwrite = TRUE + ) + + mtcars_subset <- copy_to( + sc, + dat$subset, + "mtcars_subset", + overwrite = TRUE + ) + + expect_error( + { + wt_fit <- + linear_reg() %>% + set_engine("spark") %>% + fit( + mpg ~ . - wts, + data = mtcars_wts, + case_weights = "wts" + ) + }, + regexp = NA + ) + + sub_fit <- + linear_reg() %>% + set_engine("spark") %>% + fit(mpg ~ ., data = mtcars_subset) + + expect_equal(coef(sub_fit$fit), coef(wt_fit$fit)) +}) + +test_that("logistic_reg - spark case weights", { + sc <- spark_test_connection() + + dat <- make_two_class_wts() + + two_class_dat_wts <- copy_to( + sc, + two_class_dat %>% mutate(wts = as.double(dat$wts)), + "two_class_dat_wts", + overwrite = TRUE + ) + + dat_subset <- copy_to( + sc, + dat$subset, + "dat_subset", + overwrite = TRUE + ) + + expect_error( + { + wt_fit <- + logistic_reg() %>% + set_engine("spark") %>% + fit( + Class ~ . - wts, + data = two_class_dat_wts, + case_weights = "wts" + ) + }, + regexp = NA + ) + + sub_fit <- + logistic_reg() %>% + set_engine("spark") %>% + fit(Class ~ ., data = dat_subset) + + expect_equal(coef(sub_fit$fit), coef(wt_fit$fit)) +}) + +test_that("multinom_reg - spark case weights", { + sc <- spark_test_connection() + + dat <- make_penguin_wts() + + penguin_wts <- copy_to( + sc, + penguins[complete.cases(penguins), ] %>% mutate(wts = as.double(dat$wts)), + "penguin_wts", + overwrite = TRUE + ) + + penguin_subset <- copy_to( + sc, + dat$subset, + "penguin_subset", + overwrite = TRUE + ) + + expect_error( + { + wt_fit <- + multinom_reg() %>% + set_engine("spark") %>% + fit( + island ~ . - wts, + data = penguin_wts, + case_weights = "wts" + ) + }, + regexp = NA + ) + + sub_fit <- + multinom_reg() %>% + set_engine("spark") %>% + fit(island ~ ., data = penguin_subset) + + expect_equal(coef(sub_fit$fit), coef(wt_fit$fit)) +}) diff --git a/tests/testthat/test-spark-data-descriptors.R b/tests/testthat/test-spark-data-descriptors.R index d441797..e224f42 100644 --- a/tests/testthat/test-spark-data-descriptors.R +++ b/tests/testthat/test-spark-data-descriptors.R @@ -1,6 +1,4 @@ -## Skip entire file is Spark is not installed -skip_if(spark_not_installed()) -skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000") +skip_if_no_spark() library(testthat) library(parsnip) @@ -39,14 +37,7 @@ class_tab <- table(hpc$class, dnn = NULL) # ------------------------------------------------------------------------------ test_that("spark descriptor", { - skip_if_not_installed("sparklyr") - - suppressPackageStartupMessages(library(sparklyr)) - library(dplyr) - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) + sc <- spark_test_connection() npk_descr <- copy_to(sc, npk[, 1:4], "npk_descr", overwrite = TRUE) hpc_descr <- copy_to(sc, hpc, "hpc_descr", overwrite = TRUE) @@ -88,5 +79,4 @@ test_that("spark descriptor", { eval_descrs2(parsnip:::get_descr_form(K ~ ., data = npk_descr)), ignore_attr = TRUE ) - spark_disconnect_all() }) diff --git a/tests/testthat/test-spark-linear-reg.R b/tests/testthat/test-spark-linear-reg.R index 2f615c9..4a04d83 100644 --- a/tests/testthat/test-spark-linear-reg.R +++ b/tests/testthat/test-spark-linear-reg.R @@ -1,6 +1,4 @@ -## Skip entire file is Spark is not installed -skip_if(spark_not_installed()) -skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000") +skip_if_no_spark() library(testthat) library(parsnip) @@ -13,13 +11,7 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ test_that('spark execution', { - skip_if_not_installed("sparklyr") - - suppressPackageStartupMessages(library(sparklyr)) - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) + sc <- spark_test_connection() hpc_linreg_tr <- copy_to(sc, hpc[-(1:4), ], "hpc_linreg_tr", overwrite = TRUE) hpc_linreg_te <- copy_to(sc, hpc[1:4, -1], "hpc_linreg_te", overwrite = TRUE) @@ -53,6 +45,4 @@ test_that('spark execution', { expect_equal(as.data.frame(spark_pred)$pred, lm_pred) expect_equal(as.data.frame(spark_pred_num)$pred, lm_pred) - - spark_disconnect_all() }) diff --git a/tests/testthat/test-spark-logistic-reg.R b/tests/testthat/test-spark-logistic-reg.R index 492245e..989caea 100644 --- a/tests/testthat/test-spark-logistic-reg.R +++ b/tests/testthat/test-spark-logistic-reg.R @@ -1,6 +1,4 @@ -## Skip entire file is Spark is not installed -skip_if(spark_not_installed()) -skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000") +skip_if_no_spark() library(testthat) library(parsnip) @@ -13,13 +11,7 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ test_that('spark execution', { - skip_if_not_installed("sparklyr") - - suppressPackageStartupMessages(library(sparklyr)) - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) + sc <- spark_test_connection() churn_logit_tr <- copy_to( sc, @@ -97,6 +89,4 @@ test_that('spark execution', { as.data.frame(spark_class_prob_classprob), ignore_attr = TRUE ) - - spark_disconnect_all() }) diff --git a/tests/testthat/test-spark-multinom-reg.R b/tests/testthat/test-spark-multinom-reg.R index 97541d6..68953f0 100644 --- a/tests/testthat/test-spark-multinom-reg.R +++ b/tests/testthat/test-spark-multinom-reg.R @@ -1,6 +1,4 @@ -## Skip entire file is Spark is not installed -skip_if(spark_not_installed()) -skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000") +skip_if_no_spark() library(testthat) library(parsnip) @@ -13,13 +11,7 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ test_that('spark execution', { - skip_if_not_installed("sparklyr") - - library(sparklyr) - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) + sc <- spark_test_connection() hpc_rows <- c(1, 51, 101) hpc_tr <- copy_to(sc, hpc[-hpc_rows, ], "hpc_tr", overwrite = TRUE) @@ -80,5 +72,4 @@ test_that('spark execution', { ignore_attr = TRUE ) - spark_disconnect_all() }) diff --git a/tests/testthat/test-spark-rand-forest.R b/tests/testthat/test-spark-rand-forest.R index c2cb907..60384bd 100644 --- a/tests/testthat/test-spark-rand-forest.R +++ b/tests/testthat/test-spark-rand-forest.R @@ -1,6 +1,4 @@ -## Skip entire file is Spark is not installed -skip_if(spark_not_installed()) -skip_if_not_installed("sparklyr", minimum_version = "1.9.1.9000") +skip_if_no_spark() library(testthat) library(parsnip) @@ -13,13 +11,7 @@ hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ test_that('spark execution', { - skip_if_not_installed("sparklyr") - - suppressPackageStartupMessages(library(sparklyr)) - - sc <- try(spark_connect(master = "local"), silent = TRUE) - - skip_if(inherits(sc, "try-error")) + sc <- spark_test_connection() hpc_rf_tr <- copy_to(sc, hpc[-(1:4), ], "hpc_rf_tr", overwrite = TRUE) hpc_rf_te <- copy_to(sc, hpc[1:4, -1], "hpc_rf_te", overwrite = TRUE) @@ -185,6 +177,4 @@ test_that('spark execution', { as.data.frame(spark_class_prob_classprob), as.data.frame(spark_class_dup_classprob) ) - - spark_disconnect_all() })