Skip to content

Commit 240e82e

Browse files
authored
Merge pull request #307 from stan-dev/loo-compare-include-n
Print unequal sample sizes in loo_compare
2 parents 6e8bb7f + bab8827 commit 240e82e

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

R/loo_compare.R

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,15 @@ loo_compare_checks <- function(loos) {
186186
stop("All inputs should have class 'loo'.", call.=FALSE)
187187
}
188188

189-
Ns <- sapply(loos, function(x) nrow(x$pointwise))
190-
if (!all(Ns == Ns[1L])) {
191-
stop("Not all models have the same number of data points.", call.=FALSE)
189+
Ns <- vapply(loos, function(x) nrow(x$pointwise), integer(1))
190+
if (any(Ns != Ns[1L])) {
191+
stop(
192+
paste0(
193+
"All models must have the same number of observations, but models have inconsistent observation counts: ",
194+
paste(paste0("'", find_model_names(loos), "' (", Ns, ")"), collapse = ", ")
195+
),
196+
call. = FALSE
197+
)
192198
}
193199

194200
## warnings

tests/testthat/test_compare.R

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,21 @@ test_that("loo_compare throws appropriate errors", {
1818
expect_error(loo_compare(w1, list(1, 2, 3)), "class 'loo'")
1919
expect_error(loo_compare(w1), "requires at least two models")
2020
expect_error(loo_compare(x = list(w1)), "requires at least two models")
21-
expect_error(loo_compare(w1, w3), "same number of data points")
22-
expect_error(loo_compare(w1, w2, w3), "same number of data points")
21+
expect_error(
22+
loo_compare(w1, w3),
23+
"All models must have the same number of observations, but models have inconsistent observation counts: 'model1' (32), 'model2' (31)",
24+
fixed = TRUE
25+
)
26+
expect_error(
27+
loo_compare(w1, w2, w3),
28+
"All models must have the same number of observations, but models have inconsistent observation counts: 'model1' (32), 'model2' (32), 'model3' (31)",
29+
fixed = TRUE
30+
)
31+
expect_error(
32+
loo_compare(x = list("Model A" = w1, "Model B" = w2, "Model C" = w3)),
33+
"All models must have the same number of observations, but models have inconsistent observation counts: 'Model A' (32), 'Model B' (32), 'Model C' (31)",
34+
fixed = TRUE
35+
)
2336
})
2437

2538
test_that("loo_compare throws appropriate warnings", {

0 commit comments

Comments
 (0)