Skip to content

Commit dfd9b0e

Browse files
authored
merge pr #41: remove tailor(type)
2 parents 6704eb6 + a9bb433 commit dfd9b0e

15 files changed

+214
-91
lines changed

R/tailor.R

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@
1919
#' with the [tidymodels](https://tidymodels.org) framework; for greatest ease
2020
#' of use, situate tailors in model workflows with `?workflows::add_tailor()`.
2121
#'
22-
#' @param type Character. The model sub-mode. Possible values are
23-
#' `"unknown"`, `"regression"`, `"binary"`, or `"multiclass"`. Only required
24-
#' when used independently of `?workflows::add_tailor()`.
2522
#' @param outcome <[`tidy-select`][dplyr::dplyr_tidy_select]> Only required
2623
#' when used independently of `?workflows::add_tailor()`, and can also be passed
2724
#' at `fit()` time instead. The column name of the outcome variable.
@@ -64,18 +61,16 @@
6461
#' # adjust hard class predictions
6562
#' predict(tlr_fit, two_class_example) %>% count(predicted)
6663
#' @export
67-
tailor <- function(type = "unknown", outcome = NULL, estimate = NULL,
68-
probabilities = NULL) {
64+
tailor <- function(outcome = NULL, estimate = NULL, probabilities = NULL) {
6965
columns <-
7066
list(
7167
outcome = outcome,
72-
type = type,
7368
estimate = estimate,
7469
probabilities = probabilities
7570
)
7671

7772
new_tailor(
78-
type,
73+
"unknown",
7974
adjustments = list(),
8075
columns = columns,
8176
ptype = tibble::new_tibble(list()),
@@ -84,8 +79,6 @@ tailor <- function(type = "unknown", outcome = NULL, estimate = NULL,
8479
}
8580

8681
new_tailor <- function(type, adjustments, columns, ptype, call) {
87-
type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))
88-
8982
if (!is.list(adjustments)) {
9083
cli_abort("The {.arg adjustments} argument should be a list.", call = call)
9184
}
@@ -97,8 +90,14 @@ new_tailor <- function(type, adjustments, columns, ptype, call) {
9790
{.val adjustment}: {bad_adjustment}.", call = call)
9891
}
9992

93+
orderings <- adjustment_orderings(adjustments)
94+
95+
if (type == "unknown") {
96+
type <- infer_type(orderings)
97+
}
98+
10099
# validate adjustment order and check duplicates
101-
validate_order(adjustments, type, call)
100+
validate_order(orderings, type, call)
102101

103102
# check columns
104103
res <- list(
@@ -233,5 +232,5 @@ set_tailor_type <- function(object, y) {
233232
# todo setup eval_time
234233
# todo missing methods:
235234
# todo tune_args
236-
# todo tidy
235+
# todo tidy (this should probably just be `adjustment_orderings()`)
237236
# todo extract_parameter_set_dials

R/utils.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ tailor_adjustment_requires_fit <- function(x) {
8888
isTRUE(x$requires_fit)
8989
}
9090

91+
# an tidy-esque method for adjustment lists, used in validating
92+
# compatibility of adjustments
93+
adjustment_orderings <- function(adjustments) {
94+
tibble::new_tibble(list(
95+
name = purrr::map_chr(adjustments, ~ class(.x)[1]),
96+
input = purrr::map_chr(adjustments, ~ .x$inputs),
97+
output_numeric = purrr::map_lgl(adjustments, ~ grepl("numeric", .x$outputs)),
98+
output_prob = purrr::map_lgl(adjustments, ~ grepl("probability", .x$outputs)),
99+
output_class = purrr::map_lgl(adjustments, ~ grepl("class", .x$outputs)),
100+
output_all = purrr::map_lgl(adjustments, ~ grepl("everything", .x$outputs))
101+
))
102+
}
103+
91104
# ad-hoc checking --------------------------------------------------------------
92105
check_tailor <- function(x, calibration_type = NULL, call = caller_env(), arg = caller_arg(x)) {
93106
if (!is_tailor(x)) {

R/validation-rules.R

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,9 @@
1-
validate_order <- function(adjustments, type, call = caller_env()) {
2-
orderings <-
3-
tibble::new_tibble(list(
4-
name = purrr::map_chr(adjustments, ~ class(.x)[1]),
5-
input = purrr::map_chr(adjustments, ~ .x$inputs),
6-
output_numeric = purrr::map_lgl(adjustments, ~ grepl("numeric", .x$outputs)),
7-
output_prob = purrr::map_lgl(adjustments, ~ grepl("probability", .x$outputs)),
8-
output_class = purrr::map_lgl(adjustments, ~ grepl("class", .x$outputs)),
9-
output_all = purrr::map_lgl(adjustments, ~ grepl("everything", .x$outputs))
10-
))
11-
12-
if (length(adjustments) < 2) {
1+
validate_order <- function(orderings, type, call = caller_env()) {
2+
if (nrow(orderings) < 2) {
133
return(invisible(orderings))
144
}
155

16-
if (type == "unknown") {
17-
type <- infer_type(orderings)
18-
}
6+
check_incompatible_types(orderings, call)
197

208
switch(
219
type,
@@ -27,6 +15,24 @@ validate_order <- function(adjustments, type, call = caller_env()) {
2715
invisible(orderings)
2816
}
2917

18+
check_incompatible_types <- function(orderings, call) {
19+
if (all(c("numeric", "probability") %in% orderings$input)) {
20+
numeric_adjustments <- orderings$name[which(orderings$input == "numeric")]
21+
probability_adjustments <- orderings$name[which(orderings$input == "probability")]
22+
cli_abort(
23+
c(
24+
"Can't compose adjustments for different prediction types.",
25+
"i" = "{cli::qty(numeric_adjustments)}
26+
Adjustment{?s} {.fn {paste0('adjust_', numeric_adjustments)}}
27+
{cli::qty(numeric_adjustments[-1])} operate{?s} on numerics while
28+
{.fn {paste0('adjust_', probability_adjustments)}}
29+
{cli::qty(probability_adjustments[-1])} operate{?s} on probabilities."
30+
),
31+
call = call
32+
)
33+
}
34+
}
35+
3036
check_classification_order <- function(x, call) {
3137
cal_ind <- which(grepl("calibration$", x$name))
3238
eq_ind <- which(grepl("equivocal", x$name))

man/tailor.Rd

Lines changed: 1 addition & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/adjust-equivocal-zone.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Message
66
77
-- tailor ----------------------------------------------------------------------
8-
A postprocessor with 1 adjustment:
8+
A binary postprocessor with 1 adjustment:
99
1010
* Add equivocal zone of size 0.1.
1111

@@ -16,7 +16,7 @@
1616
Message
1717
1818
-- tailor ----------------------------------------------------------------------
19-
A postprocessor with 1 adjustment:
19+
A binary postprocessor with 1 adjustment:
2020
2121
* Add equivocal zone of optimized size.
2222

tests/testthat/_snaps/adjust-numeric-calibration.md

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Message
66
77
-- tailor ----------------------------------------------------------------------
8-
A postprocessor with 1 adjustment:
8+
A regression postprocessor with 1 adjustment:
99
1010
* Re-calibrate numeric predictions.
1111

@@ -20,15 +20,7 @@
2020
---
2121

2222
Code
23-
tailor("binary") %>% adjust_numeric_calibration("linear")
24-
Condition
25-
Error in `adjust_numeric_calibration()`:
26-
! A binary tailor is incompatible with the adjustment `adjust_numeric_calibration()`.
27-
28-
---
29-
30-
Code
31-
tailor("regression") %>% adjust_numeric_calibration("binary")
23+
tailor() %>% adjust_numeric_calibration("binary")
3224
Condition
3325
Error in `adjust_numeric_calibration()`:
3426
! `method` must be one of "linear", "isotonic", or "isotonic_boot", not "binary".

tests/testthat/_snaps/adjust-numeric-range.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Message
66
77
-- tailor ----------------------------------------------------------------------
8-
A postprocessor with 1 adjustment:
8+
A regression postprocessor with 1 adjustment:
99
1010
* Constrain numeric predictions to be between [-Inf, Inf].
1111

@@ -16,7 +16,7 @@
1616
Message
1717
1818
-- tailor ----------------------------------------------------------------------
19-
A postprocessor with 1 adjustment:
19+
A regression postprocessor with 1 adjustment:
2020
2121
* Constrain numeric predictions to be between [?, Inf].
2222

@@ -27,7 +27,7 @@
2727
Message
2828
2929
-- tailor ----------------------------------------------------------------------
30-
A postprocessor with 1 adjustment:
30+
A regression postprocessor with 1 adjustment:
3131
3232
* Constrain numeric predictions to be between [-1, ?].
3333

@@ -38,7 +38,7 @@
3838
Message
3939
4040
-- tailor ----------------------------------------------------------------------
41-
A postprocessor with 1 adjustment:
41+
A regression postprocessor with 1 adjustment:
4242
4343
* Constrain numeric predictions to be between [?, 1].
4444

tests/testthat/_snaps/adjust-probability-calibration.md

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Message
66
77
-- tailor ----------------------------------------------------------------------
8-
A postprocessor with 1 adjustment:
8+
A binary postprocessor with 1 adjustment:
99
1010
* Re-calibrate classification probabilities.
1111

@@ -20,15 +20,7 @@
2020
---
2121

2222
Code
23-
tailor("regression") %>% adjust_probability_calibration("binary")
24-
Condition
25-
Error in `adjust_probability_calibration()`:
26-
! A regression tailor is incompatible with the adjustment `adjust_probability_calibration()`.
27-
28-
---
29-
30-
Code
31-
tailor("binary") %>% adjust_probability_calibration("linear")
23+
tailor() %>% adjust_probability_calibration("linear")
3224
Condition
3325
Error in `adjust_probability_calibration()`:
3426
! `method` must be one of "logistic", "multinomial", "beta", "isotonic", or "isotonic_boot", not "linear".

tests/testthat/_snaps/adjust-probability-threshold.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Message
66
77
-- tailor ----------------------------------------------------------------------
8-
A postprocessor with 1 adjustment:
8+
A binary postprocessor with 1 adjustment:
99
1010
* Adjust probability threshold to 0.5.
1111

@@ -16,7 +16,7 @@
1616
Message
1717
1818
-- tailor ----------------------------------------------------------------------
19-
A postprocessor with 1 adjustment:
19+
A binary postprocessor with 1 adjustment:
2020
2121
* Adjust probability threshold to optimized value.
2222

tests/testthat/_snaps/tailor.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@
1010
---
1111

1212
Code
13-
tailor(type = "binary")
13+
tailor()
1414
Message
1515
1616
-- tailor ----------------------------------------------------------------------
17-
A binary postprocessor with 0 adjustments.
17+
A postprocessor with 0 adjustments.
1818

1919
---
2020

2121
Code
22-
tailor(type = "binary") %>% adjust_probability_threshold(0.2)
22+
tailor() %>% adjust_probability_threshold(0.2)
2323
Message
2424
2525
-- tailor ----------------------------------------------------------------------
@@ -30,8 +30,7 @@
3030
---
3131

3232
Code
33-
tailor(type = "binary") %>% adjust_probability_threshold(0.2) %>%
34-
adjust_equivocal_zone()
33+
tailor() %>% adjust_probability_threshold(0.2) %>% adjust_equivocal_zone()
3534
Message
3635
3736
-- tailor ----------------------------------------------------------------------

0 commit comments

Comments
 (0)