Skip to content

Commit 3104a3e

Browse files
committed
make parameter,value,value_rank first columns of mcmc_trace_data()
1 parent 0e46464 commit 3104a3e

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

R/mcmc-traces.R

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#' Trace plots (time series plot) of MCMC draws
1+
#' Trace plots of MCMC draws
22
#'
33
#' Trace plot (or traceplot) of MCMC draws. See the **Plot Descriptions**
44
#' section, below, for details.
@@ -12,8 +12,7 @@
1212
#' @template args-facet_args
1313
#' @param ... Currently ignored.
1414
#' @param size An optional value to override the default line size
15-
#' (`mcmc_trace()`) or the default point size
16-
#' (`mcmc_trace_highlight()`).
15+
#' for `mcmc_trace()` or the default point size for `mcmc_trace_highlight()`.
1716
#' @param alpha For `mcmc_trace_highlight()`, passed to
1817
#' [ggplot2::geom_point()] to control the transparency of the points
1918
#' for the chains not highlighted.
@@ -27,10 +26,10 @@
2726
#' if `n_warmup` is also set to a positive value.
2827
#' @param window An integer vector of length two specifying the limits of a
2928
#' range of iterations to display.
30-
#' @param np For models fit using [NUTS] (more generally, any [symplectic
31-
#' integrator](https://en.wikipedia.org/wiki/Symplectic_integrator)), an
32-
#' optional data frame providing NUTS diagnostic information. The data frame
33-
#' should be the object returned by [nuts_params()] or one with the same
29+
#' @param np For models fit using [NUTS] (more generally, any
30+
#' [symplectic integrator](https://en.wikipedia.org/wiki/Symplectic_integrator)),
31+
#' an optional data frame providing NUTS diagnostic information. The data
32+
#' frame should be the object returned by [nuts_params()] or one with the same
3433
#' structure. If `np` is specified then tick marks are added to the bottom of
3534
#' the trace plot indicating within which iterations there was a divergence
3635
#' (if there were any). See the end of the **Examples** section, below.
@@ -40,6 +39,8 @@
4039
#' @param divergences Deprecated. Use the `np` argument instead.
4140
#'
4241
#' @template return-ggplot-or-data
42+
#' @return `mcmc_trace_data()` returns the data for the trace *and* rank plots
43+
#' in the same data frame.
4344
#'
4445
#' @section Plot Descriptions:
4546
#' \describe{
@@ -208,10 +209,9 @@ mcmc_trace <-
208209
}
209210

210211
#' @rdname MCMC-traces
212+
#' @export
211213
#' @param highlight For `mcmc_trace_highlight()`, an integer specifying one
212214
#' of the chains that will be more visible than the others in the plot.
213-
#' @export
214-
#' @md
215215
mcmc_trace_highlight <- function(x,
216216
pars = character(),
217217
regex_pars = character(),
@@ -242,13 +242,13 @@ mcmc_trace_highlight <- function(x,
242242

243243

244244
#' @rdname MCMC-traces
245+
#' @export
245246
#' @param div_color,div_size,div_alpha Optional arguments to the
246247
#' `trace_style_np()` helper function that are eventually passed to
247248
#' [ggplot2::geom_rug()] if the `np` argument is also specified. They control
248249
#' the color, size, and transparency specifications for showing divergences in
249250
#' the plot. The default values are displayed in the **Usage** section above.
250-
#' @export
251-
#' @md
251+
#'
252252
trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
253253
stopifnot(
254254
is.character(div_color),
@@ -466,10 +466,12 @@ mcmc_trace_data <- function(x,
466466
data$n_parameters <- num_params(data)
467467
data <- rlang::set_names(data, tolower)
468468

469+
first_cols <- syms(c("parameter", "value", "value_rank"))
469470
data <- data %>%
470471
group_by(.data$parameter) %>%
471472
mutate(value_rank = dplyr::row_number(.data$value)) %>%
472-
ungroup()
473+
ungroup() %>%
474+
select(!!! first_cols, dplyr::everything())
473475

474476
data$highlight <- if (!is.null(highlight)) {
475477
data$chain == highlight
@@ -478,8 +480,7 @@ mcmc_trace_data <- function(x,
478480
}
479481

480482
data$warmup <- data$iteration <= n_warmup
481-
data$iteration <- data$iteration + iter1
482-
483+
data$iteration <- data$iteration + as.integer(iter1)
483484
tibble::as_tibble(data)
484485
}
485486

0 commit comments

Comments
 (0)