Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions R/check_auc.R
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ roc_curve <- function(
next
}

weights <- extract_weight_data(weights)

# Handle zero and negative weights
if (any(weights <= 0, na.rm = TRUE)) {
n_zero_neg <- sum(weights <= 0, na.rm = TRUE)
Expand Down Expand Up @@ -297,6 +299,9 @@ compute_roc_curve_imp <- function(
# Set default weights
if (is.null(weights)) {
weights <- rep(1, length(truth))
} else {
# Extract numeric data from psw weights if present
weights <- extract_weight_data(weights)
}

# Convert to binary (1 = event, 0 = non-event)
Expand Down
35 changes: 19 additions & 16 deletions R/compute_balance.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ bal_smd <- function(
res <- smd::smd(
x = covariate,
g = group,
w = weights,
w = extract_weight_data(weights),
gref = gref_index,
na.rm = na.rm
)
Expand Down Expand Up @@ -195,7 +195,7 @@ bal_vr <- function(
if (
check_na_return(
covariate[c(idx_ref, idx_other)],
weights[c(idx_ref, idx_other)] %||% NULL,
extract_weight_data(weights)[c(idx_ref, idx_other)] %||% NULL,
na.rm = FALSE
)
) {
Expand All @@ -214,7 +214,7 @@ bal_vr <- function(
p <- mean(covariate[idx_ref])
p * (1 - p)
} else {
wr <- weights[idx_ref]
wr <- extract_weight_data(weights)[idx_ref]
xr <- covariate[idx_ref]
p <- sum(wr * xr) / sum(wr)
p * (1 - p)
Expand All @@ -223,7 +223,7 @@ bal_vr <- function(
p <- mean(covariate[idx_other])
p * (1 - p)
} else {
wo <- weights[idx_other]
wo <- extract_weight_data(weights)[idx_other]
xo <- covariate[idx_other]
p <- sum(wo * xo) / sum(wo)
p * (1 - p)
Expand All @@ -233,7 +233,7 @@ bal_vr <- function(
var_ref <- if (is.null(weights)) {
stats::var(covariate[idx_ref])
} else {
wr <- weights[idx_ref]
wr <- extract_weight_data(weights)[idx_ref]
xr <- covariate[idx_ref]
mr <- sum(wr * xr) / sum(wr)
# Use Bessel's correction for weighted sample variance
Expand All @@ -244,7 +244,7 @@ bal_vr <- function(
var_other <- if (is.null(weights)) {
stats::var(covariate[idx_other])
} else {
wo <- weights[idx_other]
wo <- extract_weight_data(weights)[idx_other]
xo <- covariate[idx_other]
mo <- sum(wo * xo) / sum(wo)
# Use Bessel's correction for weighted sample variance
Expand Down Expand Up @@ -352,7 +352,7 @@ bal_ks <- function(
if (
check_na_return(
covariate[c(idx_ref, idx_other)],
weights[c(idx_ref, idx_other)] %||% NULL,
extract_weight_data(weights)[c(idx_ref, idx_other)] %||% NULL,
na.rm = FALSE
)
) {
Expand All @@ -370,12 +370,12 @@ bal_ks <- function(
p_ref <- if (is.null(weights)) {
mean(covariate[idx_ref])
} else {
sum(weights[idx_ref] * covariate[idx_ref]) / sum(weights[idx_ref])
sum(extract_weight_data(weights)[idx_ref] * covariate[idx_ref]) / sum(extract_weight_data(weights)[idx_ref])
}
p_other <- if (is.null(weights)) {
mean(covariate[idx_other])
} else {
sum(weights[idx_other] * covariate[idx_other]) / sum(weights[idx_other])
sum(extract_weight_data(weights)[idx_other] * covariate[idx_other]) / sum(extract_weight_data(weights)[idx_other])
}
return(abs(p_other - p_ref))
}
Expand All @@ -384,9 +384,9 @@ bal_ks <- function(
# Extract and weight
x_ref <- covariate[idx_ref]
x_other <- covariate[idx_other]
w_ref <- if (is.null(weights)) rep(1, length(x_ref)) else weights[idx_ref]
w_ref <- if (is.null(weights)) rep(1, length(x_ref)) else extract_weight_data(weights)[idx_ref]
w_other <- if (is.null(weights)) rep(1, length(x_other)) else
weights[idx_other]
extract_weight_data(weights)[idx_other]
w_ref <- w_ref / sum(w_ref)
w_other <- w_other / sum(w_other)
# Sort and CDF
Expand Down Expand Up @@ -458,8 +458,10 @@ bal_corr <- function(x, y, weights = NULL, na.rm = FALSE) {
}
x <- x[idx]
y <- y[idx]
if (!is.null(weights)) weights <- weights[idx]
if (!is.null(weights)) weights <- extract_weight_data(weights)[idx]
} else {
# Extract weight data if needed
if (!is.null(weights)) weights <- extract_weight_data(weights)
# Check for missing values
if (is.null(weights)) {
if (any(is.na(x) | is.na(y))) return(NA_real_)
Expand Down Expand Up @@ -643,7 +645,7 @@ bal_energy <- function(
complete_cases <- stats::complete.cases(covariates, group)
if (!is.null(weights)) {
complete_cases <- complete_cases & !is.na(weights)
weights <- weights[complete_cases]
weights <- extract_weight_data(weights)[complete_cases]
}
covariates <- covariates[complete_cases, , drop = FALSE]
group <- group[complete_cases]
Expand Down Expand Up @@ -697,11 +699,12 @@ bal_energy <- function(
}

# Normalize weights by group
weights_normalized <- weights
weights_numeric <- extract_weight_data(weights)
weights_normalized <- weights_numeric
for (g in unique_groups) {
group_mask <- group == g
if (any(group_mask)) {
group_weights <- weights[group_mask]
group_weights <- weights_numeric[group_mask]
weights_normalized[group_mask] <- group_weights / mean(group_weights)
}
}
Expand Down Expand Up @@ -980,7 +983,7 @@ bal_energy_att_atc <- function(

# Identify focal group observations
focal_mask <- group == treatment_level
focal_weights <- weights[focal_mask]
focal_weights <- extract_weight_data(weights)[focal_mask]
focal_weights_norm <- focal_weights / sum(focal_weights)

# Compute P matrix
Expand Down
3 changes: 3 additions & 0 deletions R/compute_qq.R
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ compute_method_quantiles <- function(
#'
#' @export
weighted_quantile <- function(values, quantiles, .wts) {
# Extract numeric data from weights (handles both numeric and psw objects)
.wts <- extract_weight_data(.wts)

# Remove NA values if present
na_idx <- is.na(values) | is.na(.wts)
if (any(na_idx)) {
Expand Down
2 changes: 2 additions & 0 deletions R/ess.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@
#'
#' @export
ess <- function(wts) {
# Extract numeric data from psw weights if present
wts <- extract_weight_data(wts)
sum(wts)^2 / sum(wts^2)
}
2 changes: 2 additions & 0 deletions R/geom_ecdf.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ StatWeightedECDF <- ggplot2::ggproto(
ggplot2::StatEcdf,
compute_group = function(data, scales, n = NULL, pad = NULL) {
if ("weights" %in% names(data)) {
# Extract numeric data from psw weights if present
data$weights <- extract_weight_data(data$weights)
data <- data[order(data$x), ]
# ggplot2 3.4.1 changed this stat's name from `y` to `ecdf`
if (packageVersion("ggplot2") >= "3.4.1") {
Expand Down
5 changes: 5 additions & 0 deletions R/geom_mirrored_density.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ StatMirrorDensity <- ggplot2::ggproto(
# Store mirroring flag
should_mirror <- unique(data$.should_mirror)

# Extract numeric data from psw weights if present
if ("weight" %in% names(data)) {
data$weight <- extract_weight_data(data$weight)
}

data <- ggplot2::StatDensity$compute_group(
data = data,
scales = scales,
Expand Down
5 changes: 5 additions & 0 deletions R/geom_mirrored_histogram.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ StatMirrorCount <- ggplot2::ggproto(
# Store mirroring flag
should_mirror <- unique(data$.should_mirror)

# Extract numeric data from psw weights if present
if ("weight" %in% names(data)) {
data$weight <- extract_weight_data(data$weight)
}

data <- ggplot2::StatBin$compute_group(
data = data,
scales = scales,
Expand Down
11 changes: 8 additions & 3 deletions R/geom_qq2.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@
#' geom_abline(intercept = 0, slope = 1, linetype = "dashed")
#'
#' # Compare multiple weights using long format
#' # TODO: Remove vec_data() workaround once propensity implements vctrs methods
#' # Extract numeric data from psw objects first
#' nhefs_for_pivot <- nhefs_weights
#' nhefs_for_pivot$w_ate <- vctrs::vec_data(nhefs_weights$w_ate)
#' nhefs_for_pivot$w_att <- vctrs::vec_data(nhefs_weights$w_att)
#' long_data <- tidyr::pivot_longer(
#' nhefs_weights,
#' nhefs_for_pivot,
#' cols = c(w_ate, w_att),
#' names_to = "weight_type",
#' values_to = "weight"
Expand Down Expand Up @@ -184,7 +189,7 @@ process_aesthetic_group <- function(

# Add weight if present
if (!is.null(combined_data$weight) && all(!is.na(combined_data$weight))) {
temp_data$.wts <- combined_data$weight
temp_data$.wts <- extract_weight_data(combined_data$weight)
wts_arg <- ".wts"
} else {
wts_arg <- NULL
Expand Down Expand Up @@ -308,7 +313,7 @@ StatQq2 <- ggplot2::ggproto(

# Add weight if present
if (!is.null(data$weight) && all(!is.na(data$weight))) {
temp_data$.wts <- data$weight
temp_data$.wts <- extract_weight_data(data$weight)
wts_arg <- ".wts"
} else {
wts_arg <- NULL
Expand Down
7 changes: 6 additions & 1 deletion R/geom_roc.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@
#' geom_abline(intercept = 0, slope = 1, linetype = "dashed")
#'
#' # With grouping by weight
#' # TODO: Remove vec_data() workaround once propensity implements vctrs methods
#' # Extract numeric data from psw objects first
#' nhefs_for_pivot <- nhefs_weights
#' nhefs_for_pivot$w_ate <- vctrs::vec_data(nhefs_weights$w_ate)
#' nhefs_for_pivot$w_att <- vctrs::vec_data(nhefs_weights$w_att)
#' long_data <- tidyr::pivot_longer(
#' nhefs_weights,
#' nhefs_for_pivot,
#' cols = c(w_ate, w_att),
#' names_to = "weight_type",
#' values_to = "weight"
Expand Down
5 changes: 5 additions & 0 deletions R/plot_mirror_distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ plot_mirror_distributions <- function(
wts_cols <- tidyselect::eval_select(wts_quo, .data)
wts_names <- names(wts_cols)

# Convert psw weight columns to numeric for compatibility with pivot_longer
for (wts_name in wts_names) {
.data[[wts_name]] <- extract_weight_data(.data[[wts_name]])
}

if (include_unweighted) {
.data$.observed <- 1
wts_names <- c(".observed", wts_names)
Expand Down
5 changes: 5 additions & 0 deletions R/plot_qq.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ plot_qq <- function(
wts_names <- names(wts_cols)

# Create long format data
# Convert psw weight columns to numeric for compatibility with pivot_longer
for (wts_name in wts_names) {
.data[[wts_name]] <- extract_weight_data(.data[[wts_name]])
}

if (include_observed) {
# Add observed as a weight column with value 1
.data$.observed <- 1
Expand Down
7 changes: 5 additions & 2 deletions R/utils-validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ validate_numeric <- function(x, arg_name = deparse(substitute(x)), call = rlang:
validate_weights <- function(weights, n, arg_name = "weights", call = rlang::caller_env()) {
if (is.null(weights)) return(invisible(weights))

if (!is.numeric(weights)) {
# Accept both numeric vectors and psw objects from propensity package
is_valid_weights <- is.numeric(weights) || inherits(weights, "psw")

if (!is_valid_weights) {
abort(
"{.arg {arg_name}} must be numeric or {.code NULL}",
"{.arg {arg_name}} must be numeric, a psw object, or {.code NULL}",
error_class = "halfmoon_type_error",
call = call
)
Expand Down
11 changes: 11 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,14 @@ create_group_signature <- function(group_data, aes_cols) {
"no_aes"
}
}

# Extract numeric data from weights (handles both numeric and psw objects)
extract_weight_data <- function(weights) {
if (is.null(weights)) {
return(NULL)
} else {
# Use vctrs::vec_data to extract underlying numeric data
# This works for both numeric vectors and psw objects
vctrs::vec_data(weights)
}
}
46 changes: 21 additions & 25 deletions data-raw/nhefs_weights.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,56 +93,52 @@ cat_ps <- as.data.frame(predict(
# Calculate weights and join back to data
nhefs_with_cat_weights <- nhefs_for_model %>%
mutate(
# Convert psw objects to numeric using as.numeric
w_cat_ate = as.numeric(wt_ate(cat_ps, alcoholfreq_cat)),
w_cat_att_none = as.numeric(wt_att(
# Keep psw objects from propensity package
w_cat_ate = wt_ate(cat_ps, alcoholfreq_cat),
w_cat_att_none = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "none"
)),
w_cat_att_lt12 = as.numeric(wt_att(
),
w_cat_att_lt12 = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "lt_12_per_year"
)),
w_cat_att_1_4mo = as.numeric(wt_att(
),
w_cat_att_1_4mo = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "1_4_per_month"
)),
w_cat_att_2_3wk = as.numeric(wt_att(
),
w_cat_att_2_3wk = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "2_3_per_week"
)),
w_cat_att_daily = as.numeric(wt_att(
),
w_cat_att_daily = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "daily"
)),
w_cat_atu_none = as.numeric(wt_atu(
),
w_cat_atu_none = wt_atu(
cat_ps,
alcoholfreq_cat,
focal = "none"
)),
w_cat_ato = as.numeric(wt_ato(cat_ps, alcoholfreq_cat)),
w_cat_atm = as.numeric(wt_atm(cat_ps, alcoholfreq_cat))
),
w_cat_ato = wt_ato(cat_ps, alcoholfreq_cat),
w_cat_atm = wt_atm(cat_ps, alcoholfreq_cat)
)

# Now combine with binary exposure weights, keeping all rows
nhefs_weights <- propensity_model %>%
augment(type.predict = "response", data = nhefs_with_cat) %>%
mutate(
wts = 1 / ifelse(qsmk == 0, 1 - .fitted, .fitted),
w_ate = (qsmk / .fitted) +
((1 - qsmk) / (1 - .fitted)),
w_att = ((.fitted * qsmk) / .fitted) +
((.fitted * (1 - qsmk)) / (1 - .fitted)),
w_atc = (((1 - .fitted) * qsmk) / .fitted) +
(((1 - .fitted) * (1 - qsmk)) / (1 - .fitted)),
w_atm = pmin(.fitted, 1 - .fitted) /
(qsmk * .fitted + (1 - qsmk) * (1 - .fitted)),
w_ato = (1 - .fitted) * qsmk + .fitted * (1 - qsmk)
w_ate = wt_ate(.fitted, qsmk),
w_att = wt_att(.fitted, qsmk),
w_atc = wt_atu(.fitted, qsmk),
w_atm = wt_atm(.fitted, qsmk),
w_ato = wt_ato(.fitted, qsmk)
) %>%
# Join categorical weights, they will be NA for unknown alcohol frequency
left_join(
Expand Down
Binary file modified data/nhefs_weights.rda
Binary file not shown.
7 changes: 6 additions & 1 deletion man/geom_qq2.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading