Skip to content

Commit 10aa75d

Browse files
committed
Call as.array on input to prepare_mcmc_array
1 parent ea1ca27 commit 10aa75d

File tree

4 files changed

+60
-61
lines changed

4 files changed

+60
-61
lines changed

R/helpers-mcmc.R

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,27 @@ prepare_mcmc_array <- function(x,
1212
x <- df_with_chain2array(x)
1313
} else if (is_chain_list(x)) {
1414
x <- chain_list2array(x)
15+
} else if (is.data.frame(x)) {
16+
# data frame without Chain column
17+
x <- as.matrix(x)
18+
} else if (!is.array(x)) {
19+
x <- as.array(x)
20+
}
21+
22+
stopifnot(is.matrix(x) || is.array(x))
23+
if (is.array(x) && !(length(dim(x)) %in% c(2,3))) {
24+
stop("Arrays should have 2 or 3 dimensions. See help('MCMC-overview').")
25+
}
26+
if (anyNA(x)) {
27+
stop("NAs not allowed in 'x'.")
1528
}
16-
x <- validate_mcmc_x(x)
1729

1830
parnames <- parameter_names(x)
1931
pars <- select_parameters(
2032
explicit = pars,
2133
patterns = regex_pars,
22-
complete = parnames)
34+
complete = parnames
35+
)
2336

