Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions tests/testthat/helper-cache.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# tools for caching results within testthat.

is_object_available <- function(object, fail = FALSE, save_path = "saved_objects") {
cl <- match.call()
file_name <- paste0(cl$object, ".RData")
file_path <- file.path(save_path, file_name)
has_file <- file.exists(file_path)
if (fail && !has_file) {
msg <- paste0("File '", file_name, "' is not in '", save_path, "'.")
cli::cli_abort(msg)
}
has_file
}

save_object <- function(object, save_path = "saved_objects") {
cl <- match.call()
file_name <- paste0(cl$object, ".RData")
file_path <- file.path(save_path, file_name)
res <- try(save(object, file = file_path), silent = TRUE)
# returned NULL if it worked
if (is.null(res)) {
# verify
res <- file.exists(file_path)
} else {
# save failed
print(as.character(res))
res <- FALSE
}
res
}

return_object <- function(object, save_path = "saved_objects") {
cl <- match.call()
file_name <- paste0(cl$object, ".RData")
file_path <- file.path(save_path, file_name)
load(file_path)
object
}

purge_objects <- function(save_path = "saved_objects") {
all_files <- list.files(save_path, pattern = "RData$", full.names = TRUE)
res <- vapply(all_files, unlink, integer(1))
df_res <- tibble::tibble(file = names(res))
df_res$deleted <- ifelse(res == 0, TRUE, FALSE)
invisible(df_res)
}

# Example usage
if (FALSE) {
pkg <- "tune"
is_object_available(pkg)

save_object(pkg)
is_object_available(pkg)

rm(pkg)
pkg <- return_object(pkg)
pkg

file_86 <- purge_objects()
file_86
is_object_available(pkg)

is_object_available(some_other_pkg, fail = TRUE)
}
Binary file not shown.
87 changes: 53 additions & 34 deletions tests/testthat/test-survival-tune-grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,78 @@ skip_if_not_installed("censored", minimum_version = "0.2.0.9000")
skip_if_not_installed("tune", minimum_version = "1.1.1.9001")
skip_if_not_installed("yardstick", minimum_version = "1.2.0.9001")

test_that("grid tuning survival models with static metric", {
test_that("grid tuning with static metric", {
skip_if_not_installed("prodlim")
skip_if_not_installed("coin") # required for partykit engine

stc_mtrc <- metric_set(concordance_survival)

# standard setup start
set.seed(1)
sim_dat <- prodlim::SimSurv(500) %>%
mutate(event_time = Surv(time, event)) %>%
select(event_time, X1, X2)

set.seed(2)
split <- initial_split(sim_dat)
sim_tr <- training(split)
sim_te <- testing(split)
sim_rs <- vfold_cv(sim_tr)

time_points <- c(10, 1, 5, 15)

mod_spec <-
decision_tree(tree_depth = tune(), min_n = 4) %>%
set_engine("partykit") %>%
set_mode("censored regression")
if (is_object_available(grid_static_res)) {
grid_static_res <- return_object(grid_static_res)
} else {
stc_mtrc <- metric_set(concordance_survival)

set.seed(1)
sim_dat <- prodlim::SimSurv(500) %>%
mutate(event_time = Surv(time, event)) %>%
select(event_time, X1, X2)

set.seed(2)
split <- initial_split(sim_dat)
sim_tr <- training(split)
sim_te <- testing(split)
sim_rs <- vfold_cv(sim_tr)

time_points <- c(10, 1, 5, 15)

mod_spec <-
decision_tree(tree_depth = tune(), min_n = 4) %>%
set_engine("partykit") %>%
set_mode("censored regression")

grid <- tibble(tree_depth = c(1, 2, 10))

gctrl <- control_grid(save_pred = TRUE)

set.seed(2193)
grid_static_res <-
mod_spec %>%
tune_grid(
event_time ~ X1 + X2,
resamples = sim_rs,
grid = grid,
metrics = stc_mtrc,
control = gctrl
)
save_object(grid_static_res)
}

expect_s3_class(grid_static_res, "tune_results")
})

grid <- tibble(tree_depth = c(1, 2, 10))

gctrl <- control_grid(save_pred = TRUE)
# standard setup end
test_that("grid tuning with static metric - check structure", {

set.seed(2193)
grid_static_res <-
mod_spec %>%
tune_grid(
event_time ~ X1 + X2,
resamples = sim_rs,
grid = grid,
metrics = stc_mtrc,
control = gctrl
)
is_object_available(grid_static_res, fail = TRUE)
grid_static_res <- return_object(grid_static_res)

expect_false(".eval_time" %in% names(grid_static_res$.metrics[[1]]))
expect_equal(
names(grid_static_res$.predictions[[1]]),
c(".pred_time", ".row", "tree_depth", "event_time", ".config")
)
})

test_that("grid tuning with static metric - autoplot", {

is_object_available(grid_static_res, fail = TRUE)
grid_static_res <- return_object(grid_static_res)

expect_snapshot_plot(
print(autoplot(grid_static_res)),
"static-metric-grid-search"
)
})


test_that("grid tuning survival models with integrated metric", {
skip_if_not_installed("prodlim")
skip_if_not_installed("coin") # required for partykit engine
Expand Down