Skip to content

Commit abe9fa6

Browse files
authored
refactor compute_grid_info() (#957)
1 parent 12faa6d commit abe9fa6

File tree

1 file changed

+44
-251
lines changed

1 file changed

+44
-251
lines changed

R/grid_helpers.R

Lines changed: 44 additions & 251 deletions
Original file line numberDiff line numberDiff line change
@@ -316,19 +316,52 @@ compute_grid_info <- function(workflow, grid) {
316316
any_parameters_model <- nrow(parameters_model) > 0
317317
any_parameters_preprocessor <- nrow(parameters_preprocessor) > 0
318318

319-
if (any_parameters_model) {
320-
if (any_parameters_preprocessor) {
321-
compute_grid_info_model_and_preprocessor(workflow, grid, parameters_model)
322-
} else {
323-
compute_grid_info_model(workflow, grid, parameters_model)
324-
}
319+
res <- min_grid(extract_spec_parsnip(workflow), grid)
320+
321+
if (any_parameters_preprocessor) {
322+
res$.iter_preprocessor <- seq_len(nrow(res))
325323
} else {
326-
if (any_parameters_preprocessor) {
327-
compute_grid_info_preprocessor(workflow, grid, parameters_model)
328-
} else {
329-
rlang::abort("Internal error: `workflow` should have some tunable parameters if `grid` is not `NULL`.")
330-
}
324+
res$.iter_preprocessor <- 1L
325+
}
326+
327+
res$.msg_preprocessor <-
328+
new_msgs_preprocessor(
329+
seq_len(max(res$.iter_preprocessor)),
330+
max(res$.iter_preprocessor)
331+
)
332+
333+
if (nrow(res) != nrow(grid) ||
334+
(any_parameters_model && !any_parameters_preprocessor)) {
335+
res$.iter_model <- seq_len(dplyr::n_distinct(res[parameters_model$id]))
336+
} else {
337+
res$.iter_model <- 1L
338+
}
339+
340+
res$.iter_config <- list(list())
341+
for (row in seq_len(nrow(res))) {
342+
res$.iter_config[row] <- list(iter_config(res[row, ]))
331343
}
344+
345+
res$.msg_model <-
346+
new_msgs_model(i = res$.iter_model, n = max(res$.iter_model), res$.msg_preprocessor)
347+
348+
res
349+
}
350+
351+
# Build the configuration id(s) for one row of the grid-info tibble.
#
# `res_row` is a single row of `res` (the caller passes `res[row, ]`); the
# body reads its `.submodels`, `.iter_model` and `.iter_preprocessor` columns.
#
# Returns a character vector shaped like "Preprocessor<i>_Model<j>", with one
# element per model index computed below.
#
# NOTE(review): `format_with_padding()` is applied to this row's indices only,
# so the zero-padding width is decided per row. The removed
# `compute_config_ids()` padded across all models of a preprocessor, and the
# removed helpers also padded the preprocessor index, which is pasted here
# unpadded. Confirm whether ids such as "Preprocessor1_Model1" versus
# "Preprocessor01_Model01" matter to downstream consumers.
iter_config <- function(res_row) {
352+
# `.submodels` is a list column; take the entry belonging to this row.
submodels <- res_row$.submodels[[1]]
353+
if (identical(submodels, list())) {
354+
# No submodels: the row maps to exactly one model index.
models <- res_row$.iter_model
355+
} else {
356+
# One index per value of the first submodel parameter, plus one for the
# model that is actually fit (it is not counted among the submodels).
# NOTE(review): numbering restarts at 1 here and ignores `.iter_model`;
# verify that two rows with submodels cannot end up with the same ids.
models <- seq_len(length(submodels[[1]]) + 1)
357+
}
358+
359+
paste0(
360+
"Preprocessor",
361+
res_row$.iter_preprocessor,
362+
"_Model",
363+
format_with_padding(models)
364+
)
332365
}
333366

334367
# This generates a "dummy" grid_info object that has the same
@@ -360,217 +393,6 @@ new_grid_info_resamples <- function() {
360393
out
361394
}
362395

363-
# (Removed by this commit.) Build the grid-info tibble for the case where only
# the preprocessor has tuning parameters.
#
# `grid` is used as-is, one row per preprocessor candidate. `workflow` and
# `parameters_model` are accepted but never referenced in the body.
#
# Returns `grid` with `.iter_preprocessor` and `.msg_preprocessor` inserted as
# the first columns and `.iter_model`, `.iter_config`, `.msg_model` and
# `.submodels` appended after the existing columns.
compute_grid_info_preprocessor <- function(workflow,
364-
grid,
365-
parameters_model) {
366-
out <- grid
367-
368-
n_preprocessors <- nrow(out)
369-
seq_preprocessors <- seq_len(n_preprocessors)
370-
371-
# Preprocessor<i>_Model1
372-
ids <- format_with_padding(seq_preprocessors)
373-
iter_configs <- paste0("Preprocessor", ids, "_Model1")
374-
# Stored as a list so `.iter_config` becomes a list column.
iter_configs <- as.list(iter_configs)
375-
376-
# preprocessor <i>/<n>
377-
msgs_preprocessor <- new_msgs_preprocessor(
378-
i = seq_preprocessors,
379-
n = n_preprocessors
380-
)
381-
382-
# preprocessor <i>/<n>, model 1/1
383-
msgs_model <- new_msgs_model(
384-
i = 1L,
385-
n = 1L,
386-
msgs_preprocessor = msgs_preprocessor
387-
)
388-
389-
# Manually add .submodels column, which will always have empty lists
390-
submodels <- rep_len(list(list()), n_preprocessors)
391-
392-
out <- tibble::add_column(
393-
.data = out,
394-
.iter_preprocessor = seq_preprocessors,
395-
.before = 1L
396-
)
397-
398-
out <- tibble::add_column(
399-
.data = out,
400-
.msg_preprocessor = msgs_preprocessor,
401-
.after = ".iter_preprocessor"
402-
)
403-
404-
# Add at the end
405-
out <- tibble::add_column(
406-
.data = out,
407-
.iter_model = 1L,
408-
.after = NULL
409-
)
410-
411-
out <- tibble::add_column(
412-
.data = out,
413-
.iter_config = iter_configs,
414-
.after = ".iter_model"
415-
)
416-
417-
out <- tibble::add_column(
418-
.data = out,
419-
.msg_model = msgs_model,
420-
.after = ".iter_config"
421-
)
422-
423-
out <- tibble::add_column(
424-
.data = out,
425-
.submodels = submodels,
426-
.after = ".msg_model"
427-
)
428-
429-
out
430-
}
431-
432-
# (Removed by this commit.) Build the grid-info tibble for the case where only
# the model has tuning parameters.
#
# The grid is first reduced with `min_grid()` on the workflow's parsnip spec,
# so each output row is one model that is actually fit. `parameters_model` is
# accepted but never referenced in the body.
#
# Returns the `min_grid()` result with `.iter_preprocessor`,
# `.msg_preprocessor`, `.iter_model`, `.iter_config` and `.msg_model` inserted
# as the leading columns, in that order.
compute_grid_info_model <- function(workflow,
433-
grid,
434-
parameters_model) {
435-
spec <- extract_spec_parsnip(workflow)
436-
out <- min_grid(spec, grid)
437-
438-
n_fit_models <- nrow(out)
439-
seq_fit_models <- seq_len(n_fit_models)
440-
441-
# preprocessor 1/1
442-
msgs_preprocessor <- new_msgs_preprocessor(i = 1L, n = 1L)
443-
# Repeat the single message so there is one per fitted model.
msgs_preprocessor <- rep(msgs_preprocessor, times = n_fit_models)
444-
445-
# preprocessor 1/1, model <i_fit>/<n_fit>
446-
msgs_model <- new_msgs_model(
447-
i = seq_fit_models,
448-
n = n_fit_models,
449-
msgs_preprocessor = msgs_preprocessor
450-
)
451-
452-
# Preprocessor1_Model<i>
453-
iter_configs <- compute_config_ids(out, "Preprocessor1")
454-
455-
out <- tibble::add_column(
456-
.data = out,
457-
.iter_preprocessor = 1L,
458-
.before = 1L
459-
)
460-
461-
out <- tibble::add_column(
462-
.data = out,
463-
.msg_preprocessor = msgs_preprocessor,
464-
.after = ".iter_preprocessor"
465-
)
466-
467-
out <- tibble::add_column(
468-
.data = out,
469-
.iter_model = seq_fit_models,
470-
.after = ".msg_preprocessor"
471-
)
472-
473-
out <- tibble::add_column(
474-
.data = out,
475-
.iter_config = iter_configs,
476-
.after = ".iter_model"
477-
)
478-
479-
out <- tibble::add_column(
480-
.data = out,
481-
.msg_model = msgs_model,
482-
.after = ".iter_config"
483-
)
484-
485-
out
486-
}
487-
488-
# (Removed by this commit.) Build the grid-info tibble for the case where both
# the preprocessor and the model have tuning parameters.
#
# The model parameter columns (named by `parameters_model[["id"]]`) are nested
# into a `data` list column, leaving one outer row per preprocessor candidate.
# Each nested model grid is then reduced with `min_grid()` and given its own
# `.iter_model`, `.iter_config` and `.msg_model` columns, so model numbering
# restarts for every preprocessor. The result is unnested before returning.
compute_grid_info_model_and_preprocessor <- function(workflow,
489-
grid,
490-
parameters_model) {
491-
parameter_names_model <- parameters_model[["id"]]
492-
493-
# Nest model parameters, keep preprocessor parameters outside
494-
out <- tidyr::nest(grid, data = dplyr::all_of(parameter_names_model))
495-
496-
n_preprocessors <- nrow(out)
497-
seq_preprocessors <- seq_len(n_preprocessors)
498-
499-
# preprocessor <i_pre>/<n_pre>
500-
msgs_preprocessor <- new_msgs_preprocessor(
501-
i = seq_preprocessors,
502-
n = n_preprocessors
503-
)
504-
505-
out <- tibble::add_column(
506-
.data = out,
507-
.iter_preprocessor = seq_preprocessors,
508-
.before = 1L
509-
)
510-
511-
out <- tibble::add_column(
512-
.data = out,
513-
.msg_preprocessor = msgs_preprocessor,
514-
.after = ".iter_preprocessor"
515-
)
516-
517-
spec <- extract_spec_parsnip(workflow)
518-
519-
ids_preprocessor <- format_with_padding(seq_preprocessors)
520-
ids_preprocessor <- paste0("Preprocessor", ids_preprocessor)
521-
522-
model_grids <- out[["data"]]
523-
524-
# Process each preprocessor's nested model grid independently.
for (i in seq_preprocessors) {
525-
model_grid <- model_grids[[i]]
526-
527-
model_grid <- min_grid(spec, model_grid)
528-
529-
n_fit_models <- nrow(model_grid)
530-
seq_fit_models <- seq_len(n_fit_models)
531-
532-
msg_preprocessor <- msgs_preprocessor[[i]]
533-
id_preprocessor <- ids_preprocessor[[i]]
534-
535-
# preprocessor <i_pre>/<n_pre>, model <i_mod>/<n_mod>
536-
msgs_model <- new_msgs_model(
537-
i = seq_fit_models,
538-
n = n_fit_models,
539-
msgs_preprocessor = msg_preprocessor
540-
)
541-
542-
# Preprocessor<i_pre>_Model<i>
543-
iter_configs <- compute_config_ids(model_grid, id_preprocessor)
544-
545-
model_grid <- tibble::add_column(
546-
.data = model_grid,
547-
.iter_model = seq_fit_models,
548-
.before = 1L
549-
)
550-
551-
model_grid <- tibble::add_column(
552-
.data = model_grid,
553-
.iter_config = iter_configs,
554-
.after = ".iter_model"
555-
)
556-
557-
model_grid <- tibble::add_column(
558-
.data = model_grid,
559-
.msg_model = msgs_model,
560-
.after = ".iter_config"
561-
)
562-
563-
model_grids[[i]] <- model_grid
564-
}
565-
566-
out[["data"]] <- model_grids
567-
568-
# Unnest to match other grid-info generators
569-
out <- tidyr::unnest(out, data)
570-
571-
out
572-
}
573-
574396
# Format progress labels of the form "preprocessor <i>/<n>".
#
# `i` and `n` are pasted together with `paste0()`, so passing a vector for `i`
# yields one label per element.
new_msgs_preprocessor <- function(i, n) {
575397
paste0("preprocessor ", i, "/", n)
576398
}
@@ -583,35 +405,6 @@ format_with_padding <- function(x) {
583405
gsub(" ", "0", format(x))
584406
}
585407

586-
# (Removed by this commit.) Generate configuration ids for every model
# represented by a `min_grid()`-style tibble, including submodels.
#
# `data` must carry a `.submodels` list column; `id_preprocessor` is the
# prefix pasted in front of "_Model<k>" (callers pass e.g. "Preprocessor1").
#
# Returns a list with one element per row of `data`; each element is the
# character vector of ids covering that row's fitted model and its submodels.
# Ids are numbered consecutively across rows and zero-padded to a common
# width by `format_with_padding()`.
compute_config_ids <- function(data, id_preprocessor) {
587-
# `keep_empty = TRUE` retains rows whose `.submodels` entry is empty.
submodels <- unnest(data, .submodels, keep_empty = TRUE)
588-
submodels <- pull(submodels, .submodels)
589-
590-
# Current model that actually is fit is not included in the submodel count
591-
# so we add 1
592-
model_sizes <- lengths(submodels) + 1L
593-
594-
n_total_models <- sum(model_sizes)
595-
596-
ids <- format_with_padding(seq_len(n_total_models))
597-
ids <- paste0(id_preprocessor, "_Model", ids)
598-
599-
n_fit_models <- nrow(data)
600-
601-
out <- vector("list", length = n_fit_models)
602-
603-
start <- 1L
604-
605-
# Hand each row the next `size` consecutive ids.
for (i in seq_len(n_fit_models)) {
606-
size <- model_sizes[[i]]
607-
stop <- start + size - 1L
608-
out[[i]] <- ids[rlang::seq2(start, stop)]
609-
start <- stop + 1L
610-
}
611-
612-
out
613-
}
614-
615408
# ------------------------------------------------------------------------------
616409

617410
has_preprocessor <- function(workflow) {

0 commit comments

Comments
 (0)