Skip to content

Commit e54367b

Browse files
authored
add extractor methods for CmdStanMCMC objects (from CmdStanR) (#227)
* add extractor methods for CmdStanMCMC objects * fix test * Update bayesplot-extractors.R * Update NEWS.md
1 parent f4ed652 commit e54367b

File tree

8 files changed

+123
-18
lines changed

8 files changed

+123
-18
lines changed

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
S3method("[",neff_ratio)
44
S3method("[",rhat)
5+
S3method(log_posterior,CmdStanMCMC)
56
S3method(log_posterior,stanfit)
67
S3method(log_posterior,stanreg)
8+
S3method(neff_ratio,CmdStanMCMC)
79
S3method(neff_ratio,stanfit)
810
S3method(neff_ratio,stanreg)
11+
S3method(nuts_params,CmdStanMCMC)
912
S3method(nuts_params,list)
1013
S3method(nuts_params,stanfit)
1114
S3method(nuts_params,stanreg)
@@ -15,6 +18,7 @@ S3method(pp_check,default)
1518
S3method(print,bayesplot_function_list)
1619
S3method(print,bayesplot_grid)
1720
S3method(print,bayesplot_scheme)
21+
S3method(rhat,CmdStanMCMC)
1822
S3method(rhat,stanfit)
1923
S3method(rhat,stanreg)
2024
export(abline_01)

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
* Items for next release go here
99
-->
1010

11+
* CmdStanMCMC objects (from CmdStanR) can now be used with extractor
12+
functions `nuts_params()`, `log_posterior()`, `rhat()`, and
13+
`neff_ratio()`. (#227)
14+
1115
* Added missing `facet_args` argument to `mcmc_rank_overlay()`. (#221, @hhau)
1216
* Size of points and interval lines can set in
1317
`mcmc_intervals(..., outer_size, inner_size, point_size)`. (#215, #228, #229)

R/bayesplot-extractors.R

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#' Extract quantities needed for plotting from model objects
22
#'
33
#' Generics and methods for extracting quantities needed for plotting from
4-
#' various types of model objects. Currently methods are only provided for
5-
#' stanfit (**rstan**) and stanreg (**rstanarm**) objects, but adding new
6-
#' methods should be relatively straightforward.
4+
#' various types of model objects. Currently methods are provided for stanfit
5+
#' (**rstan**), CmdStanMCMC (**cmdstanr**), and stanreg (**rstanarm**) objects,
6+
#' but adding new methods should be relatively straightforward.
77
#'
88
#' @name bayesplot-extractors
99
#' @param object The object to use.
@@ -87,7 +87,8 @@ log_posterior.stanfit <- function(object, inc_warmup = FALSE, ...) {
8787
...)
8888
lp <- lapply(lp, as.array)
8989
lp <- set_names(reshape2::melt(lp), c("Iteration", "Value", "Chain"))
90-
validate_df_classes(lp, c("integer", "numeric", "integer"))
90+
validate_df_classes(lp[, c("Chain", "Iteration", "Value")],
91+
c("integer", "integer", "numeric"))
9192
}
9293

9394
#' @rdname bayesplot-extractors
@@ -98,11 +99,22 @@ log_posterior.stanreg <- function(object, inc_warmup = FALSE, ...) {
9899
log_posterior.stanfit(object$stanfit, inc_warmup = inc_warmup, ...)
99100
}
100101

102+
#' @rdname bayesplot-extractors
103+
#' @export
104+
#' @method log_posterior CmdStanMCMC
105+
log_posterior.CmdStanMCMC <- function(object, inc_warmup = FALSE, ...) {
106+
lp <- object$draws("lp__", inc_warmup = inc_warmup)
107+
lp <- reshape2::melt(lp)
108+
lp$variable <- NULL
109+
lp <- dplyr::rename_with(lp, capitalize_first)
110+
validate_df_classes(lp[, c("Chain", "Iteration", "Value")],
111+
c("integer", "integer", "numeric"))
112+
}
113+
101114

102115
#' @rdname bayesplot-extractors
103116
#' @export
104117
#' @method nuts_params stanfit
105-
#'
106118
nuts_params.stanfit <-
107119
function(object,
108120
pars = NULL,
@@ -153,7 +165,23 @@ nuts_params.list <- function(object, pars = NULL, ...) {
153165

154166
out <- reshape2::melt(object)
155167
out <- set_names(out, c("Iteration", "Parameter", "Value", "Chain"))
156-
validate_df_classes(out, c("integer", "factor", "numeric", "integer"))
168+
validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
169+
c("integer", "integer", "factor", "numeric"))
170+
}
171+
172+
#' @rdname bayesplot-extractors
173+
#' @export
174+
#' @method nuts_params CmdStanMCMC
175+
nuts_params.CmdStanMCMC <- function(object, pars = NULL, ...) {
176+
arr <- object$sampler_diagnostics()
177+
if (!is.null(pars)) {
178+
arr <- arr[,, pars]
179+
}
180+
out <- reshape2::melt(arr)
181+
colnames(out)[colnames(out) == "variable"] <- "parameter"
182+
out <- dplyr::rename_with(out, capitalize_first)
183+
validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
184+
c("integer", "integer", "factor", "numeric"))
157185
}
158186

159187

@@ -188,6 +216,17 @@ rhat.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) {
188216
r[!names(r) %in% c("mean_PPD", "log-posterior")]
189217
}
190218

219+
#' @rdname bayesplot-extractors
220+
#' @export
221+
#' @method rhat CmdStanMCMC
222+
rhat.CmdStanMCMC <- function(object, pars = NULL, ...) {
223+
.rhat <- utils::getFromNamespace("rhat", "posterior")
224+
s <- object$summary(pars, rhat = .rhat)[, c("variable", "rhat")]
225+
r <- setNames(s$rhat, s$variable)
226+
r <- validate_rhat(r)
227+
r[!names(r) %in% "lp__"]
228+
}
229+
191230

192231
#' @rdname bayesplot-extractors
193232
#' @export
@@ -223,6 +262,18 @@ neff_ratio.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) {
223262
ratio[!names(ratio) %in% c("mean_PPD", "log-posterior")]
224263
}
225264

