Skip to content

Commit b9bbc4f

Browse files
committed
rewrite arm's binned error function
1 parent ea1ca27 commit b9bbc4f

File tree

1 file changed

+61
-32
lines changed

1 file changed

+61
-32
lines changed

R/ppc-errors.R

Lines changed: 61 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -319,31 +319,25 @@ ppc_error_binned <- function(y, yrep, ..., size = 1, alpha = 0.25) {
319319
yrep <- validate_yrep(yrep, y)
320320
errors <- compute_errors(y, yrep)
321321

322-
ny <- length(y)
323-
if (ny >= 100) {
324-
nbins <- floor(sqrt(ny))
325-
} else if (ny > 10 && ny < 100) {
322+
N <- length(y)
323+
if (N >= 100) {
324+
nbins <- floor(sqrt(N))
325+
} else if (N > 10 && N < 100) {
326326
nbins <- 10
327327
} else {
328-
# if (ny <= 10)
329-
nbins <- floor(ny / 2)
328+
# if (N <= 10)
329+
nbins <- floor(N / 2)
330330
}
331331

332-
n <- nrow(yrep)
333-
binned <- .binner(
334-
rep_id = 1,
335-
ey = yrep[1, ],
336-
r = errors[1, ],
337-
nbins = nbins
338-
)
339-
if (n > 1) {
340-
for (i in 2:nrow(errors))
341-
binned <- rbind(binned, .binner(
342-
rep_id = i,
343-
ey = yrep[i,],
344-
r = errors[i,],
345-
nbins
346-
))
332+
S <- nrow(yrep)
333+
binned <- bin_errors(rep_id = 1, ey = yrep[1, ], r = errors[1, ],
334+
nbins = nbins)
335+
if (S > 1) {
336+
for (i in 2:nrow(errors)) {
337+
binned_i <- bin_errors(rep_id = i, ey = yrep[i,], r = errors[i,],
338+
nbins = nbins)
339+
binned <- rbind(binned, binned_i)
340+
}
347341
}
348342

349343
mixed_scheme <- is_mixed_scheme(color_scheme_get())
@@ -384,12 +378,13 @@ ppc_error_binned <- function(y, yrep, ..., size = 1, alpha = 0.25) {
384378
) +
385379
bayesplot_theme_get()
386380

387-
if (n > 1)
381+
if (S > 1) {
388382
graph <- graph +
389383
facet_wrap(
390384
facets = ~rep_id
391385
# labeller = label_bquote(italic(y)[rep](.(rep_id)))
392386
)
387+
}
393388

394389
graph +
395390
force_axes_in_facets() +
@@ -419,14 +414,48 @@ grouped_error_data <- function(y, yrep, group) {
419414
}
420415

421416

422-
.binner <- function(rep_id, ey, r, nbins) {
423-
binned_errors <- arm::binned.resids(ey, r, nbins)$binned
424-
binned_errors <- binned_errors[, c("xbar", "ybar", "2se")]
425-
if (length(dim(binned_errors)) < 2)
426-
binned_errors <- t(binned_errors)
427-
colnames(binned_errors) <- c("xbar", "ybar", "se2")
428-
data.frame(
429-
rep_id = as.integer(rep_id), #create_yrep_ids(rep_id),
430-
binned_errors
431-
)
417+
bin_errors <- function(rep_id, ey, r, nbins) {
418+
N <- length(ey)
419+
break_ids <- floor(N * (1:(nbins - 1)) / nbins)
420+
if (any(break_ids == 0)) {
421+
nbins <- 1
422+
}
423+
if (nbins == 1) {
424+
breaks <- c(-Inf, sum(range(ey)) / 2, Inf)
425+
} else {
426+
ey_sort <- sort(ey)
427+
breaks <- -Inf
428+
for (i in 1:(nbins - 1)) {
429+
break_i <- break_ids[i]
430+
ey_range <- ey_sort[c(break_i, break_i + 1)]
431+
if (diff(ey_range) == 0) {
432+
if (ey_range[1] == min(ey)) {
433+
ey_range[1] <- -Inf
434+
} else {
435+
ey_range[1] <- max(ey[ey < ey_range[1]])
436+
}
437+
}
438+
breaks <- c(breaks, sum(ey_range) / 2)
439+
}
440+
breaks <- unique(c(breaks, Inf))
441+
}
442+
443+
nbins <- length(breaks) - 1
444+
ey_binned <- as.numeric(cut(ey, breaks))
445+
446+
out <- matrix(NA, nrow = nbins, ncol = 3)
447+
for (i in 1:nbins) {
448+
mark <- which(ey_binned == i)
449+
ey_bar <- mean(ey[mark])
450+
r_bar <- mean(r[mark])
451+
s <- if (length(r[mark]) > 1) sd(r[mark]) else 0
452+
out[i, ] <- c(ey_bar, r_bar, 2 * s / sqrt(length(mark)))
453+
}
454+
out <- as.data.frame(out)
455+
colnames(out) <- c("xbar", "ybar", "se2")
456+
out$rep_id <- as.integer(rep_id)
457+
return(out)
432458
}
459+
460+
461+

0 commit comments

Comments
 (0)