Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@ S3method(vec_cast,double.hardhat_importance_weights)
S3method(vec_cast,hardhat_frequency_weights.hardhat_frequency_weights)
S3method(vec_cast,hardhat_importance_weights.hardhat_importance_weights)
S3method(vec_cast,integer.hardhat_frequency_weights)
S3method(vec_cast,quantile_pred.quantile_pred)
S3method(vec_proxy_compare,quantile_pred)
S3method(vec_proxy_order,quantile_pred)
S3method(vec_ptype2,hardhat_frequency_weights.hardhat_frequency_weights)
S3method(vec_ptype2,hardhat_importance_weights.hardhat_importance_weights)
S3method(vec_ptype2,quantile_pred.quantile_pred)
S3method(vec_ptype_abbr,hardhat_frequency_weights)
S3method(vec_ptype_abbr,hardhat_importance_weights)
S3method(vec_ptype_abbr,quantile_pred)
Expand Down
99 changes: 88 additions & 11 deletions R/quantile-pred.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@

rownames(values) <- NULL
colnames(values) <- NULL
values <- lapply(vctrs::vec_chop(values), drop)
values <- list(quantile_values = values)
new_quantile_pred(values, quantile_levels)
}

# Low-level constructor for `quantile_pred` record vectors.
#
# `values` is the named list of record fields handed to `vctrs::new_rcrd()`;
# the default is a single `quantile_values` field holding an empty 0 x 0
# double matrix. `quantile_levels` is cast to double and stored as an
# attribute on the resulting record.
#
# NOTE(review): this constructor performs no checks of its own on `values`
# beyond whatever `vctrs::new_rcrd()` enforces.
new_quantile_pred <- function(
    values = list(quantile_values = matrix(double(), 0L, 0L)),
    quantile_levels = double()
) {
  quantile_levels <- vctrs::vec_cast(quantile_levels, double())
  vctrs::new_rcrd(
    values,
    quantile_levels = quantile_levels,
    class = "quantile_pred"
  )
}
Expand All @@ -73,7 +73,7 @@
n_samp <- length(x)
n_quant <- length(lvls)
tibble::new_tibble(list(
.pred_quantile = unlist(x),
.pred_quantile = `dim<-`(t(field(x, "quantile_values")), NULL),
.quantile_levels = rep(lvls, n_samp),
.row = rep(1:n_samp, each = n_quant)
))
Expand All @@ -82,17 +82,19 @@
#' @export
#' @rdname quantile_pred
as.matrix.quantile_pred <- function(x, ...) {
  # The predicted values are stored as the `quantile_values` record field, so
  # that field is returned as-is. `...` is accepted for S3 compatibility and
  # is not used.
  field(x, "quantile_values")
}

#' @export
format.quantile_pred <- function(x, digits = 3L, ...) {
format.quantile_pred <- function(x, digits = NULL, ...) {
if (is.null(digits)) {
digits <- 3L
}
quantile_levels <- attr(x, "quantile_levels")
if (length(quantile_levels) == 1L) {
x <- unlist(x)
x <- field(x, "quantile_values")
dim(x) <- NULL
out <- signif(x, digits = digits)
out[is.na(x)] <- NA_real_
} else {
m <- median(x, na.rm = TRUE)
out <- paste0("[", signif(m, digits = digits), "]")
Expand All @@ -103,14 +105,17 @@
#' @export
median.quantile_pred <- function(x, ...) {
  lvls <- attr(x, "quantile_levels")
  # Value matrix, indexed below as [prediction, quantile level].
  vals <- field(x, "quantile_values")

  # If 0.5 is (numerically) one of the levels, return that column directly.
  loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps))
  if (any(loc_median)) {
    return(vals[, min(which(loc_median))])
  }

  # Interpolation needs at least two levels, and they must bracket 0.5;
  # otherwise the median is reported as missing for every prediction.
  if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) {
    return(rep(NA, vctrs::vec_size(x)))
  }

  # Linearly interpolate each prediction's values at the 0.5 level.
  map_dbl(vec_seq_along(vals), function(row_i) {
    stats::approx(lvls, vals[row_i, ], xout = 0.5)$y
  })
}

