Skip to content

Commit 776c2f4

Browse files
authored
implement tune_args() and tunable() (#51)
* implement `tune_args()` and `tunable()` * remove redundant method * namespace fn in test * more machinery from recipes * add vctrs to Imports * test `find_tune_id()` * add snapshot * address `vec_rbind()` internal error re: bad recycling * update for new object structure * add `extract_parameter_set_dials()` method * migrate check helper from workflows * generate snaps
1 parent 317a4db commit 776c2f4

23 files changed

+517
-16
lines changed

DESCRIPTION

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@ Imports:
2727
purrr,
2828
rlang (>= 1.1.0),
2929
tibble,
30-
tidyselect
30+
tidyselect,
31+
vctrs
3132
Suggests:
33+
dials,
3234
modeldata,
3335
testthat (>= 3.0.0),
3436
workflows
3537
Remotes:
38+
tidymodels/dials#358,
3639
tidymodels/probably,
3740
tidymodels/workflows
3841
Config/testthat/edition: 3

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(extract_parameter_dials,tailor)
4+
S3method(extract_parameter_set_dials,tailor)
35
S3method(fit,equivocal_zone)
46
S3method(fit,numeric_calibration)
57
S3method(fit,numeric_range)
@@ -33,6 +35,9 @@ S3method(tunable,numeric_range)
3335
S3method(tunable,predictions_custom)
3436
S3method(tunable,probability_calibration)
3537
S3method(tunable,probability_threshold)
38+
S3method(tunable,tailor)
39+
S3method(tune_args,adjustment)
40+
S3method(tune_args,tailor)
3641
export("%>%")
3742
export(adjust_equivocal_zone)
3843
export(adjust_numeric_calibration)

R/adjust-equivocal-zone.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,15 @@ required_pkgs.equivocal_zone <- function(x, ...) {
128128

129129
#' @export
130130
tunable.equivocal_zone <- function(x, ...) {
131-
tibble::new_tibble(list(
131+
tibble::tibble(
132132
name = "buffer",
133133
call_info = list(list(pkg = "dials", fun = "buffer")),
134134
source = "tailor",
135135
component = "equivocal_zone",
136136
component_id = "equivocal_zone"
137-
))
137+
)
138138
}
139139

140140
# todo missing methods:
141-
# todo tune_args
142141
# todo tidy
143142
# todo extract_parameter_set_dials

R/adjust-numeric-calibration.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,5 @@ tunable.numeric_calibration <- function(x, ...) {
126126
}
127127

128128
# todo missing methods:
129-
# todo tune_args
130129
# todo tidy
131130
# todo extract_parameter_set_dials

R/adjust-numeric-range.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ required_pkgs.numeric_range <- function(x, ...) {
129129

130130
#' @export
131131
tunable.numeric_range <- function(x, ...) {
132-
tibble::new_tibble(list(
132+
tibble::tibble(
133133
name = c("lower_limit", "upper_limit"),
134134
call_info = list(
135135
list(pkg = "dials", fun = "lower_limit"), # todo make these dials functions
@@ -138,10 +138,9 @@ tunable.numeric_range <- function(x, ...) {
138138
source = "tailor",
139139
component = "numeric_range",
140140
component_id = "numeric_range"
141-
))
141+
)
142142
}
143143

144144
# todo missing methods:
145-
# todo tune_args
146145
# todo tidy
147146
# todo extract_parameter_set_dials

R/adjust-predictions-custom.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,5 @@ tunable.predictions_custom <- function(x, ...) {
9292
}
9393

9494
# todo missing methods:
95-
# todo tune_args
9695
# todo tidy
9796
# todo extract_parameter_set_dials

R/adjust-probability-calibration.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,5 @@ tunable.probability_calibration <- function(x, ...) {
134134
}
135135

136136
# todo missing methods:
137-
# todo tune_args
138137
# todo tidy
139138
# todo extract_parameter_set_dials

R/adjust-probability-threshold.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,15 @@ required_pkgs.probability_threshold <- function(x, ...) {
113113

114114
#' @export
115115
tunable.probability_threshold <- function(x, ...) {
116-
tibble::new_tibble(list(
116+
tibble::tibble(
117117
name = "threshold",
118118
call_info = list(list(pkg = "dials", fun = "threshold")),
119119
source = "tailor",
120120
component = "probability_threshold",
121121
component_id = "probability_threshold"
122-
))
122+
)
123123
}
124124

125125
# todo missing methods:
126-
# todo tune_args
127126
# todo tidy
128127
# todo extract_parameter_set_dials

R/extract.R

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#' @export
2+
extract_parameter_set_dials.tailor <- function(x, ...) {
3+
all_args <- generics::tunable(x)
4+
tuning_param <- generics::tune_args(x)
5+
res <-
6+
dplyr::inner_join(
7+
tuning_param %>% dplyr::select(-tunable),
8+
all_args,
9+
by = c("name", "source", "component", "component_id")
10+
) %>%
11+
dplyr::mutate(object = purrr::map(call_info, eval_call_info))
12+
13+
dials::parameters_constr(
14+
res$name,
15+
res$id,
16+
res$source,
17+
res$component,
18+
res$component_id,
19+
res$object
20+
)
21+
}
22+
23+
eval_call_info <- function(x) {
24+
if (!is.null(x)) {
25+
# Look for other options
26+
allowed_opts <- c("range", "trans", "values")
27+
if (any(names(x) %in% allowed_opts)) {
28+
opts <- x[names(x) %in% allowed_opts]
29+
} else {
30+
opts <- list()
31+
}
32+
res <- try(rlang::eval_tidy(rlang::call2(x$fun, .ns = x$pkg, !!!opts)), silent = TRUE)
33+
if (inherits(res, "try-error")) {
34+
cli::cli_abort(
35+
"Error when calling {.fn {x$fun}}: {as.character(res)}"
36+
)
37+
}
38+
} else {
39+
res <- NA
40+
}
41+
res
42+
}
43+
44+
#' @export
45+
extract_parameter_dials.tailor <- function(x, parameter, ...) {
46+
extract_parameter_dials(extract_parameter_set_dials(x), parameter)
47+
}

R/tailor.R

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,42 @@ set_tailor_type <- function(object, y, call = caller_env()) {
235235
# todo: where to validate #levels?
236236
# todo setup eval_time
237237
# todo missing methods:
238-
# todo tune_args
238+
239+
#' @export
240+
tune_args.tailor <- function(object, full = FALSE, ...) {
241+
adjustments <- object$adjustments
242+
243+
if (length(adjustments) == 0L) {
244+
return(tune_tbl())
245+
}
246+
247+
res <- purrr::map(object$adjustments, tune_args, full = full)
248+
res <- purrr::list_rbind(res)
249+
250+
tune_tbl(
251+
res$name,
252+
res$tunable,
253+
res$id,
254+
res$source,
255+
res$component,
256+
res$component_id,
257+
full = full
258+
)
259+
}
260+
261+
#' @export
262+
tunable.tailor <- function(x, ...) {
263+
if (length(x$adjustments) == 0) {
264+
res <- no_param
265+
} else {
266+
res <- purrr::map(x$adjustments, tunable)
267+
res <- vctrs::vec_rbind(!!!res)
268+
if (nrow(res) > 0) {
269+
res <- res[!is.na(res$name), ]
270+
}
271+
}
272+
res
273+
}
274+
239275
# todo tidy (this should probably just be `adjustment_orderings()`)
240276
# todo extract_parameter_set_dials

0 commit comments

Comments
 (0)