Skip to content

Commit 96ac346

Browse files
Merge pull request #48 from r-causal/plot_ess
Add ESS extensions
2 parents 33ba280 + 2d3ba81 commit 96ac346

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1540
-236
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Imports:
2323
propensity (>= 0.0.0.9000),
2424
purrr,
2525
rlang,
26+
scales,
2627
smd,
2728
tibble,
2829
tidyr,

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export(bind_matches)
1919
export(check_auc)
2020
export(check_balance)
2121
export(check_calibration)
22+
export(check_ess)
2223
export(contains)
2324
export(ends_with)
2425
export(ess)
@@ -38,6 +39,7 @@ export(one_of)
3839
export(peek_vars)
3940
export(plot_balance)
4041
export(plot_calibration)
42+
export(plot_ess)
4143
export(plot_mirror_distributions)
4244
export(plot_qq)
4345
export(plot_roc_auc)

R/check_auc.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ roc_curve <- function(
226226
}
227227

228228
weights <- extract_weight_data(weights)
229-
229+
230230
# Handle zero and negative weights
231231
if (any(weights <= 0, na.rm = TRUE)) {
232232
n_zero_neg <- sum(weights <= 0, na.rm = TRUE)

R/check_ess.R

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
#' Check Effective Sample Size
#'
#' Computes the effective sample size (ESS) for one or more weighting schemes,
#' optionally stratified by treatment groups. ESS reflects how many observations
#' you would have if all were equally weighted.
#'
#' @details
#' The effective sample size (ESS) is calculated using the classical formula:
#' \eqn{ESS = (\sum w)^2 / \sum(w^2)}.
#'
#' When weights vary substantially, the ESS can be much smaller than the actual
#' number of observations, indicating that a few observations carry
#' disproportionately large weights.
#'
#' When `.group` is provided, ESS is calculated separately for each group level:
#' - For binary/categorical exposures: ESS is computed within each treatment level
#' - For continuous exposures: The variable is divided into quantiles (using
#'   `dplyr::ntile()`) and ESS is computed within each quantile. A numeric
#'   variable is treated as continuous when it has more than 10 unique
#'   non-missing values.
#'
#' Missing weights are dropped when computing `ess`, but `n` counts every row
#' in the group, so rows with missing weights lower `ess_pct`.
#'
#' The function returns results in a tidy format suitable for plotting or
#' further analysis.
#'
#' @inheritParams check_params
#' @param .group Optional grouping variable. When provided, ESS is calculated
#'   separately for each group level. For continuous variables, groups are
#'   created using quantiles.
#' @param n_tiles For continuous `.group` variables, the number of quantile
#'   groups to create. Must be a single positive whole number. Default is 4
#'   (quartiles).
#' @param tile_labels Optional character vector of labels for the quantile groups
#'   when `.group` is continuous. If NULL, uses "Q1", "Q2", etc.
#'
#' @return A tibble with columns:
#'   \item{method}{Character. The weighting method ("observed" or weight variable name).}
#'   \item{group}{The group level (only present if `.group` is provided). Keeps
#'     the type of the `.group` column, or is a factor of quantile labels when
#'     `.group` is continuous.}
#'   \item{n}{Integer. The number of observations in the group.}
#'   \item{ess}{Numeric. The effective sample size.}
#'   \item{ess_pct}{Numeric. ESS as a percentage of the actual sample size.}
#'
#' @family balance functions
#' @seealso [ess()] for the underlying ESS calculation, [plot_ess()] for visualization
#'
#' @examples
#' # Overall ESS for different weighting schemes
#' check_ess(nhefs_weights, .wts = c(w_ate, w_att, w_atm))
#'
#' # ESS by treatment group (binary exposure)
#' check_ess(nhefs_weights, .wts = c(w_ate, w_att), .group = qsmk)
#'
#' # ESS by treatment group (categorical exposure)
#' check_ess(nhefs_weights, .wts = w_cat_ate, .group = alcoholfreq_cat)
#'
#' # ESS by quartiles of a continuous variable
#' check_ess(nhefs_weights, .wts = w_ate, .group = age, n_tiles = 4)
#'
#' # Custom labels for continuous groups
#' check_ess(nhefs_weights, .wts = w_ate, .group = age,
#'           n_tiles = 3, tile_labels = c("Young", "Middle", "Older"))
#'
#' # Without unweighted comparison
#' check_ess(nhefs_weights, .wts = w_ate, .group = qsmk,
#'           include_observed = FALSE)
#'
#' @export
check_ess <- function(
  .data,
  .wts = NULL,
  .group = NULL,
  include_observed = TRUE,
  n_tiles = 4,
  tile_labels = NULL
) {
  # Validate inputs
  validate_data_frame(.data)

  # Grouping already present on `.data` never affected the result (ESS is
  # grouped only by `method` and `.group`, and quantiles are taken over the
  # whole column). Dropping it up front keeps `select()` below from silently
  # re-adding the old grouping columns.
  .data <- dplyr::ungroup(.data)

  # Handle group variable
  group_quo <- rlang::enquo(.group)
  has_group <- !rlang::quo_is_null(group_quo)
  group_values <- NULL

  if (has_group) {
    group_name <- get_column_name(group_quo, ".group")
    validate_column_exists(.data, group_name, ".group")
    group_values <- .data[[group_name]]

    # Check if continuous (numeric and more than 10 unique values)
    is_continuous <- is.numeric(group_values) &&
      length(unique(stats::na.omit(group_values))) > 10

    if (is_continuous) {
      # `n_tiles` drives `ntile()`, `seq_len()` and the label check, so reject
      # anything that is not a single positive whole number up front.
      n_tiles_ok <- is.numeric(n_tiles) &&
        length(n_tiles) == 1 &&
        is.finite(n_tiles) &&
        n_tiles >= 1 &&
        n_tiles == round(n_tiles)
      if (!n_tiles_ok) {
        abort(
          "{.arg n_tiles} must be a single positive whole number",
          error_class = "halfmoon_arg_error"
        )
      }

      if (!is.null(tile_labels) && length(tile_labels) != n_tiles) {
        abort(
          "Length of {.arg tile_labels} must equal {.arg n_tiles}",
          error_class = "halfmoon_length_error"
        )
      }

      # Apply labels
      if (is.null(tile_labels)) {
        tile_labels <- paste0("Q", seq_len(n_tiles))
      }

      # Create labelled quantile groups
      group_values <- factor(
        dplyr::ntile(group_values, n_tiles),
        levels = seq_len(n_tiles),
        labels = tile_labels
      )
    }
  }

  # Handle weights. This is resolved before the temporary group column is
  # attached so tidyselect helpers in `.wts` cannot pick that column up.
  wts_quo <- rlang::enquo(.wts)

  if (rlang::quo_is_null(wts_quo)) {
    # No weights provided, just use observed
    wts_names <- character()
  } else {
    wts_names <- names(tidyselect::eval_select(wts_quo, .data))

    # Convert psw weight columns to numeric
    for (wts_name in wts_names) {
      .data[[wts_name]] <- extract_weight_data(.data[[wts_name]])
    }
  }

  # Add observed (all weights equal to 1) if requested, or when there is
  # nothing else to report
  if (include_observed || length(wts_names) == 0) {
    .data$.observed <- 1
    wts_names <- c(".observed", wts_names)
  }

  # Keep only the columns ESS needs, with the group stored under an internal
  # name. Unrelated columns called `method`, `weight` or `group` would
  # otherwise clash with the output names created below.
  keep_cols <- wts_names
  if (has_group) {
    .data$.ess_group <- group_values
    keep_cols <- c(".ess_group", keep_cols)
  }

  # Reshape to long format
  long_data <- tidyr::pivot_longer(
    dplyr::select(.data, dplyr::all_of(keep_cols)),
    cols = dplyr::all_of(wts_names),
    names_to = "method",
    values_to = "weight"
  )

  # Clean up method names
  long_data$method[long_data$method == ".observed"] <- "observed"

  # Calculate ESS per method, and per group level when a group was supplied
  grouping_cols <- if (has_group) c("method", ".ess_group") else "method"

  ess_data <- long_data |>
    dplyr::group_by(dplyr::across(dplyr::all_of(grouping_cols))) |>
    dplyr::summarise(
      n = dplyr::n(),
      ess = ess(weight, na.rm = TRUE),
      ess_pct = ess / n * 100,
      .groups = "drop"
    )

  if (has_group) {
    ess_data <- dplyr::rename(ess_data, group = ".ess_group")
  }

  ess_data
}

R/compute_balance.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,12 +370,14 @@ bal_ks <- function(
370370
p_ref <- if (is.null(weights)) {
371371
mean(covariate[idx_ref])
372372
} else {
373-
sum(extract_weight_data(weights)[idx_ref] * covariate[idx_ref]) / sum(extract_weight_data(weights)[idx_ref])
373+
sum(extract_weight_data(weights)[idx_ref] * covariate[idx_ref]) /
374+
sum(extract_weight_data(weights)[idx_ref])
374375
}
375376
p_other <- if (is.null(weights)) {
376377
mean(covariate[idx_other])
377378
} else {
378-
sum(extract_weight_data(weights)[idx_other] * covariate[idx_other]) / sum(extract_weight_data(weights)[idx_other])
379+
sum(extract_weight_data(weights)[idx_other] * covariate[idx_other]) /
380+
sum(extract_weight_data(weights)[idx_other])
379381
}
380382
return(abs(p_other - p_ref))
381383
}
@@ -384,7 +386,8 @@ bal_ks <- function(
384386
# Extract and weight
385387
x_ref <- covariate[idx_ref]
386388
x_other <- covariate[idx_other]
387-
w_ref <- if (is.null(weights)) rep(1, length(x_ref)) else extract_weight_data(weights)[idx_ref]
389+
w_ref <- if (is.null(weights)) rep(1, length(x_ref)) else
390+
extract_weight_data(weights)[idx_ref]
388391
w_other <- if (is.null(weights)) rep(1, length(x_other)) else
389392
extract_weight_data(weights)[idx_other]
390393
w_ref <- w_ref / sum(w_ref)

R/compute_qq.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ compute_method_quantiles <- function(
259259
weighted_quantile <- function(values, quantiles, .wts) {
260260
# Extract numeric data from weights (handles both numeric and psw objects)
261261
.wts <- extract_weight_data(.wts)
262-
262+
263263
# Remove NA values if present
264264
na_idx <- is.na(values) | is.na(.wts)
265265
if (any(na_idx)) {

R/ess.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#'
77
#' @param wts A numeric vector of weights (e.g., from survey or
88
#' inverse-probability weighting).
9+
#' @param na.rm Logical. Should missing values be removed? Default is FALSE.
910
#'
1011
#' @return A single numeric value representing the effective sample size.
1112
#'
@@ -42,8 +43,8 @@
4243
#' ess(wts2)
4344
#'
4445
#' @export
ess <- function(wts, na.rm = FALSE) {
  # Work on the underlying numeric values when `wts` is a psw weights object
  weight_values <- extract_weight_data(wts)

  # ESS = (sum of weights)^2 / (sum of squared weights); `na.rm` is applied
  # to both sums
  sum_of_weights <- sum(weight_values, na.rm = na.rm)
  sum_of_squares <- sum(weight_values^2, na.rm = na.rm)

  sum_of_weights^2 / sum_of_squares
}

R/geom_calibration.R

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,13 @@ check_treatment_level <- function(group_var, treatment_level) {
152152
create_treatment_indicator(group_var, treatment_level)
153153
}
154154

155-
check_columns <- function(data, fitted_name, group_name, treatment_level, call = rlang::caller_env()) {
155+
check_columns <- function(
156+
data,
157+
fitted_name,
158+
group_name,
159+
treatment_level,
160+
call = rlang::caller_env()
161+
) {
156162
if (is.null(treatment_level)) {
157163
if (!fitted_name %in% names(data)) {
158164
abort(
@@ -591,7 +597,13 @@ compute_calibration_for_group <- function(
591597

592598
# Compute calibration based on method
593599
calibration_result <- if (method == "breaks") {
594-
compute_calibration_breaks_imp(df, bins, binning_method, conf_level, call = call)
600+
compute_calibration_breaks_imp(
601+
df,
602+
bins,
603+
binning_method,
604+
conf_level,
605+
call = call
606+
)
595607
} else if (method == "logistic") {
596608
compute_calibration_logistic_imp(df, smooth, conf_level, k = k, call = call)
597609
} else if (method == "windowed") {

R/geom_mirrored_density.R

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,27 +62,29 @@ StatMirrorDensity <- ggplot2::ggproto(
6262
.n_groups = length(unique(group)),
6363
.groups = "drop"
6464
)
65-
65+
6666
# Check for panels with more than 2 groups
6767
if (any(panel_groups$.n_groups > 2)) {
6868
abort(
6969
"Groups of three or greater not supported in `geom_mirror_density()`",
7070
error_class = "halfmoon_group_error"
7171
)
7272
}
73-
73+
7474
# Join back to get panel group info for each row
7575
data <- dplyr::left_join(data, panel_groups, by = "PANEL")
76-
76+
7777
# Mark which groups should be mirrored (first group in each panel)
78-
data$.should_mirror <- purrr::map2_lgl(data$group, data$.panel_groups,
78+
data$.should_mirror <- purrr::map2_lgl(
79+
data$group,
80+
data$.panel_groups,
7981
~ length(.y) == 2 && .x == .y[1]
8082
)
81-
83+
8284
# Clean up temporary columns
8385
data$.panel_groups <- NULL
8486
data$.n_groups <- NULL
85-
87+
8688
data
8789
},
8890
compute_group = function(
@@ -109,15 +111,15 @@ StatMirrorDensity <- ggplot2::ggproto(
109111
error_class = "halfmoon_aes_error"
110112
)
111113
}
112-
114+
113115
# Store mirroring flag
114116
should_mirror <- unique(data$.should_mirror)
115-
117+
116118
# Extract numeric data from psw weights if present
117119
if ("weight" %in% names(data)) {
118120
data$weight <- extract_weight_data(data$weight)
119121
}
120-
122+
121123
data <- ggplot2::StatDensity$compute_group(
122124
data = data,
123125
scales = scales,
@@ -130,15 +132,15 @@ StatMirrorDensity <- ggplot2::ggproto(
130132
bounds = bounds,
131133
flipped_aes = flipped_aes
132134
)
133-
135+
134136
# Apply mirroring if needed
135137
if (length(should_mirror) == 1 && should_mirror) {
136138
data$density <- -data$density
137139
data$count <- -data$count
138140
data$scaled <- -data$scaled
139141
data$ndensity <- -data$ndensity
140142
}
141-
143+
142144
data
143145
}
144146
)

0 commit comments

Comments
 (0)