@@ -98,13 +98,14 @@ psis <- function(log_ratios, ...) UseMethod("psis")
9898# ' @template array
9999# '
100100psis.array <-
101- function (log_ratios , ... ,
102- r_eff = 1 ,
103- cores = getOption(" mc.cores" , 1 )) {
104- importance_sampling.array(log_ratios = log_ratios , ... ,
105- r_eff = r_eff ,
106- cores = cores ,
107- method = " psis" )
101+ function (log_ratios , ... , r_eff = 1 , cores = getOption(" mc.cores" , 1 )) {
102+ importance_sampling.array(
103+ log_ratios = log_ratios ,
104+ ... ,
105+ r_eff = r_eff ,
106+ cores = cores ,
107+ method = " psis"
108+ )
108109 }
109110
110111
@@ -113,15 +114,14 @@ psis.array <-
113114# ' @template matrix
114115# '
115116psis.matrix <-
116- function (log_ratios ,
117- ... ,
118- r_eff = 1 ,
119- cores = getOption(" mc.cores" , 1 )) {
120- importance_sampling.matrix(log_ratios ,
121- ... ,
122- r_eff = r_eff ,
123- cores = cores ,
124- method = " psis" )
117+ function (log_ratios , ... , r_eff = 1 , cores = getOption(" mc.cores" , 1 )) {
118+ importance_sampling.matrix(
119+ log_ratios ,
120+ ... ,
121+ r_eff = r_eff ,
122+ cores = cores ,
123+ method = " psis"
124+ )
125125 }
126126
127127# ' @export
@@ -130,9 +130,12 @@ psis.matrix <-
130130# '
131131psis.default <-
132132 function (log_ratios , ... , r_eff = 1 ) {
133- importance_sampling.default(log_ratios = log_ratios , ... ,
134- r_eff = r_eff ,
135- method = " psis" )
133+ importance_sampling.default(
134+ log_ratios = log_ratios ,
135+ ... ,
136+ r_eff = r_eff ,
137+ method = " psis"
138+ )
136139 }
137140
138141
@@ -149,25 +152,26 @@ is.psis <- function(x) {
149152# ' @noRd
150153# ' @seealso importance_sampling_object
151154psis_object <-
152- function (unnormalized_log_weights ,
153- pareto_k ,
154- tail_len ,
155- r_eff ) {
156- importance_sampling_object(unnormalized_log_weights = unnormalized_log_weights ,
157- pareto_k = pareto_k ,
158- tail_len = tail_len ,
159- r_eff = r_eff ,
160- method = " psis" )
155+ function (unnormalized_log_weights , pareto_k , tail_len , r_eff ) {
156+ importance_sampling_object(
157+ unnormalized_log_weights = unnormalized_log_weights ,
158+ pareto_k = pareto_k ,
159+ tail_len = tail_len ,
160+ r_eff = r_eff ,
161+ method = " psis"
162+ )
161163 }
162164
163165
164166# ' @noRd
165167# ' @seealso do_importance_sampling
166- do_psis <- function (log_ratios , r_eff , cores , method ){
167- do_importance_sampling(log_ratios = log_ratios ,
168- r_eff = r_eff ,
169- cores = cores ,
170- method = " psis" )
168+ do_psis <- function (log_ratios , r_eff , cores , method ) {
169+ do_importance_sampling(
170+ log_ratios = log_ratios ,
171+ r_eff = r_eff ,
172+ cores = cores ,
173+ method = " psis"
174+ )
171175}
172176
173177# ' Extract named components from each list in the list of lists obtained by
@@ -181,7 +185,9 @@ do_psis <- function(log_ratios, r_eff, cores, method){
181185# ' @return Numeric vector or matrix.
182186# '
183187psis_apply <- function (x , item , fun = c(" [[" , " attr" ), fun_val = numeric (1 )) {
184- if (! is.list(x )) stop(" Internal error ('x' must be a list for psis_apply)" )
188+ if (! is.list(x )) {
189+ stop(" Internal error ('x' must be a list for psis_apply)" )
190+ }
185191 vapply(x , FUN = match.arg(fun ), FUN.VALUE = fun_val , item )
186192}
187193
@@ -212,7 +218,7 @@ do_psis_i <- function(log_ratios_i, tail_len_i, ...) {
212218 ord <- sort.int(lw_i , index.return = TRUE )
213219 tail_ids <- seq(S - tail_len_i + 1 , S )
214220 lw_tail <- ord $ x [tail_ids ]
215- if (abs(max(lw_tail ) - min(lw_tail )) < .Machine $ double.eps / 100 ) {
221+ if (abs(max(lw_tail ) - min(lw_tail )) < .Machine $ double.eps / 100 ) {
216222 warning(
217223 " Can't fit generalized Pareto distribution " ,
218224 " because all tail values are the same." ,
@@ -252,11 +258,11 @@ psis_smooth_tail <- function(x, cutoff) {
252258 k <- fit $ k
253259 sigma <- fit $ sigma
254260 if (is.finite(k )) {
255- p <- (seq_len(len ) - 0.5 ) / len
256- qq <- qgpd(p , k , sigma ) + exp_cutoff
257- tail <- log(qq )
261+ p <- (seq_len(len ) - 0.5 ) / len
262+ qq <- qgpd(p , k , sigma ) + exp_cutoff
263+ tail <- log(qq )
258264 } else {
259- tail <- x
265+ tail <- x
260266 }
261267 list (tail = tail , k = k )
262268}
@@ -322,7 +328,8 @@ throw_tail_length_warnings <- function(tail_lengths) {
322328 if (length(tail_lengths ) == 1 ) {
323329 warning(
324330 " Not enough tail samples to fit the generalized Pareto distribution." ,
325- call. = FALSE , immediate. = TRUE
331+ call. = FALSE ,
332+ immediate. = TRUE
326333 )
327334 } else {
328335 bad <- which(tail_len_bad )
@@ -332,7 +339,11 @@ throw_tail_length_warnings <- function(tail_lengths) {
332339 " in some or all columns of matrix of log importance ratios. " ,
333340 " Skipping the following columns: " ,
334341 paste(if (Nbad < = 10 ) bad else bad [1 : 10 ], collapse = " , " ),
335- if (Nbad > 10 ) paste0(" , ... [" , Nbad - 10 , " more not printed].\n " ) else " \n " ,
342+ if (Nbad > 10 ) {
343+ paste0(" , ... [" , Nbad - 10 , " more not printed].\n " )
344+ } else {
345+ " \n "
346+ },
336347 call. = FALSE ,
337348 immediate. = TRUE
338349 )
@@ -352,17 +363,20 @@ throw_tail_length_warnings <- function(tail_lengths) {
352363# ' * If `r_eff` is `NA` then `rep(1, len)` is returned.
353364# ' * If `r_eff` is a scalar then `rep(r_eff, len)` is returned.
354365# ' * If `r_eff` is not a scalar but the length is not `len` then an error is thrown.
355- # ' * If `r_eff` has length `len` but has `NA`s then an error is thrown .
366+ # ' * If `r_eff` has length `len` but has `NA`s then `NA`s are filled in with `1`s .
356367# '
357368prepare_psis_r_eff <- function (r_eff , len ) {
358369 if (isTRUE(is.null(r_eff ) || all(is.na(r_eff )))) {
359370 r_eff <- rep(1 , len )
360371 } else if (length(r_eff ) == 1 ) {
361372 r_eff <- rep(r_eff , len )
362373 } else if (length(r_eff ) != len ) {
363- stop(" 'r_eff' must have one value or one value per observation." , call. = FALSE )
374+ stop(
375+ " 'r_eff' must have one value or one value per observation." ,
376+ call. = FALSE
377+ )
364378 } else if (anyNA(r_eff )) {
365- message(" If `r_eff` has length `len` but has `NA`s then `NA`s are replaced with 1's. " )
379+ message(" Replacing NAs in `r_eff` with 1s " )
366380 r_eff [is.na(r_eff )] <- 1
367381 }
368382 r_eff
@@ -391,4 +405,3 @@ throw_psis_r_eff_warning <- function() {
391405 call. = FALSE
392406 )
393407}
394-
0 commit comments