Skip to content

Commit 7b95f20

Browse files
committed
Changes for #956
1 parent a212f78 commit 7b95f20

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed

R/linear_reg.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,27 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
7373
# evaluated value for the parameter.
7474
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
7575
}
76+
77+
# ------------------------------------------------------------------------------
78+
# We want to avoid folks passing in a poisson family instead of using
79+
# poisson_reg(). It's hard to detect this.
80+
81+
is_fam <- names(x$eng_args) == "family"
82+
if (any(is_fam)) {
83+
eng_args <- rlang::eval_tidy(x$eng_args[[which(is_fam)]])
84+
if (is.function(eng_args)) {
85+
eng_args <- try(eng_args(), silent = TRUE)
86+
}
87+
if (inherits(eng_args, "family")) {
88+
eng_args <- eng_args$family
89+
}
90+
if (eng_args == "poisson") {
91+
cli::cli_abort(
92+
"A Poisson family was requested for {.fn linear_reg}. Please use
93+
{.fn poisson_reg} and the engines in the {.pkg poissonreg} package.",
94+
call = rlang::call2("linear_reg"))
95+
}
96+
}
7697
x
7798
}
7899

tests/testthat/_snaps/linear_reg.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,39 @@
139139
Error in `fit()`:
140140
! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1.
141141

142+
# Poisson family (#956)
143+
144+
Code
145+
linear_reg(penalty = 1) %>% set_engine("glmnet", family = poisson) %>%
146+
translate()
147+
Condition
148+
Error in `linear_reg()`:
149+
! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package.
150+
151+
---
152+
153+
Code
154+
linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson) %>%
155+
translate()
156+
Condition
157+
Error in `linear_reg()`:
158+
! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package.
159+
160+
---
161+
162+
Code
163+
linear_reg(penalty = 1) %>% set_engine("glmnet", family = stats::poisson()) %>%
164+
translate()
165+
Condition
166+
Error in `linear_reg()`:
167+
! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package.
168+
169+
---
170+
171+
Code
172+
linear_reg(penalty = 1) %>% set_engine("glmnet", family = "poisson") %>%
173+
translate()
174+
Condition
175+
Error in `linear_reg()`:
176+
! A Poisson family was requested for `linear_reg()`. Please use `poisson_reg()` and the engines in the poissonreg package.
177+

tests/testthat/test-linear_reg.R

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,33 @@ test_that("check_args() works", {
358358
}
359359
)
360360
})
361+
362+
363+
test_that('Poisson family (#956)', {
364+
expect_snapshot(
365+
linear_reg(penalty = 1) %>%
366+
set_engine("glmnet", family = poisson) %>%
367+
translate(),
368+
error = TRUE
369+
)
370+
expect_snapshot(
371+
linear_reg(penalty = 1) %>%
372+
set_engine("glmnet", family = stats::poisson) %>%
373+
translate(),
374+
error = TRUE
375+
)
376+
expect_snapshot(
377+
linear_reg(penalty = 1) %>%
378+
set_engine("glmnet", family = stats::poisson()) %>%
379+
translate(),
380+
error = TRUE
381+
)
382+
expect_snapshot(
383+
linear_reg(penalty = 1) %>%
384+
set_engine("glmnet", family = "poisson") %>%
385+
translate(),
386+
error = TRUE
387+
)
388+
389+
390+
})

0 commit comments

Comments
 (0)