Skip to content

Commit 794ffb0

Browse files
committed
add model column instead of row names
2 parents abcf209 + 3ed1c0c commit 794ffb0

File tree

4 files changed

+113
-64
lines changed

4 files changed

+113
-64
lines changed

R/loo_compare.R

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#' list.
1111
#'
1212
#' @return A data frame with class `"compare.loo"` that has its own
13-
#' print method. See the **Details** section.
13+
#' print method. See the **Details** and **Examples** sections.
1414
#'
1515
#' @details
1616
#' When comparing two fitted models, we can estimate the difference in their
@@ -149,16 +149,18 @@ loo_compare.default <- function(x, ...) {
149149
))
150150
diag_diff[khat_diff > 0.5] <- "khat_diff > 0.5"
151151
}
152+
152153
comp <- cbind(
153154
data.frame(
155+
model = rnms,
154156
elpd_diff = elpd_diff,
155157
se_diff = se_diff,
156158
p_worse = p_worse,
157159
diag_diff = diag_diff
158160
),
159161
as.data.frame(comp)
160162
)
161-
rownames(comp) <- rnms
163+
rownames(comp) <- NULL
162164

163165
# run order statistics-based checks for many model comparisons
164166
loo_order_stat_check(loos, ord)
@@ -175,25 +177,28 @@ loo_compare.default <- function(x, ...) {
175177
#' approximation based probability of each model having worse performance than
176178
#' the best model? The default is `TRUE`.
177179
print.compare.loo <- function(x, ..., digits = 1, p_worse = TRUE) {
178-
if (!inherits(x, "data.frame") && !inherits(x, "old_compare.loo")) {
180+
if (inherits(x, "old_compare.loo")) {
181+
return(unclass(x))
182+
}
183+
if (!inherits(x, "data.frame")) {
179184
class(x) <- c(class(x), "data.frame")
180185
}
181-
xcopy <- x
182-
if (NCOL(xcopy) >= 2) {
183-
xcopy <- xcopy[, c("elpd_diff", "se_diff")]
186+
if (!all(c("model", "elpd_diff", "se_diff") %in% colnames(x))) {
187+
print(as.data.frame(x))
188+
return(x)
184189
}
185-
if (p_worse &&
186-
"p_worse" %in% colnames(x) &&
187-
!inherits(x, "old_compare.loo")) {
188-
print(
189-
cbind(.fr(xcopy, digits),
190-
p_worse = .fr(x[, "p_worse"], 2),
191-
diag_diff = x[, "diag_diff"]),
192-
quote = FALSE
190+
x2 <- cbind(
191+
model = x$model,
192+
.fr(x[, c("elpd_diff", "se_diff")], digits)
193+
)
194+
if (p_worse && "p_worse" %in% colnames(x)) {
195+
x2 <- cbind(
196+
x2,
197+
p_worse = .fr(x[, "p_worse"], digits = 2),
198+
diag_diff = x[, "diag_diff"]
193199
)
194-
} else {
195-
print(.fr(xcopy, digits), quote = FALSE)
196200
}
201+
print(x2, quote = FALSE)
197202
invisible(x)
198203
}
199204

man/loo_compare.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/compare.md

Lines changed: 71 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,86 @@
11
# loo_compare returns expected results (2 models)
22

3-
WAoAAAACAAQEAgACAwAAAAMTAAAACgAAAA4AAAACAAAAAAAAAADAEDpTX5xF7gAAAA4AAAAC
4-
AAAAAAAAAAA/tmpHtC8TAQAAAA4AAAACf/AAAAAAB6I/8AAAAAAAAAAAABAAAAACAAQACQAA
5-
AAAABAAJAAAAB04gPCAxMDAAAAAOAAAAAsBU4fDdyUJYwFXllhPDBrkAAAAOAAAAAkARCD2z
6-
EXBfQBEalRIN2T8AAAAOAAAAAkAKKMBh1blTQCZnlesA0IoAAAAOAAAAAj/x/WXscvNeP/Gb
7-
YJxtZ8cAAAAOAAAAAkBk4fDdyUJYQGXllhPDBrkAAAAOAAAAAkAhCD2zEXBfQCEalRIN2T8A
8-
AAQCAAAAAQAEAAkAAAAFbmFtZXMAAAAQAAAACgAEAAkAAAAJZWxwZF9kaWZmAAQACQAAAAdz
9-
ZV9kaWZmAAQACQAAAAdwX3dvcnNlAAQACQAAAAlkaWFnX2RpZmYABAAJAAAACWVscGRfd2Fp
10-
YwAEAAkAAAAMc2VfZWxwZF93YWljAAQACQAAAAZwX3dhaWMABAAJAAAACXNlX3Bfd2FpYwAE
11-
AAkAAAAEd2FpYwAEAAkAAAAHc2Vfd2FpYwAABAIAAAABAAQACQAAAAVjbGFzcwAAABAAAAAC
12-
AAQACQAAAAtjb21wYXJlLmxvbwAEAAkAAAAKZGF0YS5mcmFtZQAABAIAAAABAAQACQAAAAly
13-
b3cubmFtZXMAAAAQAAAAAgAEAAkAAAAGbW9kZWwxAAQACQAAAAZtb2RlbDIAAAD+
3+
WAoAAAACAAQEAgACAwAAAAMTAAAACwAAABAAAAACAAQACQAAAAZtb2RlbDEABAAJAAAABm1v
4+
ZGVsMgAAAA4AAAACAAAAAAAAAAAAAAAAAAAAAAAAAA4AAAACAAAAAAAAAAAAAAAAAAAAAAAA
5+
AA4AAAACf/AAAAAAB6J/8AAAAAAHogAAABAAAAACAAQACQAAAAAABAAJAAAAAAAAAA4AAAAC
6+
wFTh8N3JQljAVOHw3clCWAAAAA4AAAACQBEIPbMRcF9AEQg9sxFwXwAAAA4AAAACQAoowGHV
7+
uVNACijAYdW5UwAAAA4AAAACP/H9Zexy814/8f1l7HLzXgAAAA4AAAACQGTh8N3JQlhAZOHw
8+
3clCWAAAAA4AAAACQCEIPbMRcF9AIQg9sxFwXwAABAIAAAABAAQACQAAAAVuYW1lcwAAABAA
9+
AAALAAQACQAAAAVtb2RlbAAEAAkAAAAJZWxwZF9kaWZmAAQACQAAAAdzZV9kaWZmAAQACQAA
10+
AAdwX3dvcnNlAAQACQAAAAlkaWFnX2RpZmYABAAJAAAACWVscGRfd2FpYwAEAAkAAAAMc2Vf
11+
ZWxwZF93YWljAAQACQAAAAZwX3dhaWMABAAJAAAACXNlX3Bfd2FpYwAEAAkAAAAEd2FpYwAE
12+
AAkAAAAHc2Vfd2FpYwAABAIAAAABAAQACQAAAAVjbGFzcwAAABAAAAACAAQACQAAAAtjb21w
13+
YXJlLmxvbwAEAAkAAAAKZGF0YS5mcmFtZQAABAIAAAABAAQACQAAAAlyb3cubmFtZXMAAAAN
14+
AAAAAoAAAAD////+AAAA/g==
1415

15-
# loo_compare returns expected result (3 models)
16+
---
17+
18+
Code
19+
print(comp1)
20+
Output
21+
model elpd_diff se_diff p_worse diag_diff
22+
1 model1 0.0 0.0 NA
23+
2 model2 0.0 0.0 NA
1624

17-
WAoAAAACAAQEAgACAwAAAAMTAAAACgAAAA4AAAADAAAAAAAAAADAEDpTX5xF7sAwDcqRtgQY
18-
AAAADgAAAAMAAAAAAAAAAD+2ake0LxMBP8uv7eE07V4AAAAOAAAAA3/wAAAAAAeiP/AAAAAA
19-
AAA/8AAAAAAAAAAAABAAAAADAAQACQAAAAAABAAJAAAAB04gPCAxMDAABAAJAAAAB04gPCAx
20-
MDAAAAAOAAAAA8BU4fDdyUJYwFXllhPDBrnAWOVjgjbDYgAAAA4AAAADQBEIPbMRcF9AERqV
21-
Eg3ZP0AQ8gXcaKATAAAADgAAAANACijAYdW5U0AmZ5XrANCKQEHI2FIa3QoAAAAOAAAAAz/x
22-
/WXscvNeP/GbYJxtZ8c/8YDQkmfJXwAAAA4AAAADQGTh8N3JQlhAZeWWE8MGuUBo5WOCNsNi
23-
AAAADgAAAANAIQg9sxFwX0AhGpUSDdk/QCDyBdxooBMAAAQCAAAAAQAEAAkAAAAFbmFtZXMA
24-
AAAQAAAACgAEAAkAAAAJZWxwZF9kaWZmAAQACQAAAAdzZV9kaWZmAAQACQAAAAdwX3dvcnNl
25-
AAQACQAAAAlkaWFnX2RpZmYABAAJAAAACWVscGRfd2FpYwAEAAkAAAAMc2VfZWxwZF93YWlj
26-
AAQACQAAAAZwX3dhaWMABAAJAAAACXNlX3Bfd2FpYwAEAAkAAAAEd2FpYwAEAAkAAAAHc2Vf
27-
d2FpYwAABAIAAAABAAQACQAAAAVjbGFzcwAAABAAAAACAAQACQAAAAtjb21wYXJlLmxvbwAE
28-
AAkAAAAKZGF0YS5mcmFtZQAABAIAAAABAAQACQAAAAlyb3cubmFtZXMAAAAQAAAAAwAEAAkA
29-
AAAGbW9kZWwxAAQACQAAAAZtb2RlbDIABAAJAAAABm1vZGVsMwAAAP4=
30-
31-
# compare returns expected result (2 models)
25+
---
26+
27+
WAoAAAACAAQEAgACAwAAAAMTAAAACwAAABAAAAACAAQACQAAAAZtb2RlbDEABAAJAAAABm1v
28+
ZGVsMgAAAA4AAAACAAAAAAAAAADAEDpTX5xF7gAAAA4AAAACAAAAAAAAAAA/tmpHtC8TAQAA
29+
AA4AAAACf/AAAAAAB6I/8AAAAAAAAAAAABAAAAACAAQACQAAAAAABAAJAAAAB04gPCAxMDAA
30+
AAAOAAAAAsBU4fDdyUJYwFXllhPDBrkAAAAOAAAAAkARCD2zEXBfQBEalRIN2T8AAAAOAAAA
31+
AkAKKMBh1blTQCZnlesA0IoAAAAOAAAAAj/x/WXscvNeP/GbYJxtZ8cAAAAOAAAAAkBk4fDd
32+
yUJYQGXllhPDBrkAAAAOAAAAAkAhCD2zEXBfQCEalRIN2T8AAAQCAAAAAQAEAAkAAAAFbmFt
33+
ZXMAAAAQAAAACwAEAAkAAAAFbW9kZWwABAAJAAAACWVscGRfZGlmZgAEAAkAAAAHc2VfZGlm
34+
ZgAEAAkAAAAHcF93b3JzZQAEAAkAAAAJZGlhZ19kaWZmAAQACQAAAAllbHBkX3dhaWMABAAJ
35+
AAAADHNlX2VscGRfd2FpYwAEAAkAAAAGcF93YWljAAQACQAAAAlzZV9wX3dhaWMABAAJAAAA
36+
BHdhaWMABAAJAAAAB3NlX3dhaWMAAAQCAAAAAQAEAAkAAAAFY2xhc3MAAAAQAAAAAgAEAAkA
37+
AAALY29tcGFyZS5sb28ABAAJAAAACmRhdGEuZnJhbWUAAAQCAAAAAQAEAAkAAAAJcm93Lm5h
38+
bWVzAAAADQAAAAKAAAAA/////gAAAP4=
39+
40+
---
3241

3342
Code
34-
comp1
43+
print(comp2)
3544
Output
36-
elpd_diff se
37-
0.0 0.0
45+
model elpd_diff se_diff p_worse diag_diff
46+
1 model1 0.0 0.0 NA
47+
2 model2 -4.1 0.1 1.00 N < 100
48+
49+
---
50+
51+
Code
52+
print(comp2, p_worse = FALSE)
53+
Output
54+
model elpd_diff se_diff
55+
1 model1 0.0 0.0
56+
2 model2 -4.1 0.1
57+
58+
# loo_compare returns expected result (3 models)
59+
60+
WAoAAAACAAQEAgACAwAAAAMTAAAACwAAABAAAAADAAQACQAAAAZtb2RlbDEABAAJAAAABm1v
61+
ZGVsMgAEAAkAAAAGbW9kZWwzAAAADgAAAAMAAAAAAAAAAMAQOlNfnEXuwDANypG2BBgAAAAO
62+
AAAAAwAAAAAAAAAAP7ZqR7QvEwE/y6/t4TTtXgAAAA4AAAADf/AAAAAAB6I/8AAAAAAAAD/w
63+
AAAAAAAAAAAAEAAAAAMABAAJAAAAAAAEAAkAAAAHTiA8IDEwMAAEAAkAAAAHTiA8IDEwMAAA
64+
AA4AAAADwFTh8N3JQljAVeWWE8MGucBY5WOCNsNiAAAADgAAAANAEQg9sxFwX0ARGpUSDdk/
65+
QBDyBdxooBMAAAAOAAAAA0AKKMBh1blTQCZnlesA0IpAQcjYUhrdCgAAAA4AAAADP/H9Zexy
66+
814/8ZtgnG1nxz/xgNCSZ8lfAAAADgAAAANAZOHw3clCWEBl5ZYTwwa5QGjlY4I2w2IAAAAO
67+
AAAAA0AhCD2zEXBfQCEalRIN2T9AIPIF3GigEwAABAIAAAABAAQACQAAAAVuYW1lcwAAABAA
68+
AAALAAQACQAAAAVtb2RlbAAEAAkAAAAJZWxwZF9kaWZmAAQACQAAAAdzZV9kaWZmAAQACQAA
69+
AAdwX3dvcnNlAAQACQAAAAlkaWFnX2RpZmYABAAJAAAACWVscGRfd2FpYwAEAAkAAAAMc2Vf
70+
ZWxwZF93YWljAAQACQAAAAZwX3dhaWMABAAJAAAACXNlX3Bfd2FpYwAEAAkAAAAEd2FpYwAE
71+
AAkAAAAHc2Vfd2FpYwAABAIAAAABAAQACQAAAAVjbGFzcwAAABAAAAACAAQACQAAAAtjb21w
72+
YXJlLmxvbwAEAAkAAAAKZGF0YS5mcmFtZQAABAIAAAABAAQACQAAAAlyb3cubmFtZXMAAAAN
73+
AAAAAoAAAAD////9AAAA/g==
3874

3975
---
4076

4177
Code
42-
comp2
78+
print(comp1)
4379
Output
44-
elpd_diff se
45-
-4.1 0.1
80+
model elpd_diff se_diff p_worse diag_diff
81+
1 model1 0.0 0.0 NA
82+
2 model2 -4.1 0.1 1.00 N < 100
83+
3 model3 -16.1 0.2 1.00 N < 100
4684

4785
# compare returns expected result (3 models)
4886

tests/testthat/test_compare.R

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ test_that("loo_compare throws appropriate warnings", {
6060

6161

6262
comp_colnames <- c(
63+
"model",
6364
"elpd_diff",
6465
"se_diff",
6566
"p_worse",
@@ -77,19 +78,26 @@ test_that("loo_compare returns expected results (2 models)", {
7778
expect_s3_class(comp1, "compare.loo")
7879
expect_s3_class(comp1, "data.frame")
7980
expect_equal(colnames(comp1), comp_colnames)
80-
expect_equal(rownames(comp1), c("model1", "model2"))
81-
expect_output(print(comp1), "elpd_diff")
82-
expect_equal(comp1[1:2, 1], c(0, 0), ignore_attr = TRUE)
83-
expect_equal(comp1[1:2, 2], c(0, 0), ignore_attr = TRUE)
81+
expect_equal(comp1$model, c("model1", "model2"))
82+
expect_equal(comp1$elpd_diff, c(0, 0), ignore_attr = TRUE)
83+
expect_equal(comp1$se_diff, c(0, 0), ignore_attr = TRUE)
84+
expect_equal(comp1$p_worse, c(NA_real_, NA_real_), ignore_attr = TRUE)
85+
expect_snapshot_value(comp1, style = "serialize")
86+
expect_snapshot(print(comp1))
8487

8588
comp2 <- loo_compare(w1, w2)
8689
expect_s3_class(comp2, "compare.loo")
8790
expect_equal(colnames(comp2), comp_colnames)
88-
8991
expect_snapshot_value(comp2, style = "serialize")
92+
expect_snapshot(print(comp2))
93+
expect_snapshot(print(comp2, p_worse = FALSE))
9094

9195
# specifying objects via ... and via arg x gives equal results
9296
expect_equal(comp2, loo_compare(x = list(w1, w2)))
97+
98+
# custom naming works
99+
comp3 <- loo_compare(x = list("A" = w2, "B" = w1))
100+
expect_equal(comp3$model, c("B", "A"))
93101
})
94102

95103

@@ -98,12 +106,12 @@ test_that("loo_compare returns expected result (3 models)", {
98106
comp1 <- loo_compare(w1, w2, w3)
99107

100108
expect_equal(colnames(comp1), comp_colnames)
101-
expect_equal(rownames(comp1), c("model1", "model2", "model3"))
102-
expect_equal(comp1[1, 1], 0)
109+
expect_equal(comp1$model, c("model1", "model2", "model3"))
110+
expect_equal(comp1$p_worse, c(NA, 1, 1))
103111
expect_s3_class(comp1, "compare.loo")
104112
expect_s3_class(comp1, "data.frame")
105-
106113
expect_snapshot_value(comp1, style = "serialize")
114+
expect_snapshot(print(comp1))
107115

108116
# specifying objects via '...' gives equivalent results (equal
109117
# except rownames) to using 'x' argument
@@ -119,13 +127,11 @@ test_that("compare throws deprecation warnings", {
119127

120128
test_that("compare returns expected result (2 models)", {
121129
expect_warning(comp1 <- loo::compare(w1, w1), "Deprecated")
122-
expect_snapshot(comp1)
123130
expect_equal(comp1[1:2], c(elpd_diff = 0, se = 0))
124131

125132
expect_warning(comp2 <- loo::compare(w1, w2), "Deprecated")
126-
expect_snapshot(comp2)
127-
expect_named(comp2, c("elpd_diff", "se"))
128-
expect_s3_class(comp2, "compare.loo")
133+
expect_equal(round(comp2[1:2], 3), c(elpd_diff = -4.057, se = 0.088))
134+
expect_s3_class(comp2, "old_compare.loo")
129135

130136
# specifying objects via ... and via arg x gives equal results
131137
expect_warning(comp_via_list <- loo::compare(x = list(w1, w2)), "Deprecated")

0 commit comments

Comments
 (0)