Skip to content

Commit 98da77c

Browse files
authored
Merge pull request #188 from stan-dev/feature-tidyselect
Feature tidyselect
2 parents bf1dec2 + a1c1eec commit 98da77c

19 files changed

+688
-152
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Date: 2018-08-01
66
Authors@R: c(person("Jonah", "Gabry", role = c("aut", "cre"), email = "[email protected]"),
77
person("Tristan", "Mahr", role = "aut"),
88
person("Paul-Christian", "Bürkner", role = "ctb"),
9-
person("Martin", "Modrák", role = "ctb"),
9+
person("Martin", "Modrák", role = "ctb"),
1010
person("Malcolm", "Barrett", role = "ctb"))
1111
Maintainer: Jonah Gabry <[email protected]>
1212
Description: Plotting functions for posterior analysis, posterior predictive checks,
@@ -25,10 +25,12 @@ Imports:
2525
dplyr (>= 0.8.0),
2626
ggplot2 (>= 2.2.1),
2727
ggridges,
28+
glue,
2829
reshape2,
2930
rlang (>= 0.3.0),
3031
stats,
3132
tibble,
33+
tidyselect,
3234
utils
3335
Suggests:
3436
gridExtra (>= 2.2.1),

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ export(overlay_function)
8989
export(pairs_condition)
9090
export(pairs_style_np)
9191
export(panel_bg)
92+
export(param_glue)
93+
export(param_range)
9294
export(parcoord_style_np)
9395
export(plot_bg)
9496
export(pp_check)
@@ -132,6 +134,7 @@ export(rhat)
132134
export(scatter_style_np)
133135
export(theme_default)
134136
export(trace_style_np)
137+
export(vars)
135138
export(vline_0)
136139
export(vline_at)
137140
export(xaxis_text)
@@ -158,6 +161,7 @@ importFrom(dplyr,select)
158161
importFrom(dplyr,summarise)
159162
importFrom(dplyr,top_n)
160163
importFrom(dplyr,ungroup)
164+
importFrom(dplyr,vars)
161165
importFrom(ggplot2,"%+replace%")
162166
importFrom(ggridges,geom_density_ridges)
163167
importFrom(ggridges,geom_density_ridges2)

NEWS.md

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
<!-- Items for next release go here* -->
88

