@@ -119,79 +119,176 @@ NULL
119
119
# ' quantiles.
120
120
# ' @param trim Passed to [ggplot2::stat_density()].
121
121
# ' @template args-density-controls
122
+ # ' @param boundary_correction For `ppc_loo_pit_overlay()`, when set to `TRUE`
123
+ # ' (the default) the function will compute boundary corrected density values
124
+ # ' via convolution and a Gaussian filter, also known as the reflection method
125
+ # ' (Boneva et al., 1971). As a result, parameters controlling the standard
126
+ # ' kernel density estimation such as `adjust`, `kernel` and `n_dens` are
127
+ # ' ignored. NOTE: The current implementation only works well for continuous
128
+ # ' observations.
129
+ # ' @param grid_len For `ppc_loo_pit_overlay()`, when `boundary_correction` is
130
+ # ' set to `TRUE` this parameter specifies the number of points used to
131
+ # ' generate the estimations. This is set to 512 by default.
132
+ # '
133
+ # ' @references Boneva, L. I., Kendall, D., & Stefanov, I. (1971). Spline
134
+ # ' transformations: Three new diagnostic aids for the statistical
135
+ # ' data-analyst. *J. R. Stat. Soc. B* (Methodological), 33(1), 1-71.
136
+ # ' https://www.jstor.org/stable/2986005.
137
+ # '
122
138
ppc_loo_pit_overlay <- function(y,
                                yrep,
                                lw,
                                ...,
                                pit = NULL,
                                samples = 100,
                                size = 0.25,
                                alpha = 0.7,
                                boundary_correction = TRUE,
                                grid_len = 512,
                                bw = "nrd0",
                                trim = FALSE,
                                adjust = 1,
                                kernel = "gaussian",
                                n_dens = 1024) {
  check_ignored_arguments(...)

  # All input validation and the PIT / reference computations happen in
  # ppc_loo_pit_data(); an invalid `y`, `yrep`, `lw` or `pit` errors there.
  data <-
    ppc_loo_pit_data(
      y = y,
      yrep = yrep,
      lw = lw,
      pit = pit,
      samples = samples,
      bw = bw,
      boundary_correction = boundary_correction,
      grid_len = grid_len
    )

  # The binary-data check has to look at the raw inputs. With boundary
  # correction enabled, `data$value` holds smoothed density values, so testing
  # those against 0/1 would never trigger the warning.
  checked_values <- if (is.null(pit)) y else pit
  if (all(checked_values %in% 0:1)) {
    warning(
      "This plot is not recommended for binary data. ",
      "For plots that are more suitable see ",
      "\nhttps://avehtari.github.io/modelselection/diabetes.html#44_calibration_of_predictions",
      call. = FALSE
    )
  }

  # Layer-level data filters: reference (uniform) curves vs. the PIT curve.
  keep_yrep <- function(x) dplyr::filter(x, !.data$is_y)
  keep_y <- function(x) dplyr::filter(x, .data$is_y)

  if (boundary_correction) {
    message("NOTE: Current boundary correction implementation works for continuous observations only.")

    # Densities are precomputed on a grid, so they are drawn directly as lines
    # (x = grid position, y = density value).
    p <- ggplot(data) +
      aes_(x = ~ x, y = ~ value) +
      geom_line(
        aes_(group = ~ rep_id, color = "yrep"),
        data = keep_yrep,
        alpha = alpha,
        size = size,
        na.rm = TRUE
      ) +
      geom_line(
        aes_(color = "y"),
        data = keep_y,
        size = 1,
        lineend = "round",
        na.rm = TRUE
      ) +
      scale_x_continuous(
        limits = c(0, 1),
        expand = expansion(0, 0.01),
        breaks = seq(0, 1, by = 0.25),
        labels = c("0", "0.25", "0.5", "0.75", "1")
      )
  } else {
    # Standard kernel density estimate computed by ggplot2 from the raw
    # PIT values and uniform draws.
    p <- ggplot(data) +
      aes_(x = ~ value) +
      stat_density(
        aes_(group = ~ rep_id, color = "yrep"),
        data = keep_yrep,
        geom = "line",
        position = "identity",
        size = size,
        alpha = alpha,
        trim = trim,
        bw = bw,
        adjust = adjust,
        kernel = kernel,
        n = n_dens,
        na.rm = TRUE
      ) +
      stat_density(
        aes_(color = "y"),
        data = keep_y,
        geom = "line",
        position = "identity",
        lineend = "round",
        size = 1,
        trim = trim,
        bw = bw,
        adjust = adjust,
        kernel = kernel,
        n = n_dens,
        na.rm = TRUE
      ) +
      scale_x_continuous(
        limits = c(0.05, 0.95),
        expand = expansion(0, 0),
        breaks = seq(from = .1, to = .9, by = .2)
      )
  }

  # Scales and theme shared by both variants.
  p +
    scale_color_ppc_dist(labels = c("PIT", "Unif")) +
    scale_y_continuous(
      limits = c(0, NA),
      expand = expansion(mult = c(0, .25))
    ) +
    bayesplot_theme_get() +
    yaxis_title(FALSE) +
    xaxis_title(FALSE) +
    yaxis_text(FALSE) +
    yaxis_ticks(FALSE)
}
194
248
249
#' @rdname PPC-loo
#' @export
ppc_loo_pit_data <-
  function(y,
           yrep,
           lw,
           ...,
           pit = NULL,
           samples = 100,
           bw = "nrd0",
           boundary_correction = TRUE,
           grid_len = 512) {
    if (!is.null(pit)) {
      stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
      inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
    } else {
      suggested_package("rstantools")
      y <- validate_y(y)
      yrep <- validate_yrep(yrep, y)
      stopifnot(identical(dim(yrep), dim(lw)))
      pit <- rstantools::loo_pit(object = yrep, y = y, lw = lw)
    }

    if (!boundary_correction) {
      # Raw PIT values plus `samples` rows of uniform draws; the density
      # estimation is left to the plotting layer.
      unifs <- matrix(runif(length(pit) * samples), nrow = samples)
      return(ppc_data(pit, unifs))
    }

    # .kde_correction() never uses fewer than 100 grid points. Apply the same
    # floor here so the reference matrix and the corrected PIT curve end up
    # with the same number of columns.
    grid_len <- max(grid_len, 100)

    # NOTE(review): each reference draw uses `grid_len` uniforms rather than
    # `length(pit)`; confirm this is the intended reference sample size.
    unifs <- matrix(runif(grid_len * samples), nrow = samples)
    ref_list <- .ref_kde_correction(unifs, bw = bw, grid_len = grid_len)
    pit_list <- .kde_correction(pit, bw = bw, grid_len = grid_len)

    # From here on `pit` and `unifs` hold density values on the grid, and
    # `xs` holds the matching grid positions (PIT curve first, then each
    # reference curve in row order).
    pit <- pit_list$bc_pvals
    unifs <- ref_list$unifs
    xs <- c(pit_list$xs, ref_list$xs)

    # NOTE(review): assumes that after sorting by `rep_id` the rows are ordered
    # like `xs` (observed curve, then reference curves) - verify against
    # ppc_data().
    data <- dplyr::arrange(ppc_data(pit, unifs), .data$rep_id)
    dplyr::mutate(data, x = xs)
  }
