diff --git a/R/ppc-errors.R b/R/ppc-errors.R index 31ecfb2b..7a750b69 100644 --- a/R/ppc-errors.R +++ b/R/ppc-errors.R @@ -324,6 +324,45 @@ ppc_error_scatter_avg_vs_x <- function( coord_flip() } +#' @rdname PPC-errors +#' @export +ppc_residual_scatter <- function( + y, + yrep, + x = NULL, + ..., + stat = "mean", + facet_args = list(), + size = 2.5, + alpha = 0.8) { + check_ignored_arguments(...) + + y <- validate_y(y) + yrep <- validate_predictions(yrep, length(y)) + + if (!missing(x)) { + qx <- enquo(x) + x <- validate_x(x, y) + } + + stat <- as_tagged_function({{ stat }}) + residuals <- compute_residuals(y, yrep, stat) + stat_yrep <- apply(yrep, 2, stat) + + ppc_scatter( + y = if (is_null(x)) stat_yrep else x, + yrep = residuals, + facet_args = facet_args, + size = size, + alpha = alpha, + ref_line = FALSE + ) + + labs( + x = residual_label(stat), + y = if (is_null(x)) y_label() else as_label((qx)) + ) + coord_flip() +} + #' @rdname PPC-errors #' @export @@ -415,6 +454,19 @@ compute_errors <- function(y, yrep) { rstantools::predictive_error(object = yrep, y = y) } +#' Compute predictive residuals `y` - `stat(yrep)` +#' @noRd +#' @param y,yrep User's `y` and `yrep` arguments. +#' @param stat Function or string to compute statistic +#' across the draws for each data point. +#' @return A vector of residuals, one for each observation +compute_residuals <- function(y, yrep, stat) { + yrep_stat <- apply(yrep, 2, stat) + residuals <- y - yrep_stat + residuals <- matrix(residuals, nrow = 1) + return(residuals) +} + #' Create facet layer for PPC error plots #' @@ -473,6 +525,22 @@ error_avg_label <- function(stat = NULL) { expr(paste((!!de))*(italic(y) - italic(y)[rep])) } +residual_label <- function(stat = NULL) { + stat <- as_tagged_function({{ stat }}, fallback = "stat") + e <- attr(stat, "tagged_expr") + if (attr(stat, "is_anonymous_function")) { + e <- sym("stat") + } + de <- deparse1(e) + + # create some dummy variables to pass the R package check for + # global variables in the expression below + italic <- sym("italic") + y <- sym("y") + + expr(paste(italic(y) - (!!de)*(italic(y)[rep]))) +} + # Data for binned errors plots ppc_error_binnned_data <- function(y, yrep, x = NULL, bins = NULL) {