Skip to content

Commit 639c9f1

Browse files
committed
Merge branch 'ridgeline-size' of https://github.com/stan-dev/bayesplot into ridgeline-size
2 parents a27059a + 29e7e68 commit 639c9f1

File tree

6 files changed

+324
-62
lines changed

6 files changed

+324
-62
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ Authors@R: c(person("Jonah", "Gabry", role = c("aut", "cre"), email = "jsg2201@c
88
person("Paul-Christian", "Bürkner", role = "ctb"),
99
person("Martin", "Modrák", role = "ctb"),
1010
person("Malcolm", "Barrett", role = "ctb"),
11-
person("Frank", "Weber", role = "ctb"))
11+
person("Frank", "Weber", role = "ctb"),
12+
person("Eduardo", "Coronado Sroka", role = "ctb"))
1213
Maintainer: Jonah Gabry <[email protected]>
1314
Description: Plotting functions for posterior analysis, MCMC diagnostics,
1415
prior and posterior predictive checks, and other visualizations

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ export(ppc_intervals_grouped)
120120
export(ppc_km_overlay)
121121
export(ppc_loo_intervals)
122122
export(ppc_loo_pit)
123+
export(ppc_loo_pit_data)
123124
export(ppc_loo_pit_overlay)
124125
export(ppc_loo_pit_qq)
125126
export(ppc_loo_ribbon)

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929