291
+
195
292
196
293
# ' @rdname PPC-loo
197
294
# ' @export
@@ -458,3 +555,118 @@ ppc_loo_ribbon <-
458
555
return (psis_object )
459
556
}
460
557
558
## Boundary correction based on code by the ArviZ development team
559
+ # The main method is a 1-D density estimation for linear data with
560
+ # convolution with a Gaussian filter.
561
+
562
# Unnormalised Gaussian window, based on the scipy.signal.gaussian formula.
#
# @param N Number of points in the window.
# @param bw Standard deviation of the window, in units of grid steps.
# @return Numeric vector of window weights, centred on the middle of the
#   window.
.gaussian <- function(N, bw) {
  # Signed distance of every point from the window centre
  offsets <- 0:(N - 1) - (N - 1) / 2
  exp(-offsets^2 / (2 * bw * bw))
}
570
+
571
# 1-D Gaussian density estimate on a regular grid, computed by convolving the
# binned relative frequencies with a Gaussian filter. Both ends of the grid
# are padded by reflection so the estimate is boundary corrected.
#
# @param x Observations; only `length(x)` is used, for normalisation.
# @param bw Bandwidth on the data scale (converted to bin units internally).
# @param grid_counts Number of observations per bin (one value per bin).
# @param grid_breaks Bin edges; the bin width is taken from the first two
#   entries, so the bins are assumed to be equally wide.
# @param grid_len Number of bins.
# @return Numeric vector with one density value per bin.
.linear_convolution <- function(x,
                                bw,
                                grid_counts,
                                grid_breaks,
                                grid_len) {
  bin_width <- grid_breaks[2] - grid_breaks[1]
  f <- grid_counts / bin_width / length(x)
  bw <- bw / bin_width

  # Number of points for the Gaussian filter (at least one)
  gauss_n <- max(as.integer(bw * 2 * pi), 1L)

  # Gaussian filter vector
  kernel <- .gaussian(gauss_n, bw)
  npad <- as.integer(grid_len / 5)
  pad_idx <- seq_len(npad)

  # Reflection method: mirror the first `npad` values in front of the grid
  # and the last `npad` values behind it. `rev(f)[pad_idx]` is
  # f[n], f[n - 1], ..., i.e. the right-hand edge mirrored outwards.
  f <- c(rev(f[pad_idx]), f, rev(f)[pad_idx])

  # Convolution: Gaussian filter + reflection method (padding) works as a
  # moving average based on a Gaussian density, which takes care of the
  # density values near the boundaries. `stats::` is spelled out because
  # dplyr also provides a `filter()`.
  smoothed <- stats::filter(f, kernel, method = "convolution", sides = 2)

  # Drop the padding again and normalise the Gaussian window
  bc_pvals <- smoothed[(npad + 1):(npad + grid_len)]
  bc_pvals / (bw * (2 * pi)^0.5)
}
608
+
609
# Boundary corrected density values for `x`, obtained via a linear
# convolution with a 1-D Gaussian window filter (the "reflection method").
# Working on binned data also helps speed up the code.
#
# @param x Numeric vector of values; non-finite entries (NA, NaN, Inf, -Inf)
#   are dropped with a warning.
# @param bw Bandwidth specification, passed to `density()` to obtain the
#   numeric bandwidth.
# @param grid_len Number of grid points; values below 100 are raised to 100.
# @return A list with `xs` (bin midpoints) and `bc_pvals` (density values at
#   those midpoints).
.kde_correction <- function(x,
                            bw,
                            grid_len) {
  # `is.finite()` is FALSE for NA and NaN as well as +/-Inf, so missing
  # values are removed here instead of breaking the binning below.
  non_finite <- !is.finite(x)
  if (any(non_finite)) {
    warning(paste("Ignored", sum(non_finite),
                  "Non-finite PIT values are invalid for KDE boundary correction method"))
    x <- x[!non_finite]
  }

  grid_len <- max(grid_len, 100)

  # Bin edges spanning the observed range, and the counts per bin
  bins <- seq(from = min(x), to = max(x), length.out = grid_len + 1)
  hist_obj <- hist(x, breaks = bins, plot = FALSE)
  grid_breaks <- hist_obj$breaks
  grid_counts <- hist_obj$counts

  # Numeric bandwidth based on the user specification
  bw <- density(x, bw = bw)$bw

  # 1-D convolution
  bc_pvals <- .linear_convolution(x, bw, grid_counts, grid_breaks, grid_len)

  # x-axis values for plotting: the midpoint of every bin
  n_breaks <- length(grid_breaks)
  xs <- (grid_breaks[-1] + grid_breaks[-n_breaks]) / 2

  # Values the filter could not compute at either end are replaced by the
  # nearest available value. Nothing to fill from if every value is NA.
  available <- which(!is.na(bc_pvals))
  if (length(available) > 0) {
    first_ok <- available[1]
    last_ok <- available[length(available)]
    bc_pvals[seq_len(first_ok)] <- bc_pvals[first_ok]
    bc_pvals[last_ok:length(bc_pvals)] <- bc_pvals[last_ok]
  }

  list(xs = xs, bc_pvals = bc_pvals)
}
649
+
650
# Wrapper that generates the boundary corrected reference curves by applying
# .kde_correction() to every row of `unifs`.
#
# @param unifs Numeric matrix with one set of reference draws per row.
# @param bw,grid_len Passed on to .kde_correction().
# @return A list with `unifs`, a matrix holding one corrected curve per row,
#   and `xs`, the x-axis values of all curves concatenated in row order.
.ref_kde_correction <- function(unifs, bw, grid_len) {
  corrected <- lapply(seq_len(nrow(unifs)), function(i) {
    .kde_correction(unifs[i, ], bw = bw, grid_len = grid_len)
  })

  if (length(corrected) == 0) {
    return(list(xs = numeric(0),
                unifs = matrix(0, nrow = 0, ncol = ncol(unifs))))
  }

  # The output is sized from what .kde_correction() actually returned, not
  # from ncol(unifs): the number of grid points may differ from the number of
  # draws per row.
  bc_mat <- do.call(rbind, lapply(corrected, `[[`, "bc_pvals"))
  xs <- unlist(lapply(corrected, `[[`, "xs"), use.names = FALSE)

  list(xs = xs, unifs = bc_mat)
}
0 commit comments