Skip to content

Commit 6677c46

Browse files
authored
TLC for loop_over_all_stages() (#1142)
* first version of loop doc * align pattern for grid to pattern for workflow have `current_*` be the object that gets updated across the stages * anchor naming more in context * don't use `all_` prefix unless it's across all iterations * add note to point to `loop.qmd`
1 parent 8a4ea3a commit 6677c46

File tree

4 files changed

+228
-29
lines changed

4 files changed

+228
-29
lines changed

R/loop_over_all_stages-helpers.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ process_prediction_data <- function(wflow_fit, static) {
338338
# ------------------------------------------------------------------------------
339339
# Misc functions
340340

341-
rebind_grid <- function(...) {
341+
extend_grid <- function(...) {
342342
list(...) |> purrr::map(remove_stage) |> purrr::list_cbind()
343343
}
344344

R/loop_over_all_stages.R

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Notes for easier reading are in `inst/loop.qmd`.
12
# Notes on debugging:
23
# 1. You can set `options(future.debug = TRUE)` to help
34
# 2. If you are debugging .loop_over_all_stages, use the control option
@@ -52,7 +53,8 @@
5253

5354
for (iter_pre in seq_len(num_iterations_pre)) {
5455
current_sched_pre <- sched[iter_pre, ]
55-
0
56+
current_grid <- remove_stage(current_sched_pre)
57+
5658
location <- glue::glue("preprocessor {iter_pre}/{num_iterations_pre}")
5759

5860
# Note: finalize_fit_pre() will process the data used for modeling. We'll
@@ -98,12 +100,15 @@
98100
# values currently are tune()
99101
wflow_with_fitted_pre <- current_wflow
100102

103+
grid_with_pre <- current_grid
104+
101105
if (is_failure(pred_data)) {
102106
next
103107
}
104108

105109
for (iter_model in seq_len(num_iterations_model)) {
106110
current_sched_model <- current_sched_pre$model_stage[[1]][iter_model, ]
111+
current_grid <- extend_grid(grid_with_pre, current_sched_model)
107112

108113
# Splice in any parameters marked for tuning and fit the model
109114
location <- glue::glue(
@@ -122,8 +127,6 @@
122127
next
123128
}
124129

125-
current_grid <- rebind_grid(current_sched_pre, current_sched_model)
126-
127130
has_submodel <- has_sub_param(current_sched_model$predict_stage[[1]])
128131
num_iterations_pred <- max(
129132
nrow(current_sched_model$predict_stage[[1]]),
@@ -136,25 +139,37 @@
136139

137140
if (has_submodel) {
138141
# Collect all submodel values and predict once
139-
all_sub_sched <- current_sched_model$predict_stage[[1]]
140-
sub_nm <- get_sub_param(all_sub_sched)
141-
all_sub_grid <- all_sub_sched[, sub_nm, drop = FALSE]
142+
sched_pred_all_submodels <- current_sched_model$predict_stage[[1]]
143+
sub_nm <- get_sub_param(sched_pred_all_submodels)
144+
grid_pred_all_submodels <- sched_pred_all_submodels[,
145+
sub_nm,
146+
drop = FALSE
147+
]
148+
149+
# Submodel parameters will be added in the predict stage
150+
grid_with_pre_model <- current_grid |>
151+
dplyr::select(-dplyr::all_of(sub_nm))
142152

143153
location <- glue::glue(
144154
"preprocessor {iter_pre}/{num_iterations_pre}, model {iter_model}/{num_iterations_model} (predictions)"
145155
)
146-
all_submodel_pred <- .catch_and_log(
147-
predict_all_types(current_wflow, pred_data, static, all_sub_grid),
156+
pred_all_submodels <- .catch_and_log(
157+
predict_all_types(
158+
current_wflow,
159+
pred_data,
160+
static,
161+
grid_pred_all_submodels
162+
),
148163
control = static$control,
149164
split_labels = split_labs,
150165
location = location,
151166
notes = notes
152167
)
153168

154-
if (is_failure(all_submodel_pred)) {
169+
if (is_failure(pred_all_submodels)) {
155170
next
156171
}
157-
all_submodel_pred <- remove_log_notes(all_submodel_pred)
172+
pred_all_submodels <- remove_log_notes(pred_all_submodels)
158173
}
159174

160175
for (iter_pred in seq_len(num_iterations_pred)) {
@@ -166,14 +181,11 @@
166181
sub_nm <- get_sub_param(current_sched_pred)
167182
sub_val <- current_sched_pred[[sub_nm]]
168183

169-
# The assigned submodel parameter (from min_grid()) is in the
170-
# current grid. Remove that and add the one that we are predicting on
171-
current_grid <- current_grid |>
172-
dplyr::select(-dplyr::all_of(sub_nm)) |>
173-
rebind_grid(current_sched_pred)
184+
# Add submodel param to grid
185+
current_grid <- extend_grid(grid_with_pre_model, current_sched_pred)
174186

175187
# Filter to this submodel's predictions (already computed above)
176-
current_pred <- all_submodel_pred |>
188+
current_pred <- pred_all_submodels |>
177189
dplyr::filter(.data[[sub_nm]] == sub_val) |>
178190
dplyr::select(-dplyr::all_of(sub_nm))
179191
} else {
@@ -205,16 +217,15 @@
205217
# values currently are tune()
206218
wflow_with_fitted_pre_and_model <- current_wflow
207219

208-
current_predict_grid <- current_grid
220+
grid_with_pre_model_pred <- current_grid
209221

210222
for (iter_post in seq_len(num_iterations_post)) {
211223
if (has_post) {
212224
current_sched_post <-
213225
current_sched_pred$post_stage[[1]][iter_post, ]
214-
post_grid <- current_sched_post
215226

216-
current_post_grid <- rebind_grid(
217-
current_predict_grid,
227+
current_grid <- extend_grid(
228+
grid_with_pre_model_pred,
218229
current_sched_post
219230
)
220231

@@ -236,7 +247,7 @@
236247
finalize_fit_post(
237248
wflow_with_fitted_pre_and_model,
238249
data_calibration = tailor_train_data,
239-
grid = post_grid
250+
grid = current_sched_post
240251
),
241252
control = static$control,
242253
split_labels = split_labs,
@@ -262,13 +273,10 @@
262273
next
263274
}
264275

265-
final_pred <- dplyr::bind_cols(post_pred, current_post_grid)
266-
current_extract_grid <- current_post_grid
267-
# end submodels
276+
final_pred <- dplyr::bind_cols(post_pred, current_grid)
268277
} else {
269278
# No postprocessor so just use what we have
270-
final_pred <- dplyr::bind_cols(current_pred, current_predict_grid)
271-
current_extract_grid <- current_predict_grid
279+
final_pred <- dplyr::bind_cols(current_pred, current_grid)
272280
}
273281

274282
current_wflow <- workflows::.fit_finalize(current_wflow)
@@ -298,7 +306,7 @@
298306
extracts <- tibble::tibble(.extracts = list(1))
299307
if (nrow(static$param_info) > 0) {
300308
extracts <- tibble::add_column(
301-
current_extract_grid,
309+
current_grid,
302310
.extracts = list(1)
303311
)
304312
}
@@ -309,7 +317,7 @@
309317
extracts <- tibble::add_row(
310318
extracts,
311319
tibble::add_column(
312-
current_extract_grid,
320+
current_grid,
313321
.extracts = list(elt_extract)
314322
)
315323
)

inst/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/.quarto/
2+
**/*.quarto_ipynb

inst/loop.qmd

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
---
2+
title: Loop over all stages aka loopy
3+
---
4+
5+
We have a big ol loop at the heart of `tune_grid()` (and thus the rest of the tuning/resampling functions). That can be a lot to hold in your head at once, so here's a memory aid.
6+
7+
## Overview
8+
9+
We need to work our way through a whole lot of resamples and tuning parameter combinations. We could take the whole grid of tuning parameters, splice them into the workflow, and fit all those workflows on all resamples. This repeats potentially costly calculations, so we have carefully crafted the loop to avoid redundant computation.
10+
11+
The loop runs over a single resample. The basic structure is:
12+
13+
- For each preproc parameter combination
14+
- Fit preprocessor
15+
- Apply it to analysis set
16+
- Apply it to the assessment set
17+
- For each (non-sub) model parameter combination
18+
- Fit the model
19+
- Predict for all submodel parameters at once (via multi_predict)
20+
- For each submodel parameter value
21+
- Filter to this submodel's predictions
22+
- For each post parameter combination
23+
- Fit the postprocessor
24+
- Apply the postprocessor to the predictions (on the assessment set)
25+
- Combine (post-processed) predictions with the grid to `final_pred`
26+
- Save `final_pred` by appending it to `pred_reserve`
27+
- Do the extracts
28+
- Compute the metrics
29+
30+
## Inputs
31+
32+
### The schedule
33+
34+
The schedule is a nested tibble created by `schedule_grid()` that organizes tuning parameters by stage. Each stage contains the next stage in a list-column:
35+
36+
```
37+
sched
38+
├── [preproc params]
39+
└── model_stage (list-col)
40+
└── tibble
41+
├── [model params]
42+
└── predict_stage (list-col)
43+
└── tibble
44+
├── [submodel params]
45+
└── post_stage (list-col)
46+
└── tibble
47+
└── [post params]
48+
```
49+
50+
The four stages are:
51+
52+
- **pre**: preprocessing via recipes
53+
- **model**: the model fit via parsnip (submodel parameters are collapsed via `min_grid()`)
54+
- **predict**: prediction (submodel parameters are expanded)
55+
- **post**: postprocessing via tailor
56+
57+
To access the next stage, extract with `[[1]]`:
58+
59+
- `current_sched_pre$model_stage[[1]]` → tibble of model param combinations
60+
- `current_sched_model$predict_stage[[1]]` → tibble of submodel param values
61+
- `current_sched_pred$post_stage[[1]]` → tibble of post param combinations
62+
63+
### The `static` object
64+
65+
The `static` list contains everything that stays constant throughout the loop:
66+
67+
- `wflow` - the original workflow (template for finalization)
68+
- `param_info` - parameter set info from `tune_args()`
69+
- `configs` - tibble mapping parameter values to `.config` labels
70+
- `metrics` - the metric set
71+
- `pred_types` - prediction types needed (e.g., "class", "prob", "numeric")
72+
- `eval_time` - evaluation times for survival models
73+
- `control` - control options
74+
- `data` - list with `fit`, `pred`, and `cal` data partitions (added after setup)
75+
- `y_name` - outcome column name(s)
76+
77+
### Data partitions
78+
79+
The `static$data` list contains three partitions (set up once per resample):
80+
81+
- `fit` - training data for preprocessor and model
82+
- `pred` - assessment data for predictions (used to compute metrics)
83+
- `cal` - calibration data for postprocessors that need fitting (e.g., probability calibration)
84+
85+
When there's no postprocessor requiring calibration, `cal` is NULL and `fit` uses the full analysis set. When calibration is needed, the analysis set is further split into `fit` and `cal`.
86+
87+
## Reading the code
88+
89+
### Naming conventions
90+
91+
**Loop variables:**
92+
93+
- `iter_{stage}` - iteration counter (e.g., `iter_pre`, `iter_model`, `iter_pred`, `iter_post`)
94+
- `num_iterations_{stage}` - total iterations for that stage
95+
96+
**Schedule objects:**
97+
98+
- `sched` - the full nested schedule tibble
99+
- `current_sched_{stage}` - current row of the schedule at each stage
100+
101+
**Grid objects:**
102+
103+
- `current_grid` - progressively accumulates tuning params as we descend into loops
104+
- `grid_with_pre` - snapshot with pre params (before model loop)
105+
- `grid_with_pre_model` - snapshot with pre + model params, without submodel col (before pred loop)
106+
- `grid_with_pre_model_pred` - snapshot with pre + model + pred params (before post loop)
107+
108+
**Workflow snapshots** (saved to allow re-finalization in inner loops):
109+
110+
- `current_wflow` - the workflow being modified
111+
- `wflow_with_fitted_pre` - snapshot after fitting preprocessor
112+
- `wflow_with_fitted_pre_and_model` - snapshot after fitting model
113+
114+
**Prediction objects:**
115+
116+
- `pred_data` - processed prediction data (features + outcomes)
117+
- `pred_all_submodels` - batched predictions for all submodel values (source for filtering)
118+
- `current_pred` - predictions for current pred iteration (filtered from `pred_all_submodels`)
119+
- `final_pred` - predictions after postprocessing (ready to save)
120+
- `pred_reserve` - accumulator for all final predictions
121+
122+
**General conventions:**
123+
124+
- `static` - things that don't change during the loop
125+
- `current_*` - value for the current iteration
126+
- `*_all_submodels` - batched values for all submodel params (e.g., `pred_all_submodels`)
127+
- Stage suffixes: `_pre`, `_model`, `_pred`, `_post`
128+
129+
### Key helper functions
130+
131+
**Finalization** (splice tuning params into workflow and fit):
132+
133+
- `finalize_fit_pre()` - finalize recipe params, fit preprocessor
134+
- `finalize_fit_model()` - finalize model params, fit model
135+
- `finalize_fit_post()` - finalize tailor params, fit postprocessor
136+
137+
**Grid helpers:**
138+
139+
- `extend_grid()` - extend a grid with params from a schedule row (strips `*_stage` columns)
140+
- `remove_stage()` - remove nested stage columns from a schedule row
141+
142+
**Prediction:**
143+
144+
- `process_prediction_data()` - apply fitted preprocessor to assessment data
145+
- `predict_all_types()` - generate all needed prediction types
146+
147+
### Error handling
148+
149+
The loop uses a consistent error handling pattern:
150+
151+
```r
152+
result <- .catch_and_log(
153+
some_operation(),
154+
control = static$control,
155+
split_labels = split_labs,
156+
location = location,
157+
notes = notes
158+
)
159+
160+
if (is_failure(result)) {
161+
next
162+
}
163+
result <- remove_log_notes(result)
164+
```
165+
166+
- `.catch_and_log()` wraps operations to capture errors/warnings without stopping
167+
- `is_failure()` checks if the operation failed
168+
- `next` skips to the next iteration (the failed config won't have results)
169+
- `remove_log_notes()` strips logging metadata from successful results
170+
171+
## Efficiency
172+
173+
The nested structure avoids redundant computation:
174+
175+
| What | Computed | Reused for |
176+
|------|----------|------------|
177+
| Preprocessor fit | Once per preproc param combo | All model params below it |
178+
| Processed prediction data | Once per preproc param combo | All predictions below it |
179+
| Model fit | Once per model param combo | All submodel predictions |
180+
| Submodel predictions | Once per model (batched) | All submodel × post combos |
181+
| Postprocessor fit | Once per post param combo | That specific config |
182+
183+
Submodel parameters (like `penalty` in glmnet) are predicted all at once using `multi_predict()`, which is much faster than predicting one at a time.
184+
185+
## Debugging
186+
187+
Set `control = control_grid(allow_par = FALSE)` to run sequentially with `lapply()` so you can see output and use `browser()`.
188+
189+
For parallel debugging, `options(future.debug = TRUE)` can help.

0 commit comments

Comments
 (0)