3030
* Added missing `facet_args` argument to `mcmc_rank_overlay()`. (#221, @hhau)
3131

32+
* `ppc_loo_pit_overlay()` now uses a boundary correction for an improved kernel
33+
density estimation. The new argument `boundary_correction` defaults to TRUE but
34+
can be set to FALSE to recover the old version of the plot. (#171, #235,
35+
@ecoronado92)
36+
3237

3338
# bayesplot 1.7.2
3439

R/ppc-loo.R

Lines changed: 267 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -119,79 +119,176 @@ NULL
119119
#' quantiles.
120120
#' @param trim Passed to [ggplot2::stat_density()].
121121
#' @template args-density-controls
122+
#' @param boundary_correction For `ppc_loo_pit_overlay()`, when set to `TRUE`
123+
#' (the default) the function will compute boundary corrected density values
124+
#' via convolution and a Gaussian filter, also known as the reflection method
125+
#' (Boneva et al., 1971). As a result, parameters controlling the standard
126+
#' kernel density estimation such as `adjust`, `kernel` and `n_dens` are
127+
#' ignored. NOTE: The current implementation only works well for continuous
128+
#' observations.
129+
#' @param grid_len For `ppc_loo_pit_overlay()`, when `boundary_correction` is
130+
#' set to `TRUE` this parameter specifies the number of points used to
131+
#' generate the estimations. This is set to 512 by default.
132+
#'
133+
#' @references Boneva, L. I., Kendall, D., & Stefanov, I. (1971). Spline
134+
#' transformations: Three new diagnostic aids for the statistical
135+
#' data-analyst. *J. R. Stat. Soc. B* (Methodological), 33(1), 1-71.
136+
#' https://www.jstor.org/stable/2986005.
137+
#'
122138
ppc_loo_pit_overlay <- function(y,
123139
yrep,
124140
lw,
125-
pit,
126-
samples = 100,
127141
...,
142+
pit = NULL,
143+
samples = 100,
128144
size = 0.25,
129145
alpha = 0.7,
130-
trim = FALSE,
146+
boundary_correction = TRUE,
147+
grid_len = 512,
131148
bw = "nrd0",
149+
trim = FALSE,
132150
adjust = 1,
133151
kernel = "gaussian",
134152
n_dens = 1024) {
135153
check_ignored_arguments(...)
136154

137-
if (!missing(pit)) {
138-
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
139-
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
140-
} else {
141-
suggested_package("rstantools")
142-
y <- validate_y(y)
143-
yrep <- validate_yrep(yrep, y)
144-
stopifnot(identical(dim(yrep), dim(lw)))
145-
pit <- rstantools::loo_pit(object = yrep, y = y, lw = lw)
155+
data <-
156+
ppc_loo_pit_data(
157+
y = y,
158+
yrep = yrep,
159+
lw = lw,
160+
pit = pit,
161+
samples = samples,
162+
bw = bw,
163+
boundary_correction = boundary_correction,
164+
grid_len = grid_len
165+
)
166+
167+
if (all(data$value[data$is_y] %in% 0:1)) {
168+
warning(
169+
"This plot is not recommended for binary data. ",
170+
"For plots that are more suitable see ",
171+
"\nhttps://avehtari.github.io/modelselection/diabetes.html#44_calibration_of_predictions",
172+
call. = FALSE
173+
)
146174
}
147175

148-
unifs <- matrix(runif(length(pit) * samples), nrow = samples)
176+
if (boundary_correction) {
177+
message("NOTE: Current boundary correction implementation works for continuous observations only.")
149178

150-
data <- ppc_data(pit, unifs)
179+
p <- ggplot(data) +
180+
aes_(x = ~ x, y = ~ value) +
181+
geom_line(
182+
aes_(group = ~ rep_id, color = "yrep"),
183+
data = function(x) dplyr::filter(x, !.data$is_y),
184+
alpha = alpha,
185+
size = size,
186+
na.rm = TRUE) +
187+
geom_line(
188+
aes_(color = "y"),
189+
data = function(x) dplyr::filter(x, .data$is_y),
190+
size = 1,
191+
lineend = "round",
192+
na.rm = TRUE) +
193+
scale_x_continuous(
194+
limits = c(0, 1),
195+
expand = expansion(0, 0.01),
196+
breaks = seq(0, 1, by = 0.25),
197+
labels = c("0", "0.25", "0.5", "0.75", "1")
198+
)
151199

152-
ggplot(data) +
153-
aes_(x = ~ value) +
154-
stat_density(
155-
aes_(group = ~ rep_id, color = "yrep"),
156-
data = function(x) dplyr::filter(x, !.data$is_y),
157-
geom = "line",
158-
position = "identity",
159-
size = size,
160-
alpha = alpha,
161-
trim = trim,
162-
bw = bw,
163-
adjust = adjust,
164-
kernel = kernel,
165-
n = n_dens,
166-
na.rm = TRUE) +
167-
stat_density(
168-
aes_(color = "y"),
169-
data = function(x) dplyr::filter(x, .data$is_y),
170-
geom = "line",
171-
position = "identity",
172-
lineend = "round",
173-
size = 1,
174-
trim = trim,
175-
bw = bw,
176-
adjust = adjust,
177-
kernel = kernel,
178-
n = n_dens,
179-
na.rm = TRUE) +
180-
scale_color_ppc_dist(labels = c("PIT", "Unif")) +
181-
scale_x_continuous(
182-
limits = c(.1, .9),
183-
expand = expansion(0, 0),
184-
breaks = seq(from = .1, to = .9, by = .2)) +
185-
scale_y_continuous(
186-
limits = c(0, NA),
187-
expand = expansion(mult = c(0, .25))) +
188-
bayesplot_theme_get() +
189-
yaxis_title(FALSE) +
190-
xaxis_title(FALSE) +
191-
yaxis_text(FALSE) +
192-
yaxis_ticks(FALSE)
200+
} else {
201+
p <- ggplot(data) +
202+
aes_(x = ~ value) +
203+
stat_density(
204+
aes_(group = ~ rep_id, color = "yrep"),
205+
data = function(x) dplyr::filter(x, !.data$is_y),
206+
geom = "line",
207+
position = "identity",
208+
size = size,
209+
alpha = alpha,
210+
trim = trim,
211+
bw = bw,
212+
adjust = adjust,
213+
kernel = kernel,
214+
n = n_dens,
215+
na.rm = TRUE) +
216+
stat_density(
217+
aes_(color = "y"),
218+
data = function(x) dplyr::filter(x, .data$is_y),
219+
geom = "line",
220+
position = "identity",
221+
lineend = "round",
222+
size = 1,
223+
trim = trim,
224+
bw = bw,
225+
adjust = adjust,
226+
kernel = kernel,
227+
n = n_dens,
228+
na.rm = TRUE) +
229+
scale_x_continuous(
230+
limits = c(0.05, 0.95),
231+
expand = expansion(0, 0),
232+
breaks = seq(from = .1, to = .9, by = .2)
233+
)
234+
}
235+
236+
p +
237+
scale_color_ppc_dist(labels = c("PIT", "Unif")) +
238+
scale_y_continuous(
239+
limits = c(0, NA),
240+
expand = expansion(mult = c(0, .25))
241+
) +
242+
bayesplot_theme_get() +
243+
yaxis_title(FALSE) +
244+
xaxis_title(FALSE) +
245+
yaxis_text(FALSE) +
246+
yaxis_ticks(FALSE)
193247
}
194248

249+
#' @rdname PPC-loo
250+
#' @export
251+
ppc_loo_pit_data <-
  function(y,
           yrep,
           lw,
           ...,
           pit = NULL,
           samples = 100,
           bw = "nrd0",
           boundary_correction = TRUE,
           grid_len = 512) {
    # Build the data frame behind ppc_loo_pit_overlay().
    #
    # When `pit` is NULL the LOO-PIT values are computed from `y`, `yrep`
    # and `lw` via rstantools::loo_pit(); otherwise the supplied `pit`
    # vector is used and the other three arguments are ignored.
    if (is.null(pit)) {
      suggested_package("rstantools")
      y <- validate_y(y)
      yrep <- validate_yrep(yrep, y)
      stopifnot(identical(dim(yrep), dim(lw)))
      pit <- rstantools::loo_pit(object = yrep, y = y, lw = lw)
    } else {
      stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
      inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
    }

    if (boundary_correction) {
      # One row of uniform draws per reference line; the reference curves
      # are estimated first, then the curve for the PIT values themselves.
      unif_draws <- matrix(runif(grid_len * samples), nrow = samples)
      ref_kde <- .ref_kde_correction(unif_draws, bw = bw, grid_len = grid_len)
      pit_kde <- .kde_correction(pit, bw = bw, grid_len = grid_len)

      # x positions: the PIT grid first, followed by each reference grid.
      x_grid <- c(pit_kde$xs, ref_kde$xs)

      # Sorting by rep_id lines the rows up with the order of `x_grid`
      # before the x column is attached.
      long <- ppc_data(pit_kde$bc_pvals, ref_kde$unifs)
      long <- dplyr::arrange(long, .data$rep_id)
      return(mutate(long, x = x_grid))
    }

    # No boundary correction: raw PIT values plus `samples` uniform
    # reference samples of the same length; the density is left to the
    # plotting layer.
    unif_draws <- matrix(runif(length(pit) * samples), nrow = samples)
    ppc_data(pit, unif_draws)
  }
291+
195292

196293
#' @rdname PPC-loo
197294
#' @export
@@ -458,3 +555,118 @@ ppc_loo_ribbon <-
458555
return(psis_object)
459556
}
460557

558+
## Boundary correction based on code by the ArviZ development team
559+
# The main method is a 1-D density estimation for linear data with
560+
# convolution with a Gaussian filter.
561+
562+
# Unnormalized Gaussian window, based on the scipy.signal.gaussian formula.
#
# N:  number of points in the window.
# bw: standard deviation of the Gaussian, in units of window points.
#
# Returns a numeric vector of length N, symmetric about its midpoint.
.gaussian <- function(N, bw) {
  # Offsets of each window point from the centre of the window.
  centred_idx <- seq(0, N - 1) - (N - 1) / 2
  exp(-centred_idx^2 / (2 * bw * bw))
}
570+
571+
# 1-D Gaussian kernel density estimate computed by convolving binned
# relative frequencies with a Gaussian filter, using the reflection method
# to correct the estimate near the two ends of the grid.
#
# x:           the data vector (only its length is used here).
# bw:          kernel bandwidth on the scale of the data.
# grid_counts: number of observations in each of the `grid_len` bins.
# grid_breaks: equally spaced bin boundaries (length `grid_len + 1`).
# grid_len:    number of bins.
#
# Returns a numeric vector of length `grid_len` with the density estimate at
# each bin. Entries are NA wherever the filter window reaches past the
# padded series (stats::filter() behaviour).
.linear_convolution <- function(x,
                                bw,
                                grid_counts,
                                grid_breaks,
                                grid_len) {
  bin_width <- grid_breaks[2] - grid_breaks[1]
  # Binned relative frequencies on the density scale.
  f <- grid_counts / bin_width / length(x)
  # Express the bandwidth in units of bins.
  bw <- bw / bin_width

  # Number of points in the Gaussian filter (at least one).
  gauss_n <- max(as.integer(bw * 2 * pi), 1L)
  kernel <- .gaussian(gauss_n, bw)
  npad <- as.integer(grid_len / 5)

  # Reflection method: mirror the first `npad` bins in front of the series
  # and the last `npad` bins behind it. The right-hand pad must be taken from
  # the END of `f`; taking it from the start would smear mass from the lower
  # boundary into the estimate at the upper boundary. seq_len() keeps both
  # pads empty when npad is 0.
  pad_idx <- seq_len(npad)
  left_pad <- rev(f[pad_idx])
  right_pad <- rev(f[grid_len - npad + pad_idx])
  f <- c(left_pad, f, right_pad)

  # Convolution: Gaussian filter + reflection method (padding) works as a
  # moving average weighted by a Gaussian density, which takes care of the
  # density values near the boundaries. Only the un-padded part is kept.
  keep <- npad + seq_len(grid_len)
  bc_pvals <- stats::filter(f,
                            kernel,
                            method = "convolution",
                            sides = 2)[keep]

  # Normalize the unnormalized Gaussian window.
  bc_pvals / (bw * sqrt(2 * pi))
}
608+
609+
# Boundary corrected kernel density estimate of `x`, obtained by binning the
# data and applying a linear convolution with a 1-D Gaussian window filter
# (the "reflection method").
#
# x:        numeric vector of values; non-finite entries (NA, NaN, Inf,
#           -Inf) are dropped with a warning.
# bw:       bandwidth specification, passed to density() to obtain the
#           numeric bandwidth.
# grid_len: requested number of grid points; values below 100 are raised
#           to 100.
#
# Returns a list with `xs` (bin midpoints) and `bc_pvals` (density estimate
# at each midpoint), both of length max(grid_len, 100).
.kde_correction <- function(x,
                            bw,
                            grid_len) {
  # Drop everything that is not finite. The test and the filter use the same
  # predicate so that NA/NaN are counted and removed as well, instead of
  # reaching seq(min(x), ...) below.
  non_finite <- !is.finite(x)
  if (any(non_finite)) {
    warning(paste("Ignored", sum(non_finite),
                  "Non-finite PIT values are invalid for KDE boundary correction method"))
    x <- x[!non_finite]
  }

  # Enforce a minimum grid resolution.
  grid_len <- max(grid_len, 100)

  # Get relative frequency boundaries and counts for the input vector.
  bins <- seq(from = min(x), to = max(x), length.out = grid_len + 1)
  hist_obj <- hist(x, breaks = bins, plot = FALSE)
  grid_breaks <- hist_obj$breaks
  grid_counts <- hist_obj$counts

  # Compute the numeric bandwidth based on the user specification.
  bw <- density(x, bw = bw)$bw

  # 1-D convolution.
  bc_pvals <- .linear_convolution(x, bw, grid_counts, grid_breaks, grid_len)

  # x-axis values for plotting: the midpoint of each bin.
  n_breaks <- length(grid_breaks)
  xs <- (grid_breaks[-1] + grid_breaks[-n_breaks]) / 2

  # The convolution can leave NAs at either end; extend the nearest
  # non-missing value outwards. Skipped when there is nothing to extend.
  observed <- which(!is.na(bc_pvals))
  if (length(observed) > 0) {
    first_obs <- observed[1]
    last_obs <- observed[length(observed)]
    bc_pvals[seq_len(first_obs)] <- bc_pvals[first_obs]
    bc_pvals[last_obs:length(bc_pvals)] <- bc_pvals[last_obs]
  }

  list(xs = xs, bc_pvals = bc_pvals)
}
649+
650+
# Wrapper that applies .kde_correction() to every row of `unifs` to generate
# the boundary corrected reference (uniform) lines.
#
# unifs:    numeric matrix with one reference sample per row.
# bw:       bandwidth specification forwarded to .kde_correction().
# grid_len: grid size forwarded to .kde_correction().
#
# Returns a list with
#   xs:    the x positions of every row's estimate, concatenated row by row;
#   unifs: a matrix with one row of density estimates per row of the input.
.ref_kde_correction <- function(unifs, bw, grid_len) {
  n_rows <- nrow(unifs)
  if (n_rows == 0) {
    return(list(xs = numeric(0), unifs = matrix(0, 0, ncol(unifs))))
  }

  # Size the output from what .kde_correction() actually returns rather than
  # from ncol(unifs): .kde_correction() raises small grid_len values to 100,
  # so the two lengths need not agree.
  per_row <- lapply(seq_len(n_rows), function(i) {
    .kde_correction(unifs[i, ], bw = bw, grid_len = grid_len)
  })

  xs <- unlist(lapply(per_row, function(res) res$xs), use.names = FALSE)
  bc_mat <- do.call(rbind, lapply(per_row, function(res) res$bc_pvals))

  list(xs = xs, unifs = bc_mat)
}

0 commit comments

Comments
 (0)