9+
* The `pars` argument of all MCMC plotting functions now supports tidy variable selection.
10+
See `help("tidy-params", package="bayesplot")` for details and examples. (#161, #183, #188)
11+
912
* Two new plots have been added for inspecting the distribution of ranks.
1013
Rank histograms were introduced by the Stan team's [new paper on
1114
MCMC diagnostics](https://arxiv.org/abs/1903.08008). (#178, #179)
@@ -41,21 +44,21 @@
4144
curves. The default `"equal area"` constrains the heights so that the curves
4245
have the same area. As a result, a narrow interval will appear as a spike
4346
of density, while a wide, uncertain interval is spread thin over the _x_ axis.
44-
Alternatively `"equal height"` will set the maximum height on each curve to
45-
the same value. This works well when the intervals are about the same width.
46-
Otherwise, that wide, uncertain interval will dominate the visual space
47-
compared to a narrow, less uncertain interval. A compromise between the two is
48-
`"scaled height"` which scales the curves from `"equal height"` using
47+
Alternatively `"equal height"` will set the maximum height on each curve to
48+
the same value. This works well when the intervals are about the same width.
49+
Otherwise, that wide, uncertain interval will dominate the visual space
50+
compared to a narrow, less uncertain interval. A compromise between the two is
51+
`"scaled height"` which scales the curves from `"equal height"` using
4952
`height * sqrt(height)`. (#163, #169)
50-
51-
* `mcmc_areas()` correctly plots density curves where the point estimate
52-
does not include the highest point of the density curve.
53+
54+
* `mcmc_areas()` correctly plots density curves where the point estimate
55+
does not include the highest point of the density curve.
5356
(#168, #169, @jtimonen)
54-
55-
* `mcmc_areas_ridges()` draws the vertical line at *x* = 0 over the curves so
57+
58+
* `mcmc_areas_ridges()` draws the vertical line at *x* = 0 over the curves so
5659
that it is always visible.
5760

58-
* `mcmc_intervals()` and `mcmc_areas()` raise a warning if `prob_outer` is ever
61+
* `mcmc_intervals()` and `mcmc_areas()` raise a warning if `prob_outer` is ever
5962
less than `prob`. It sorts these two values into the correct order. (#138)
6063

6164
* MCMC parameter names are now *always* converted to factors prior to
@@ -148,7 +151,7 @@
148151

149152
* Added `mcmc_intervals_data()` and `mcmc_areas_data()` that return data
150153
plotted by `mcmc_intervals()` and `mcmc_areas()`. (Advances #97)
151-
154+
152155
* New `ppc_data()` function returns the data plotted by many of the PPC plotting
153156
functions. (Advances #97)
154157

@@ -165,29 +168,29 @@
165168

166169
(GitHub issue/PR numbers in parentheses)
167170

168-
* New plotting function `mcmc_parcoord()` for parallel coordinates plots of
171+
* New plotting function `mcmc_parcoord()` for parallel coordinates plots of
169172
MCMC draws (optionally including HMC/NUTS diagnostic information). (#108)
170-
173+
171174
* `mcmc_scatter` gains an `np` argument for specifying NUTS parameters, which
172175
allows highlighting divergences in the plot. (#112)
173-
174-
* New functions with names ending with suffix `_data` don't make the plots,
175-
they just return the data prepared for plotting (more of these to come in
176+
177+
* New functions with names ending with suffix `_data` don't make the plots,
178+
they just return the data prepared for plotting (more of these to come in
176179
future releases):
177180
- `ppc_intervals_data()` (#101)
178181
- `ppc_ribbon_data()` (#101)
179182
- `mcmc_parcoord_data()` (#108)
180183
- `mcmc_rhat_data()` (#110)
181184
- `mcmc_neff_data()` (#110)
182-
183-
* `ppc_stat_grouped()`, `ppc_stat_freqpoly_grouped()` gain a `facet_args`
184-
argument for controlling **ggplot2** faceting (many of the `mcmc_` functions
185+
186+
* `ppc_stat_grouped()`, `ppc_stat_freqpoly_grouped()` gain a `facet_args`
187+
argument for controlling **ggplot2** faceting (many of the `mcmc_` functions
185188
already have this).
186-
187-
* The `divergences` argument to `mcmc_trace()` has been deprecated in favor
188-
of `np` (NUTS parameters) to match the other functions that have an `np`
189+
190+
* The `divergences` argument to `mcmc_trace()` has been deprecated in favor
191+
of `np` (NUTS parameters) to match the other functions that have an `np`
189192
argument.
190-
193+
191194
* Fixed an issue where duplicated rhat values would break `mcmc_rhat()` (#105).
192195

193196

@@ -316,5 +319,5 @@ Initial CRAN release
316319

317320

318321

319-
[ggridges]: https://CRAN.R-project.org/package=ggridges
322+
[ggridges]: https://CRAN.R-project.org/package=ggridges
320323
"ggridges package"

R/helpers-mcmc.R

Lines changed: 101 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
# Prepare 3-D array for MCMC plots
2-
#
3-
# @param x,pars,regex_pars,transformations Users's arguments to one of the
4-
# mcmc_* functions.
5-
# @return A 3-D (Iterations x Chains x Parameters) array.
6-
#
1+
#' Prepare 3-D array for MCMC plots
2+
#'
3+
#' @noRd
4+
#' @param x,pars,regex_pars,transformations Users's arguments to one of the
5+
#' mcmc_* functions.
6+
#' @return A 3-D (Iterations x Chains x Parameters) array.
7+
#'
78
prepare_mcmc_array <- function(x,
89
pars = character(),
910
regex_pars = character(),
@@ -29,12 +30,14 @@ prepare_mcmc_array <- function(x,
2930
abort("NAs not allowed in 'x'.")
3031
}
3132

32-
parnames <- parameter_names(x)
33-
pars <- select_parameters(
34-
explicit = pars,
35-
patterns = regex_pars,
36-
complete = parnames
37-
)
33+
if (rlang::is_quosures(pars)) {
34+
pars <- tidyselect_parameters(complete_pars = parameter_names(x),
35+
pars_list = pars)
36+
} else {
37+
pars <- select_parameters(complete_pars = parameter_names(x),
38+
explicit = pars,
39+
patterns = regex_pars)
40+
}
3841

3942
# possibly recycle transformations (apply same to all pars)
4043
if (is.function(transformations) ||
@@ -61,12 +64,63 @@ prepare_mcmc_array <- function(x,
6164
}
6265

6366

64-
# Melt a 3-D array or matrix of MCMC draws
65-
#
66-
# @param x An mcmc_array (from prepare_mcmc_array).
67-
# @param varnames,value.name,... Passed to reshape2::melt (array method).
68-
# @return A molten data frame.
69-
#
67+
#' Explicit and/or regex parameter selection
68+
#'
69+
#' @noRd
70+
#' @param explicit Character vector of selected parameter names.
71+
#' @param patterns Character vector of regular expressions.
72+
#' @param complete_pars Character vector of all possible parameter names.
73+
#' @return Character vector of combined explicit and matched (via regex)
74+
#' parameter names, unless an error is thrown.
75+
#'
76+
select_parameters <-
77+
function(explicit = character(),
78+
patterns = character(),
79+
complete_pars = character()) {
80+
81+
stopifnot(is.character(explicit),
82+
is.character(patterns),
83+
is.character(complete_pars))
84+
85+
if (!length(explicit) && !length(patterns)) {
86+
return(complete_pars)
87+
}
88+
89+
if (length(explicit)) {
90+
if (!all(explicit %in% complete_pars)) {
91+
not_found <- which(!explicit %in% complete_pars)
92+
abort(paste(
93+
"Some 'pars' don't match parameter names:",
94+
paste(explicit[not_found], collapse = ", "),
95+
call. = FALSE
96+
))
97+
}
98+
}
99+
100+
if (!length(patterns)) {
101+
return(unique(explicit))
102+
} else {
103+
regex_pars <-
104+
unlist(lapply(seq_along(patterns), function(j) {
105+
grep(patterns[j], complete_pars, value = TRUE)
106+
}))
107+
108+
if (!length(regex_pars)) {
109+
abort("No matches for 'regex_pars'.")
110+
}
111+
}
112+
113+
unique(c(explicit, regex_pars))
114+
}
115+
116+
117+
#' Melt a 3-D array or matrix of MCMC draws
118+
#'
119+
#' @noRd
120+
#' @param x An mcmc_array (from prepare_mcmc_array).
121+
#' @param varnames,value.name,... Passed to reshape2::melt (array method).
122+
#' @return A molten data frame.
123+
#'
70124
melt_mcmc <- function(x, ...) UseMethod("melt_mcmc")
71125
melt_mcmc.mcmc_array <- function(x,
72126
varnames =
@@ -103,9 +157,11 @@ melt_mcmc.matrix <- function(x,
103157
long
104158
}
105159

106-
# Set dimnames of 3-D array
107-
# @param x 3-D array
108-
# @param parnames Character vector of parameter names
160+
#' Set dimnames of 3-D array
161+
#' @noRd
162+
#' @param x 3-D array
163+
#' @param parnames Character vector of parameter names
164+
#' @return x with a modified dimnames.
109165
set_mcmc_dimnames <- function(x, parnames) {
110166
stopifnot(is_3d_array(x))
111167
dimnames(x) <- list(
@@ -116,11 +172,12 @@ set_mcmc_dimnames <- function(x, parnames) {
116172
structure(x, class = c(class(x), "mcmc_array"))
117173
}
118174

119-
# Convert 3-D array to matrix with chains merged
120-
#
121-
# @param x A 3-D array (iter x chain x param)
122-
# @return A matrix with one column per parameter
123-
#
175+
#' Convert 3-D array to matrix with chains merged
176+
#'
177+
#' @noRd
178+
#' @param x A 3-D array (iter x chain x param)
179+
#' @return A matrix with one column per parameter
180+
#'
124181
merge_chains <- function(x) {
125182
xdim <- dim(x)
126183
mat <- array(x, dim = c(prod(xdim[1:2]), xdim[3]))
@@ -129,10 +186,11 @@ merge_chains <- function(x) {
129186
}
130187

131188

132-
# Check if an object is a data.frame with a chain index column
133-
#
134-
# @param x object to check
135-
# @return TRUE or FALSE
189+
#' Check if an object is a data.frame with a chain index column
190+
#'
191+
#' @noRd
192+
#' @param x object to check
193+
#' @return TRUE or FALSE
136194
is_df_with_chain <- function(x) {
137195
is.data.frame(x) && any(tolower(colnames(x)) %in% "chain")
138196
}
@@ -167,11 +225,11 @@ df_with_chain2array <- function(x) {
167225
}
168226

169227

170-
# Check if an object is a list (but not a data.frame) that contains
171-
# all 2-D objects
172-
#
173-
# @param x object to check
174-
# @return TRUE or FALSE
228+
#' Check if an object is a list (but not a data.frame) that contains
229+
#' all 2-D objects
230+
#' @noRd
231+
#' @param x object to check
232+
#' @return TRUE or FALSE
175233
is_chain_list <- function(x) {
176234
check1 <- !is.data.frame(x) && is.list(x)
177235
dims <- try(sapply(x, function(chain) length(dim(chain))), silent=TRUE)
@@ -316,13 +374,14 @@ validate_transformations <-
316374
}
317375

318376

319-
# Apply transformations to matrix or 3-D array of parameter draws
320-
#
321-
# @param x A matrix or 3-D array of draws
322-
# @param transformation User's 'transformations' argument to one of the mcmc_*
323-
# functions.
324-
# @return x, with tranformations having been applied to some parameters.
325-
#
377+
#' Apply transformations to matrix or 3-D array of parameter draws
378+
#'
379+
#' @noRd
380+
#' @param x A matrix or 3-D array of draws
381+
#' @param transformation User's 'transformations' argument to one of the mcmc_*
382+
#' functions.
383+
#' @return x, with tranformations having been applied to some parameters.
384+
#'
326385
apply_transformations <- function(x, transformations = list(), ...) {
327386
UseMethod("apply_transformations")
328387
}
@@ -395,4 +454,3 @@ num_iters.data.frame <- function(x, ...) {
395454

396455
n
397456
}
398-

0 commit comments

Comments
 (0)