Skip to content

Commit 9383acd

Browse files
committed
make set_tailor_type() play nice with infer_type() (closes #38)
1 parent d80ef2a commit 9383acd

File tree

4 files changed

+118
-2
lines changed

4 files changed

+118
-2
lines changed

R/tailor.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,9 @@ predict.tailor <- function(object, new_data, ...) {
209209
new_data
210210
}
211211

212-
set_tailor_type <- function(object, y) {
212+
set_tailor_type <- function(object, y, call = caller_env()) {
213213
if (object$type != "unknown") {
214+
check_outcome_type(y, object$type, call = call)
214215
return(object)
215216
}
216217
if (is.factor(y)) {
@@ -223,7 +224,10 @@ set_tailor_type <- function(object, y) {
223224
} else if (is.numeric(y)) {
224225
object$type <- "regression"
225226
} else {
226-
cli_abort("Only factor and numeric outcomes are currently supported.")
227+
cli_abort(
228+
"Only factor and numeric outcomes are currently supported.",
229+
call = call
230+
)
227231
}
228232
object
229233
}

R/utils.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,27 @@ check_method <- function(method,
193193
method
194194
}
195195

196+
# at `fit()` time, we check the type of the outcome vs the type
197+
# supported by the applied adjustments. where this is called currently,
198+
# we know already that `type` is not "unknown"
199+
check_outcome_type <- function(outcome, type, call) {
200+
outcome_is_compatible <-
201+
switch(
202+
type,
203+
regression = is.numeric(outcome),
204+
binary = , multiclass = is.factor(outcome),
205+
FALSE
206+
)
207+
208+
if (!outcome_is_compatible) {
209+
cli_abort(
210+
"Tailors with {type} adjustments are not compatible
211+
with {.cls {class(outcome)}} outcomes.",
212+
call = call
213+
)
214+
}
215+
}
216+
196217
check_selection <- function(selector, result, arg, call = caller_env()) {
197218
if (length(result) == 0) {
198219
cli_abort(

tests/testthat/_snaps/utils.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,39 @@
66
Error in `adjust_probability_threshold()`:
77
! `x` should be a <tailor> (`?tailor::tailor()`), not a string.
88

9+
# fit.tailor() errors informatively with incompatible outcome
10+
11+
Code
12+
fit(tailor() %>% adjust_probability_threshold(0.1), two_class_example, outcome = c(
13+
test_numeric), estimate = c(predicted), probabilities = c(Class1, Class2))
14+
Condition
15+
Error in `fit()`:
16+
! Tailors with binary adjustments are not compatible with <numeric> outcomes.
17+
18+
---
19+
20+
Code
21+
fit(tailor() %>% adjust_numeric_range(lower_limit = 0.1), two_class_example,
22+
outcome = c(truth), estimate = c(Class1))
23+
Condition
24+
Error in `fit()`:
25+
! Tailors with regression adjustments are not compatible with <factor> outcomes.
26+
27+
---
28+
29+
Code
30+
fit(tailor() %>% adjust_probability_threshold(0.1), two_class_example, outcome = c(
31+
test_date), estimate = c(predicted), probabilities = c(Class1, Class2))
32+
Condition
33+
Error in `fit()`:
34+
! Tailors with binary adjustments are not compatible with <POSIXct/POSIXt> outcomes.
35+
36+
---
37+
38+
Code
39+
fit(tailor() %>% adjust_predictions_custom(hey = "there"), two_class_example,
40+
outcome = c(test_date), estimate = c(predicted), probabilities = c(Class1))
41+
Condition
42+
Error in `fit()`:
43+
! Only factor and numeric outcomes are currently supported.
44+

tests/testthat/test-utils.R

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,58 @@ test_that("tailor_requires_fit works", {
7373
)
7474
)
7575
})
76+
77+
test_that("fit.tailor() errors informatively with incompatible outcome", {
78+
skip_if_not_installed("modeldata")
79+
library(modeldata)
80+
81+
two_class_example$test_numeric <- two_class_example$Class1 + 1
82+
two_class_example$test_date <- as.POSIXct(two_class_example$Class1)
83+
84+
# supply a numeric outcome to a binary tailor
85+
expect_snapshot(
86+
error = TRUE,
87+
fit(
88+
tailor() %>% adjust_probability_threshold(.1),
89+
two_class_example,
90+
outcome = c(test_numeric),
91+
estimate = c(predicted),
92+
probabilities = c(Class1, Class2)
93+
)
94+
)
95+
96+
# supply a factor outcome to a regression tailor
97+
expect_snapshot(
98+
error = TRUE,
99+
fit(
100+
tailor() %>% adjust_numeric_range(lower_limit = .1),
101+
two_class_example,
102+
outcome = c(truth),
103+
estimate = c(Class1)
104+
)
105+
)
106+
107+
# supply a totally wild outcome to a regression tailor
108+
expect_snapshot(
109+
error = TRUE,
110+
fit(
111+
tailor() %>% adjust_probability_threshold(.1),
112+
two_class_example,
113+
outcome = c(test_date),
114+
estimate = c(predicted),
115+
probabilities = c(Class1, Class2)
116+
)
117+
)
118+
119+
# supply a totally wild outcome to an unknown tailor
120+
expect_snapshot(
121+
error = TRUE,
122+
fit(
123+
tailor() %>% adjust_predictions_custom(hey = "there"),
124+
two_class_example,
125+
outcome = c(test_date),
126+
estimate = c(predicted),
127+
probabilities = c(Class1)
128+
)
129+
)
130+
})

0 commit comments

Comments
 (0)