Skip to content

Commit f3fcb28

Browse files
committed
add diag_elpd
1 parent a375c93 commit f3fcb28

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

R/loo_compare.R

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,34 @@ loo_compare.default <- function(x, ...) {
150150
diag_diff[khat_diff > 0.5] <- "khat_diff > 0.5"
151151
}
152152

153+
# get khats for PSIS
154+
khat_psis <- sapply(loos[ord],
155+
\(loo) {
156+
k <- loo$diagnostics[["pareto_k"]]
157+
if (is.null(k)) {
158+
out = ""
159+
} else {
160+
S <- dim(loo)[1]
161+
khat_threshold <- ps_khat_threshold(S)
162+
K <- sum(k > khat_threshold)
163+
if (K==0) {
164+
out <- ""
165+
} else {
166+
out <- paste0(K, " khat_psis > ", round(khat_threshold, 2))
167+
}
168+
}
169+
out
170+
}
171+
)
172+
153173
comp <- cbind(
154174
data.frame(
155175
model = rnms,
156176
elpd_diff = elpd_diff,
157177
se_diff = se_diff,
158178
p_worse = p_worse,
159-
diag_diff = diag_diff
179+
diag_diff = diag_diff,
180+
diag_elpd = khat_psis
160181
),
161182
as.data.frame(comp)
162183
)
@@ -195,7 +216,8 @@ print.compare.loo <- function(x, ..., digits = 1, p_worse = TRUE) {
195216
x2 <- cbind(
196217
x2,
197218
p_worse = .fr(x[, "p_worse"], digits = 2),
198-
diag_diff = x[, "diag_diff"]
219+
diag_diff = x[, "diag_diff"],
220+
diag_elpd = x[, "diag_elpd"]
199221
)
200222
}
201223
print(x2, quote = FALSE, row.names = FALSE)

0 commit comments

Comments
 (0)