Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# workflows (development version)

* Added a standalone file `standalone-input-names.R` with APIs for returning the
names of the predictors in the original data given to `fit()`.

* Each of the `pull_*()` functions soft-deprecated in workflows v0.2.3 now warn on every usage.

* `add_recipe()` will now error informatively when supplied a trained recipe (#179).
Expand Down
64 changes: 64 additions & 0 deletions R/standalone-input-names.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ---
# repo: tidymodels/workflows
# file: standalone-input-names.R
# last-updated: 2024-01-21
# license: https://unlicense.org
# ---

# secret gist at: https://gist.github.com/topepo/17d51cafcd0ac8dff0552198d6aeadbf

# This file provides a portable set of helper functions for determining the
# names of the predictor columns used as inputs into a workflow.

# ## Changelog
# 2024-01-21
# * First version

# ------------------------------------------------------------------------------

check_workflow_fit <- function(x) {
if (!x$trained) {
stop("The workflow should be trainined.")
}
invisible(NULL)
}

check_recipe_fit <- function(x) {
is_trained <- vapply(x$steps, function(x) x$trained, logical(1))
if (!all(is_trained)) {
stop("All recipe steps should be trainined.")
}
invisible(NULL)
}

blueprint_ptype <- function(x) {
names(x$pre$mold$blueprint$ptypes$predictors)
}

.get_input_predictors_workflow <- function(x, ...) {
check_workflow_fit(x)
# We can get the columns that are inputs to the recipe but some of these may
# not be predictors. We'll interrogate the recipe and pull out the current
# predictor names from the original input
if ("recipe" %in% names(x$pre$actions)) {
mold <- x$pre$mold
rec <- mold$blueprint$recipe
res <- .get_input_predictors_recipe(rec)
} else {
res <- blueprint_ptype(x)
}
sort(unique(res))
}

is_predictor_role <- function(x) {
vapply(x$role, function(x) any(x == "predictor"), logical(1))
}

.get_input_predictors_recipe <- function(x, ...) {
check_recipe_fit(x)
var_info <- x$last_term_info

keep_rows <- var_info$source == "original" & is_predictor_role(var_info)
var_info <- var_info[keep_rows,]
var_info$variable
}
32 changes: 32 additions & 0 deletions tests/testthat/_snaps/input-names.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# get recipe input column names

Code
workflows:::.get_input_predictors_workflow(workflow)
Condition
Error in `check_workflow_fit()`:
! The workflow should be trainined.

---

Code
workflows:::.get_input_predictors_recipe(rec_with_id)
Condition
Error in `check_recipe_fit()`:
! All recipe steps should be trainined.

# get formula input column names

Code
workflows:::.get_input_predictors_workflow(workflow)
Condition
Error in `check_workflow_fit()`:
! The workflow should be trainined.

# get predictor input column names

Code
workflows:::.get_input_predictors_workflow(workflow)
Condition
Error in `check_workflow_fit()`:
! The workflow should be trainined.

88 changes: 88 additions & 0 deletions tests/testthat/test-input-names.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
test_that("get recipe input column names", {
skip_if_not_installed("modeldata")
skip_if_not_installed("recipes")

library(recipes)

data(cells, package = "modeldata")

cells <- cells[, 1:10]
pred_names <- sort(names(cells)[3:10])

rec_with_id <-
recipes::recipe(class ~ ., cells) %>%
update_role(case, new_role = "destination") %>%
step_rm(angle_ch_1) %>%
step_pca(all_predictors())

workflow <- workflow()
workflow <- add_recipe(workflow, rec_with_id)
workflow <- add_model(workflow, parsnip::logistic_reg())
workflow_fit <- fit(workflow, cells)

expect_snapshot(
workflows:::.get_input_predictors_workflow(workflow),
error = TRUE
)
expect_equal(
workflows:::.get_input_predictors_workflow(workflow_fit),
pred_names
)
expect_snapshot(
workflows:::.get_input_predictors_recipe(rec_with_id),
error = TRUE
)

})

test_that("get formula input column names", {
skip_if_not_installed("modeldata")

data(Chicago, package = "modeldata")

Chicago <- Chicago[, c("ridership", "date", "Austin")]
pred_names <- sort(c("date", "Austin"))

workflow <- workflow()
workflow <- add_formula(workflow, ridership ~ .)
workflow <- add_model(workflow, parsnip::linear_reg())
workflow_fit <- fit(workflow, Chicago)

expect_snapshot(
workflows:::.get_input_predictors_workflow(workflow),
error = TRUE
)
expect_equal(
workflows:::.get_input_predictors_workflow(workflow_fit),
pred_names
)

})


test_that("get predictor input column names", {
skip_if_not_installed("modeldata")

data(Chicago, package = "modeldata")

Chicago <- Chicago[, c("ridership", "date", "Austin")]
pred_names <- sort(c("date", "Austin"))

workflow <- workflow()
workflow <-
add_variables(workflow,
outcomes = c(ridership),
predictors = c(tidyselect::everything()))
workflow <- add_model(workflow, parsnip::linear_reg())
workflow_fit <- fit(workflow, Chicago)

expect_snapshot(
workflows:::.get_input_predictors_workflow(workflow),
error = TRUE
)
expect_equal(
workflows:::.get_input_predictors_workflow(workflow_fit),
pred_names
)

})