|
63 | 63 | #' Ranks from `mcmc_rank_hist()` are plotted using overlaid lines in a
|
64 | 64 | #' single panel.
|
65 | 65 | #' }
|
| 66 | +#' \item{`mcmc_rank_ecdf()`}{ |
| 67 | +#' The ECDFs of the ranks from `mcmc_rank_hist()` are plotted and |
| 68 | +#' simultaneous confidence bands with a coverage determined by |
| 69 | +#' `prob` are drawn. By default, the difference between the |
| 70 | +#' observed ECDF and the theoretical expectation is drawn. |
| 71 | +#' See Säilynoja et al. (2021) for details. |
| 72 | +#' } |
66 | 73 | #' }
|
67 | 74 | #'
|
68 | 75 | #' @template reference-improved-rhat
|
| 76 | +#' @template reference-uniformity-test |
69 | 77 | #' @examples
|
70 | 78 | #' # some parameter draws to use for demonstration
|
71 | 79 | #' x <- example_mcmc_draws(chains = 4, params = 6)
|
@@ -316,14 +324,14 @@ mcmc_rank_overlay <- function(x,
|
316 | 324 | } else {
|
317 | 325 | NULL
|
318 | 326 | }
|
319 |
| - |
| 327 | + |
320 | 328 | facet_call <- NULL
|
321 | 329 | if (n_param > 1) {
|
322 | 330 | facet_args$facets <- ~ parameter
|
323 | 331 | facet_args$scales <- facet_args$scales %||% "fixed"
|
324 | 332 | facet_call <- do.call("facet_wrap", facet_args)
|
325 | 333 | }
|
326 |
| - |
| 334 | + |
327 | 335 | ggplot(d_bin_counts) +
|
328 | 336 | aes_(x = ~ bin_start, y = ~ n, color = ~ chain) +
|
329 | 337 | geom_step() +
|
@@ -421,6 +429,33 @@ mcmc_rank_hist <- function(x,
|
421 | 429 | labs(x = "Rank")
|
422 | 430 | }
|
423 | 431 |
|
| 432 | +#' @rdname MCMC-traces |
| 433 | +#' @param prob For `mcmc_rank_ecdf()`, a value between 0 and 1 |
| 434 | +#' specifying the desired simultaneous confidence of the confidence bands to be |
| 435 | +#' drawn for the rank ECDF plots. |
| 436 | +#' @param plot_diff For `mcmc_rank_ecdf()`, a boolean specifying it the |
| 437 | +#' difference between the observed rank ECDFs and the theoretical expectation |
| 438 | +#' should be drawn instead of the unmodified rank ECDF plots. |
| 439 | +#' @export |
| 440 | +mcmc_rank_ecdf <- |
| 441 | + function(x, |
| 442 | + pars = character(), |
| 443 | + regex_pars = character(), |
| 444 | + transformations = list(), |
| 445 | + ..., |
| 446 | + facet_args = list(), |
| 447 | + prob = 0.99, |
| 448 | + plot_diff = TRUE) { |
| 449 | + check_ignored_arguments(...) |
| 450 | + .mcmc_rank_ecdf(x, |
| 451 | + pars = pars, |
| 452 | + regex_pars = regex_pars, |
| 453 | + transformations = transformations, |
| 454 | + facet_args = facet_args, |
| 455 | + prob = prob, |
| 456 | + plot_diff = plot_diff |
| 457 | + ) |
| 458 | + } |
424 | 459 |
|
425 | 460 | #' @rdname MCMC-traces
|
426 | 461 | #' @export
|
@@ -604,6 +639,99 @@ mcmc_trace_data <- function(x,
|
604 | 639 | yaxis_title(on = n_param == 1)
|
605 | 640 | }
|
606 | 641 |
|
| 642 | +.mcmc_rank_ecdf <- function(x, |
| 643 | + pars = character(), |
| 644 | + regex_pars = character(), |
| 645 | + transformations = list(), |
| 646 | + facet_args = list(), |
| 647 | + ..., |
| 648 | + K, |
| 649 | + plot_diff = TRUE, |
| 650 | + prob = 0.99, |
| 651 | + adj_method = "interpolate") { |
| 652 | + data <- mcmc_trace_data( |
| 653 | + x, |
| 654 | + pars = pars, |
| 655 | + regex_pars = regex_pars, |
| 656 | + transformations = transformations |
| 657 | + ) |
| 658 | + n_iter <- unique(data$n_iterations) |
| 659 | + n_chain <- unique(data$n_chains) |
| 660 | + n_param <- unique(data$n_parameters) |
| 661 | + |
| 662 | + x <- if (missing(K)) { |
| 663 | + 0:n_iter / n_iter |
| 664 | + } else { |
| 665 | + 0:K / K |
| 666 | + } |
| 667 | + gamma <- adjust_gamma( |
| 668 | + N = n_iter, |
| 669 | + L = n_chain, |
| 670 | + K = if (missing(K)) { |
| 671 | + n_iter |
| 672 | + } else { |
| 673 | + K |
| 674 | + }, |
| 675 | + conf_level = prob, |
| 676 | + ..., |
| 677 | + adj_method = adj_method |
| 678 | + ) |
| 679 | + lims <- ecdf_intervals( |
| 680 | + N = n_iter, |
| 681 | + L = n_chain, |
| 682 | + K = if (missing(K)) { |
| 683 | + n_iter |
| 684 | + } else { |
| 685 | + K |
| 686 | + }, |
| 687 | + gamma = gamma |
| 688 | + ) |
| 689 | + data_lim <- data.frame( |
| 690 | + upper = lims$upper / n_iter - (plot_diff == TRUE) * x, |
| 691 | + lower = lims$lower / n_iter - (plot_diff == TRUE) * x, |
| 692 | + x = x |
| 693 | + ) |
| 694 | + data <- data %>% |
| 695 | + group_by(parameter, chain) %>% |
| 696 | + dplyr::group_map(~ data.frame( |
| 697 | + parameter = .y[1], |
| 698 | + chain = .y[2], |
| 699 | + ecdf_value = ecdf(.x$value_rank / (n_iter * n_chain))(x) - |
| 700 | + (plot_diff == TRUE) * x, |
| 701 | + x = x |
| 702 | + )) %>% |
| 703 | + dplyr::bind_rows() |
| 704 | + |
| 705 | + mapping <- aes_( |
| 706 | + x = ~x, |
| 707 | + y = ~ecdf_value, |
| 708 | + color = ~chain, |
| 709 | + group = ~chain |
| 710 | + ) |
| 711 | + |
| 712 | + scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain)) |
| 713 | + |
| 714 | + facet_call <- NULL |
| 715 | + if (n_param == 1) { |
| 716 | + facet_call <- ylab(levels(data$parameter)) |
| 717 | + } else { |
| 718 | + facet_args$facets <- ~parameter |
| 719 | + facet_args$scales <- facet_args$scales %||% "free" |
| 720 | + facet_call <- do.call("facet_wrap", facet_args) |
| 721 | + } |
| 722 | + |
| 723 | + ggplot() + |
| 724 | + geom_step(data = data_lim, aes_(x = ~x, y = ~upper), show.legend = FALSE) + |
| 725 | + geom_step(data = data_lim, aes_(x = ~x, y = ~lower), show.legend = FALSE) + |
| 726 | + geom_step(mapping, data) + |
| 727 | + bayesplot_theme_get() + |
| 728 | + scale_color + |
| 729 | + facet_call + |
| 730 | + scale_x_continuous(breaks = pretty) + |
| 731 | + legend_move(ifelse(n_chain > 1, "right", "none")) + |
| 732 | + xaxis_title(FALSE) + |
| 733 | + yaxis_title(on = n_param == 1) |
| 734 | +} |
607 | 735 |
|
608 | 736 | chain_colors <- function(n) {
|
609 | 737 | all_clrs <- unlist(color_scheme_get())
|
|
0 commit comments