Skip to content

Commit 24693c2

Browse files
committed
Merge branch 'master' into ridgeline-size
2 parents 7453c2e + 52f1a9f commit 24693c2

25 files changed

+645
-365
lines changed

.github/FUNDING.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
github: stan-dev
2+
custom: https://mc-stan.org/support/

.github/workflows/R-CMD-check.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
config:
22-
- {os: macOS-latest, r: 'devel'}
23-
- {os: macOS-latest, r: 'release'}
2422
- {os: windows-latest, r: 'release'}
25-
- {os: ubuntu-16.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"}
26-
- {os: ubuntu-16.04, r: 'oldrel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"}
23+
- {os: macOS-latest, r: 'release'}
24+
- {os: ubuntu-20.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"}
25+
- {os: ubuntu-20.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"}
2726
env:
2827
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
2928
RSPM: ${{ matrix.config.rspm }}

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Suggests:
4949
shinystan (>= 2.3.0),
5050
testthat (>= 2.0.0),
5151
vdiffr
52-
RoxygenNote: 7.1.0.9000
52+
RoxygenNote: 7.1.1
5353
VignetteBuilder: knitr
5454
Encoding: UTF-8
5555
Roxygen: list(markdown = TRUE)

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
S3method("[",neff_ratio)
44
S3method("[",rhat)
5+
S3method(log_posterior,CmdStanMCMC)
56
S3method(log_posterior,stanfit)
67
S3method(log_posterior,stanreg)
8+
S3method(neff_ratio,CmdStanMCMC)
79
S3method(neff_ratio,stanfit)
810
S3method(neff_ratio,stanreg)
11+
S3method(nuts_params,CmdStanMCMC)
912
S3method(nuts_params,list)
1013
S3method(nuts_params,stanfit)
1114
S3method(nuts_params,stanreg)
@@ -15,6 +18,7 @@ S3method(pp_check,default)
1518
S3method(print,bayesplot_function_list)
1619
S3method(print,bayesplot_grid)
1720
S3method(print,bayesplot_scheme)
21+
S3method(rhat,CmdStanMCMC)
1822
S3method(rhat,stanfit)
1923
S3method(rhat,stanreg)
2024
export(abline_01)

NEWS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
* Items for next release go here
99
-->
1010

11+
* CmdStanMCMC objects (from CmdStanR) can now be used with extractor
12+
functions `nuts_params()`, `log_posterior()`, `rhat()`, and
13+
`neff_ratio()`. (#227)
14+
1115
* Added missing `facet_args` argument to `mcmc_rank_overlay()`. (#221, @hhau)
16+
* Size of points and interval lines can be set in
17+
`mcmc_intervals(..., outer_size, inner_size, point_size)`. (#215, #228, #229)
18+
* `mcmc_areas()` tries to use less vertical blank space. (#218, #230)
1219

1320

1421
# bayesplot 1.7.2

R/bayesplot-extractors.R

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#' Extract quantities needed for plotting from model objects
22
#'
33
#' Generics and methods for extracting quantities needed for plotting from
4-
#' various types of model objects. Currently methods are only provided for
5-
#' stanfit (**rstan**) and stanreg (**rstanarm**) objects, but adding new
6-
#' methods should be relatively straightforward.
4+
#' various types of model objects. Currently methods are provided for stanfit
5+
#' (**rstan**), CmdStanMCMC (**cmdstanr**), and stanreg (**rstanarm**) objects,
6+
#' but adding new methods should be relatively straightforward.
77
#'
88
#' @name bayesplot-extractors
99
#' @param object The object to use.
@@ -87,7 +87,8 @@ log_posterior.stanfit <- function(object, inc_warmup = FALSE, ...) {
8787
...)
8888
lp <- lapply(lp, as.array)
8989
lp <- set_names(reshape2::melt(lp), c("Iteration", "Value", "Chain"))
90-
validate_df_classes(lp, c("integer", "numeric", "integer"))
90+
validate_df_classes(lp[, c("Chain", "Iteration", "Value")],
91+
c("integer", "integer", "numeric"))
9192
}
9293

9394
#' @rdname bayesplot-extractors
@@ -98,11 +99,22 @@ log_posterior.stanreg <- function(object, inc_warmup = FALSE, ...) {
9899
log_posterior.stanfit(object$stanfit, inc_warmup = inc_warmup, ...)
99100
}
100101

102+
#' @rdname bayesplot-extractors
103+
#' @export
104+
#' @method log_posterior CmdStanMCMC
105+
log_posterior.CmdStanMCMC <- function(object, inc_warmup = FALSE, ...) {
106+
lp <- object$draws("lp__", inc_warmup = inc_warmup)
107+
lp <- reshape2::melt(lp)
108+
lp$variable <- NULL
109+
lp <- dplyr::rename_with(lp, capitalize_first)
110+
validate_df_classes(lp[, c("Chain", "Iteration", "Value")],
111+
c("integer", "integer", "numeric"))
112+
}
113+
101114

102115
#' @rdname bayesplot-extractors
103116
#' @export
104117
#' @method nuts_params stanfit
105-
#'
106118
nuts_params.stanfit <-
107119
function(object,
108120
pars = NULL,
@@ -153,7 +165,23 @@ nuts_params.list <- function(object, pars = NULL, ...) {
153165

154166
out <- reshape2::melt(object)
155167
out <- set_names(out, c("Iteration", "Parameter", "Value", "Chain"))
156-
validate_df_classes(out, c("integer", "factor", "numeric", "integer"))
168+
validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
169+
c("integer", "integer", "factor", "numeric"))
170+
}
171+
172+
#' @rdname bayesplot-extractors
173+
#' @export
174+
#' @method nuts_params CmdStanMCMC
175+
nuts_params.CmdStanMCMC <- function(object, pars = NULL, ...) {
176+
arr <- object$sampler_diagnostics()
177+
if (!is.null(pars)) {
178+
arr <- arr[,, pars]
179+
}
180+
out <- reshape2::melt(arr)
181+
colnames(out)[colnames(out) == "variable"] <- "parameter"
182+
out <- dplyr::rename_with(out, capitalize_first)
183+
validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
184+
c("integer", "integer", "factor", "numeric"))
157185
}
158186

159187

@@ -188,6 +216,17 @@ rhat.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) {
188216
r[!names(r) %in% c("mean_PPD", "log-posterior")]
189217
}
190218

219+
#' @rdname bayesplot-extractors
220+
#' @export
221+
#' @method rhat CmdStanMCMC
222+
rhat.CmdStanMCMC <- function(object, pars = NULL, ...) {
223+
.rhat <- utils::getFromNamespace("rhat", "posterior")
224+
s <- object$summary(pars, rhat = .rhat)[, c("variable", "rhat")]
225+
r <- setNames(s$rhat, s$variable)
226+
r <- validate_rhat(r)
227+
r[!names(r) %in% "lp__"]
228+
}
229+
191230

192231
#' @rdname bayesplot-extractors
193232
#' @export
@@ -223,6 +262,18 @@ neff_ratio.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) {
223262
ratio[!names(ratio) %in% c("mean_PPD", "log-posterior")]
224263
}
225264

265+
#' @rdname bayesplot-extractors
266+
#' @export
267+
#' @method neff_ratio CmdStanMCMC
268+
neff_ratio.CmdStanMCMC <- function(object, pars = NULL, ...) {
269+
s <- object$summary(pars, "n_eff" = "ess_basic")[, c("variable", "n_eff")]
270+
ess <- setNames(s$n_eff, s$variable)
271+
tss <- prod(dim(object$draws())[1:2])
272+
ratio <- ess / tss
273+
ratio <- validate_neff_ratio(ratio)
274+
ratio[!names(ratio) %in% "lp__"]
275+
}
276+
226277

227278
# internals ---------------------------------------------------------------
228279

@@ -245,3 +296,10 @@ validate_df_classes <- function(x, classes = character()) {
245296
}
246297
x
247298
}
299+
300+
# lowercase a string, then capitalize only its first letter
301+
capitalize_first <- function(name) {
302+
name <- tolower(name) # in case whole string is capitalized
303+
substr(name, 1, 1) <- toupper(substr(name, 1, 1))
304+
name
305+
}

R/mcmc-diagnostics-nuts.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,8 @@ validate_nuts_data_frame <- function(x, lp) {
513513
abort("NUTS parameters should be in a data frame.")
514514
}
515515

516-
valid_cols <- c("Iteration", "Parameter", "Value", "Chain")
517-
if (!identical(colnames(x), valid_cols)) {
516+
valid_cols <- sort(c("Iteration", "Parameter", "Value", "Chain"))
517+
if (!identical(sort(colnames(x)), valid_cols)) {
518518
abort(paste(
519519
"NUTS parameter data frame must have columns:",
520520
paste(valid_cols, collapse = ", ")
@@ -529,8 +529,8 @@ validate_nuts_data_frame <- function(x, lp) {
529529
abort("lp should be in a data frame.")
530530
}
531531

532-
valid_lp_cols <- c("Iteration", "Value", "Chain")
533-
if (!identical(colnames(lp), valid_lp_cols)) {
532+
valid_lp_cols <- sort(c("Iteration", "Value", "Chain"))
533+
if (!identical(sort(colnames(lp)), valid_lp_cols)) {
534534
abort(paste(
535535
"lp data frame must have columns:",
536536
paste(valid_lp_cols, collapse = ", ")

R/mcmc-intervals.R

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
#' `"equal height"` are scaled using `height*sqrt(height)`
2626
#' @param point_est The point estimate to show. Either `"median"` (the
2727
#' default), `"mean"`, or `"none"`.
28+
#' @param inner_size,outer_size For `mcmc_intervals()`, the size of
29+
#' the inner and outer interval segments, respectively.
30+
#' @param point_size For `mcmc_intervals()`, the size of the point estimate.
2831
#' @param rhat An optional numeric vector of R-hat estimates, with one element
2932
#' per parameter included in `x`. If `rhat` is provided, the intervals/areas
3033
#' and point estimates in the resulting plot are colored based on R-hat value.
@@ -52,6 +55,10 @@
5255
#'
5356
#' @examples
5457
#' set.seed(9262017)
58+
#'
59+
#' # load ggplot2 to use its functions to modify our plots
60+
#' library(ggplot2)
61+
#'
5562
#' # some parameter draws to use for demonstration
5663
#' x <- example_mcmc_draws(params = 6)
5764
#' dim(x)
@@ -61,19 +68,31 @@
6168
#' mcmc_intervals(x)
6269
#' mcmc_intervals(x, pars = c("beta[1]", "beta[2]"))
6370
#' mcmc_areas(x, regex_pars = "beta\\[[1-3]\\]", prob = 0.8) +
64-
#' ggplot2::labs(
71+
#' labs(
6572
#' title = "Posterior distributions",
6673
#' subtitle = "with medians and 80% intervals"
6774
#' )
6875
#'
6976
#' color_scheme_set("red")
70-
#' mcmc_areas(
77+
#' p <- mcmc_areas(
7178
#' x,
7279
#' pars = c("alpha", "beta[4]"),
7380
#' prob = 2/3,
7481
#' prob_outer = 0.9,
7582
#' point_est = "mean"
7683
#' )
84+
#' plot(p)
85+
#'
86+
#' # control spacing at top and bottom of plot
87+
#' # see ?ggplot2::expansion
88+
#' p + scale_y_discrete(
89+
#' limits = c("beta[4]", "alpha"),
90+
#' expand = expansion(add = c(1, 2))
91+
#' )
92+
#' p + scale_y_discrete(
93+
#' limits = c("beta[4]", "alpha"),
94+
#' expand = expansion(add = c(.1, .3))
95+
#' )
7796
#'
7897
#' # color by rhat value
7998
#' color_scheme_set("blue")
@@ -94,22 +113,22 @@
94113
#' b3 <- c("beta[1]", "beta[2]", "beta[3]")
95114
#'
96115
#' mcmc_areas(x, pars = b3, area_method = "equal area") +
97-
#' ggplot2::labs(
116+
#' labs(
98117
#' title = "Curves have same area",
99-
#' subtitle =
100-
#' "A wide, uncertain interval is spread thin when areas are equal")
118+
#' subtitle = "A wide, uncertain interval is spread thin when areas are equal"
119+
#' )
101120
#'
102121
#' mcmc_areas(x, pars = b3, area_method = "equal height") +
103-
#' ggplot2::labs(
122+
#' labs(
104123
#' title = "Curves have same maximum height",
105-
#' subtitle =
106-
#' "Local curvature is clearer but more uncertain curves use more area")
124+
#' subtitle = "Local curvature is clearer but more uncertain curves use more area"
125+
#' )
107126
#'
108127
#' mcmc_areas(x, pars = b3, area_method = "scaled height") +
109-
#' ggplot2::labs(
128+
#' labs(
110129
#' title = "Same maximum heights but heights scaled by square-root",
111-
#' subtitle =
112-
#' "Compromise: Local curvature is accentuated and less area is used")
130+
#' subtitle = "Compromise: Local curvature is accentuated and less area is used"
131+
#' )
113132
#'
114133
#' \donttest{
115134
#' # apply transformations
@@ -145,7 +164,7 @@
145164
#' # plotted with ridgelines
146165
#' m <- shinystan::eight_schools@posterior_sample
147166
#' mcmc_areas_ridges(m, pars = "mu", regex_pars = "theta") +
148-
#' ggplot2::ggtitle("Treatment effect on eight schools (Rubin, 1981)")
167+
#' ggtitle("Treatment effect on eight schools (Rubin, 1981)")
149168
#' }
150169
#'
151170
NULL
@@ -160,6 +179,9 @@ mcmc_intervals <- function(x,
160179
prob = 0.5,
161180
prob_outer = 0.9,
162181
point_est = c("median", "mean", "none"),
182+
outer_size = 0.5,
183+
inner_size = 2,
184+
point_size = 4,
163185
rhat = numeric()) {
164186
check_ignored_arguments(...)
165187

@@ -184,17 +206,18 @@ mcmc_intervals <- function(x,
184206

185207
args_outer <- list(
186208
mapping = aes_(x = ~ ll, xend = ~ hh, y = ~ parameter, yend = ~ parameter),
187-
color = get_color("mid")
209+
color = get_color("mid"),
210+
size = outer_size
188211
)
189212
args_inner <- list(
190213
mapping = aes_(x = ~ l, xend = ~ h, y = ~ parameter, yend = ~ parameter),
191-
size = 2,
214+
size = inner_size,
192215
show.legend = FALSE
193216
)
194217
args_point <- list(
195218
mapping = aes_(x = ~ m, y = ~ parameter),
196219
data = data,
197-
size = 4,
220+
size = point_size,
198221
shape = 21
199222
)
200223

@@ -269,7 +292,8 @@ mcmc_areas <- function(x,
269292
x, pars, regex_pars, transformations,
270293
prob = prob, prob_outer = prob_outer,
271294
point_est = point_est, rhat = rhat,
272-
bw = bw, adjust = adjust, kernel = kernel, n_dens = n_dens)
295+
bw = bw, adjust = adjust, kernel = kernel, n_dens = n_dens
296+
)
273297
datas <- split(data, data$interval)
274298

275299
# Use a dummy empty dataframe if no point estimate
@@ -312,7 +336,11 @@ mcmc_areas <- function(x,
312336

313337
datas$bottom <- datas$outer %>%
314338
group_by(!!! groups) %>%
315-
summarise(ll = min(.data$x), hh = max(.data$x)) %>%
339+
summarise(
340+
ll = min(.data$x),
341+
hh = max(.data$x),
342+
.groups = "drop_last"
343+
) %>%
316344
ungroup()
317345

318346
args_bottom <- list(
@@ -360,9 +388,16 @@ mcmc_areas <- function(x,
360388
args_outer$color <- get_color("dark")
361389
}
362390

391+
# An invisible layer that is 2.5% taller than the plotted one
392+
args_outer2 <- args_outer
393+
args_outer2$mapping <- args_outer2$mapping %>%
394+
modify_aes_(scale = .925)
395+
args_outer2$color <- NA
396+
363397
layer_bottom <- do.call(geom_segment, args_bottom)
364398
layer_inner <- do.call(ggridges::geom_ridgeline, args_inner)
365399
layer_outer <- do.call(ggridges::geom_ridgeline, args_outer)
400+
layer_outer2 <- do.call(ggridges::geom_ridgeline, args_outer2)
366401

367402
point_geom <- if (no_point_est) {
368403
geom_ignore
@@ -386,12 +421,17 @@ mcmc_areas <- function(x,
386421
layer_inner +
387422
layer_point +
388423
layer_outer +
424+
layer_outer2 +
389425
layer_bottom +
390426
scale_color +
391427
scale_fill +
392428
scale_y_discrete(
393429
limits = unique(rev(data$parameter)),
394-
expand = expansion(add = c(0, .1), mult = c(.1, .3))) +
430+
expand = expansion(
431+
add = c(0, .5 + 1/(2 * nlevels(data$parameter))),
432+
mult = c(.1, .1)
433+
)
434+
) +
395435
xlim(x_lim) +
396436
bayesplot_theme_get() +
397437
legend_move(ifelse(color_by_rhat, "top", "none")) +

0 commit comments

Comments
 (0)