@@ -122,24 +122,40 @@ check_balance <- function(
122122 # Extract just the variables we're working with
123123 vars_data <- dplyr :: select(.data , dplyr :: all_of(var_names ))
124124
125+ # Track variable origins for interaction filtering
126+ dummy_var_mapping <- list ()
127+
125128 # Create dummy variables if requested
126129 if (make_dummy_vars ) {
127- vars_data <- create_dummy_variables(vars_data , binary_as_single = TRUE )
130+ dummy_result <- create_dummy_variables(
131+ vars_data ,
132+ binary_as_single = TRUE ,
133+ return_mapping = TRUE
134+ )
135+ vars_data <- dummy_result $ data
136+ dummy_var_mapping <- dummy_result $ mapping
128137 }
129138
130139 # Add squared terms if requested
131140 if (squares ) {
132141 numeric_vars <- purrr :: map_lgl(vars_data , is.numeric )
133142 if (any(numeric_vars )) {
134143 numeric_data <- dplyr :: select(vars_data , dplyr :: where(is.numeric ))
135- squared_data <- dplyr :: mutate(
144+ # Only square non-binary numeric variables
145+ non_binary_numeric <- dplyr :: select(
136146 numeric_data ,
137- dplyr :: across(everything(), \(x ) x ^ 2 , .names = " {.col}_squared" )
138- )
139- vars_data <- dplyr :: bind_cols(
140- vars_data ,
141- dplyr :: select(squared_data , dplyr :: ends_with(" _squared" ))
147+ dplyr :: where(\(x ) ! is_binary(x ))
142148 )
149+ if (ncol(non_binary_numeric ) > 0 ) {
150+ squared_data <- dplyr :: mutate(
151+ non_binary_numeric ,
152+ dplyr :: across(everything(), \(x ) x ^ 2 , .names = " {.col}_squared" )
153+ )
154+ vars_data <- dplyr :: bind_cols(
155+ vars_data ,
156+ dplyr :: select(squared_data , dplyr :: ends_with(" _squared" ))
157+ )
158+ }
143159 }
144160 }
145161
@@ -148,14 +164,19 @@ check_balance <- function(
148164 numeric_vars <- purrr :: map_lgl(vars_data , is.numeric )
149165 if (any(numeric_vars )) {
150166 numeric_data <- dplyr :: select(vars_data , dplyr :: where(is.numeric ))
151- # Only cube original variables, not squared ones
167+ # Only cube original non-binary variables, not squared ones
152168 original_numeric <- dplyr :: select(
153169 numeric_data ,
154170 - dplyr :: ends_with(" _squared" )
155171 )
156- if (ncol(original_numeric ) > 0 ) {
172+ # Filter out binary variables
173+ non_binary_original <- dplyr :: select(
174+ original_numeric ,
175+ dplyr :: where(\(x ) ! is_binary(x ))
176+ )
177+ if (ncol(non_binary_original ) > 0 ) {
157178 cubed_data <- dplyr :: mutate(
158- original_numeric ,
179+ non_binary_original ,
159180 dplyr :: across(everything(), \(x ) x ^ 3 , .names = " {.col}_cubed" )
160181 )
161182 vars_data <- dplyr :: bind_cols(
@@ -206,13 +227,32 @@ check_balance <- function(
206227 }
207228
208229 # Prepare variables for interactions
209- interaction_vars <- purrr :: imap(
230+ interaction_vars_list <- purrr :: imap(
210231 original_numeric ,
211232 prepare_interaction_variable ,
212233 binary_categorical_names = binary_categorical_names ,
213234 original_vars_data = original_vars_data
214- ) | >
215- purrr :: flatten()
235+ )
236+
237+ # Flatten the per-variable results into a single list of interaction variables
238+ interaction_vars <- purrr :: flatten(interaction_vars_list )
239+
240+ # Update mapping for any expanded binary categoricals
241+ for (i in seq_along(interaction_vars_list )) {
242+ var_result <- interaction_vars_list [[i ]]
243+ var_name <- names(original_numeric )[i ]
244+
245+ # Check if this variable was expanded (binary categorical)
246+ if (var_name %in% binary_categorical_names ) {
247+ # The result is already flattened by prepare_interaction_variable
248+ # Get the names of the expanded dummies
249+ expanded_names <- names(var_result )
250+ for (expanded_name in expanded_names ) {
251+ # Track that this expanded dummy came from the original variable
252+ dummy_var_mapping [[expanded_name ]] <- var_name
253+ }
254+ }
255+ }
216256
217257 # Now create interactions between all pairs
218258 var_combinations <- utils :: combn(
@@ -224,7 +264,7 @@ check_balance <- function(
224264 # Filter out same-variable dummy interactions (e.g., sex0 x sex1)
225265 valid_combinations <- purrr :: keep(
226266 var_combinations ,
227- is_valid_interaction_combo
267+ \( combo ) is_valid_interaction_combo( combo , dummy_var_mapping )
228268 )
229269
230270 # Create interaction terms using functional programming
@@ -234,9 +274,12 @@ check_balance <- function(
234274 interaction_vars = interaction_vars
235275 )
236276
237- # Flatten the list and add to vars_data
277+ # Flatten the list and convert to data frame
238278 interaction_terms <- purrr :: flatten(interaction_terms )
239- vars_data <- c(vars_data , interaction_terms )
279+ if (length(interaction_terms ) > 0 ) {
280+ interaction_df <- dplyr :: as_tibble(interaction_terms )
281+ vars_data <- dplyr :: bind_cols(vars_data , interaction_df )
282+ }
240283 }
241284 }
242285 }
@@ -568,10 +611,21 @@ prepare_interaction_variable <- function(
568611}
569612
570613# Check if an interaction combination is valid (not between same variable dummies)
571- is_valid_interaction_combo <- function (combo ) {
614+ is_valid_interaction_combo <- function (combo , variable_mapping = NULL ) {
572615 var1 <- combo [1 ]
573616 var2 <- combo [2 ]
574617
618+ # If we have a mapping, use it to determine if variables come from same source
619+ if (! is.null(variable_mapping )) {
620+ # Get the original variable for each dummy (or the variable itself if not a dummy)
621+ origin1 <- variable_mapping [[var1 ]] %|| % var1
622+ origin2 <- variable_mapping [[var2 ]] %|| % var2
623+
624+ # Only keep interactions between different original variables
625+ return (origin1 != origin2 )
626+ }
627+
628+ # Fallback to the old regex approach if no mapping provided
575629 # Extract base variable names (before dummy suffixes)
576630 base1 <- sub(" ^([^0-9]+).*" , " \\ 1" , var1 )
577631 base2 <- sub(" ^([^0-9]+).*" , " \\ 1" , var2 )
0 commit comments