Skip to content

Commit 06a70ba

Browse files
committed
Updated error message
1 parent cb48ff9 commit 06a70ba

File tree

3 files changed

+60
-47
lines changed

3 files changed

+60
-47
lines changed

R/psis.R

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,14 @@ psis <- function(log_ratios, ...) UseMethod("psis")
9898
#' @template array
9999
#'
100100
psis.array <-
101-
function(log_ratios, ...,
102-
r_eff = 1,
103-
cores = getOption("mc.cores", 1)) {
104-
importance_sampling.array(log_ratios = log_ratios, ...,
105-
r_eff = r_eff,
106-
cores = cores,
107-
method = "psis")
101+
function(log_ratios, ..., r_eff = 1, cores = getOption("mc.cores", 1)) {
102+
importance_sampling.array(
103+
log_ratios = log_ratios,
104+
...,
105+
r_eff = r_eff,
106+
cores = cores,
107+
method = "psis"
108+
)
108109
}
109110

110111

@@ -113,15 +114,14 @@ psis.array <-
113114
#' @template matrix
114115
#'
115116
psis.matrix <-
116-
function(log_ratios,
117-
...,
118-
r_eff = 1,
119-
cores = getOption("mc.cores", 1)) {
120-
importance_sampling.matrix(log_ratios,
121-
...,
122-
r_eff = r_eff,
123-
cores = cores,
124-
method = "psis")
117+
function(log_ratios, ..., r_eff = 1, cores = getOption("mc.cores", 1)) {
118+
importance_sampling.matrix(
119+
log_ratios,
120+
...,
121+
r_eff = r_eff,
122+
cores = cores,
123+
method = "psis"
124+
)
125125
}
126126

127127
#' @export
@@ -130,9 +130,12 @@ psis.matrix <-
130130
#'
131131
psis.default <-
132132
function(log_ratios, ..., r_eff = 1) {
133-
importance_sampling.default(log_ratios = log_ratios, ...,
134-
r_eff = r_eff,
135-
method = "psis")
133+
importance_sampling.default(
134+
log_ratios = log_ratios,
135+
...,
136+
r_eff = r_eff,
137+
method = "psis"
138+
)
136139
}
137140

138141

@@ -149,25 +152,26 @@ is.psis <- function(x) {
149152
#' @noRd
150153
#' @seealso importance_sampling_object
151154
psis_object <-
152-
function(unnormalized_log_weights,
153-
pareto_k,
154-
tail_len,
155-
r_eff) {
156-
importance_sampling_object(unnormalized_log_weights = unnormalized_log_weights,
157-
pareto_k = pareto_k,
158-
tail_len = tail_len,
159-
r_eff = r_eff,
160-
method = "psis")
155+
function(unnormalized_log_weights, pareto_k, tail_len, r_eff) {
156+
importance_sampling_object(
157+
unnormalized_log_weights = unnormalized_log_weights,
158+
pareto_k = pareto_k,
159+
tail_len = tail_len,
160+
r_eff = r_eff,
161+
method = "psis"
162+
)
161163
}
162164

163165

164166
#' @noRd
165167
#' @seealso do_importance_sampling
166-
do_psis <- function(log_ratios, r_eff, cores, method){
167-
do_importance_sampling(log_ratios = log_ratios,
168-
r_eff = r_eff,
169-
cores = cores,
170-
method = "psis")
168+
do_psis <- function(log_ratios, r_eff, cores, method) {
169+
do_importance_sampling(
170+
log_ratios = log_ratios,
171+
r_eff = r_eff,
172+
cores = cores,
173+
method = "psis"
174+
)
171175
}
172176

173177
#' Extract named components from each list in the list of lists obtained by
@@ -181,7 +185,9 @@ do_psis <- function(log_ratios, r_eff, cores, method){
181185
#' @return Numeric vector or matrix.
182186
#'
183187
psis_apply <- function(x, item, fun = c("[[", "attr"), fun_val = numeric(1)) {
184-
if (!is.list(x)) stop("Internal error ('x' must be a list for psis_apply)")
188+
if (!is.list(x)) {
189+
stop("Internal error ('x' must be a list for psis_apply)")
190+
}
185191
vapply(x, FUN = match.arg(fun), FUN.VALUE = fun_val, item)
186192
}
187193

@@ -212,7 +218,7 @@ do_psis_i <- function(log_ratios_i, tail_len_i, ...) {
212218
ord <- sort.int(lw_i, index.return = TRUE)
213219
tail_ids <- seq(S - tail_len_i + 1, S)
214220
lw_tail <- ord$x[tail_ids]
215-
if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps/100) {
221+
if (abs(max(lw_tail) - min(lw_tail)) < .Machine$double.eps / 100) {
216222
warning(
217223
"Can't fit generalized Pareto distribution ",
218224
"because all tail values are the same.",
@@ -252,11 +258,11 @@ psis_smooth_tail <- function(x, cutoff) {
252258
k <- fit$k
253259
sigma <- fit$sigma
254260
if (is.finite(k)) {
255-
p <- (seq_len(len) - 0.5) / len
256-
qq <- qgpd(p, k, sigma) + exp_cutoff
257-
tail <- log(qq)
261+
p <- (seq_len(len) - 0.5) / len
262+
qq <- qgpd(p, k, sigma) + exp_cutoff
263+
tail <- log(qq)
258264
} else {
259-
tail <- x
265+
tail <- x
260266
}
261267
list(tail = tail, k = k)
262268
}
@@ -322,7 +328,8 @@ throw_tail_length_warnings <- function(tail_lengths) {
322328
if (length(tail_lengths) == 1) {
323329
warning(
324330
"Not enough tail samples to fit the generalized Pareto distribution.",
325-
call. = FALSE, immediate. = TRUE
331+
call. = FALSE,
332+
immediate. = TRUE
326333
)
327334
} else {
328335
bad <- which(tail_len_bad)
@@ -332,7 +339,11 @@ throw_tail_length_warnings <- function(tail_lengths) {
332339
"in some or all columns of matrix of log importance ratios. ",
333340
"Skipping the following columns: ",
334341
paste(if (Nbad <= 10) bad else bad[1:10], collapse = ", "),
335-
if (Nbad > 10) paste0(", ... [", Nbad - 10, " more not printed].\n") else "\n",
342+
if (Nbad > 10) {
343+
paste0(", ... [", Nbad - 10, " more not printed].\n")
344+
} else {
345+
"\n"
346+
},
336347
call. = FALSE,
337348
immediate. = TRUE
338349
)
@@ -352,17 +363,20 @@ throw_tail_length_warnings <- function(tail_lengths) {
352363
#' * If `r_eff` is `NA` then `rep(1, len)` is returned.
353364
#' * If `r_eff` is a scalar then `rep(r_eff, len)` is returned.
354365
#' * If `r_eff` is not a scalar but the length is not `len` then an error is thrown.
355-
#' * If `r_eff` has length `len` but has `NA`s then an error is thrown.
366+
#' * If `r_eff` has length `len` but has `NA`s then `NA`s are filled in with `1`s.
356367
#'
357368
prepare_psis_r_eff <- function(r_eff, len) {
358369
if (isTRUE(is.null(r_eff) || all(is.na(r_eff)))) {
359370
r_eff <- rep(1, len)
360371
} else if (length(r_eff) == 1) {
361372
r_eff <- rep(r_eff, len)
362373
} else if (length(r_eff) != len) {
363-
stop("'r_eff' must have one value or one value per observation.", call. = FALSE)
374+
stop(
375+
"'r_eff' must have one value or one value per observation.",
376+
call. = FALSE
377+
)
364378
} else if (anyNA(r_eff)) {
365-
message("If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's.")
379+
message("Replacing NAs in `r_eff` with 1s")
366380
r_eff[is.na(r_eff)] <- 1
367381
}
368382
r_eff
@@ -391,4 +405,3 @@ throw_psis_r_eff_warning <- function() {
391405
call. = FALSE
392406
)
393407
}
394-

tests/testthat/_snaps/psis.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4785,7 +4785,7 @@
47854785
Code
47864786
psis(-LLarr, r_eff = r_eff_arr)
47874787
Message
4788-
If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's.
4788+
Replacing NAs in `r_eff` with 1s
47894789
Output
47904790
Computed from 1000 by 32 log-weights matrix.
47914791
MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]).

tests/testthat/_snaps/tisis.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Code
44
psis(-LLarr, r_eff = r_eff_arr)
55
Message
6-
If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's.
6+
Replacing NAs in `r_eff` with 1s
77
Output
88
Computed from 1000 by 32 log-weights matrix.
99
MCSE and ESS estimates assume MCMC draws (r_eff in [0.6, 1.0]).

0 commit comments

Comments
 (0)