diff --git a/R/fit.R b/R/fit.R index 7be77c3ca..ff6fb71ff 100644 --- a/R/fit.R +++ b/R/fit.R @@ -137,6 +137,10 @@ fit.model_spec <- cli::cli_abort(msg) } + if (is_sparse_matrix(data)) { + data <- sparsevctrs::coerce_to_sparse_tibble(data) + } + dots <- quos(...) if (length(possible_engines(object)) == 0) { diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index 5fe3633ae..73b6b6443 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -1,5 +1,5 @@ -to_sparse_data_frame <- function(x, object) { - if (methods::is(x, "sparseMatrix")) { +to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) { + if (is_sparse_matrix(x)) { if (allow_sparse(object)) { x <- sparsevctrs::coerce_to_sparse_data_frame(x) } else { @@ -8,8 +8,10 @@ to_sparse_data_frame <- function(x, object) { } cli::cli_abort( - "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with - engine {.code {object$engine}} doesn't accept that.") + "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with + engine {.val {object$engine}} doesn't accept that.", + call = call + ) } } else if (is.data.frame(x)) { x <- materialize_sparse_tibble(x, object, "x") @@ -21,6 +23,10 @@ is_sparse_tibble <- function(x) { any(vapply(x, sparsevctrs::is_sparse_vector, logical(1))) } +is_sparse_matrix <- function(x) { + methods::is(x, "sparseMatrix") +} + materialize_sparse_tibble <- function(x, object, input) { if (is_sparse_tibble(x) && (!allow_sparse(object))) { if (inherits(object, "model_fit")) { @@ -29,7 +35,7 @@ materialize_sparse_tibble <- function(x, object, input) { cli::cli_warn( "{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with - engine {.code {object$engine}} doesn't accept that. Converting to + engine {.val {object$engine}} doesn't accept that. Converting to non-sparse." ) for (i in seq_along(ncol(x))) { diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 7eb9d3a55..797dd2285 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -4,7 +4,15 @@ lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) Condition Warning: - `data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + `data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. + +# sparse matrix can be passed to `fit() + + Code + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + Condition + Warning: + `data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. # sparse tibble can be passed to `fit_xy() @@ -12,15 +20,15 @@ lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) Condition Warning: - `x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + `x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. # sparse matrices can be passed to `fit_xy() Code lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) Condition - Error in `to_sparse_data_frame()`: - ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. + Error in `fit_xy()`: + ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. # sparse tibble can be passed to `predict() @@ -28,15 +36,15 @@ preds <- predict(lm_fit, sparse_mtcars) Condition Warning: - `x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + `x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse. # sparse matrices can be passed to `predict() Code predict(lm_fit, sparse_mtcars) Condition - Error in `to_sparse_data_frame()`: - ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. + Error in `predict()`: + ! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that. # to_sparse_data_frame() is used correctly diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index aa452f2e3..067498bf1 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -21,6 +21,28 @@ test_that("sparse tibble can be passed to `fit()", { ) }) +test_that("sparse matrix can be passed to `fit()", { + skip_if_not_installed("xgboost") + + hotel_data <- sparse_hotel_rates() + + spec <- boost_tree() %>% + set_mode("regression") %>% + set_engine("xgboost") + + expect_no_error( + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + ) + + spec <- linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") + + expect_snapshot( + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + ) +}) + test_that("sparse tibble can be passed to `fit_xy()", { skip_if_not_installed("xgboost")