Skip to content

Commit 2000f2d

Browse files
committed
use an inner split when training calibrators
1 parent 9a3e42c commit 2000f2d

File tree

4 files changed

+183
-11
lines changed

4 files changed

+183
-11
lines changed

R/fit.R

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,34 @@ fit.workflow <- function(object, data, ..., control = control_workflow()) {
5959
abort("`data` must be provided to fit a workflow.")
6060
}
6161

62+
# If `potato` is not overwritten in the following `if` statement, then the
63+
# postprocessor doesn't actually require training and the dataset
64+
# passed to `.fit_post()` will have no effect.
65+
potato <- data
6266
if (should_inner_split(object)) {
63-
# todo: make an inner_split here
64-
TRUE
67+
validate_rsample_available()
68+
69+
mocked_split <-
70+
rsample::make_splits(
71+
list(analysis = seq_len(nrow(data)), assessment = integer()),
72+
data = data,
73+
class = object$post$actions$tailor$method %||% "mc_split"
74+
)
75+
76+
inner_split <- rsample::inner_split(
77+
mocked_split,
78+
list(prop = object$post$actions$tailor$prop %||% 2/3)
79+
)
80+
81+
data <- rsample::analysis(inner_split)
82+
potato <- rsample::assessment(inner_split)
6583
}
6684

6785
workflow <- object
6886
workflow <- .fit_pre(workflow, data)
6987
workflow <- .fit_model(workflow, control)
7088
if (has_postprocessor(workflow)) {
71-
workflow <- .fit_post(workflow, data)
89+
workflow <- .fit_post(workflow, potato)
7290
}
7391
workflow <- .fit_finalize(workflow)
7492

