Skip to content

Commit 46992a0

Browse files
committed
Format project using Air
1 parent 9886905 commit 46992a0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+5526
-2131
lines changed

.Rbuildignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@ vignettes/loo2-non-factorizable_cache/*
2121

2222
^CRAN-SUBMISSION$
2323
^release-prep\.R$
24+
^[\.]?air\.toml$
25+
^\.vscode$
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Workflow derived from https://github.com/posit-dev/setup-air/tree/main/examples
2+
on:
3+
pull_request:
4+
5+
name: format-suggest.yaml
6+
7+
permissions: read-all
8+
9+
jobs:
10+
format-suggest:
11+
name: format-suggest
12+
runs-on: ubuntu-latest
13+
permissions:
14+
pull-requests: write
15+
steps:
16+
- uses: actions/checkout@v4
17+
18+
- name: Install
19+
uses: posit-dev/setup-air@v1
20+
21+
- name: Format
22+
run: air format .
23+
24+
- name: Suggest
25+
uses: reviewdog/action-suggester@v1
26+
with:
27+
level: error
28+
fail_level: error
29+
tool_name: air

R/E_loo.R

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,14 @@ E_loo <- function(x, psis_object, ...) {
100100
#' @rdname E_loo
101101
#' @export
102102
E_loo.default <-
103-
function(x,
104-
psis_object,
105-
...,
106-
type = c("mean", "variance", "sd", "quantile"),
107-
probs = NULL,
108-
log_ratios = NULL) {
103+
function(
104+
x,
105+
psis_object,
106+
...,
107+
type = c("mean", "variance", "sd", "quantile"),
108+
probs = NULL,
109+
log_ratios = NULL
110+
) {
109111
stopifnot(
110112
is.numeric(x),
111113
is.psis(psis_object),
@@ -137,12 +139,14 @@ E_loo.default <-
137139
#' @rdname E_loo
138140
#' @export
139141
E_loo.matrix <-
140-
function(x,
141-
psis_object,
142-
...,
143-
type = c("mean", "variance", "sd", "quantile"),
144-
probs = NULL,
145-
log_ratios = NULL) {
142+
function(
143+
x,
144+
psis_object,
145+
...,
146+
type = c("mean", "variance", "sd", "quantile"),
147+
probs = NULL,
148+
log_ratios = NULL
149+
) {
146150
stopifnot(
147151
is.numeric(x),
148152
is.psis(psis_object),
@@ -162,9 +166,13 @@ E_loo.matrix <-
162166
}
163167
w <- weights(psis_object, log = FALSE)
164168

165-
out <- vapply(seq_len(ncol(x)), function(i) {
166-
E_fun(x[, i], w[, i], probs = probs)
167-
}, FUN.VALUE = fun_val)
169+
out <- vapply(
170+
seq_len(ncol(x)),
171+
function(i) {
172+
E_fun(x[, i], w[, i], probs = probs)
173+
},
174+
FUN.VALUE = fun_val
175+
)
168176

169177
if (is.null(log_ratios)) {
170178
# Use of smoothed ratios gives slightly optimistic
@@ -183,7 +191,6 @@ E_loo.matrix <-
183191
}
184192

185193

186-
187194
#' Select the function to use based on user's 'type' argument
188195
#'
189196
#' @noRd
@@ -290,22 +297,37 @@ E_loo_khat.matrix <- function(x, psis_object, log_ratios, ...) {
290297
.E_loo_khat_i <- function(x_i, log_ratios_i, tail_len_i) {
291298
h_theta <- x_i
292299
r_theta <- exp(log_ratios_i - max(log_ratios_i))
293-
khat_r <- posterior::pareto_khat(r_theta, tail = "right", ndraws_tail = tail_len_i)
294-
if (is.list(khat_r)) { # retain compatiblity with older posterior that returned a list
300+
khat_r <- posterior::pareto_khat(
301+
r_theta,
302+
tail = "right",
303+
ndraws_tail = tail_len_i
304+
)
305+
if (is.list(khat_r)) {
306+
# retain compatiblity with older posterior that returned a list
295307
khat_r <- khat_r$khat
296308
}
297-
if (is.null(x_i) || is_constant(x_i) || length(unique(x_i))==2 ||
298-
anyNA(x_i) || any(is.infinite(x_i))) {
309+
if (
310+
is.null(x_i) ||
311+
is_constant(x_i) ||
312+
length(unique(x_i)) == 2 ||
313+
anyNA(x_i) ||
314+
any(is.infinite(x_i))
315+
) {
299316
khat_r
300317
} else {
301-
khat_hr <- posterior::pareto_khat(h_theta * r_theta, tail = "both", ndraws_tail = tail_len_i)
302-
if (is.list(khat_hr)) { # retain compatiblity with older posterior that returned a list
318+
khat_hr <- posterior::pareto_khat(
319+
h_theta * r_theta,
320+
tail = "both",
321+
ndraws_tail = tail_len_i
322+
)
323+
if (is.list(khat_hr)) {
324+
# retain compatiblity with older posterior that returned a list
303325
khat_hr <- khat_hr$khat
304326
}
305327
if (is.na(khat_hr) && is.na(khat_r)) {
306328
k <- NA
307329
} else {
308-
k <- max(khat_hr, khat_r, na.rm=TRUE)
330+
k <- max(khat_hr, khat_r, na.rm = TRUE)
309331
}
310332
k
311333
}

R/compare.R

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ compare <- function(..., x = list()) {
6363
dots <- list(...)
6464
if (length(dots)) {
6565
if (length(x)) {
66-
stop("If 'x' is specified then '...' should not be specified.",
67-
call. = FALSE)
66+
stop(
67+
"If 'x' is specified then '...' should not be specified.",
68+
call. = FALSE
69+
)
6870
}
6971
nms <- as.character(match.call(expand.dots = TRUE))[-1L]
7072
} else {
@@ -97,16 +99,18 @@ compare <- function(..., x = list()) {
9799

98100
x <- sapply(dots, function(x) {
99101
est <- x$estimates
100-
setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))) )
102+
setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))))
101103
})
102104
colnames(x) <- nms
103105
rnms <- rownames(x)
104106
comp <- x
105107
ord <- order(x[grep("^elpd", rnms), ], decreasing = TRUE)
106108
comp <- t(comp)[ord, ]
107109
patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$")
108-
col_ord <- unlist(sapply(patts, function(p) grep(p, colnames(comp))),
109-
use.names = FALSE)
110+
col_ord <- unlist(
111+
sapply(patts, function(p) grep(p, colnames(comp))),
112+
use.names = FALSE
113+
)
110114
comp <- comp[, col_ord]
111115

112116
# compute elpd_diff and se_elpd_diff relative to best model
@@ -122,13 +126,25 @@ compare <- function(..., x = list()) {
122126
}
123127

124128

125-
126129
# internal ----------------------------------------------------------------
127-
compare_two_models <- function(loo_a, loo_b, return = c("elpd_diff", "se"), check_dims = TRUE) {
130+
compare_two_models <- function(
131+
loo_a,
132+
loo_b,
133+
return = c("elpd_diff", "se"),
134+
check_dims = TRUE
135+
) {
128136
if (check_dims) {
129137
if (dim(loo_a$pointwise)[1] != dim(loo_b$pointwise)[1]) {
130-
stop(paste("Models don't have the same number of data points.",
131-
"\nFound N_1 =", dim(loo_a$pointwise)[1], "and N_2 =", dim(loo_b$pointwise)[1]), call. = FALSE)
138+
stop(
139+
paste(
140+
"Models don't have the same number of data points.",
141+
"\nFound N_1 =",
142+
dim(loo_a$pointwise)[1],
143+
"and N_2 =",
144+
dim(loo_b$pointwise)[1]
145+
),
146+
call. = FALSE
147+
)
132148
}
133149
}
134150

R/crps.R

Lines changed: 68 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ crps.matrix <- function(x, x2, y, ..., permutations = 1) {
9292
#' @rdname crps
9393
#' @export
9494
crps.numeric <- function(x, x2, y, ..., permutations = 1) {
95-
stopifnot(length(x) == length(x2),
96-
length(y) == 1)
95+
stopifnot(length(x) == length(x2), length(y) == 1)
9796
crps.matrix(as.matrix(x), as.matrix(x2), y, permutations)
9897
}
9998

@@ -106,23 +105,32 @@ crps.numeric <- function(x, x2, y, ..., permutations = 1) {
106105
#' @param cores The number of cores to use for parallelization of `[psis()]`.
107106
#' See [psis()] for details.
108107
loo_crps.matrix <-
109-
function(x,
110-
x2,
111-
y,
112-
log_lik,
113-
...,
114-
permutations = 1,
115-
r_eff = 1,
116-
cores = getOption("mc.cores", 1)) {
117-
validate_crps_input(x, x2, y, log_lik)
118-
repeats <- replicate(permutations,
119-
EXX_loo_compute(x, x2, log_lik, r_eff = r_eff, ...),
120-
simplify = F)
121-
EXX <- Reduce(`+`, repeats) / permutations
122-
psis_obj <- psis(-log_lik, r_eff = r_eff, cores = cores)
123-
EXy <- E_loo(abs(sweep(x, 2, y)), psis_obj, log_ratios = -log_lik, ...)$value
124-
crps_output(.crps_fun(EXX, EXy))
125-
}
108+
function(
109+
x,
110+
x2,
111+
y,
112+
log_lik,
113+
...,
114+
permutations = 1,
115+
r_eff = 1,
116+
cores = getOption("mc.cores", 1)
117+
) {
118+
validate_crps_input(x, x2, y, log_lik)
119+
repeats <- replicate(
120+
permutations,
121+
EXX_loo_compute(x, x2, log_lik, r_eff = r_eff, ...),
122+
simplify = F
123+
)
124+
EXX <- Reduce(`+`, repeats) / permutations
125+
psis_obj <- psis(-log_lik, r_eff = r_eff, cores = cores)
126+
EXy <- E_loo(
127+
abs(sweep(x, 2, y)),
128+
psis_obj,
129+
log_ratios = -log_lik,
130+
...
131+
)$value
132+
crps_output(.crps_fun(EXX, EXy))
133+
}
126134

127135

128136
#' @rdname crps
@@ -138,8 +146,7 @@ scrps.matrix <- function(x, x2, y, ..., permutations = 1) {
138146
#' @rdname crps
139147
#' @export
140148
scrps.numeric <- function(x, x2, y, ..., permutations = 1) {
141-
stopifnot(length(x) == length(x2),
142-
length(y) == 1)
149+
stopifnot(length(x) == length(x2), length(y) == 1)
143150
scrps.matrix(as.matrix(x), as.matrix(x2), y, permutations)
144151
}
145152

@@ -155,40 +162,54 @@ loo_scrps.matrix <-
155162
...,
156163
permutations = 1,
157164
r_eff = 1,
158-
cores = getOption("mc.cores", 1)) {
159-
validate_crps_input(x, x2, y, log_lik)
160-
repeats <- replicate(permutations,
161-
EXX_loo_compute(x, x2, log_lik, r_eff = r_eff, ...),
162-
simplify = F)
163-
EXX <- Reduce(`+`, repeats) / permutations
164-
psis_obj <- psis(-log_lik, r_eff = r_eff, cores = cores)
165-
EXy <- E_loo(abs(sweep(x, 2, y)), psis_obj, log_ratios = -log_lik, ...)$value
166-
crps_output(.crps_fun(EXX, EXy, scale = TRUE))
167-
}
165+
cores = getOption("mc.cores", 1)
166+
) {
167+
validate_crps_input(x, x2, y, log_lik)
168+
repeats <- replicate(
169+
permutations,
170+
EXX_loo_compute(x, x2, log_lik, r_eff = r_eff, ...),
171+
simplify = F
172+
)
173+
EXX <- Reduce(`+`, repeats) / permutations
174+
psis_obj <- psis(-log_lik, r_eff = r_eff, cores = cores)
175+
EXy <- E_loo(
176+
abs(sweep(x, 2, y)),
177+
psis_obj,
178+
log_ratios = -log_lik,
179+
...
180+
)$value
181+
crps_output(.crps_fun(EXX, EXy, scale = TRUE))
182+
}
168183

169184
# ------------ Internals ----------------
170185

171-
172186
EXX_compute <- function(x, x2) {
173187
S <- nrow(x)
174-
colMeans(abs(x - x2[sample(1:S),]))
188+
colMeans(abs(x - x2[sample(1:S), ]))
175189
}
176190

177191

178192
EXX_loo_compute <- function(x, x2, log_lik, r_eff = 1, ...) {
179193
S <- nrow(x)
180-
shuffle <- sample (1:S)
181-
x2 <- x2[shuffle,]
182-
log_lik2 <- log_lik[shuffle,]
183-
psis_obj_joint <- psis(-log_lik - log_lik2 , r_eff = r_eff)
184-
E_loo(abs(x - x2), psis_obj_joint, log_ratios = -log_lik - log_lik2, ...)$value
194+
shuffle <- sample(1:S)
195+
x2 <- x2[shuffle, ]
196+
log_lik2 <- log_lik[shuffle, ]
197+
psis_obj_joint <- psis(-log_lik - log_lik2, r_eff = r_eff)
198+
E_loo(
199+
abs(x - x2),
200+
psis_obj_joint,
201+
log_ratios = -log_lik - log_lik2,
202+
...
203+
)$value
185204
}
186205

187206

188207
#' Function to compute crps and scrps
189208
#' @noRd
190209
.crps_fun <- function(EXX, EXy, scale = FALSE) {
191-
if (scale) return(-EXy/EXX - 0.5 * log(EXX))
210+
if (scale) {
211+
return(-EXy / EXX - 0.5 * log(EXX))
212+
}
192213
0.5 * EXX - EXy
193214
}
194215

@@ -208,11 +229,12 @@ crps_output <- function(crps_pw) {
208229
#' Check that predictive draws and observed data are of compatible shape
209230
#' @noRd
210231
validate_crps_input <- function(x, x2, y, log_lik = NULL) {
211-
stopifnot(is.numeric(x),
212-
is.numeric(x2),
213-
is.numeric(y),
214-
identical(dim(x), dim(x2)),
215-
ncol(x) == length(y),
216-
ifelse(is.null(log_lik), TRUE, identical(dim(log_lik), dim(x)))
217-
)
232+
stopifnot(
233+
is.numeric(x),
234+
is.numeric(x2),
235+
is.numeric(y),
236+
identical(dim(x), dim(x2)),
237+
ncol(x) == length(y),
238+
ifelse(is.null(log_lik), TRUE, identical(dim(log_lik), dim(x)))
239+
)
218240
}

0 commit comments

Comments
 (0)