Skip to content

Commit 617a769

Browse files
Documentation and example for mcmc-rank-ecdf
1 parent 2cc9dff commit 617a769

File tree

2 files changed

+113
-109
lines changed

2 files changed

+113
-109
lines changed

R/mcmc-traces.R

Lines changed: 92 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#' @template args-regex_pars
1111
#' @template args-transformations
1212
#' @template args-facet_args
13+
#' @template args-pit-ecdf
1314
#' @param ... Currently ignored.
1415
#' @param size An optional value to override the default line size
1516
#' for `mcmc_trace()` or the default point size for `mcmc_trace_highlight()`.
@@ -113,6 +114,11 @@
113114
#' mcmc_rank_hist(x, pars = c("alpha", "sigma"), ref_line = TRUE)
114115
#' mcmc_rank_overlay(x, "alpha")
115116
#'
117+
#' # ECDF difference plots of the ranking of MCMC samples between chains.
118+
#' # Provide 99% simultaneous confidence intervals for the chains sampling from
119+
#' the same distribution.
120+
#' mcmc_rank_ecdf(x, prob = 0.99)
121+
#'
116122
#' \dontrun{
117123
#' # parse facet label text
118124
#' color_scheme_set("purple")
@@ -443,20 +449,96 @@ mcmc_rank_ecdf <-
443449
regex_pars = character(),
444450
transformations = list(),
445451
...,
452+
K = NULL,
446453
facet_args = list(),
447454
prob = 0.99,
448-
plot_diff = TRUE) {
449-
check_ignored_arguments(...)
450-
.mcmc_rank_ecdf(x,
451-
pars = pars,
452-
regex_pars = regex_pars,
453-
transformations = transformations,
454-
facet_args = facet_args,
455-
prob = prob,
456-
plot_diff = plot_diff
457-
)
455+
plot_diff = TRUE,
456+
adj_method = "interpolate") {
457+
check_ignored_arguments(..., ok_args = c("M"))
458+
data <- mcmc_trace_data(
459+
x,
460+
pars = pars,
461+
regex_pars = regex_pars,
462+
transformations = transformations
463+
)
464+
n_iter <- unique(data$n_iterations)
465+
n_chain <- unique(data$n_chains)
466+
n_param <- unique(data$n_parameters)
467+
468+
x <- if (is.null(K)) {
469+
0:n_iter / n_iter
470+
} else {
471+
0:K / K
472+
}
473+
gamma <- adjust_gamma(
474+
N = n_iter,
475+
L = n_chain,
476+
K = if (is.null(K)) {
477+
n_iter
478+
} else {
479+
K
480+
},
481+
prob = prob,
482+
...,
483+
adj_method = adj_method
484+
)
485+
lims <- ecdf_intervals(
486+
N = n_iter,
487+
L = n_chain,
488+
K = if (is.null(K)) {
489+
n_iter
490+
} else {
491+
K
492+
},
493+
gamma = gamma
494+
)
495+
data_lim <- data.frame(
496+
upper = lims$upper / n_iter - (plot_diff == TRUE) * x,
497+
lower = lims$lower / n_iter - (plot_diff == TRUE) * x,
498+
x = x
499+
)
500+
data <- data %>%
501+
group_by(parameter, chain) %>%
502+
dplyr::group_map(~ data.frame(
503+
parameter = .y[1],
504+
chain = .y[2],
505+
ecdf_value = ecdf(.x$value_rank / (n_iter * n_chain))(x) -
506+
(plot_diff == TRUE) * x,
507+
x = x
508+
)) %>%
509+
dplyr::bind_rows()
510+
511+
mapping <- aes_(
512+
x = ~x,
513+
y = ~ecdf_value,
514+
color = ~chain,
515+
group = ~chain
516+
)
517+
518+
scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain))
519+
520+
facet_call <- NULL
521+
if (n_param == 1) {
522+
facet_call <- ylab(levels(data$parameter))
523+
} else {
524+
facet_args$facets <- ~parameter
525+
facet_args$scales <- facet_args$scales %||% "free"
526+
facet_call <- do.call("facet_wrap", facet_args)
458527
}
459528

