|
10 | 10 | #' @template args-regex_pars
|
11 | 11 | #' @template args-transformations
|
12 | 12 | #' @template args-facet_args
|
| 13 | +#' @template args-pit-ecdf |
13 | 14 | #' @param ... Currently ignored.
|
14 | 15 | #' @param size An optional value to override the default line size
|
15 | 16 | #' for `mcmc_trace()` or the default point size for `mcmc_trace_highlight()`.
|
|
113 | 114 | #' mcmc_rank_hist(x, pars = c("alpha", "sigma"), ref_line = TRUE)
|
114 | 115 | #' mcmc_rank_overlay(x, "alpha")
|
115 | 116 | #'
|
| 117 | +#' # ECDF difference plots of the ranking of MCMC samples between chains. |
| 118 | +#' # Provide 99% simultaneous confidence intervals for the chains sampling from |
| 119 | +#' the same distribution. |
| 120 | +#' mcmc_rank_ecdf(x, prob = 0.99) |
| 121 | +#' |
116 | 122 | #' \dontrun{
|
117 | 123 | #' # parse facet label text
|
118 | 124 | #' color_scheme_set("purple")
|
@@ -443,20 +449,96 @@ mcmc_rank_ecdf <-
|
443 | 449 | regex_pars = character(),
|
444 | 450 | transformations = list(),
|
445 | 451 | ...,
|
| 452 | + K = NULL, |
446 | 453 | facet_args = list(),
|
447 | 454 | prob = 0.99,
|
448 |
| - plot_diff = TRUE) { |
449 |
| - check_ignored_arguments(...) |
450 |
| - .mcmc_rank_ecdf(x, |
451 |
| - pars = pars, |
452 |
| - regex_pars = regex_pars, |
453 |
| - transformations = transformations, |
454 |
| - facet_args = facet_args, |
455 |
| - prob = prob, |
456 |
| - plot_diff = plot_diff |
457 |
| - ) |
| 455 | + plot_diff = TRUE, |
| 456 | + adj_method = "interpolate") { |
| 457 | + check_ignored_arguments(..., ok_args = c("M")) |
| 458 | + data <- mcmc_trace_data( |
| 459 | + x, |
| 460 | + pars = pars, |
| 461 | + regex_pars = regex_pars, |
| 462 | + transformations = transformations |
| 463 | + ) |
| 464 | + n_iter <- unique(data$n_iterations) |
| 465 | + n_chain <- unique(data$n_chains) |
| 466 | + n_param <- unique(data$n_parameters) |
| 467 | + |
| 468 | + x <- if (is.null(K)) { |
| 469 | + 0:n_iter / n_iter |
| 470 | + } else { |
| 471 | + 0:K / K |
| 472 | + } |
| 473 | + gamma <- adjust_gamma( |
| 474 | + N = n_iter, |
| 475 | + L = n_chain, |
| 476 | + K = if (is.null(K)) { |
| 477 | + n_iter |
| 478 | + } else { |
| 479 | + K |
| 480 | + }, |
| 481 | + prob = prob, |
| 482 | + ..., |
| 483 | + adj_method = adj_method |
| 484 | + ) |
| 485 | + lims <- ecdf_intervals( |
| 486 | + N = n_iter, |
| 487 | + L = n_chain, |
| 488 | + K = if (is.null(K)) { |
| 489 | + n_iter |
| 490 | + } else { |
| 491 | + K |
| 492 | + }, |
| 493 | + gamma = gamma |
| 494 | + ) |
| 495 | + data_lim <- data.frame( |
| 496 | + upper = lims$upper / n_iter - (plot_diff == TRUE) * x, |
| 497 | + lower = lims$lower / n_iter - (plot_diff == TRUE) * x, |
| 498 | + x = x |
| 499 | + ) |
| 500 | + data <- data %>% |
| 501 | + group_by(parameter, chain) %>% |
| 502 | + dplyr::group_map(~ data.frame( |
| 503 | + parameter = .y[1], |
| 504 | + chain = .y[2], |
| 505 | + ecdf_value = ecdf(.x$value_rank / (n_iter * n_chain))(x) - |
| 506 | + (plot_diff == TRUE) * x, |
| 507 | + x = x |
| 508 | + )) %>% |
| 509 | + dplyr::bind_rows() |
| 510 | + |
| 511 | + mapping <- aes_( |
| 512 | + x = ~x, |
| 513 | + y = ~ecdf_value, |
| 514 | + color = ~chain, |
| 515 | + group = ~chain |
| 516 | + ) |
| 517 | + |
| 518 | + scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain)) |
| 519 | + |
| 520 | + facet_call <- NULL |
| 521 | + if (n_param == 1) { |
| 522 | + facet_call <- ylab(levels(data$parameter)) |
| 523 | + } else { |
| 524 | + facet_args$facets <- ~parameter |
| 525 | + facet_args$scales <- facet_args$scales %||% "free" |
| 526 | + facet_call <- do.call("facet_wrap", facet_args) |
458 | 527 | }
|
459 | 528 |
|
| 529 | + ggplot() + |
| 530 | + geom_step(data = data_lim, aes_(x = ~x, y = ~upper), show.legend = FALSE) + |
| 531 | + geom_step(data = data_lim, aes_(x = ~x, y = ~lower), show.legend = FALSE) + |
| 532 | + geom_step(mapping, data) + |
| 533 | + bayesplot_theme_get() + |
| 534 | + scale_color + |
| 535 | + facet_call + |
| 536 | + scale_x_continuous(breaks = pretty) + |
| 537 | + legend_move(ifelse(n_chain > 1, "right", "none")) + |
| 538 | + xaxis_title(FALSE) + |
| 539 | + yaxis_title(on = n_param == 1) |
| 540 | +} |
| 541 | + |
460 | 542 | #' @rdname MCMC-traces
|
461 | 543 | #' @export
|
462 | 544 | mcmc_trace_data <- function(x,
|
@@ -639,100 +721,6 @@ mcmc_trace_data <- function(x,
|
639 | 721 | yaxis_title(on = n_param == 1)
|
640 | 722 | }
|
641 | 723 |
|
642 |
| -.mcmc_rank_ecdf <- function(x, |
643 |
| - pars = character(), |
644 |
| - regex_pars = character(), |
645 |
| - transformations = list(), |
646 |
| - facet_args = list(), |
647 |
| - ..., |
648 |
| - K, |
649 |
| - plot_diff = TRUE, |
650 |
| - prob = 0.99, |
651 |
| - adj_method = "interpolate") { |
652 |
| - data <- mcmc_trace_data( |
653 |
| - x, |
654 |
| - pars = pars, |
655 |
| - regex_pars = regex_pars, |
656 |
| - transformations = transformations |
657 |
| - ) |
658 |
| - n_iter <- unique(data$n_iterations) |
659 |
| - n_chain <- unique(data$n_chains) |
660 |
| - n_param <- unique(data$n_parameters) |
661 |
| - |
662 |
| - x <- if (missing(K)) { |
663 |
| - 0:n_iter / n_iter |
664 |
| - } else { |
665 |
| - 0:K / K |
666 |
| - } |
667 |
| - gamma <- adjust_gamma( |
668 |
| - N = n_iter, |
669 |
| - L = n_chain, |
670 |
| - K = if (missing(K)) { |
671 |
| - n_iter |
672 |
| - } else { |
673 |
| - K |
674 |
| - }, |
675 |
| - conf_level = prob, |
676 |
| - ..., |
677 |
| - adj_method = adj_method |
678 |
| - ) |
679 |
| - lims <- ecdf_intervals( |
680 |
| - N = n_iter, |
681 |
| - L = n_chain, |
682 |
| - K = if (missing(K)) { |
683 |
| - n_iter |
684 |
| - } else { |
685 |
| - K |
686 |
| - }, |
687 |
| - gamma = gamma |
688 |
| - ) |
689 |
| - data_lim <- data.frame( |
690 |
| - upper = lims$upper / n_iter - (plot_diff == TRUE) * x, |
691 |
| - lower = lims$lower / n_iter - (plot_diff == TRUE) * x, |
692 |
| - x = x |
693 |
| - ) |
694 |
| - data <- data %>% |
695 |
| - group_by(parameter, chain) %>% |
696 |
| - dplyr::group_map(~ data.frame( |
697 |
| - parameter = .y[1], |
698 |
| - chain = .y[2], |
699 |
| - ecdf_value = ecdf(.x$value_rank / (n_iter * n_chain))(x) - |
700 |
| - (plot_diff == TRUE) * x, |
701 |
| - x = x |
702 |
| - )) %>% |
703 |
| - dplyr::bind_rows() |
704 |
| - |
705 |
| - mapping <- aes_( |
706 |
| - x = ~x, |
707 |
| - y = ~ecdf_value, |
708 |
| - color = ~chain, |
709 |
| - group = ~chain |
710 |
| - ) |
711 |
| - |
712 |
| - scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain)) |
713 |
| - |
714 |
| - facet_call <- NULL |
715 |
| - if (n_param == 1) { |
716 |
| - facet_call <- ylab(levels(data$parameter)) |
717 |
| - } else { |
718 |
| - facet_args$facets <- ~parameter |
719 |
| - facet_args$scales <- facet_args$scales %||% "free" |
720 |
| - facet_call <- do.call("facet_wrap", facet_args) |
721 |
| - } |
722 |
| - |
723 |
| - ggplot() + |
724 |
| - geom_step(data = data_lim, aes_(x = ~x, y = ~upper), show.legend = FALSE) + |
725 |
| - geom_step(data = data_lim, aes_(x = ~x, y = ~lower), show.legend = FALSE) + |
726 |
| - geom_step(mapping, data) + |
727 |
| - bayesplot_theme_get() + |
728 |
| - scale_color + |
729 |
| - facet_call + |
730 |
| - scale_x_continuous(breaks = pretty) + |
731 |
| - legend_move(ifelse(n_chain > 1, "right", "none")) + |
732 |
| - xaxis_title(FALSE) + |
733 |
| - yaxis_title(on = n_param == 1) |
734 |
| -} |
735 |
| - |
736 | 724 | chain_colors <- function(n) {
|
737 | 725 | all_clrs <- unlist(color_scheme_get())
|
738 | 726 | clrs <- switch(
|
|
0 commit comments