265+
#' @rdname bayesplot-extractors
266+
#' @export
267+
#' @method neff_ratio CmdStanMCMC
268+
neff_ratio.CmdStanMCMC <- function(object, pars = NULL, ...) {
269+
s <- object$summary(pars, "n_eff" = "ess_basic")[, c("variable", "n_eff")]
270+
ess <- setNames(s$n_eff, s$variable)
271+
tss <- prod(dim(object$draws())[1:2])
272+
ratio <- ess / tss
273+
ratio <- validate_neff_ratio(ratio)
274+
ratio[!names(ratio) %in% "lp__"]
275+
}
276+
226277

227278
# internals ---------------------------------------------------------------
228279

@@ -245,3 +296,10 @@ validate_df_classes <- function(x, classes = character()) {
245296
}
246297
x
247298
}
299+
300+
# capitalize first letter in a string only
301+
capitalize_first <- function(name) {
302+
name <- tolower(name) # in case whole string is capitalized
303+
substr(name, 1, 1) <- toupper(substr(name, 1, 1))
304+
name
305+
}

R/mcmc-diagnostics-nuts.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,8 @@ validate_nuts_data_frame <- function(x, lp) {
513513
abort("NUTS parameters should be in a data frame.")
514514
}
515515

516-
valid_cols <- c("Iteration", "Parameter", "Value", "Chain")
517-
if (!identical(colnames(x), valid_cols)) {
516+
valid_cols <- sort(c("Iteration", "Parameter", "Value", "Chain"))
517+
if (!identical(sort(colnames(x)), valid_cols)) {
518518
abort(paste(
519519
"NUTS parameter data frame must have columns:",
520520
paste(valid_cols, collapse = ", ")
@@ -529,8 +529,8 @@ validate_nuts_data_frame <- function(x, lp) {
529529
abort("lp should be in a data frame.")
530530
}
531531

532-
valid_lp_cols <- c("Iteration", "Value", "Chain")
533-
if (!identical(colnames(lp), valid_lp_cols)) {
532+
valid_lp_cols <- sort(c("Iteration", "Value", "Chain"))
533+
if (!identical(sort(colnames(lp)), valid_lp_cols)) {
534534
abort(paste(
535535
"lp data frame must have columns:",
536536
paste(valid_lp_cols, collapse = ", ")

man/bayesplot-extractors.Rd

Lines changed: 15 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-extractors.R

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ test_that("all nuts_params methods identical", {
4242

4343
test_that("nuts_params.stanreg returns correct structure", {
4444
np <- nuts_params(fit)
45-
expect_identical(colnames(np), c("Iteration", "Parameter", "Value", "Chain"))
45+
expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value"))
4646

4747
np_names <- paste0(c("accept_stat", "stepsize", "treedepth", "n_leapfrog",
4848
"divergent", "energy"), "__")
@@ -54,7 +54,7 @@ test_that("nuts_params.stanreg returns correct structure", {
5454

5555
test_that("log_posterior.stanreg returns correct structure", {
5656
lp <- log_posterior(fit)
57-
expect_identical(colnames(lp), c("Iteration", "Value", "Chain"))
57+
expect_identical(colnames(lp), c("Chain", "Iteration", "Value"))
5858
expect_equal(length(unique(lp$Iteration)), floor(ITER / 2))
5959
expect_equal(length(unique(lp$Chain)), CHAINS)
6060
})
@@ -100,3 +100,30 @@ test_that("neff_ratio.stanreg returns correct structure", {
100100
ans2 <- summary(fit, pars = c("wt", "sigma"))[, "n_eff"] / denom
101101
expect_equal(ratio2, ans2, tol = 0.001)
102102
})
103+
104+
test_that("cmdstanr methods work", {
105+
skip_on_cran()
106+
skip_if_not_installed("cmdstanr")
107+
108+
fit <- cmdstanr::cmdstanr_example("logistic", iter_sampling = 500, chains = 2)
109+
np <- nuts_params(fit)
110+
np_names <- paste0(c("accept_stat", "stepsize", "treedepth", "n_leapfrog",
111+
"divergent", "energy"), "__")
112+
expect_identical(levels(np$Parameter), np_names)
113+
expect_equal(range(np$Iteration), c(1, 500))
114+
expect_equal(range(np$Chain), c(1, 2))
115+
expect_true(all(np$Value[np$Parameter == "divergent__"] == 0))
116+
117+
lp <- log_posterior(fit)
118+
expect_named(lp, c("Chain", "Iteration", "Value"))
119+
expect_equal(range(np$Chain), c(1, 2))
120+
expect_equal(range(np$Iteration), c(1, 500))
121+
122+
r <- rhat(fit)
123+
expect_named(r, c("alpha", "beta[1]", "beta[2]", "beta[3]"))
124+
expect_true(all(round(r) == 1))
125+
126+
ratio <- neff_ratio(fit)
127+
expect_named(ratio, c("alpha", "beta[1]", "beta[2]", "beta[3]"))
128+
expect_true(all(ratio > 0))
129+
})

tests/testthat/test-mcmc-nuts.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ test_that("validate_nuts_data_frame throws errors", {
5858
)
5959
expect_error(
6060
validate_nuts_data_frame(data.frame(Iteration = 1, apple = 2)),
61-
"NUTS parameter data frame must have columns: Iteration, Parameter, Value, Chain"
61+
"NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value"
6262
)
6363
expect_error(
6464
validate_nuts_data_frame(np, as.matrix(lp)),
@@ -69,7 +69,7 @@ test_that("validate_nuts_data_frame throws errors", {
6969
colnames(lp2)[3] <- "Chains"
7070
expect_error(
7171
validate_nuts_data_frame(np, lp2),
72-
"lp data frame must have columns: Iteration, Value, Chain"
72+
"lp data frame must have columns: Chain, Iteration, Value"
7373
)
7474

7575
lp2 <- subset(lp, Chain %in% 1:2)

tests/testthat/test-mcmc-scatter-and-parcoord.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ test_that("mcmc_parcoord throws correct warnings and errors", {
311311

312312
expect_error(
313313
mcmc_parcoord(post, np = np[, -1]),
314-
"NUTS parameter data frame must have columns: Iteration, Parameter, Value, Chain",
314+
"NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value",
315315
fixed = TRUE
316316
)
317317

0 commit comments

Comments
 (0)