Skip to content

Commit 3a0a61c

Browse files
Rank ECDF with simultaneous confidence intervals
1 parent 7e492aa commit 3a0a61c

File tree

2 files changed

+161
-2
lines changed

2 files changed

+161
-2
lines changed

R/mcmc-traces.R

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,17 @@
6363
#' Ranks from `mcmc_rank_hist()` are plotted using overlaid lines in a
6464
#' single panel.
6565
#' }
66+
#' \item{`mcmc_rank_ecdf()`}{
67+
#' The ECDFs of the ranks from `mcmc_rank_hist()` are plotted and
68+
#' simultaneous confidence bands with a coverage determined by
69+
#' `prob` are drawn. By default, the difference between the
70+
#' observed ECDF and the theoretical expectation is drawn.
71+
#' See Säilynoja et al. (2021) for details.
72+
#' }
6673
#' }
6774
#'
6875
#' @template reference-improved-rhat
76+
#' @template reference-uniformity-test
6977
#' @examples
7078
#' # some parameter draws to use for demonstration
7179
#' x <- example_mcmc_draws(chains = 4, params = 6)
@@ -316,14 +324,14 @@ mcmc_rank_overlay <- function(x,
316324
} else {
317325
NULL
318326
}
319-
327+
320328
facet_call <- NULL
321329
if (n_param > 1) {
322330
facet_args$facets <- ~ parameter
323331
facet_args$scales <- facet_args$scales %||% "fixed"
324332
facet_call <- do.call("facet_wrap", facet_args)
325333
}
326-
334+
327335
ggplot(d_bin_counts) +
328336
aes_(x = ~ bin_start, y = ~ n, color = ~ chain) +
329337
geom_step() +
@@ -421,6 +429,33 @@ mcmc_rank_hist <- function(x,
421429
labs(x = "Rank")
422430
}
423431

432+
#' @rdname MCMC-traces
433+
#' @param prob For `mcmc_rank_ecdf()`, a value between 0 and 1
434+
#' specifying the desired simultaneous confidence of the confidence bands to be
435+
#' drawn for the rank ECDF plots.
436+
#' @param plot_diff For `mcmc_rank_ecdf()`, a boolean specifying it the
437+
#' difference between the observed rank ECDFs and the theoretical expectation
438+
#' should be drawn instead of the unmodified rank ECDF plots.
439+
#' @export
440+
mcmc_rank_ecdf <-
441+
function(x,
442+
pars = character(),
443+
regex_pars = character(),
444+
transformations = list(),
445+
...,
446+
facet_args = list(),
447+
prob = 0.99,
448+
plot_diff = TRUE) {
449+
check_ignored_arguments(...)
450+
.mcmc_rank_ecdf(x,
451+
pars = pars,
452+
regex_pars = regex_pars,
453+
transformations = transformations,
454+
facet_args = facet_args,
455+
prob = prob,
456+
plot_diff = plot_diff
457+
)
458+
}
424459

425460
#' @rdname MCMC-traces
426461
#' @export
@@ -604,6 +639,99 @@ mcmc_trace_data <- function(x,
604639
yaxis_title(on = n_param == 1)
605640
}
606641

642+
.mcmc_rank_ecdf <- function(x,
643+
pars = character(),
644+
regex_pars = character(),
645+
transformations = list(),
646+
facet_args = list(),
647+
...,
648+
K,
649+
plot_diff = TRUE,
650+
prob = 0.99,
651+
adj_method = "interpolate") {
652+
data <- mcmc_trace_data(
653+
x,
654+
pars = pars,
655+
regex_pars = regex_pars,
656+
transformations = transformations
657+
)
658+
n_iter <- unique(data$n_iterations)
659+
n_chain <- unique(data$n_chains)
660+
n_param <- unique(data$n_parameters)
661+
662+
x <- if (missing(K)) {
663+
0:n_iter / n_iter
664+
} else {
665+
0:K / K
666+
}
667+
gamma <- adjust_gamma(
668+
N = n_iter,
669+
L = n_chain,
670+
K = if (missing(K)) {
671+
n_iter
672+
} else {
673+
K
674+
},
675+
conf_level = prob,
676+
...,
677+
adj_method = adj_method
678+
)
679+
lims <- ecdf_intervals(
680+
N = n_iter,
681+
L = n_chain,
682+
K = if (missing(K)) {
683+
n_iter
684+
} else {
685+
K
686+
},
687+
gamma = gamma
688+
)
689+
data_lim <- data.frame(
690+
upper = lims$upper / n_iter - (plot_diff == TRUE) * x,
691+
lower = lims$lower / n_iter - (plot_diff == TRUE) * x,
692+
x = x
693+
)
694+
data <- data %>%
695+
group_by(parameter, chain) %>%
696+
dplyr::group_map(~ data.frame(
697+
parameter = .y[1],
698+
chain = .y[2],
699+
ecdf_value = ecdf(.x$value_rank / (n_iter * n_chain))(x) -
700+
(plot_diff == TRUE) * x,
701+
x = x
702+
)) %>%
703+
dplyr::bind_rows()
704+
705+
mapping <- aes_(
706+
x = ~x,
707+
y = ~ecdf_value,
708+
color = ~chain,
709+
group = ~chain
710+
)
711+
712+
scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain))
713+
714+
facet_call <- NULL
715+
if (n_param == 1) {
716+
facet_call <- ylab(levels(data$parameter))
717+
} else {
718+
facet_args$facets <- ~parameter
719+
facet_args$scales <- facet_args$scales %||% "free"
720+
facet_call <- do.call("facet_wrap", facet_args)
721+
}
722+
723+
ggplot() +
724+
geom_step(data = data_lim, aes_(x = ~x, y = ~upper), show.legend = FALSE) +
725+
geom_step(data = data_lim, aes_(x = ~x, y = ~lower), show.legend = FALSE) +
726+
geom_step(mapping, data) +
727+
bayesplot_theme_get() +
728+
scale_color +
729+
facet_call +
730+
scale_x_continuous(breaks = pretty) +
731+
legend_move(ifelse(n_chain > 1, "right", "none")) +
732+
xaxis_title(FALSE) +
733+
yaxis_title(on = n_param == 1)
734+
}
607735

608736
chain_colors <- function(n) {
609737
all_clrs <- unlist(color_scheme_get())

man/MCMC-traces.Rd

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)