Skip to content

Commit f4d2d16

Browse files
Tailor support (#103)
1 parent fcff0d8 commit f4d2d16

19 files changed

+1171
-30
lines changed

DESCRIPTION

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ Suggests:
3737
modeldata,
3838
parsnip,
3939
partykit,
40+
probably,
4041
R6,
4142
recipes,
4243
rmarkdown,
4344
RSQLite,
4445
rstanarm,
4546
sparklyr,
47+
tailor,
4648
testthat (>= 3.0.0),
4749
themis,
4850
tibble,

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
S3method(augment,orbital_class)
44
S3method(orbital,constparty)
55
S3method(orbital,default)
6+
S3method(orbital,equivocal_zone)
67
S3method(orbital,glm)
78
S3method(orbital,last_fit)
89
S3method(orbital,model_fit)
910
S3method(orbital,model_spec)
11+
S3method(orbital,numeric_range)
12+
S3method(orbital,predictions_custom)
13+
S3method(orbital,probability_threshold)
1014
S3method(orbital,recipe)
1115
S3method(orbital,step_BoxCox)
1216
S3method(orbital,step_adasyn)
@@ -54,6 +58,7 @@ S3method(orbital,step_tomek)
5458
S3method(orbital,step_unknown)
5559
S3method(orbital,step_upsample)
5660
S3method(orbital,step_zv)
61+
S3method(orbital,tailor)
5762
S3method(orbital,workflow)
5863
S3method(orbital,xgb.Booster)
5964
S3method(predict,orbital_class)

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# orbital (development version)
22

3+
* Added support for tailor package and its integration into workflows. The following adjustments have gained `orbital()` support. (#103)
4+
- `adjust_equivocal_zone()`
5+
- `adjust_numeric_range()`
6+
- `adjust_predictions_custom()`
7+
- `adjust_probability_threshold()`
8+
39
# orbital 0.3.1
410

511
* Fixed bug where PCA steps didn't work if they were trained with more than 99 predictors. (#82)

R/adjust_equivocal_zone.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#' @export
2+
orbital.equivocal_zone <- function(x, tailor, type, prefix, ...) {
3+
if (!rlang::is_missing(type) && !(all(c("prob", "class") %in% type))) {
4+
cli::cli_abort(c(
5+
x = "{.arg type} must contain {.val prob} and {.val class} to work with
6+
{.fn adjust_equivocal_zone}."
7+
))
8+
}
9+
10+
input <- x$arguments
11+
12+
out_name <- tailor$columns$estimate
13+
prob_name <- tailor$columns$probabilities[[1]]
14+
15+
levels <- gsub("^\\.pred_", "", tailor$columns$probabilities)
16+
17+
if (prefix != "prefix") {
18+
out_name <- gsub("^\\.pred", prefix, out_name)
19+
prob_name <- gsub("^\\.pred", prefix, prob_name)
20+
}
21+
22+
out <- glue::glue(
23+
"dplyr::case_when(
24+
{prob_name} > {input$threshold} + {input$value} ~ '{levels[1]}',
25+
{prob_name} < {input$threshold} - {input$value} ~ '{levels[2]}',
26+
.default = '[EQ]'
27+
)"
28+
)
29+
names(out) <- out_name
30+
out
31+
}

R/adjust_numeric_range.R

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#' @export
2+
orbital.numeric_range <- function(x, tailor, ...) {
3+
lower <- x$arguments$lower_limit
4+
upper <- x$arguments$upper_limit
5+
6+
estimate <- tailor$columns$estimate
7+
8+
if (!is.finite(lower) && !is.finite(upper)) {
9+
return(NULL)
10+
}
11+
12+
out <- "dplyr::case_when("
13+
14+
if (is.finite(lower)) {
15+
out <- paste0(out, "{estimate} < {lower} ~ {lower},")
16+
}
17+
if (is.finite(upper)) {
18+
out <- paste0(out, "{estimate} > {upper} ~ {upper},")
19+
}
20+
21+
out <- paste0(out, "TRUE ~ {estimate})")
22+
out <- glue::glue(out)
23+
names(out) <- estimate
24+
out
25+
}

R/adjust_predictions_custom.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#' @export
2+
orbital.predictions_custom <- function(x, ...) {
3+
input <- x$arguments$commands
4+
5+
if (length(input) == 0) {
6+
return(NULL)
7+
}
8+
9+
out <- vapply(input, rlang::as_label, character(1))
10+
out
11+
}

R/adjust_probability_threshold.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#' @export
2+
orbital.probability_threshold <- function(x, tailor, type, prefix, ...) {
3+
if (!rlang::is_missing(type) && !(all(c("prob", "class") %in% type))) {
4+
cli::cli_abort(c(
5+
x = "{.arg type} must contain {.val prob} and {.val class} to work with
6+
{.fn adjust_equivocal_zone}."
7+
))
8+
}
9+
10+
input <- x$arguments
11+
12+
prob_name <- tailor$columns$probabilities[[1]]
13+
14+
levels <- gsub("^\\.pred_", "", tailor$columns$probabilities)
15+
16+
out_name <- paste0(prefix, "_class")
17+
18+
if (prefix != "prefix") {
19+
prob_name <- gsub("^\\.pred", prefix, prob_name)
20+
}
21+
22+
out <- glue::glue(
23+
"dplyr::case_when(
24+
{prob_name} > {input$threshold} ~ '{levels[1]}',
25+
{prob_name} < {input$threshold} ~ '{levels[2]}',
26+
.default = '[EQ]'
27+
)"
28+
)
29+
names(out) <- out_name
30+
out
31+
}

