Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ fit.model_spec <-
eval_env$formula <- formula
eval_env$weights <- wts

if (is_sparse_matrix(data)) {
cli::cli_abort(c(
x = "Sparse matrices cannot be used with {.fn fit}.",
i = "Please use {.fn fit_xy} interface instead."
))
}

data <- materialize_sparse_tibble(data, object, "data")

fit_interface <-
Expand Down
16 changes: 11 additions & 5 deletions R/sparsevctrs.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
to_sparse_data_frame <- function(x, object) {
if (methods::is(x, "sparseMatrix")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Digging this refactor. :)

to_sparse_data_frame <- function(x, object, call = rlang::caller_env()) {
if (is_sparse_matrix(x)) {
if (allow_sparse(object)) {
x <- sparsevctrs::coerce_to_sparse_data_frame(x)
} else {
Expand All @@ -8,8 +8,10 @@ to_sparse_data_frame <- function(x, object) {
}

cli::cli_abort(
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that.")
"{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
engine {.val {object$engine}} doesn't accept that.",
call = call
)
}
} else if (is.data.frame(x)) {
x <- materialize_sparse_tibble(x, object, "x")
Expand All @@ -21,6 +23,10 @@ is_sparse_tibble <- function(x) {
any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))
}

is_sparse_matrix <- function(x) {
methods::is(x, "sparseMatrix")
}

materialize_sparse_tibble <- function(x, object, input) {
if (is_sparse_tibble(x) && (!allow_sparse(object))) {
if (inherits(object, "model_fit")) {
Expand All @@ -29,7 +35,7 @@ materialize_sparse_tibble <- function(x, object, input) {

cli::cli_warn(
"{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that. Converting to
engine {.val {object$engine}} doesn't accept that. Converting to
non-sparse."
)
for (i in seq_along(ncol(x))) {
Expand Down
23 changes: 16 additions & 7 deletions tests/testthat/_snaps/sparsevctrs.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,48 @@
lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ])
Condition
Warning:
`data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
`data` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse tibble can be passed to `fit_xy()

Code
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
Condition
Warning:
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse matrices can be passed to `fit_xy()

Code
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
Condition
Error in `to_sparse_data_frame()`:
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
Error in `fit_xy()`:
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.

# sparse matrices can not be passed to `fit()

Code
hotel_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
Condition
Error in `fit()`:
x Sparse matrices cannot be used with `fit()`.
i Please use `fit_xy()` interface instead.

# sparse tibble can be passed to `predict()

Code
preds <- predict(lm_fit, sparse_mtcars)
Condition
Warning:
`x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse.
`x` is a sparse tibble, but `linear_reg()` with engine "lm" doesn't accept that. Converting to non-sparse.

# sparse matrices can be passed to `predict()

Code
predict(lm_fit, sparse_mtcars)
Condition
Error in `to_sparse_data_frame()`:
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
Error in `predict()`:
! `x` is a sparse matrix, but `linear_reg()` with engine "lm" doesn't accept that.

# to_sparse_data_frame() is used correctly

Expand Down
15 changes: 15 additions & 0 deletions tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ test_that("sparse matrices can be passed to `fit_xy()", {
)
})

test_that("sparse matrices can not be passed to `fit()", {
skip_if_not_installed("xgboost")

hotel_data <- sparse_hotel_rates()

spec <- boost_tree() %>%
set_mode("regression") %>%
set_engine("xgboost")

expect_snapshot(
error = TRUE,
hotel_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data)
)
})

test_that("sparse tibble can be passed to `predict()", {
skip_if_not_installed("ranger")

Expand Down
Loading