R/post-action-tailor.R

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,70 @@
1717
#' should not have been trained already with [tailor::fit()]; workflows
1818
#' will handle training internally.
1919
#'
20+
#' @param prop The proportion of the data in [fit.workflow()] that should be used
21+
#' to train the preprocessor and model; the rest trains the postprocessor. Only relevant for
22+
#' postprocessors that require estimation---see section Data Usage below to
23+
#' learn more. Defaults to 2/3.
24+
#'
25+
#' @param method The method with which to split the data in [fit.workflow()],
26+
#' as a character vector. Only relevant for postprocessors that
27+
#' require estimation and not required when resampling the workflow with
28+
#' tune. If `fit.workflow(data)` arose as `training(split_object)`, this argument can
29+
#' usually be supplied as `class(split_object)`. Defaults to `"mc_split"`, which
30+
#' randomly samples `fit.workflow(data)` into two sets, similarly to
31+
#' [rsample::initial_split()]. See section Data Usage below to learn more.
32+
#'
33+
#' @section Data Usage:
34+
#'
35+
#' While preprocessors and models are trained on data in the usual sense,
36+
#' postprocessors are trained on _predictions_ on data. When a workflow
37+
#' is fitted, the user supplies training data with the `data` argument.
38+
#' When workflows don't contain a postprocessor that requires training,
39+
#' they can use all of the supplied `data` to train the preprocessor and model.
40+
#' However, in the case where a postprocessor must be trained as well,
41+
#' training the preprocessor and model on all of `data` would leave no data
42+
#' left to train the postprocessor with---if that were the case, workflows
43+
#' would need to `predict()` from the preprocessor and model on the same `data`
44+
#' that they were trained on, with the postprocessor then training on those
45+
#' predictions. Predictions on data that a model was trained on likely follow
46+
#' different distributions than predictions on unseen data; thus, workflows must
47+
#' split up the supplied `data` into two training sets, where the first is used to
48+
#' train the preprocessor and model and the second is passed to that trained
49+
#' preprocessor and model to generate predictions, which then form the training data
50+
#' for the postprocessor.
51+
#'
52+
#' The arguments `prop` and `method` parameterize how that data is split up.
53+
#' `prop` determines the proportion of rows in `fit.workflow(data)` that are
54+
#' allotted to training the preprocessor and model, while the rest are used to
55+
#' train the postprocessor. `method` determines how that split occurs; since
56+
#' `fit.workflow()` just takes in a data frame, the function doesn't have
57+
#' any information on how that dataset came to be. For example, `data` could
58+
#' have been created as:
59+
#'
60+
#' ```
61+
#' split <- rsample::initial_split(some_other_data)
62+
#' data <- rsample::training(split)
63+
#' ```
64+
#'
65+
#' ...in which case it's okay to randomly allot some rows of `data` to train the
66+
#' preprocessor and model and the rest to train the postprocessor. However,
67+
#' `data` could also have arisen as:
68+
#'
69+
#' ```
70+
#' boots <- rsample::bootstraps(some_other_data)
71+
#' split <- rsample::get_rsplit(boots, 1)
72+
#' data <- rsample::analysis(split)
73+
#' ```
74+
#'
75+
#' In this case, some of the rows in `data` will be duplicated. Thus, randomly
76+
#' allotting some of them to train the preprocessor and model and others to train
77+
#' the postprocessor would likely result in the same rows appearing in both
78+
#' datasets, resulting in the preprocessor and model generating predictions on
79+
#' rows they've seen before. Similarly problematic situations could arise in the
80+
#' context of other resampling situations, like time-based splits.
81+
#' The `method` argument ensures that data is allotted properly (and is
82+
#' internally handled by the tune package when resampling workflows).
83+
#'
2084
#' @param ... Not used.
2185
#'
2286
#' @return
@@ -38,10 +102,10 @@
38102
#' remove_tailor(workflow)
39103
#'
40104
#' update_tailor(workflow, adjust_probability_threshold(tailor, .2))
41-
add_tailor <- function(x, tailor, ...) {
105+
add_tailor <- function(x, tailor, prop = NULL, method = NULL, ...) {
42106
check_dots_empty()
43107
validate_tailor_available()
44-
action <- new_action_tailor(tailor)
108+
action <- new_action_tailor(tailor, prop = prop, method = method)
45109
res <- add_action(x, action, "tailor")
46110
if (should_inner_split(res)) {
47111
validate_rsample_available()
@@ -130,7 +194,7 @@ check_conflicts.action_tailor <- function(action, x, ..., call = caller_env()) {
130194

131195
# ------------------------------------------------------------------------------
132196

133-
new_action_tailor <- function(tailor, ..., call = caller_env()) {
197+
new_action_tailor <- function(tailor, prop, method, ..., call = caller_env()) {
134198
check_dots_empty()
135199

136200
if (!is_tailor(tailor)) {
@@ -142,8 +206,17 @@ new_action_tailor <- function(tailor, ..., call = caller_env()) {
142206
abort("Can't add a trained tailor to a workflow.", call = call)
143207
}
144208

209+
if (!is.null(prop) &&
210+
(!rlang::is_double(prop, n = 1) || prop <= 0 || prop >= 1)) {
211+
abort("`prop` must be a numeric on (0, 1).", call = call)
212+
}
213+
214+
# todo: test method
215+
145216
new_action_post(
146217
tailor = tailor,
218+
prop = prop,
219+
method = method,
147220
subclass = "action_tailor"
148221
)
149222
}

man/add_tailor.Rd

Lines changed: 65 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-post-action-tailor.R

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ test_that("postprocessor fit aligns with manually fitted version (no calibration
109109

110110
test_that("postprocessor fit aligns with manually fitted version (with calibration)", {
111111
skip_if_not_installed("modeldata")
112+
skip_if_not_installed("rsample")
112113

113114
# create example data
114115
y <- seq(0, 7, .1)
@@ -122,15 +123,31 @@ test_that("postprocessor fit aligns with manually fitted version (with calibrati
122123
wflow_post <- add_tailor(wflow_simple, post)
123124

124125
# train workflow
125-
wf_simple_fit <- fit(wflow_simple, dat)
126+
127+
# first, separate out the same data that workflows ought to split off internally
128+
# when training with a postprocessor that needs estimation
129+
mocked_split <-
130+
rsample::make_splits(
131+
list(analysis = seq_len(nrow(dat)), assessment = integer()),
132+
data = dat,
133+
class = "mc_split"
134+
)
135+
set.seed(1)
136+
inner_split <- rsample::inner_split(mocked_split, list(prop = 2/3))
137+
138+
wf_simple_fit <- fit(wflow_simple, rsample::analysis(inner_split))
139+
140+
# the following fit will do all of this internally
141+
set.seed(1)
126142
wf_post_fit <- fit(wflow_post, dat)
127143

128-
# ...verify predictions are the same as training the post-proc separately
129-
wflow_simple_preds <- augment(wf_simple_fit, dat)
144+
# ...verify predictions are the same as training the post-proc separately.
145+
# note that this test naughtily re-predicts on the potato set.
146+
wflow_simple_preds <- augment(wf_simple_fit, rsample::assessment(inner_split))
130147
post_trained <- fit(post, wflow_simple_preds, y, .pred)
131148
wflow_manual_preds <- predict(post_trained, wflow_simple_preds)
132149

133-
wflow_post_preds <- predict(wf_post_fit, dat)
150+
wflow_post_preds <- predict(wf_post_fit, rsample::assessment(inner_split))
134151

135152
expect_equal(wflow_manual_preds[".pred"], wflow_post_preds)
136153
expect_false(all(wflow_simple_preds[".pred"] == wflow_manual_preds[".pred"]))

0 commit comments

Comments
 (0)