R/tailor.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#' @export
2+
orbital.tailor <- function(x, ...) {
3+
out <- character()
4+
5+
for (adj in x$adjustments) {
6+
new <- orbital(adj, tailor = x, ...)
7+
out <- c(out, new)
8+
}
9+
10+
new_orbital_class(out)
11+
}

R/workflows.R

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,21 @@ orbital.workflow <- function(x, ..., prefix = ".pred", type = NULL) {
44
cli::cli_abort("{.arg x} must be a fully trained {.cls workflow}.")
55
}
66

7-
if (length(x$post$actions) != 0) {
8-
cli::cli_abort("post-processing is not yet supported in orbital.")
7+
out <- character()
8+
if ("tailor" %in% names(x$post$actions)) {
9+
tailor_fit <- workflows::extract_tailor(x)
10+
post <- orbital(tailor_fit, prefix = prefix, type = type)
11+
out <- post
912
}
1013

1114
model_fit <- workflows::extract_fit_parsnip(x)
12-
out <- orbital(model_fit, prefix = prefix, type = type)
15+
mod <- orbital(model_fit, prefix = prefix, type = type)
16+
mod_atr <- attributes(mod)
17+
mod_atr$names <- c(mod_atr$names, names(out))
18+
mod_cls <- class(mod)
19+
out <- c(mod, out)
20+
attributes(out) <- mod_atr
21+
class(out) <- mod_cls
1322

1423
preprocessor <- workflows::extract_preprocessor(x)
1524

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# adjust_equivocal_zone errors if types aren't set
2+
3+
Code
4+
orbital(wf_fit)
5+
Condition
6+
Error in `orbital()`:
7+
x `type` must contain "prob" and "class" to work with `adjust_equivocal_zone()`.
8+
9+
---
10+
11+
Code
12+
orbital(wf_fit, type = "prob")
13+
Condition
14+
Error in `orbital()`:
15+
x `type` must contain "prob" and "class" to work with `adjust_equivocal_zone()`.
16+
17+
---
18+
19+
Code
20+
orbital(wf_fit, type = "class")
21+
Condition
22+
Error in `orbital()`:
23+
x `type` must contain "prob" and "class" to work with `adjust_equivocal_zone()`.
24+

0 commit comments

Comments
 (0)