Skip to content

Commit 6aad3fa

Browse files
authored
correct interface to fit.probability_calibration() (#49)
1 parent e81dffd commit 6aad3fa

File tree

3 files changed

+10
-12
lines changed

3 files changed

+10
-12
lines changed

R/adjust-probability-calibration.R

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#'
1616
#' @inheritSection adjust_numeric_calibration Data Usage
1717
#'
18-
# TODO: see #36
1918
#' @examplesIf FALSE
2019
# @examplesIf rlang::is_installed("modeldata")
2120
#' library(modeldata)
@@ -95,10 +94,10 @@ fit.probability_calibration <- function(object, data, tailor = NULL, ...) {
9594
eval_bare(
9695
call2(
9796
paste0("cal_estimate_", method),
98-
.data = data,
97+
.data = expr(data),
9998
# todo: make getters for the entries in `columns`
10099
truth = tailor$columns$outcome,
101-
estimate = tailor$columns$estimate,
100+
estimate = tailor$columns$probabilities,
102101
.ns = "probably"
103102
)
104103
)
@@ -116,7 +115,11 @@ fit.probability_calibration <- function(object, data, tailor = NULL, ...) {
116115

117116
#' @export
118117
predict.probability_calibration <- function(object, new_data, tailor, ...) {
119-
probably::cal_apply(new_data, object$results$fit)
118+
probably::cal_apply(
119+
.data = new_data,
120+
object = object$results$fit,
121+
pred_class = !!tailor$columns$estimate
122+
)
120123
}
121124

122125
# todo probably needs required_pkgs methods for cal objects

tests/testthat/test-adjust-probability-calibration.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ test_that("basic adjust_probability_calibration() usage works", {
1515
adjust_probability_calibration(method = "logistic")
1616
)
1717

18-
skip("TODO: cannot run for now due to #36")
19-
2018
expect_no_condition(
2119
tlr_fit <- fit(
2220
tlr,
@@ -28,7 +26,7 @@ test_that("basic adjust_probability_calibration() usage works", {
2826
)
2927

3028
expect_no_condition(
31-
predict(tlr_fit, d_test)
29+
tlr_pred <- predict(tlr_fit, d_test)
3230
)
3331

3432
# classes are as expected

tests/testthat/test-utils.R

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ test_that("tailor_fully_trained works", {
1717
fit(
1818
two_class_example,
1919
outcome = "truth",
20-
estimate = tidyselect::contains("Class"),
20+
estimate = predicted,
2121
probabilities = tidyselect::contains("Class")
2222
) %>%
2323
adjust_probability_threshold(.5)
@@ -31,10 +31,7 @@ test_that("tailor_fully_trained works", {
3131
fit(
3232
two_class_example,
3333
outcome = "truth",
34-
# todo: this function requires a different format of `estimate`
35-
# and `probabilities` specification than the call below to
36-
# be able to fit properly.
37-
estimate = tidyselect::contains("Class"),
34+
estimate = predicted,
3835
probabilities = tidyselect::contains("Class")
3936
)
4037
)

0 commit comments

Comments
 (0)