529+
ggplot() +
530+
geom_step(data = data_lim, aes_(x = ~x, y = ~upper), show.legend = FALSE) +
531+
geom_step(data = data_lim, aes_(x = ~x, y = ~lower), show.legend = FALSE) +
532+
geom_step(mapping, data) +
533+
bayesplot_theme_get() +
534+
scale_color +
535+
facet_call +
536+
scale_x_continuous(breaks = pretty) +
537+
legend_move(ifelse(n_chain > 1, "right", "none")) +
538+
xaxis_title(FALSE) +
539+
yaxis_title(on = n_param == 1)
540+
}
541+
460542
#' @rdname MCMC-traces
461543
#' @export
462544
mcmc_trace_data <- function(x,
@@ -639,100 +721,6 @@ mcmc_trace_data <- function(x,
639721
yaxis_title(on = n_param == 1)
640722
}
641723

642-
.mcmc_rank_ecdf <- function(x,
643-
pars = character(),
644-
regex_pars = character(),
645-
transformations = list(),
646-
facet_args = list(),
647-
...,
648-
K,
649-
plot_diff = TRUE,
650-
prob = 0.99,
651-
adj_method = "interpolate") {
652-
data <- mcmc_trace_data(
653-
x,
654-
pars = pars,
655-
regex_pars = regex_pars,
656-
transformations = transformations
657-
)
658-
n_iter <- unique(data$n_iterations)
659-
n_chain <- unique(data$n_chains)
660-
n_param <- unique(data$n_parameters)
661-
662-
x <- if (missing(K)) {
663-
0:n_iter / n_iter
664-
} else {
665-
0:K / K
666-
}
667-
gamma <- adjust_gamma(
668-
N = n_iter,
669-
L = n_chain,
670-
K = if (missing(K)) {
671-
n_iter
672-
} else {
673-
K
674-
},
675-
conf_level = prob,
676-
...,
677-
adj_method = adj_method
678-
)
679-
lims <- ecdf_intervals(
680-
N = n_iter,
681-
L = n_chain,
682-
K = if (missing(K)) {
683-
n_iter
684-
} else {
685-
K
686-
},
687-
gamma = gamma
688-
)
689-
data_lim <- data.frame(
690-
upper = lims$upper / n_iter - (plot_diff == TRUE) * x,
691-
lower = lims$lower / n_iter - (plot_diff == TRUE) * x,
692-
x = x
693-
)
694-
data <- data %>%
695-
group_by(parameter, chain) %>%
696-
dplyr::group_map(~ data.frame(
697-
parameter = .y[1],
698-
chain = .y[2],
699-
ecdf_value = ecdf(.x$value_rank / (n_iter * n_chain))(x) -
700-
(plot_diff == TRUE) * x,
701-
x = x
702-
)) %>%
703-
dplyr::bind_rows()
704-
705-
mapping <- aes_(
706-
x = ~x,
707-
y = ~ecdf_value,
708-
color = ~chain,
709-
group = ~chain
710-
)
711-
712-
scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain))
713-
714-
facet_call <- NULL
715-
if (n_param == 1) {
716-
facet_call <- ylab(levels(data$parameter))
717-
} else {
718-
facet_args$facets <- ~parameter
719-
facet_args$scales <- facet_args$scales %||% "free"
720-
facet_call <- do.call("facet_wrap", facet_args)
721-
}
722-
723-
ggplot() +
724-
geom_step(data = data_lim, aes_(x = ~x, y = ~upper), show.legend = FALSE) +
725-
geom_step(data = data_lim, aes_(x = ~x, y = ~lower), show.legend = FALSE) +
726-
geom_step(mapping, data) +
727-
bayesplot_theme_get() +
728-
scale_color +
729-
facet_call +
730-
scale_x_continuous(breaks = pretty) +
731-
legend_move(ifelse(n_chain > 1, "right", "none")) +
732-
xaxis_title(FALSE) +
733-
yaxis_title(on = n_param == 1)
734-
}
735-
736724
chain_colors <- function(n) {
737725
all_clrs <- unlist(color_scheme_get())
738726
clrs <- switch(

man/MCMC-traces.Rd

Lines changed: 21 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)