Skip to content

Commit 3fc3a89

Browse files
authored
Merge pull request #261 from stan-dev/issue-258
add density controls to mcmc_dens() and mcmc_dens_overlay()
2 parents 60edc5a + bb50861 commit 3fc3a89

File tree

3 files changed

+176
-96
lines changed

3 files changed

+176
-96
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
* Fix R cmd check error on linux for CRAN
66

7+
* `mcmc_dens()` and `mcmc_dens_overlay()` gain arguments for controlling the
8+
the density calculation. (#258)
79

810
# bayesplot 1.8.0
911

R/mcmc-distributions.R

Lines changed: 162 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#' @template args-regex_pars
1212
#' @template args-transformations
1313
#' @template args-facet_args
14+
#' @template args-density-controls
1415
#' @param ... Currently ignored.
1516
#'
1617
#' @template return-ggplot
@@ -105,15 +106,17 @@ NULL
105106
#' @template args-hist
106107
#' @template args-hist-freq
107108
#'
108-
mcmc_hist <- function(x,
109-
pars = character(),
110-
regex_pars = character(),
111-
transformations = list(),
112-
...,
113-
facet_args = list(),
114-
binwidth = NULL,
115-
breaks = NULL,
116-
freq = TRUE) {
109+
mcmc_hist <- function(
110+
x,
111+
pars = character(),
112+
regex_pars = character(),
113+
transformations = list(),
114+
...,
115+
facet_args = list(),
116+
binwidth = NULL,
117+
breaks = NULL,
118+
freq = TRUE
119+
) {
117120
check_ignored_arguments(...)
118121
.mcmc_hist(
119122
x,
@@ -131,13 +134,19 @@ mcmc_hist <- function(x,
131134

132135
#' @rdname MCMC-distributions
133136
#' @export
134-
mcmc_dens <- function(x,
135-
pars = character(),
136-
regex_pars = character(),
137-
transformations = list(),
138-
...,
139-
facet_args = list(),
140-
trim = FALSE) {
137+
mcmc_dens <- function(
138+
x,
139+
pars = character(),
140+
regex_pars = character(),
141+
transformations = list(),
142+
...,
143+
facet_args = list(),
144+
trim = FALSE,
145+
bw = NULL,
146+
adjust = NULL,
147+
kernel = NULL,
148+
n_dens = NULL
149+
) {
141150
check_ignored_arguments(...)
142151
.mcmc_dens(
143152
x,
@@ -147,21 +156,27 @@ mcmc_dens <- function(x,
147156
facet_args = facet_args,
148157
by_chain = FALSE,
149158
trim = trim,
159+
bw = bw,
160+
adjust = adjust,
161+
kernel = kernel,
162+
n_dens = n_dens,
150163
...
151164
)
152165
}
153166

154167
#' @rdname MCMC-distributions
155168
#' @export
156169
#'
157-
mcmc_hist_by_chain <- function(x,
158-
pars = character(),
159-
regex_pars = character(),
160-
transformations = list(),
161-
...,
162-
facet_args = list(),
163-
binwidth = NULL,
164-
freq = TRUE) {
170+
mcmc_hist_by_chain <- function(
171+
x,
172+
pars = character(),
173+
regex_pars = character(),
174+
transformations = list(),
175+
...,
176+
facet_args = list(),
177+
binwidth = NULL,
178+
freq = TRUE
179+
) {
165180
check_ignored_arguments(...)
166181
.mcmc_hist(
167182
x,
@@ -178,14 +193,20 @@ mcmc_hist_by_chain <- function(x,
178193

179194
#' @rdname MCMC-distributions
180195
#' @export
181-
mcmc_dens_overlay <- function(x,
182-
pars = character(),
183-
regex_pars = character(),
184-
transformations = list(),
185-
...,
186-
facet_args = list(),
187-
color_chains = TRUE,
188-
trim = FALSE) {
196+
mcmc_dens_overlay <- function(
197+
x,
198+
pars = character(),
199+
regex_pars = character(),
200+
transformations = list(),
201+
...,
202+
facet_args = list(),
203+
color_chains = TRUE,
204+
trim = FALSE,
205+
bw = NULL,
206+
adjust = NULL,
207+
kernel = NULL,
208+
n_dens = NULL
209+
) {
189210
check_ignored_arguments(...)
190211
.mcmc_dens(
191212
x,
@@ -196,6 +217,10 @@ mcmc_dens_overlay <- function(x,
196217
by_chain = TRUE,
197218
color_chains = color_chains,
198219
trim = trim,
220+
bw = bw,
221+
adjust = adjust,
222+
kernel = kernel,
223+
n_dens = n_dens,
199224
...
200225
)
201226
}
@@ -204,19 +229,29 @@ mcmc_dens_overlay <- function(x,
204229
#' @template args-density-controls
205230
#' @param color_chains Option for whether to separately color chains.
206231
#' @export
207-
mcmc_dens_chains <- function(x,
208-
pars = character(),
209-
regex_pars = character(),
210-
transformations = list(),
211-
...,
212-
color_chains = TRUE,
213-
bw = NULL, adjust = NULL, kernel = NULL,
214-
n_dens = NULL) {
232+
mcmc_dens_chains <- function(
233+
x,
234+
pars = character(),
235+
regex_pars = character(),
236+
transformations = list(),
237+
...,
238+
color_chains = TRUE,
239+
bw = NULL,
240+
adjust = NULL,
241+
kernel = NULL,
242+
n_dens = NULL
243+
) {
215244
check_ignored_arguments(...)
216-
data <- mcmc_dens_chains_data(x, pars = pars, regex_pars = regex_pars,
217-
transformations = transformations, bw = bw,
218-
adjust = adjust, kernel = kernel,
219-
n_dens = n_dens)
245+
data <- mcmc_dens_chains_data(
246+
x,
247+
pars = pars,
248+
regex_pars = regex_pars,
249+
transformations = transformations,
250+
bw = bw,
251+
adjust = adjust,
252+
kernel = kernel,
253+
n_dens = n_dens
254+
)
220255

221256
n_chains <- length(unique(data$chain))
222257
if (n_chains == 1) STOP_need_multiple_chains()
@@ -233,17 +268,22 @@ mcmc_dens_chains <- function(x,
233268
}
234269

235270
ggplot(data) +
236-
aes_(x = ~ x, y = ~ parameter, color = ~ chain,
237-
group = ~ interaction(chain, parameter)) +
271+
aes_(
272+
x = ~ x, y = ~ parameter, color = ~ chain,
273+
group = ~ interaction(chain, parameter)
274+
) +
238275
geom_line(data = line_training) +
239276
ggridges::geom_density_ridges(
240277
aes_(height = ~ density),
241278
stat = "identity",
242279
fill = NA,
243-
show.legend = FALSE) +
280+
show.legend = FALSE
281+
) +
244282
labs(color = "Chain") +
245-
scale_y_discrete(limits = unique(rev(data$parameter)),
246-
expand = c(0.05, .6)) +
283+
scale_y_discrete(
284+
limits = unique(rev(data$parameter)),
285+
expand = c(0.05, .6)
286+
) +
247287
scale_color +
248288
bayesplot_theme_get() +
249289
yaxis_title(FALSE) +
@@ -254,38 +294,48 @@ mcmc_dens_chains <- function(x,
254294

255295
#' @rdname MCMC-distributions
256296
#' @export
257-
mcmc_dens_chains_data <- function(x,
258-
pars = character(),
259-
regex_pars = character(),
260-
transformations = list(),
261-
...,
262-
bw = NULL, adjust = NULL, kernel = NULL,
263-
n_dens = NULL) {
297+
mcmc_dens_chains_data <- function(
298+
x,
299+
pars = character(),
300+
regex_pars = character(),
301+
transformations = list(),
302+
...,
303+
bw = NULL, adjust = NULL, kernel = NULL,
304+
n_dens = NULL
305+
) {
264306
check_ignored_arguments(...)
265307

266308
x %>%
267-
prepare_mcmc_array(pars = pars, regex_pars = regex_pars,
268-
transformations = transformations) %>%
309+
prepare_mcmc_array(
310+
pars = pars,
311+
regex_pars = regex_pars,
312+
transformations = transformations
313+
) %>%
269314
melt_mcmc() %>%
270-
compute_column_density(c(.data$Parameter, .data$Chain), .data$Value,
271-
interval_width = 1,
272-
bw = bw, adjust = adjust, kernel = kernel,
273-
n_dens = n_dens) %>%
315+
compute_column_density(
316+
group_vars = c(.data$Parameter, .data$Chain),
317+
value_var = .data$Value,
318+
interval_width = 1,
319+
bw = bw, adjust = adjust, kernel = kernel, n_dens = n_dens
320+
) %>%
274321
mutate(Chain = factor(.data$Chain)) %>%
275322
rlang::set_names(tolower) %>%
276323
dplyr::as_tibble()
277324
}
278325

326+
279327
#' @rdname MCMC-distributions
280328
#' @inheritParams ppc_violin_grouped
281329
#' @export
282-
mcmc_violin <- function(x,
283-
pars = character(),
284-
regex_pars = character(),
285-
transformations = list(),
286-
...,
287-
facet_args = list(),
288-
probs = c(0.1, 0.5, 0.9)) {
330+
mcmc_violin <- function(
331+
x,
332+
pars = character(),
333+
regex_pars = character(),
334+
transformations = list(),
335+
...,
336+
facet_args = list(),
337+
probs = c(0.1, 0.5, 0.9)
338+
) {
289339
check_ignored_arguments(...)
290340
.mcmc_dens(
291341
x,
@@ -303,16 +353,18 @@ mcmc_violin <- function(x,
303353

304354

305355
# internal -----------------------------------------------------------------
306-
.mcmc_hist <- function(x,
307-
pars = character(),
308-
regex_pars = character(),
309-
transformations = list(),
310-
facet_args = list(),
311-
binwidth = NULL,
312-
breaks = NULL,
313-
by_chain = FALSE,
314-
freq = TRUE,
315-
...) {
356+
.mcmc_hist <- function(
357+
x,
358+
pars = character(),
359+
regex_pars = character(),
360+
transformations = list(),
361+
facet_args = list(),
362+
binwidth = NULL,
363+
breaks = NULL,
364+
by_chain = FALSE,
365+
freq = TRUE,
366+
...
367+
) {
316368
x <- prepare_mcmc_array(x, pars, regex_pars, transformations)
317369

318370
if (by_chain && !has_multiple_chains(x)) {
@@ -363,25 +415,37 @@ mcmc_violin <- function(x,
363415
xaxis_title(on = n_param == 1)
364416
}
365417

366-
.mcmc_dens <- function(x,
367-
pars = character(),
368-
regex_pars = character(),
369-
transformations = list(),
370-
facet_args = list(),
371-
by_chain = FALSE,
372-
color_chains = FALSE,
373-
geom = c("density", "violin"),
374-
probs = c(0.1, 0.5, 0.9),
375-
trim = FALSE,
376-
...) {
418+
.mcmc_dens <- function(
419+
x,
420+
pars = character(),
421+
regex_pars = character(),
422+
transformations = list(),
423+
facet_args = list(),
424+
by_chain = FALSE,
425+
color_chains = FALSE,
426+
geom = c("density", "violin"),
427+
probs = c(0.1, 0.5, 0.9),
428+
trim = FALSE,
429+
bw = NULL,
430+
adjust = NULL,
431+
kernel = NULL,
432+
n_dens = NULL,
433+
...
434+
) {
435+
436+
bw <- bw %||% "nrd0"
437+
adjust <- adjust %||% 1
438+
kernel <- kernel %||% "gaussian"
439+
n_dens <- n_dens %||% 1024
440+
377441
x <- prepare_mcmc_array(x, pars, regex_pars, transformations)
378-
data <- melt_mcmc(x)
442+
data <- melt_mcmc.mcmc_array(x)
379443
data$Chain <- factor(data$Chain)
380444
n_param <- num_params(data)
381445

382446
geom <- match.arg(geom)
383447
violin <- geom == "violin"
384-
geom_fun <- if (by_chain) "stat_density" else paste0("geom_", geom)
448+
geom_fun <- if (!violin) "stat_density" else "geom_violin"
385449

386450
if (by_chain || violin) {
387451
if (!has_multiple_chains(x)) {
@@ -396,11 +460,16 @@ mcmc_violin <- function(x,
396460
} else {
397461
list(x = ~ Value)
398462
}
463+
399464
geom_args <- list(size = 0.5, na.rm = TRUE)
400465
if (violin) {
401466
geom_args[["draw_quantiles"]] <- probs
402467
} else {
403468
geom_args[["trim"]] <- trim
469+
geom_args[["bw"]] <- bw
470+
geom_args[["adjust"]] <- adjust
471+
geom_args[["kernel"]] <- kernel
472+
geom_args[["n"]] <- n_dens
404473
}
405474

406475
if (by_chain) {
@@ -450,3 +519,4 @@ mcmc_violin <- function(x,
450519
yaxis_title(on = n_param == 1 && violin) +
451520
xaxis_title(on = n_param == 1)
452521
}
522+

0 commit comments

Comments
 (0)