---
title: Loop over all stages aka loopy
---

We have a big ol' loop at the heart of `tune_grid()` (and thus the rest of the tuning/resampling functions). That can be a lot to hold in your head at once, so here's a memory aid.
| 6 | + |
## Overview

We need to work our way through a whole lot of resamples and tuning parameter combinations. We could take each row of the tuning grid, splice it into the workflow, and fit all of those workflows on all resamples. That would repeat potentially costly calculations (for example, refitting the same preprocessor for every candidate model), so the loop is carefully structured to avoid redundant computation.

The loop runs over a single resample. The basic structure is:
| 12 | + |
- For each preproc parameter combination
  - Fit the preprocessor
  - Apply it to the analysis set
  - Apply it to the assessment set
  - For each (non-submodel) model parameter combination
    - Fit the model
    - Predict for all submodel parameters at once (via `multi_predict()`)
    - For each submodel parameter value
      - Filter to this submodel's predictions
      - For each post parameter combination
        - Fit the postprocessor
        - Apply the postprocessor to the predictions (on the assessment set)
        - Combine the (post-processed) predictions with the grid into `final_pred`
        - Save `final_pred` by appending it to `pred_reserve`
        - Do the extracts
- Compute the metrics
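Stripped of error handling, logging, and extraction, the loop body looks roughly like this. This is a simplified sketch, not the actual source: `filter_predictions()` and `compute_metrics()` stand in for internal logic, and the real helpers take more arguments than shown.

```r
# Simplified sketch of the per-resample loop; not the actual implementation.
for (iter_pre in seq_len(num_iterations_pre)) {
  current_sched_pre <- sched[iter_pre, ]
  current_wflow <- finalize_fit_pre(static$wflow, current_sched_pre, static)
  pred_data <- process_prediction_data(current_wflow, static)

  model_sched <- current_sched_pre$model_stage[[1]]
  for (iter_model in seq_len(nrow(model_sched))) {
    current_sched_model <- model_sched[iter_model, ]
    current_wflow <- finalize_fit_model(current_wflow, current_sched_model, static)
    # One model fit serves every submodel parameter value below it
    pred_all_submodels <- predict_all_types(current_wflow, pred_data, static)

    pred_sched <- current_sched_model$predict_stage[[1]]
    for (iter_pred in seq_len(nrow(pred_sched))) {
      current_sched_pred <- pred_sched[iter_pred, ]
      # Hypothetical helper: pull this submodel's rows out of the batch
      current_pred <- filter_predictions(pred_all_submodels, current_sched_pred)

      post_sched <- current_sched_pred$post_stage[[1]]
      for (iter_post in seq_len(nrow(post_sched))) {
        fitted_post <- finalize_fit_post(current_wflow, post_sched[iter_post, ], static)
        final_pred <- predict(fitted_post, current_pred)
        pred_reserve <- dplyr::bind_rows(pred_reserve, final_pred)
      }
    }
  }
}
# Hypothetical helper name; metrics come from static$metrics
metrics <- compute_metrics(pred_reserve, static)
```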
| 29 | + |
## Inputs

### The schedule

The schedule is a nested tibble created by `schedule_grid()` that organizes tuning parameters by stage. Each stage contains the next stage in a list-column:
| 35 | + |
```
sched
├── [preproc params]
└── model_stage (list-col)
    └── tibble
        ├── [model params]
        └── predict_stage (list-col)
            └── tibble
                ├── [submodel params]
                └── post_stage (list-col)
                    └── tibble
                        └── [post params]
```
| 49 | + |
The four stages are:

- **pre**: preprocessing via recipes
- **model**: the model fit via parsnip (submodel parameters are collapsed via `min_grid()`)
- **predict**: prediction (submodel parameters are expanded)
- **post**: postprocessing via tailor
| 56 | + |
To access the next stage, extract with `[[1]]`:

- `current_sched_pre$model_stage[[1]]` → tibble of model param combinations
- `current_sched_model$predict_stage[[1]]` → tibble of submodel param values
- `current_sched_pred$post_stage[[1]]` → tibble of post param combinations
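To get a feel for the shape, you can build a similarly nested tibble by hand. This is a toy with made-up parameter names, not a real schedule (those come from `schedule_grid()`):

```r
library(tibble)

# A toy nesting with the same shape as the schedule
post_stage <- tibble(lower_limit = c(0, 0.1))
predict_stage <- tibble(penalty = c(0.01, 0.1),
                        post_stage = list(post_stage, post_stage))
model_stage <- tibble(mixture = 0.5, predict_stage = list(predict_stage))
sched <- tibble(num_comp = c(2, 4), model_stage = list(model_stage, model_stage))

# Descend one level per stage with [[1]]:
sched$model_stage[[1]]                      # model param combinations
sched$model_stage[[1]]$predict_stage[[1]]   # submodel param values
```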
| 62 | + |
### The `static` object

The `static` list contains everything that stays constant throughout the loop:

- `wflow` - the original workflow (template for finalization)
- `param_info` - parameter set info from `tune_args()`
- `configs` - tibble mapping parameter values to `.config` labels
- `metrics` - the metric set
- `pred_types` - prediction types needed (e.g., "class", "prob", "numeric")
- `eval_time` - evaluation times for survival models
- `control` - control options
- `data` - list with `fit`, `pred`, and `cal` data partitions (added after setup)
- `y_name` - outcome column name(s)
| 76 | + |
### Data partitions

The `static$data` list contains three partitions (set up once per resample):

- `fit` - training data for the preprocessor and model
- `pred` - assessment data for predictions (used to compute metrics)
- `cal` - calibration data for postprocessors that need fitting (e.g., probability calibration)

When there's no postprocessor requiring calibration, `cal` is `NULL` and `fit` uses the full analysis set. When calibration is needed, the analysis set is further split into `fit` and `cal`.
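In sketch form (illustrative only; the actual split and the 80/20 proportion shown here are assumptions, and the real setup goes through tune/rsample internals):

```r
# Illustrative sketch of setting up static$data for one resample
analysis_set <- rsample::analysis(split)
assessment_set <- rsample::assessment(split)

if (post_requires_calibration) {
  # Hold out part of the analysis set so the postprocessor is not
  # calibrated on the same rows the model was trained on
  in_fit <- sample(seq_len(nrow(analysis_set)),
                   size = floor(0.8 * nrow(analysis_set)))
  static$data <- list(
    fit  = analysis_set[in_fit, ],
    cal  = analysis_set[-in_fit, ],
    pred = assessment_set
  )
} else {
  static$data <- list(fit = analysis_set, cal = NULL, pred = assessment_set)
}
```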
| 86 | + |
## Reading the code

### Naming conventions

**Loop variables:**

- `iter_{stage}` - iteration counter (e.g., `iter_pre`, `iter_model`, `iter_pred`, `iter_post`)
- `num_iterations_{stage}` - total iterations for that stage

**Schedule objects:**

- `sched` - the full nested schedule tibble
- `current_sched_{stage}` - current row of the schedule at each stage

**Grid objects:**

- `current_grid` - progressively accumulates tuning params as we descend into loops
- `grid_with_pre` - snapshot with pre params (before model loop)
- `grid_with_pre_model` - snapshot with pre + model params, without submodel col (before pred loop)
- `grid_with_pre_model_pred` - snapshot with pre + model + pred params (before post loop)

**Workflow snapshots** (saved to allow re-finalization in inner loops):

- `current_wflow` - the workflow being modified
- `wflow_with_fitted_pre` - snapshot after fitting preprocessor
- `wflow_with_fitted_pre_and_model` - snapshot after fitting model

**Prediction objects:**

- `pred_data` - processed prediction data (features + outcomes)
- `pred_all_submodels` - batched predictions for all submodel values (source for filtering)
- `current_pred` - predictions for current pred iteration (filtered from `pred_all_submodels`)
- `final_pred` - predictions after postprocessing (ready to save)
- `pred_reserve` - accumulator for all final predictions

**General conventions:**

- `static` - things that don't change during the loop
- `current_*` - value for the current iteration
- `*_all_submodels` - batched values for all submodel params (e.g., `pred_all_submodels`)
- Stage suffixes: `_pre`, `_model`, `_pred`, `_post`
| 128 | + |
### Key helper functions

**Finalization** (splice tuning params into workflow and fit):

- `finalize_fit_pre()` - finalize recipe params, fit preprocessor
- `finalize_fit_model()` - finalize model params, fit model
- `finalize_fit_post()` - finalize tailor params, fit postprocessor

**Grid helpers:**

- `extend_grid()` - extend a grid with params from a schedule row (strips `*_stage` columns)
- `remove_stage()` - remove nested stage columns from a schedule row
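A rough sketch of what `extend_grid()` does, based on the description above (assumed behavior inferred from this doc, not the actual source):

```r
# Assumed behavior: drop the nested *_stage list-columns from a
# schedule row, then bind the remaining parameter columns onto the
# current grid as we descend a level.
extend_grid_sketch <- function(current_grid, sched_row) {
  params <- sched_row[!grepl("_stage$", names(sched_row))]
  dplyr::bind_cols(current_grid, params)
}
```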

**Prediction:**

- `process_prediction_data()` - apply fitted preprocessor to assessment data
- `predict_all_types()` - generate all needed prediction types
| 146 | + |
### Error handling

The loop uses a consistent error handling pattern:

```r
result <- .catch_and_log(
  some_operation(),
  control = static$control,
  split_labels = split_labs,
  location = location,
  notes = notes
)

if (is_failure(result)) {
  next
}
result <- remove_log_notes(result)
```

- `.catch_and_log()` wraps operations to capture errors/warnings without stopping
- `is_failure()` checks if the operation failed
- `next` skips to the next iteration (the failed config won't have results)
- `remove_log_notes()` strips logging metadata from successful results
| 170 | + |
## Efficiency

The nested structure avoids redundant computation:

| What | Computed | Reused for |
|------|----------|------------|
| Preprocessor fit | Once per preproc param combo | All model params below it |
| Processed prediction data | Once per preproc param combo | All predictions below it |
| Model fit | Once per model param combo | All submodel predictions |
| Submodel predictions | Once per model (batched) | All submodel × post combos |
| Postprocessor fit | Once per post param combo | That specific config |

Submodel parameters (like `penalty` in glmnet) are predicted all at once using `multi_predict()`, which is much faster than predicting one at a time.
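For example, with a glmnet model, a single fit covers an entire path of `penalty` values, and parsnip's `multi_predict()` pulls predictions for all of them at once (toy data; the exact output columns depend on the model and mode):

```r
library(parsnip)

# One glmnet fit covers many penalty values
fit <- linear_reg(penalty = 0.1, mixture = 1) |>
  set_engine("glmnet") |>
  fit(mpg ~ ., data = mtcars)

# Predict for several submodel (penalty) values without refitting;
# the result has a .pred list-column, one row per row of new_data
multi_predict(fit, new_data = mtcars[1:3, ], penalty = c(0.001, 0.01, 0.1))
```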
| 184 | + |
## Debugging

Set `control = control_grid(allow_par = FALSE)` to run sequentially with `lapply()` so you can see output and use `browser()`.

For parallel debugging, `options(future.debug = TRUE)` can help.