Skip to content

Commit e94aeb5

Browse files
committed
Merge branch 'master' into development
2 parents 4f9a97f + 696a50d commit e94aeb5

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

R/extract-tunable-params.R

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#' Extract Tunable Parameters from Model Specifications
2+
#'
3+
#' @family Extractor
4+
#'
5+
#' @description Extract a list of tunable parameters from the `.model_spec` column
6+
#' of a `tidyaml_mod_spec_tbl`.
7+
#'
8+
#' @details This function iterates over the `.model_spec` column of a model table
9+
#' and extracts tunable parameters for each model using `tunable()`. The result
10+
#' is a list that can be further processed into a tibble if needed.
11+
#'
12+
#' @param .model_tbl A model table with a class of `tidyaml_mod_spec_tbl`.
13+
#'
14+
#' @return A list of tibbles, each containing the tunable parameters for a model.
15+
#'
16+
#' @examples
17+
#' library(dplyr)
18+
#' mods <- create_model_spec(.parsnip_eng = list("lm", "glmnet"))
19+
#' extract_tunable_params(mods)
20+
#'
21+
#' @export
22+
extract_tunable_params <- function(.model_tbl) {
23+
24+
# Tidyeval ----
25+
model_tbl <- .model_tbl
26+
model_tbl_class <- class(model_tbl)
27+
28+
# Checks ----
29+
if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")){
30+
rlang::abort(
31+
message = paste0(
32+
"'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl \n'",
33+
"The current class is: ",
34+
class(model_tbl)
35+
),
36+
use_cli_format = TRUE
37+
)
38+
}
39+
40+
# Manipulation
41+
model_factor_tbl <- model_tbl |>
42+
dplyr::mutate(.model_id = forcats::as_factor(.model_id))
43+
44+
# Make a group split object list
45+
models_list <- model_factor_tbl |>
46+
dplyr::group_split(.model_id)
47+
48+
# Extract tunable parameters using purrr imap
49+
tunable_params_list <- models_list |>
50+
purrr::imap(
51+
.f = function(obj, id) {
52+
53+
# Pull the model_spec column and then pluck the model_spec
54+
mod <- obj |> dplyr::pull(5) |> purrr::pluck(1)
55+
56+
# Extract tunable parameters
57+
ret <- tunable(mod)
58+
59+
# Return the result
60+
return(ret)
61+
}
62+
)
63+
64+
# Return
65+
return(tunable_params_list)
66+
}

0 commit comments

Comments
 (0)