2437
# possibly recycle transformations (apply same to all pars)
2538
if (is.function(transformations) ||
@@ -157,16 +170,12 @@ df_with_chain2array <- function(x) {
157170
# @param x object to check
158171
# @return TRUE or FALSE
159172
is_chain_list <- function(x) {
160-
!is.data.frame(x) && is.list(x)
173+
check1 <- !is.data.frame(x) && is.list(x)
174+
dims <- sapply(x, function(chain) length(dim(chain)))
175+
isTRUE(all(dims == 2))
161176
}
162177

163178
validate_chain_list <- function(x) {
164-
stopifnot(is_chain_list(x))
165-
dims <- sapply(x, function(chain) length(dim(chain)))
166-
if (!isTRUE(all(dims == 2))) {
167-
stop("If 'x' is a list then all elements must be matrices.")
168-
}
169-
170179
n_chain <- length(x)
171180
for (i in seq_len(n_chain)) {
172181
nms <- colnames(as.matrix(x[[i]]))
@@ -276,25 +285,6 @@ STOP_need_multiple_chains <- function(call. = FALSE) {
276285
}
277286

278287

279-
# Perform some checks on user's 'x' input for MCMC plots
280-
#
281-
# @param x User's 'x' input to one of the mcmc_* functions.
282-
# @return x, unless an error is thrown.
283-
validate_mcmc_x <- function(x) {
284-
stopifnot(!is_df_with_chain(x), !is_chain_list(x))
285-
if (is.data.frame(x)) {
286-
x <- as.matrix(x)
287-
}
288-
289-
stopifnot(is.matrix(x) || is.array(x))
290-
if (anyNA(x)) {
291-
stop("NAs not allowed in 'x'.")
292-
}
293-
294-
x
295-
}
296-
297-
298288
# Validate that transformations match parameter names
299289
validate_transformations <-
300290
function(transformations = list(),

R/mcmc-overview.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
#' corresponds to a Markov chain. All of the matrices should have the same
1919
#' number of iterations (rows) and parameters (columns), and parameters should
2020
#' have the same names and be in the same order.
21-
#' \item \strong{matrix}: A \code{\link{matrix}} with one column per parameter.
22-
#' If using matrix there should only be a single Markov chain or all chains
23-
#' should already be merged (stacked).
21+
#' \item \strong{matrix (2-D array)}: A \code{\link{matrix}} with one column
22+
#' per parameter. If using matrix there should only be a single Markov chain or
23+
#' all chains should already be merged (stacked).
2424
#' \item \strong{data frame}: There are two types of \link[=data.frame]{data
2525
#' frames} allowed. Either a data frame with one column per parameter (if only
2626
#' a single chain or all chains have already been merged), or a data frame with

man/MCMC-overview.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-helpers-mcmc.R

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,6 @@ test_that("melt_mcmc does not convert integer parameter names to integers #162",
2424
})
2525

2626

27-
28-
# validate_mcmc_x ----------------------------------------------------------
29-
test_that("validate_mcmc_x works", {
30-
expect_identical(validate_mcmc_x(mat), mat)
31-
expect_identical(validate_mcmc_x(mat1), mat1)
32-
expect_identical(validate_mcmc_x(arr), arr)
33-
expect_identical(validate_mcmc_x(arr1), arr1)
34-
expect_identical(validate_mcmc_x(arr1chain), arr1chain)
35-
36-
# error if df_with_chain
37-
expect_error(validate_mcmc_x(dframe_multiple_chains), "is_df_with_chain")
38-
39-
# converts regular df to matrix
40-
expect_identical(validate_mcmc_x(dframe), as.matrix(dframe))
41-
42-
# NAs
43-
mat[1, 2] <- NA
44-
arr[1, 2, 3] <- NA
45-
expect_error(validate_mcmc_x(mat), "NAs not allowed")
46-
expect_error(validate_mcmc_x(arr), "NAs not allowed")
47-
})
48-
49-
50-
5127
# 3-D array helpers --------------------------------------------------------
5228
test_that("is_mcmc_array works", {
5329
expect_false(is_mcmc_array(mat))
@@ -153,7 +129,6 @@ test_that("is_chain_list works", {
153129
})
154130

155131
test_that("validate_chain_list works", {
156-
expect_error(validate_chain_list(mat), "is_chain_list")
157132
expect_identical(validate_chain_list(chainlist), chainlist)
158133
expect_identical(validate_chain_list(chainlist1), chainlist1)
159134
expect_identical(validate_chain_list(chainlist1chain), chainlist1chain)
@@ -172,8 +147,6 @@ test_that("chain_list2array works", {
172147
expect_mcmc_array(chain_list2array(chainlist))
173148
expect_mcmc_array(chain_list2array(chainlist1))
174149
expect_mcmc_array(chain_list2array(chainlist1chain))
175-
176-
expect_error(chain_list2array(dframe), "is_chain_list")
177150
})
178151

179152

@@ -231,6 +204,42 @@ test_that("transformations recycled properly if not a named list", {
231204
})
232205

233206

207+
# prepare_mcmc_array ------------------------------------------------------
208+
test_that("prepare_mcmc_array processes non-array input types correctly", {
209+
# errors are mostly covered by tests of the many internal functions above
210+
211+
# data frame with no Chain column (treat as 1 chain or merged chains)
212+
a1 <- prepare_mcmc_array(dframe)
213+
expect_s3_class(a1, "mcmc_array")
214+
expect_equal(dim(a1), c(nrow(dframe), 1, ncol(dframe)))
215+
expect_equal(parameter_names(a1), colnames(dframe))
216+
217+
# data frame with Chain column
218+
a2 <- prepare_mcmc_array(dframe_multiple_chains)
219+
expect_s3_class(a2, "mcmc_array")
220+
n_chain <- max(dframe_multiple_chains$chain)
221+
expect_equal(dim(a2), c(nrow(dframe) / n_chain, n_chain, ncol(dframe)))
222+
expect_equal(parameter_names(a2), colnames(dframe))
223+
224+
# list of matrices with multiple chains
225+
a3 <- prepare_mcmc_array(chainlist)
226+
expect_s3_class(a3, "mcmc_array")
227+
expect_equal(dim(a3), c(nrow(chainlist[[1]]), length(chainlist), ncol(chainlist[[1]])))
228+
expect_equal(parameter_names(a3), colnames(chainlist[[1]]))
229+
230+
# object with acceptable as.array method
231+
suppressPackageStartupMessages(library(rstanarm))
232+
fit <- stan_glm(mpg ~ wt, data = mtcars, chains = 2, iter = 500, refresh = 0)
233+
a4 <- prepare_mcmc_array(fit)
234+
expect_s3_class(a4, "mcmc_array")
235+
expect_equal(dim(a4), c(250, 2, 3))
236+
expect_equal(parameter_names(a4), c("(Intercept)", "wt", "sigma"))
237+
238+
# object with unacceptable as.array method
239+
fit2 <- lm(mpg ~ wt, data = mtcars)
240+
expect_error(prepare_mcmc_array(fit2), "Arrays should have 2 or 3 dimensions.")
241+
})
242+
234243

235244
# rhat and neff helpers ---------------------------------------------------
236245
test_that("diagnostic_factor.rhat works", {

0 commit comments

Comments
 (0)