Skip to content

Commit f830e5a

Browse files
authored
changes for #182 (#184)
1 parent d3977c4 commit f830e5a

File tree

6 files changed

+163
-50
lines changed

6 files changed

+163
-50
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# probably (development version)
22

3+
* Updated unit tests for new ggplot2 release (#180).
4+
5+
* Better error message when using one of the `cal_validate_*()` functions with a validation set (#182).
6+
37
# probably 1.1.0
48

59
* Significant refactoring of the code underlying the calibration functions. The user-facing APIs have not changed.

R/cal-validate.R

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ cal_validate_logistic.resample_results <-
6868
metrics = NULL,
6969
save_pred = FALSE,
7070
...) {
71+
cl <- match.call()
72+
validation_check(.data, cl)
73+
7174
if (!is.null(truth)) {
7275
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
7376
}
@@ -151,6 +154,9 @@ cal_validate_isotonic.resample_results <-
151154
metrics = NULL,
152155
save_pred = FALSE,
153156
...) {
157+
cl <- match.call()
158+
validation_check(.data, cl)
159+
154160
if (!is.null(truth)) {
155161
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
156162
}
@@ -236,6 +242,9 @@ cal_validate_isotonic_boot.resample_results <-
236242
metrics = NULL,
237243
save_pred = FALSE,
238244
...) {
245+
cl <- match.call()
246+
validation_check(.data, cl)
247+
239248
if (!is.null(truth)) {
240249
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
241250
}
@@ -321,6 +330,12 @@ cal_validate_beta.resample_results <-
321330
metrics = NULL,
322331
save_pred = FALSE,
323332
...) {
333+
cl <- match.call()
334+
validation_check(.data, cl)
335+
336+
cl <- match.call()
337+
validation_check(.data, cl)
338+
324339
if (!is.null(truth)) {
325340
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
326341
}
@@ -401,6 +416,9 @@ cal_validate_multinomial.resample_results <-
401416
metrics = NULL,
402417
save_pred = FALSE,
403418
...) {
419+
cl <- match.call()
420+
validation_check(.data, cl)
421+
404422
if (!is.null(truth)) {
405423
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
406424
}
@@ -515,6 +533,7 @@ cal_validate <- function(rset,
515533
predictions_out <- pull_pred(rset, analysis = FALSE)
516534

517535
est_fn_name <- paste0("cal_estimate_", cal_function)
536+
518537
est_cl <-
519538
rlang::call2(
520539
est_fn_name,
@@ -560,18 +579,23 @@ cal_validate <- function(rset,
560579
}
561580

562581
# Pull the prediction data stored in the splits of a resampling object.
#
# Arguments:
#   x        - object with a `splits` list column of split objects.
#   analysis - `TRUE` to extract the analysis portion of each split, `FALSE`
#              to extract the assessment portion.
#
# Returns:
#   * For a validation split (`val_split`): the data frame for the first split.
#   * Otherwise: a list with one data frame per split, each carrying a `.row`
#     column (added here when the underlying data does not already have one).
#
# NOTE(review): the two branches return different shapes (data frame vs. list
# of data frames); confirm that callers handle the `val_split` shape.
pull_pred <- function(x, analysis = TRUE) {
  what <- if (analysis) "analysis" else "assessment"
  first_split <- x$splits[[1]]

  if (inherits(first_split, "val_split")) {
    # `data` has to be supplied by name: the second positional argument of
    # `as.data.frame()` is `row.names`, so a positional `what` would be
    # ignored and the default portion returned regardless of `analysis`.
    return(as.data.frame(first_split, data = what))
  }

  preds <- purrr::map(x$splits, as.data.frame, data = what)

  has_dot_row <- any(names(first_split$data) == ".row")
  if (!has_dot_row) {
    # Attach the original row indices of the requested portion so that
    # predictions can be matched back to the source data.
    rows <- purrr::map(
      x$splits,
      ~ dplyr::tibble(.row = as.integer(.x, data = what))
    )
    preds <- purrr::map2(preds, rows, ~ dplyr::bind_cols(.x, .y))
  }

  preds
}
577601

@@ -655,6 +679,9 @@ cal_validate_linear.resample_results <-
655679
metrics = NULL,
656680
save_pred = FALSE,
657681
...) {
682+
cl <- match.call()
683+
validation_check(.data, cl)
684+
658685
if (!is.null(truth)) {
659686
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
660687
}
@@ -748,6 +775,9 @@ cal_validate_none.resample_results <-
748775
metrics = NULL,
749776
save_pred = FALSE,
750777
...) {
778+
cl <- match.call()
779+
validation_check(.data, cl)
780+
751781
if (!is.null(truth)) {
752782
cli::cli_warn("{.arg truth} is automatically set when this type of object is used.")
753783
}
@@ -803,8 +833,14 @@ convert_resamples <- function(x) {
803833
predictions <-
804834
tune::collect_predictions(x, summarize = TRUE) |>
805835
dplyr::arrange(.row)
836+
837+
# Not all prediction sets, when collected, will match the size of the original
838+
# data, so pad the predictions out with a row for every row of the original data
839+
data_ind <- dplyr::tibble(.row = seq_len(nrow(x$splits[[1]]$data)))
840+
all_data <- dplyr::full_join(data_ind, predictions, by = ".row")
841+
806842
for (i in seq_along(x$splits)) {
807-
x$splits[[i]]$data <- predictions
843+
x$splits[[i]]$data <- all_data
808844
}
809845
class(x) <- c("rset", "tbl_df", "tbl", "data.frame")
810846
x
@@ -891,3 +927,18 @@ collect_predictions.cal_rset <- function(x, summarize = TRUE, ...) {
891927
}
892928
res
893929
}
930+
931+
# Abort with an actionable message when `x` was built from a validation set.
#
# Arguments:
#   x    - object with a `splits` list; only the class of its first element
#          is inspected.
#   cl   - the caller's `match.call()`, used to name the calling function in
#          the error message. Namespaced heads (`pkg::fn`, `pkg:::fn`) and S3
#          method names (`fn.class`) are reduced to `fn`. When `cl` is `NULL`
#          or its head is not a function name, a generic name is used.
#   call - environment reported as the origin of the error.
#
# Returns `NULL` invisibly when the first split is not a `val_split`;
# otherwise signals an error through `cli::cli_abort()`.
validation_check <- function(x, cl = NULL, call = rlang::caller_env()) {
  # Nothing to report for ordinary resamples, so skip the name handling.
  if (!inherits(x$splits[[1]], "val_split")) {
    return(invisible(NULL))
  }

  fn_head <- if (is.call(cl)) cl[[1]] else NULL

  # For `pkg::fn(...)` the head of the call is itself the call `pkg::fn`;
  # keep only the function name.
  is_ns_call <- is.call(fn_head) && length(fn_head) == 3L &&
    (identical(fn_head[[1]], as.name("::")) ||
      identical(fn_head[[1]], as.name(":::")))
  if (is_ns_call) {
    fn_head <- fn_head[[3]]
  }

  if (is.symbol(fn_head)) {
    # Drop an S3 class suffix, e.g. "cal_validate_beta.resample_results".
    fn <- sub("\\..*$", "", as.character(fn_head))
  } else {
    # No usable name (e.g. `cl = NULL` or a call made through `do.call()`).
    fn <- "cal_validate_*"
  }

  cli::cli_abort(
    "For validation sets, please make a resampling object from the predictions
    prior to calling {.fn {fn}}",
    call = call
  )
}
944+

man/int_conformal_full.Rd

Lines changed: 1 addition & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/cal-validate.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,59 @@
4444
This function can only be used with an <rset> object or the results of `tune::fit_resamples()` with a .predictions column.
4545
i Not an <tune_results> object.
4646

47+
# validation sets fail with better message
48+
49+
Code
50+
cal_validate_beta(mt_res)
51+
Condition
52+
Error in `cal_validate_beta()`:
53+
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_beta()`
54+
55+
---
56+
57+
Code
58+
cal_validate_isotonic(mt_res)
59+
Condition
60+
Error in `cal_validate_isotonic()`:
61+
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_isotonic()`
62+
63+
---
64+
65+
Code
66+
cal_validate_isotonic_boot(mt_res)
67+
Condition
68+
Error in `cal_validate_isotonic_boot()`:
69+
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_isotonic_boot()`
70+
71+
---
72+
73+
Code
74+
cal_validate_linear(mt_res)
75+
Condition
76+
Error in `cal_validate_linear()`:
77+
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_linear()`
78+
79+
---
80+
81+
Code
82+
cal_validate_logistic(mt_res)
83+
Condition
84+
Error in `cal_validate_logistic()`:
85+
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_logistic()`
86+
87+
---
88+
89+
Code
90+
cal_validate_multinomial(mt_res)
91+
Condition
92+
Error in `cal_validate_multinomial()`:
93+
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_multinomial()`
94+
95+
---
96+
97+
Code
98+
cal_validate_none(mt_res)
99+
Condition
100+
Error in `cal_validate_none()`:
101+
! For validation sets, please make a resampling object from the predictions prior to calling `cal_validate_none()`
102+

tests/testthat/test-cal-validate-multiclass.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ test_that("Isotonic validation with `fit_resamples` - Multiclass", {
5050
)
5151
skip_if_not_installed("tune", "1.2.0")
5252
expect_equal(
53-
names(val_with_pred$.predictions_cal[[1]]),
54-
c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")
53+
sort(names(val_with_pred$.predictions_cal[[1]])),
54+
sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class"))
5555
)
5656
expect_equal(
5757
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),

tests/testthat/test-cal-validate.R

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,8 @@ test_that("Logistic validation with `fit_resamples`", {
348348

349349
skip_if_not_installed("tune", "1.2.0")
350350
expect_equal(
351-
names(val_with_pred$.predictions_cal[[1]]),
352-
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
351+
sort(names(val_with_pred$.predictions_cal[[1]])),
352+
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
353353
)
354354
expect_equal(
355355
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -380,8 +380,8 @@ test_that("Isotonic classification validation with `fit_resamples`", {
380380

381381
skip_if_not_installed("tune", "1.2.0")
382382
expect_equal(
383-
names(val_with_pred$.predictions_cal[[1]]),
384-
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
383+
sort(names(val_with_pred$.predictions_cal[[1]])),
384+
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
385385
)
386386
expect_equal(
387387
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -413,8 +413,8 @@ test_that("Bootstrapped isotonic classification validation with `fit_resamples`"
413413

414414
skip_if_not_installed("tune", "1.2.0")
415415
expect_equal(
416-
names(val_with_pred$.predictions_cal[[1]]),
417-
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
416+
sort(names(val_with_pred$.predictions_cal[[1]])),
417+
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
418418
)
419419
expect_equal(
420420
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -446,8 +446,8 @@ test_that("Beta calibration validation with `fit_resamples`", {
446446

447447
skip_if_not_installed("tune", "1.2.0")
448448
expect_equal(
449-
names(val_with_pred$.predictions_cal[[1]]),
450-
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
449+
sort(names(val_with_pred$.predictions_cal[[1]])),
450+
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
451451
)
452452
expect_equal(
453453
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -481,8 +481,8 @@ test_that("Multinomial calibration validation with `fit_resamples`", {
481481

482482
skip_if_not_installed("tune", "1.2.0")
483483
expect_equal(
484-
names(val_with_pred$.predictions_cal[[1]]),
485-
c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")
484+
sort(names(val_with_pred$.predictions_cal[[1]])),
485+
sort(c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class"))
486486
)
487487
expect_equal(
488488
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -513,8 +513,8 @@ test_that("Validation without calibration with `fit_resamples`", {
513513

514514
skip_if_not_installed("tune", "1.2.0")
515515
expect_equal(
516-
names(val_with_pred$.predictions_cal[[1]]),
517-
c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class")
516+
sort(names(val_with_pred$.predictions_cal[[1]])),
517+
sort(c(".pred_class_1", ".pred_class_2", ".row", "outcome", ".config", ".pred_class"))
518518
)
519519
expect_equal(
520520
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -548,8 +548,8 @@ test_that("Linear validation with `fit_resamples`", {
548548

549549
skip_if_not_installed("tune", "1.2.0")
550550
expect_equal(
551-
names(val_with_pred$.predictions_cal[[1]]),
552-
c(".pred", ".row", "outcome", ".config")
551+
sort(names(val_with_pred$.predictions_cal[[1]])),
552+
sort(c(".pred", ".row", "outcome", ".config"))
553553
)
554554
expect_equal(
555555
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -621,8 +621,8 @@ test_that("Isotonic regression validation with `fit_resamples`", {
621621

622622
skip_if_not_installed("tune", "1.2.0")
623623
expect_equal(
624-
names(val_with_pred$.predictions_cal[[1]]),
625-
c(".pred", ".row", "outcome", ".config")
624+
sort(names(val_with_pred$.predictions_cal[[1]])),
625+
sort(c(".pred", ".row", "outcome", ".config"))
626626
)
627627
expect_equal(
628628
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -657,8 +657,8 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", {
657657

658658
skip_if_not_installed("tune", "1.2.0")
659659
expect_equal(
660-
names(val_with_pred$.predictions_cal[[1]]),
661-
c(".pred", ".row", "outcome", ".config")
660+
sort(names(val_with_pred$.predictions_cal[[1]])),
661+
sort(c(".pred", ".row", "outcome", ".config"))
662662
)
663663
expect_equal(
664664
purrr::map_int(val_with_pred$splits, ~ holdout_length(.x)),
@@ -670,7 +670,6 @@ test_that("Isotonic bootstrapped regression validation with `fit_resamples`", {
670670

671671
# ------------------------------------------------------------------------------
672672

673-
674673
test_that("validation functions error with tune_results input", {
675674
skip_if_not_installed("modeldata")
676675
skip_if_not_installed("nnet")
@@ -698,3 +697,27 @@ test_that("validation functions error with tune_results input", {
698697
cal_validate_none(testthat_cal_binary())
699698
)
700699
})
700+
701+
# ------------------------------------------------------------------------------
702+
703+
test_that("validation sets fail with better message", {
  # Skip rather than error when the modelling packages are unavailable.
  skip_if_not_installed("tune")
  skip_if_not_installed("rsample", "1.2.0")
  skip_if_not_installed("parsnip")

  set.seed(1)
  mt_split <- rsample::initial_validation_split(mtcars)
  mt_rset <- rsample::validation_set(mt_split)

  # Functions are namespace-qualified instead of calling `library(tune)`, so
  # the test does not alter the search path for tests that run after it.
  mt_res <-
    parsnip::linear_reg() |>
    tune::fit_resamples(
      mpg ~ .,
      resamples = mt_rset,
      control = tune::control_resamples(save_pred = TRUE)
    )

  # Every validation function should reject results from a validation set.
  expect_snapshot(cal_validate_beta(mt_res), error = TRUE)
  expect_snapshot(cal_validate_isotonic(mt_res), error = TRUE)
  expect_snapshot(cal_validate_isotonic_boot(mt_res), error = TRUE)
  expect_snapshot(cal_validate_linear(mt_res), error = TRUE)
  expect_snapshot(cal_validate_logistic(mt_res), error = TRUE)
  expect_snapshot(cal_validate_multinomial(mt_res), error = TRUE)
  expect_snapshot(cal_validate_none(mt_res), error = TRUE)
})

0 commit comments

Comments
 (0)