#' @export
Expand All @@ -129,6 +134,78 @@
cat(footer, format(lvls, digits = digits), "\n", sep = " ")
}

# Comparison proxy method for `quantile_pred`: it never returns a proxy and
# always aborts with a classed error.
#
# `x` and `...` are accepted to match the method signature but are not used.
# The error is raised with `cli::cli_abort()` and carries the class
# `hardhat_error_comparing_quantile_preds` so callers can catch it by class.
#' @export
vec_proxy_compare.quantile_pred <- function(x, ...) {
# Using a proxy-based lexicographical order doesn't make sense for binary
# comparison operators. (A partial order could be implemented by directly
# overriding the binary comparison operators, but would conflict with the
# lexicographical total order used for sorting.)
cli::cli_abort("
`vec_proxy_compare`, `<`, `<=`, `>`, and `>=` are not supported for
`quantile_pred`s.
", class = "hardhat_error_comparing_quantile_preds")
}

#' @export
vec_proxy_order.quantile_pred <- function(x, ...) {
  # Ordering stays available even though the binary comparison operators are
  # disallowed for `quantile_pred`s: as {vctrs} does for lists, the plain
  # proxy of `x` is used to define a (lexicographical) order.
  order_proxy <- vec_proxy(x)
  order_proxy
}

# ------------------------------------------------------------------------------
# ptype-related functions

# Format quantile levels as strings that round-trip to their exact values.
#
# Each level starts at 3 significant digits and is re-formatted with one more
# digit at a time (up to 17) until `as.numeric()` of the string recovers the
# original value. This keeps the output short for "nice" levels while still
# exposing tiny floating-point differences between two sets of levels.
#
# @param quantile_levels A numeric vector of quantile levels.
# @return A character vector the same length as `quantile_levels`, without
#   surrounding whitespace. Missing levels are formatted as "NA".
format_quantile_levels <- function(quantile_levels) {
  result <- formatC(quantile_levels, digits = 3)
  for (digits in 4:17) { # 17 significant digits should be enough to disambiguate
    # `which()` drops the NA comparisons produced by missing levels. Those
    # would otherwise make the `if` condition NA and raise an unrelated error
    # while an error message is being built.
    imprecise <- which(as.numeric(result) != quantile_levels)
    if (length(imprecise) == 0L) break
    result[imprecise] <- formatC(quantile_levels[imprecise], digits = digits)
  }
  trimws(result)
}

# Abort with an incompatible-type error unless `x` and `y` carry identical
# `quantile_levels` attributes.
#
# `action`, `x_arg`, `y_arg` and `call` are forwarded unchanged to
# `stop_incompatible_type()`. Returns `NULL` invisibly when the levels match.
validate_preds_have_same_quantile_levels <- function(x, y, action, x_arg, y_arg, call) {
  levels_x <- attr(x, "quantile_levels")
  levels_y <- attr(y, "quantile_levels")

  if (identical(levels_x, levels_y)) {
    return(invisible(NULL))
  }

  # These two names are referenced by the glue templates below.
  x_formatted_levels <- format_quantile_levels(levels_x)
  y_formatted_levels <- format_quantile_levels(levels_y)
  details <- cli::format_error(c(
    "They have different sets of quantile levels:",
    "*" = "1st set of quantile levels: {x_formatted_levels}",
    "*" = "2nd set of quantile levels: {y_formatted_levels}"
  ))

  stop_incompatible_type(
    x,
    y,
    action = action,
    x_arg = x_arg,
    y_arg = y_arg,
    details = details,
    call = call
  )
}

#' @export
vec_ptype2.quantile_pred.quantile_pred <- function(x,
                                                   y,
                                                   ...,
                                                   x_arg = caller_arg(x),
                                                   y_arg = caller_arg(y),
                                                   call = caller_env()) {
  # Two quantile_preds only share a common type when their levels are
  # identical; otherwise this aborts.
  validate_preds_have_same_quantile_levels(x, y, "combine", x_arg, y_arg, call)

  # The common type of the value fields decides the storage type; it is
  # written back into `x`, which is then returned.
  common_values <- vec_ptype2(field(x, "quantile_values"), field(y, "quantile_values"))
  field(x, "quantile_values") <- common_values
  x
}

#' @export
vec_cast.quantile_pred.quantile_pred <- function(x,
                                                 to,
                                                 ...,
                                                 x_arg = caller_arg(x),
                                                 to_arg = "",
                                                 call = caller_env()) {
  # Casting is only defined between quantile_preds with identical levels;
  # otherwise this aborts.
  validate_preds_have_same_quantile_levels(x, to, "convert", x_arg, to_arg, call)

  # Cast the value field of `x` to the type of the value field of `to`, write
  # it back into `x`, and return `x`.
  cast_values <- vec_cast(field(x, "quantile_values"), field(to, "quantile_values"))
  field(x, "quantile_values") <- cast_values
  x
}

# ------------------------------------------------------------------------------
# Checking functions
Expand Down
23 changes: 23 additions & 0 deletions tests/testthat/_snaps/quantile-pred.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,26 @@
Output
[1] "[1.7154]" "[0.56784]" "[1.2393]" "[2.2062]" "[0.76714]"

---

Code
data.frame(v = v)
Output
v
1 [1.72]
2 [0.568]
3 [1.24]
4 [2.21]
5 [0.767]

# quantile_pred level (in)compatibility works

Code
vec_ptype2(v1, v2)
Condition <vctrs_error_ptype2>
Error:
! Can't combine `v1` <quantiles> and `v2` <quantiles>.
They have different sets of quantile levels:
* 1st set of quantile levels: 0, 0.05, 0.1, 0.15000000000000002, and 0.2
* 2nd set of quantile levels: 0, 0.05, 0.1, 0.15, and 0.2

82 changes: 80 additions & 2 deletions tests/testthat/test-quantile-pred.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ test_that("quantile_pred outputs", {
expect_s3_class(v, "quantile_pred")
expect_identical(attr(v, "quantile_levels"), 1:4 / 5)
expect_identical(
vctrs::vec_data(v),
lapply(vctrs::vec_chop(matrix(1:20, 5)), drop)
vctrs::vec_data(v) %>% lapply(unname),
list(quantile_values = matrix(1:20, 5))
)
})

Expand Down Expand Up @@ -78,6 +78,7 @@ test_that("quantile_pred formatting", {
v <- quantile_pred(matrix(exp(rnorm(20)), ncol = 4), 1:4 / 5)
expect_snapshot(format(v))
expect_snapshot(format(v, digits = 5))
expect_snapshot(data.frame(v = v))
})

test_that("as_tibble() for quantile_pred", {
Expand All @@ -95,3 +96,80 @@ test_that("as.matrix() for quantile_pred", {
expect_true(is.matrix(m))
expect_identical(m, x)
})

test_that("Various ways to introduce NAs in quantile_pred work", {
  # Row 1 is complete, row 2 is partially missing, row 3 is entirely missing.
  dbl_mat <- matrix(
    c(1:3, c(1, NA, NA), rep(NA, 3)),
    nrow = 3, ncol = 3, byrow = TRUE
  )
  int_mat <- dbl_mat
  storage.mode(int_mat) <- "integer"
  levels <- 1:3 / 4
  first_day <- as.Date("2020-01-01")

  # The same expectations must hold for double and integer storage.
  check_na_handling <- function(preds) {
    all_missing <- preds[3]
    expect_identical(vec_init(preds, 5), rep(all_missing, 5))
    expect_identical(vec_c(preds[1:2], NA), preds)
    expect_identical(vec_c(NA, preds[1:2]), preds[c(3, 1, 2)])

    # An outer join leaves unmatched dates with a missing prediction.
    joined <- merge(
      tibble(date = first_day + 0:5),
      tibble(date = first_day + 1:3, pred = preds),
      by = "date", all = TRUE
    )
    expect_identical(
      joined,
      data.frame(date = first_day + 0:5, pred = preds[c(3, 1:3, 3, 3)])
    )

    expect_identical(vec_detect_missing(preds), c(FALSE, FALSE, TRUE))
    expect_identical(vec_detect_complete(preds), c(TRUE, FALSE, FALSE))
  }

  check_na_handling(quantile_pred(dbl_mat, levels))
  check_na_handling(quantile_pred(int_mat, levels))
})

test_that("quantile_pred == logic outputs NAs when expected", {
  # Build a length-one quantile_pred from a plain vector of values.
  single_pred <- function(values, levels) {
    one_row <- t(as.matrix(values))
    quantile_pred(one_row, levels)
  }
  thirds <- 1:2 / 3

  # With a single level, any missing value makes the comparison NA.
  expect_identical(single_pred(1, 0.5) == single_pred(NA, 0.5), NA)
  expect_identical(single_pred(NA, 0.5) == single_pred(NA, 0.5), NA)

  # A definite mismatch at one level gives FALSE even if another level is NA.
  expect_identical(
    single_pred(c(1, NA), thirds) == single_pred(c(4, NA), thirds),
    FALSE
  )
  expect_identical(
    single_pred(c(1, NA), thirds) == single_pred(c(4, 5), thirds),
    FALSE
  )
})

test_that("Inequalities don't work on quantile_preds, but equality & sorting does:", {
  values <- matrix(c(6, 1, 2, 3, 5, 6), nrow = 2, ncol = 3, byrow = TRUE)
  v <- quantile_pred(values, 1:3 / 4)
  comparison_error <- "hardhat_error_comparing_quantile_preds"

  # Ordered comparisons abort with a classed error.
  expect_error(v < v, class = comparison_error)
  expect_error(v <= v, class = comparison_error)
  expect_error(v > v, class = comparison_error)
  expect_error(v >= v, class = comparison_error)

  # Equality and sorting are still supported.
  expect_identical(v == v, c(TRUE, TRUE))
  expect_identical(v != v, c(FALSE, FALSE))
  expect_identical(sort(v), v[c(2, 1)])
})

test_that("quantile_pred typeof compatibility works", {
  dbl_mat <- matrix(
    c(1:3, c(4, NA, NA), rep(NA, 3)),
    nrow = 3, ncol = 3, byrow = TRUE
  )
  int_mat <- dbl_mat
  storage.mode(int_mat) <- "integer"
  levels <- 1:3 / 4
  dbl_v <- quantile_pred(dbl_mat, levels)
  int_v <- quantile_pred(int_mat, levels)

  # ptype: the prototype is the zero-length slice of the same storage type
  expect_identical(vec_ptype(dbl_v), dbl_v[0])
  expect_identical(vec_ptype(int_v), int_v[0])

  # ptype2: double wins over integer, in either argument order
  expect_identical(vec_ptype2(dbl_v, int_v), dbl_v[0])
  expect_identical(vec_ptype2(int_v, dbl_v), dbl_v[0])

  # cast: lossless in both directions for whole-number values
  expect_identical(vec_cast(int_v, dbl_v), dbl_v)
  expect_identical(vec_cast(dbl_v, int_v), int_v)

  # cast: fractional values cannot be cast to integer storage
  dbl_v2 <- dbl_v
  field(dbl_v2, "quantile_values") <- field(dbl_v2, "quantile_values") + 0.5
  expect_error(vec_cast(dbl_v2, int_v), class = "vctrs_error_cast_lossy")
})

test_that("quantile_pred level (in)compatibility works", {
  # Two level sets that print alike but differ by floating-point error.
  levels_from_seq <- seq(0, 0.2, by = 0.05)
  levels_literal <- c(0, 0.05, 0.1, 0.15, 0.2)
  expect_false(all(levels_from_seq == levels_literal))

  one_row <- t(as.matrix(1:5))
  v1 <- quantile_pred(one_row, levels_from_seq)
  v2 <- quantile_pred(one_row, levels_literal)
  incompatible <- "vctrs_error_incompatible_type"

  expect_error(vec_ptype2(v1, v2), class = incompatible)
  expect_snapshot(vec_ptype2(v1, v2), error = TRUE, cnd_class = TRUE)
  expect_error(vec_cast(v1, v2), class = incompatible)
  expect_error(vec_cast(v2, v1), class = incompatible)
})