Skip to content

Commit 7e0193f

Browse files
Merge pull request #53 from r-causal/bugs
Various fixes + ggplot2 4.0.0 compatibility
2 parents 919ac62 + a46ab9d commit 7e0193f

File tree

83 files changed

+4577
-3046
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

83 files changed

+4577
-3046
lines changed

R/bal_model_auc.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#'
1717
#' @param .data A data frame containing the variables.
1818
#' @param .exposure The treatment/outcome variable (unquoted).
19-
#' @param .estimate The propensity score or fitted values (unquoted).
19+
#' @param .fitted The propensity score or fitted values (unquoted).
2020
#' @param .weights Optional single weight variable (unquoted). If NULL, computes
2121
#' unweighted AUC.
2222
#' @inheritParams balance_params
@@ -40,15 +40,15 @@
4040
bal_model_auc <- function(
4141
.data,
4242
.exposure,
43-
.estimate,
43+
.fitted,
4444
.weights = NULL,
4545
na.rm = TRUE,
4646
.focal_level = NULL
4747
) {
4848
validate_data_frame(.data, call = rlang::caller_env())
4949

5050
exposure_quo <- rlang::enquo(.exposure)
51-
estimate_quo <- rlang::enquo(.estimate)
51+
estimate_quo <- rlang::enquo(.fitted)
5252
wts_quo <- rlang::enquo(.weights)
5353

5454
# Extract column names
@@ -64,7 +64,7 @@ bal_model_auc <- function(
6464
}
6565
if (length(estimate_name) != 1) {
6666
abort(
67-
"{.arg .estimate} must select exactly one variable",
67+
"{.arg .fitted} must select exactly one variable",
6868
error_class = "halfmoon_arg_error",
6969
call = rlang::current_env()
7070
)

R/bal_model_roc_curve.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#'
1818
#' @param .data A data frame containing the variables.
1919
#' @param .exposure The treatment/outcome variable (unquoted).
20-
#' @param .estimate The propensity score or fitted values (unquoted).
20+
#' @param .fitted The propensity score or fitted values (unquoted).
2121
#' @param .weights Optional single weight variable (unquoted). If NULL, computes
2222
#' unweighted ROC curve.
2323
#' @inheritParams balance_params
@@ -43,15 +43,15 @@
4343
bal_model_roc_curve <- function(
4444
.data,
4545
.exposure,
46-
.estimate,
46+
.fitted,
4747
.weights = NULL,
4848
na.rm = TRUE,
4949
.focal_level = NULL
5050
) {
5151
validate_data_frame(.data, call = rlang::caller_env())
5252

5353
exposure_quo <- rlang::enquo(.exposure)
54-
estimate_quo <- rlang::enquo(.estimate)
54+
estimate_quo <- rlang::enquo(.fitted)
5555
wts_quo <- rlang::enquo(.weights)
5656

5757
# Extract column names
@@ -67,7 +67,7 @@ bal_model_roc_curve <- function(
6767
}
6868
if (length(estimate_name) != 1) {
6969
abort(
70-
"{.arg .estimate} must select exactly one variable",
70+
"{.arg .fitted} must select exactly one variable",
7171
error_class = "halfmoon_arg_error",
7272
call = rlang::current_env()
7373
)

R/check_balance.R

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -122,24 +122,40 @@ check_balance <- function(
122122
# Extract just the variables we're working with
123123
vars_data <- dplyr::select(.data, dplyr::all_of(var_names))
124124

125+
# Track variable origins for interaction filtering
126+
dummy_var_mapping <- list()
127+
125128
# Create dummy variables if requested
126129
if (make_dummy_vars) {
127-
vars_data <- create_dummy_variables(vars_data, binary_as_single = TRUE)
130+
dummy_result <- create_dummy_variables(
131+
vars_data,
132+
binary_as_single = TRUE,
133+
return_mapping = TRUE
134+
)
135+
vars_data <- dummy_result$data
136+
dummy_var_mapping <- dummy_result$mapping
128137
}
129138

130139
# Add squared terms if requested
131140
if (squares) {
132141
numeric_vars <- purrr::map_lgl(vars_data, is.numeric)
133142
if (any(numeric_vars)) {
134143
numeric_data <- dplyr::select(vars_data, dplyr::where(is.numeric))
135-
squared_data <- dplyr::mutate(
144+
# Only square non-binary numeric variables
145+
non_binary_numeric <- dplyr::select(
136146
numeric_data,
137-
dplyr::across(everything(), \(x) x^2, .names = "{.col}_squared")
138-
)
139-
vars_data <- dplyr::bind_cols(
140-
vars_data,
141-
dplyr::select(squared_data, dplyr::ends_with("_squared"))
147+
dplyr::where(\(x) !is_binary(x))
142148
)
149+
if (ncol(non_binary_numeric) > 0) {
150+
squared_data <- dplyr::mutate(
151+
non_binary_numeric,
152+
dplyr::across(everything(), \(x) x^2, .names = "{.col}_squared")
153+
)
154+
vars_data <- dplyr::bind_cols(
155+
vars_data,
156+
dplyr::select(squared_data, dplyr::ends_with("_squared"))
157+
)
158+
}
143159
}
144160
}
145161

@@ -148,14 +164,19 @@ check_balance <- function(
148164
numeric_vars <- purrr::map_lgl(vars_data, is.numeric)
149165
if (any(numeric_vars)) {
150166
numeric_data <- dplyr::select(vars_data, dplyr::where(is.numeric))
151-
# Only cube original variables, not squared ones
167+
# Only cube original non-binary variables, not squared ones
152168
original_numeric <- dplyr::select(
153169
numeric_data,
154170
-dplyr::ends_with("_squared")
155171
)
156-
if (ncol(original_numeric) > 0) {
172+
# Filter out binary variables
173+
non_binary_original <- dplyr::select(
174+
original_numeric,
175+
dplyr::where(\(x) !is_binary(x))
176+
)
177+
if (ncol(non_binary_original) > 0) {
157178
cubed_data <- dplyr::mutate(
158-
original_numeric,
179+
non_binary_original,
159180
dplyr::across(everything(), \(x) x^3, .names = "{.col}_cubed")
160181
)
161182
vars_data <- dplyr::bind_cols(
@@ -206,13 +227,32 @@ check_balance <- function(
206227
}
207228

208229
# Prepare variables for interactions
209-
interaction_vars <- purrr::imap(
230+
interaction_vars_list <- purrr::imap(
210231
original_numeric,
211232
prepare_interaction_variable,
212233
binary_categorical_names = binary_categorical_names,
213234
original_vars_data = original_vars_data
214-
) |>
215-
purrr::flatten()
235+
)
236+
237+
# Extract the variables and update mapping for expanded binaries
238+
interaction_vars <- purrr::flatten(interaction_vars_list)
239+
240+
# Update mapping for any expanded binary categoricals
241+
for (i in seq_along(interaction_vars_list)) {
242+
var_result <- interaction_vars_list[[i]]
243+
var_name <- names(original_numeric)[i]
244+
245+
# Check if this variable was expanded (binary categorical)
246+
if (var_name %in% binary_categorical_names) {
247+
# The result is already flattened by prepare_interaction_variable
248+
# Get the names of the expanded dummies
249+
expanded_names <- names(var_result)
250+
for (expanded_name in expanded_names) {
251+
# Track that this expanded dummy came from the original variable
252+
dummy_var_mapping[[expanded_name]] <- var_name
253+
}
254+
}
255+
}
216256

217257
# Now create interactions between all pairs
218258
var_combinations <- utils::combn(
@@ -224,7 +264,7 @@ check_balance <- function(
224264
# Filter out same-variable dummy interactions (e.g., sex0 x sex1)
225265
valid_combinations <- purrr::keep(
226266
var_combinations,
227-
is_valid_interaction_combo
267+
\(combo) is_valid_interaction_combo(combo, dummy_var_mapping)
228268
)
229269

230270
# Create interaction terms using functional programming
@@ -234,9 +274,12 @@ check_balance <- function(
234274
interaction_vars = interaction_vars
235275
)
236276

237-
# Flatten the list and add to vars_data
277+
# Flatten the list and convert to data frame
238278
interaction_terms <- purrr::flatten(interaction_terms)
239-
vars_data <- c(vars_data, interaction_terms)
279+
if (length(interaction_terms) > 0) {
280+
interaction_df <- dplyr::as_tibble(interaction_terms)
281+
vars_data <- dplyr::bind_cols(vars_data, interaction_df)
282+
}
240283
}
241284
}
242285
}
@@ -568,10 +611,21 @@ prepare_interaction_variable <- function(
568611
}
569612

570613
# Check if an interaction combination is valid (not between same variable dummies)
571-
is_valid_interaction_combo <- function(combo) {
614+
is_valid_interaction_combo <- function(combo, variable_mapping = NULL) {
572615
var1 <- combo[1]
573616
var2 <- combo[2]
574617

618+
# If we have a mapping, use it to determine if variables come from same source
619+
if (!is.null(variable_mapping)) {
620+
# Get the original variable for each dummy (or the variable itself if not a dummy)
621+
origin1 <- variable_mapping[[var1]] %||% var1
622+
origin2 <- variable_mapping[[var2]] %||% var2
623+
624+
# Only keep interactions between different original variables
625+
return(origin1 != origin2)
626+
}
627+
628+
# Fallback to the old regex approach if no mapping provided
575629
# Extract base variable names (before dummy suffixes)
576630
base1 <- sub("^([^0-9]+).*", "\\1", var1)
577631
base2 <- sub("^([^0-9]+).*", "\\1", var2)

R/check_ess.R

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#' number of observations, indicating that a few observations carry
1313
#' disproportionately large weights.
1414
#'
15-
#' When `.group` is provided, ESS is calculated separately for each group level:
15+
#' When `.exposure` is provided, ESS is calculated separately for each exposure level:
1616
#' - For binary/categorical exposures: ESS is computed within each treatment level
1717
#' - For continuous exposures: The variable is divided into quantiles (using
1818
#' `dplyr::ntile()`) and ESS is computed within each quantile
@@ -21,17 +21,17 @@
2121
#' further analysis.
2222
#'
2323
#' @inheritParams check_params
24-
#' @param .group Optional grouping variable. When provided, ESS is calculated
25-
#' separately for each group level. For continuous variables, groups are
24+
#' @param .exposure Optional exposure variable. When provided, ESS is calculated
25+
#' separately for each exposure level. For continuous variables, groups are
2626
#' created using quantiles.
27-
#' @param n_tiles For continuous `.group` variables, the number of quantile
27+
#' @param n_tiles For continuous `.exposure` variables, the number of quantile
2828
#' groups to create. Default is 4 (quartiles).
2929
#' @param tile_labels Optional character vector of labels for the quantile groups
30-
#' when `.group` is continuous. If NULL, uses "Q1", "Q2", etc.
30+
#' when `.exposure` is continuous. If NULL, uses "Q1", "Q2", etc.
3131
#'
3232
#' @return A tibble with columns:
3333
#' \item{method}{Character. The weighting method ("observed" or weight variable name).}
34-
#' \item{group}{Character. The group level (if `.group` is provided).}
34+
#' \item{group}{Character. The exposure level (if `.exposure` is provided).}
3535
#' \item{n}{Integer. The number of observations in the group.}
3636
#' \item{ess}{Numeric. The effective sample size.}
3737
#' \item{ess_pct}{Numeric. ESS as a percentage of the actual sample size.}
@@ -44,41 +44,41 @@
4444
#' check_ess(nhefs_weights, .weights = c(w_ate, w_att, w_atm))
4545
#'
4646
#' # ESS by treatment group (binary exposure)
47-
#' check_ess(nhefs_weights, .weights = c(w_ate, w_att), .group = qsmk)
47+
#' check_ess(nhefs_weights, .weights = c(w_ate, w_att), .exposure = qsmk)
4848
#'
4949
#' # ESS by treatment group (categorical exposure)
50-
#' check_ess(nhefs_weights, .weights = w_cat_ate, .group = alcoholfreq_cat)
50+
#' check_ess(nhefs_weights, .weights = w_cat_ate, .exposure = alcoholfreq_cat)
5151
#'
5252
#' # ESS by quartiles of a continuous variable
53-
#' check_ess(nhefs_weights, .weights = w_ate, .group = age, n_tiles = 4)
53+
#' check_ess(nhefs_weights, .weights = w_ate, .exposure = age, n_tiles = 4)
5454
#'
5555
#' # Custom labels for continuous groups
56-
#' check_ess(nhefs_weights, .weights = w_ate, .group = age,
56+
#' check_ess(nhefs_weights, .weights = w_ate, .exposure = age,
5757
#' n_tiles = 3, tile_labels = c("Young", "Middle", "Older"))
5858
#'
5959
#' # Without unweighted comparison
60-
#' check_ess(nhefs_weights, .weights = w_ate, .group = qsmk,
60+
#' check_ess(nhefs_weights, .weights = w_ate, .exposure = qsmk,
6161
#' include_observed = FALSE)
6262
#'
6363
#' @export
6464
check_ess <- function(
6565
.data,
6666
.weights = NULL,
67-
.group = NULL,
67+
.exposure = NULL,
6868
include_observed = TRUE,
6969
n_tiles = 4,
7070
tile_labels = NULL
7171
) {
7272
# Validate inputs
7373
validate_data_frame(.data)
7474

75-
# Handle group variable
76-
group_quo <- rlang::enquo(.group)
75+
# Handle exposure variable
76+
group_quo <- rlang::enquo(.exposure)
7777
has_group <- !rlang::quo_is_null(group_quo)
7878

7979
if (has_group) {
80-
group_name <- get_column_name(group_quo, ".group")
81-
validate_column_exists(.data, group_name, ".group")
80+
group_name <- get_column_name(group_quo, ".exposure")
81+
validate_column_exists(.data, group_name, ".exposure")
8282
group_var <- .data[[group_name]]
8383

8484
# Check if continuous (numeric and more than 10 unique values)

0 commit comments

Comments
 (0)