Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions R/check_auc.R
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ roc_curve <- function(
next
}

weights <- extract_weight_data(weights)

# Handle zero and negative weights
if (any(weights <= 0, na.rm = TRUE)) {
n_zero_neg <- sum(weights <= 0, na.rm = TRUE)
Expand Down Expand Up @@ -297,6 +299,9 @@ compute_roc_curve_imp <- function(
# Set default weights
if (is.null(weights)) {
weights <- rep(1, length(truth))
} else {
# Extract numeric data from psw weights if present
weights <- extract_weight_data(weights)
}

# Convert to binary (1 = event, 0 = non-event)
Expand Down
34 changes: 18 additions & 16 deletions R/compute_balance.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ bal_smd <- function(
res <- smd::smd(
x = covariate,
g = group,
w = weights,
w = extract_weight_data(weights),
gref = gref_index,
na.rm = na.rm
)
Expand Down Expand Up @@ -195,7 +195,7 @@ bal_vr <- function(
if (
check_na_return(
covariate[c(idx_ref, idx_other)],
weights[c(idx_ref, idx_other)] %||% NULL,
extract_weight_data(weights)[c(idx_ref, idx_other)] %||% NULL,
na.rm = FALSE
)
) {
Expand All @@ -214,7 +214,7 @@ bal_vr <- function(
p <- mean(covariate[idx_ref])
p * (1 - p)
} else {
wr <- weights[idx_ref]
wr <- extract_weight_data(weights)[idx_ref]
xr <- covariate[idx_ref]
p <- sum(wr * xr) / sum(wr)
p * (1 - p)
Expand All @@ -223,7 +223,7 @@ bal_vr <- function(
p <- mean(covariate[idx_other])
p * (1 - p)
} else {
wo <- weights[idx_other]
wo <- extract_weight_data(weights)[idx_other]
xo <- covariate[idx_other]
p <- sum(wo * xo) / sum(wo)
p * (1 - p)
Expand All @@ -233,7 +233,7 @@ bal_vr <- function(
var_ref <- if (is.null(weights)) {
stats::var(covariate[idx_ref])
} else {
wr <- weights[idx_ref]
wr <- extract_weight_data(weights)[idx_ref]
xr <- covariate[idx_ref]
mr <- sum(wr * xr) / sum(wr)
# Use Bessel's correction for weighted sample variance
Expand All @@ -244,7 +244,7 @@ bal_vr <- function(
var_other <- if (is.null(weights)) {
stats::var(covariate[idx_other])
} else {
wo <- weights[idx_other]
wo <- extract_weight_data(weights)[idx_other]
xo <- covariate[idx_other]
mo <- sum(wo * xo) / sum(wo)
# Use Bessel's correction for weighted sample variance
Expand Down Expand Up @@ -352,7 +352,7 @@ bal_ks <- function(
if (
check_na_return(
covariate[c(idx_ref, idx_other)],
weights[c(idx_ref, idx_other)] %||% NULL,
extract_weight_data(weights)[c(idx_ref, idx_other)] %||% NULL,
na.rm = FALSE
)
) {
Expand All @@ -370,12 +370,12 @@ bal_ks <- function(
p_ref <- if (is.null(weights)) {
mean(covariate[idx_ref])
} else {
sum(weights[idx_ref] * covariate[idx_ref]) / sum(weights[idx_ref])
sum(extract_weight_data(weights)[idx_ref] * covariate[idx_ref]) / sum(extract_weight_data(weights)[idx_ref])
}
p_other <- if (is.null(weights)) {
mean(covariate[idx_other])
} else {
sum(weights[idx_other] * covariate[idx_other]) / sum(weights[idx_other])
sum(extract_weight_data(weights)[idx_other] * covariate[idx_other]) / sum(extract_weight_data(weights)[idx_other])
}
return(abs(p_other - p_ref))
}
Expand All @@ -384,9 +384,9 @@ bal_ks <- function(
# Extract and weight
x_ref <- covariate[idx_ref]
x_other <- covariate[idx_other]
w_ref <- if (is.null(weights)) rep(1, length(x_ref)) else weights[idx_ref]
w_ref <- if (is.null(weights)) rep(1, length(x_ref)) else extract_weight_data(weights)[idx_ref]
w_other <- if (is.null(weights)) rep(1, length(x_other)) else
weights[idx_other]
extract_weight_data(weights)[idx_other]
w_ref <- w_ref / sum(w_ref)
w_other <- w_other / sum(w_other)
# Sort and CDF
Expand Down Expand Up @@ -458,8 +458,10 @@ bal_corr <- function(x, y, weights = NULL, na.rm = FALSE) {
}
x <- x[idx]
y <- y[idx]
if (!is.null(weights)) weights <- weights[idx]
if (!is.null(weights)) weights <- extract_weight_data(weights)[idx]
} else {
# Extract weight data if needed
if (!is.null(weights)) weights <- extract_weight_data(weights)
# Check for missing values
if (is.null(weights)) {
if (any(is.na(x) | is.na(y))) return(NA_real_)
Expand Down Expand Up @@ -643,7 +645,7 @@ bal_energy <- function(
complete_cases <- stats::complete.cases(covariates, group)
if (!is.null(weights)) {
complete_cases <- complete_cases & !is.na(weights)
weights <- weights[complete_cases]
weights <- extract_weight_data(weights)[complete_cases]
}
covariates <- covariates[complete_cases, , drop = FALSE]
group <- group[complete_cases]
Expand Down Expand Up @@ -697,11 +699,11 @@ bal_energy <- function(
}

# Normalize weights by group
weights_normalized <- weights
weights_normalized <- extract_weight_data(weights)
for (g in unique_groups) {
group_mask <- group == g
if (any(group_mask)) {
group_weights <- weights[group_mask]
group_weights <- weights_normalized[group_mask]
weights_normalized[group_mask] <- group_weights / mean(group_weights)
}
}
Expand Down Expand Up @@ -980,7 +982,7 @@ bal_energy_att_atc <- function(

# Identify focal group observations
focal_mask <- group == treatment_level
focal_weights <- weights[focal_mask]
focal_weights <- extract_weight_data(weights)[focal_mask]
focal_weights_norm <- focal_weights / sum(focal_weights)

# Compute P matrix
Expand Down
3 changes: 3 additions & 0 deletions R/compute_qq.R
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ compute_method_quantiles <- function(
#'
#' @export
weighted_quantile <- function(values, quantiles, .wts) {
# Extract numeric data from weights (handles both numeric and psw objects)
.wts <- extract_weight_data(.wts)

# Remove NA values if present
na_idx <- is.na(values) | is.na(.wts)
if (any(na_idx)) {
Expand Down
2 changes: 2 additions & 0 deletions R/ess.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@
#'
#' @export
ess <- function(wts) {
# Extract numeric data from psw weights if present
wts <- extract_weight_data(wts)
sum(wts)^2 / sum(wts^2)
}
2 changes: 2 additions & 0 deletions R/geom_ecdf.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ StatWeightedECDF <- ggplot2::ggproto(
ggplot2::StatEcdf,
compute_group = function(data, scales, n = NULL, pad = NULL) {
if ("weights" %in% names(data)) {
# Extract numeric data from psw weights if present
data$weights <- extract_weight_data(data$weights)
data <- data[order(data$x), ]
# ggplot2 3.4.1 changed this stat's name from `y` to `ecdf`
if (packageVersion("ggplot2") >= "3.4.1") {
Expand Down
5 changes: 5 additions & 0 deletions R/geom_mirrored_density.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ StatMirrorDensity <- ggplot2::ggproto(
# Store mirroring flag
should_mirror <- unique(data$.should_mirror)

# Extract numeric data from psw weights if present
if ("weight" %in% names(data)) {
data$weight <- extract_weight_data(data$weight)
}

data <- ggplot2::StatDensity$compute_group(
data = data,
scales = scales,
Expand Down
5 changes: 5 additions & 0 deletions R/geom_mirrored_histogram.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ StatMirrorCount <- ggplot2::ggproto(
# Store mirroring flag
should_mirror <- unique(data$.should_mirror)

# Extract numeric data from psw weights if present
if ("weight" %in% names(data)) {
data$weight <- extract_weight_data(data$weight)
}

data <- ggplot2::StatBin$compute_group(
data = data,
scales = scales,
Expand Down
11 changes: 8 additions & 3 deletions R/geom_qq2.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@
#' geom_abline(intercept = 0, slope = 1, linetype = "dashed")
#'
#' # Compare multiple weights using long format
#' # TODO: Remove vec_data() workaround once propensity implements vctrs methods
#' # Extract numeric data from psw objects first
#' nhefs_for_pivot <- nhefs_weights
#' nhefs_for_pivot$w_ate <- vctrs::vec_data(nhefs_weights$w_ate)
#' nhefs_for_pivot$w_att <- vctrs::vec_data(nhefs_weights$w_att)
#' long_data <- tidyr::pivot_longer(
#' nhefs_weights,
#' nhefs_for_pivot,
#' cols = c(w_ate, w_att),
#' names_to = "weight_type",
#' values_to = "weight"
Expand Down Expand Up @@ -184,7 +189,7 @@ process_aesthetic_group <- function(

# Add weight if present
if (!is.null(combined_data$weight) && all(!is.na(combined_data$weight))) {
temp_data$.wts <- combined_data$weight
temp_data$.wts <- extract_weight_data(combined_data$weight)
wts_arg <- ".wts"
} else {
wts_arg <- NULL
Expand Down Expand Up @@ -308,7 +313,7 @@ StatQq2 <- ggplot2::ggproto(

# Add weight if present
if (!is.null(data$weight) && all(!is.na(data$weight))) {
temp_data$.wts <- data$weight
temp_data$.wts <- extract_weight_data(data$weight)
wts_arg <- ".wts"
} else {
wts_arg <- NULL
Expand Down
7 changes: 6 additions & 1 deletion R/geom_roc.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@
#' geom_abline(intercept = 0, slope = 1, linetype = "dashed")
#'
#' # With grouping by weight
#' # TODO: Remove vec_data() workaround once propensity implements vctrs methods
#' # Extract numeric data from psw objects first
#' nhefs_for_pivot <- nhefs_weights
#' nhefs_for_pivot$w_ate <- vctrs::vec_data(nhefs_weights$w_ate)
#' nhefs_for_pivot$w_att <- vctrs::vec_data(nhefs_weights$w_att)
#' long_data <- tidyr::pivot_longer(
#' nhefs_weights,
#' nhefs_for_pivot,
#' cols = c(w_ate, w_att),
#' names_to = "weight_type",
#' values_to = "weight"
Expand Down
5 changes: 5 additions & 0 deletions R/plot_mirror_distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ plot_mirror_distributions <- function(
wts_cols <- tidyselect::eval_select(wts_quo, .data)
wts_names <- names(wts_cols)

# Convert psw weight columns to numeric for compatibility with pivot_longer
for (wts_name in wts_names) {
.data[[wts_name]] <- extract_weight_data(.data[[wts_name]])
}

if (include_unweighted) {
.data$.observed <- 1
wts_names <- c(".observed", wts_names)
Expand Down
5 changes: 5 additions & 0 deletions R/plot_qq.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ plot_qq <- function(
wts_names <- names(wts_cols)

# Create long format data
# Convert psw weight columns to numeric for compatibility with pivot_longer
for (wts_name in wts_names) {
.data[[wts_name]] <- extract_weight_data(.data[[wts_name]])
}

if (include_observed) {
# Add observed as a weight column with value 1
.data$.observed <- 1
Expand Down
7 changes: 5 additions & 2 deletions R/utils-validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ validate_numeric <- function(x, arg_name = deparse(substitute(x)), call = rlang:
validate_weights <- function(weights, n, arg_name = "weights", call = rlang::caller_env()) {
if (is.null(weights)) return(invisible(weights))

if (!is.numeric(weights)) {
# Accept both numeric vectors and psw objects from propensity package
is_valid_weights <- is.numeric(weights) || inherits(weights, "psw")

if (!is_valid_weights) {
abort(
"{.arg {arg_name}} must be numeric or {.code NULL}",
"{.arg {arg_name}} must be numeric, a psw object, or {.code NULL}",
error_class = "halfmoon_type_error",
call = call
)
Expand Down
11 changes: 11 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,14 @@ create_group_signature <- function(group_data, aes_cols) {
"no_aes"
}
}

# Extract numeric data from weights (handles both numeric and psw objects)
extract_weight_data <- function(weights) {
if (is.null(weights)) {
return(NULL)
} else {
# Use vctrs::vec_data to extract underlying numeric data
# This works for both numeric vectors and psw objects
vctrs::vec_data(weights)
}
}
46 changes: 21 additions & 25 deletions data-raw/nhefs_weights.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,56 +93,52 @@ cat_ps <- as.data.frame(predict(
# Calculate weights and join back to data
nhefs_with_cat_weights <- nhefs_for_model %>%
mutate(
# Convert psw objects to numeric using as.numeric
w_cat_ate = as.numeric(wt_ate(cat_ps, alcoholfreq_cat)),
w_cat_att_none = as.numeric(wt_att(
# Keep psw objects from propensity package
w_cat_ate = wt_ate(cat_ps, alcoholfreq_cat),
w_cat_att_none = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "none"
)),
w_cat_att_lt12 = as.numeric(wt_att(
),
w_cat_att_lt12 = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "lt_12_per_year"
)),
w_cat_att_1_4mo = as.numeric(wt_att(
),
w_cat_att_1_4mo = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "1_4_per_month"
)),
w_cat_att_2_3wk = as.numeric(wt_att(
),
w_cat_att_2_3wk = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "2_3_per_week"
)),
w_cat_att_daily = as.numeric(wt_att(
),
w_cat_att_daily = wt_att(
cat_ps,
alcoholfreq_cat,
focal = "daily"
)),
w_cat_atu_none = as.numeric(wt_atu(
),
w_cat_atu_none = wt_atu(
cat_ps,
alcoholfreq_cat,
focal = "none"
)),
w_cat_ato = as.numeric(wt_ato(cat_ps, alcoholfreq_cat)),
w_cat_atm = as.numeric(wt_atm(cat_ps, alcoholfreq_cat))
),
w_cat_ato = wt_ato(cat_ps, alcoholfreq_cat),
w_cat_atm = wt_atm(cat_ps, alcoholfreq_cat)
)

# Now combine with binary exposure weights, keeping all rows
nhefs_weights <- propensity_model %>%
augment(type.predict = "response", data = nhefs_with_cat) %>%
mutate(
wts = 1 / ifelse(qsmk == 0, 1 - .fitted, .fitted),
w_ate = (qsmk / .fitted) +
((1 - qsmk) / (1 - .fitted)),
w_att = ((.fitted * qsmk) / .fitted) +
((.fitted * (1 - qsmk)) / (1 - .fitted)),
w_atc = (((1 - .fitted) * qsmk) / .fitted) +
(((1 - .fitted) * (1 - qsmk)) / (1 - .fitted)),
w_atm = pmin(.fitted, 1 - .fitted) /
(qsmk * .fitted + (1 - qsmk) * (1 - .fitted)),
w_ato = (1 - .fitted) * qsmk + .fitted * (1 - qsmk)
w_ate = wt_ate(.fitted, qsmk),
w_att = wt_att(.fitted, qsmk),
w_atc = wt_atu(.fitted, qsmk),
w_atm = wt_atm(.fitted, qsmk),
w_ato = wt_ato(.fitted, qsmk)
) %>%
# Join categorical weights, they will be NA for unknown alcohol frequency
left_join(
Expand Down
Binary file modified data/nhefs_weights.rda
Binary file not shown.
7 changes: 6 additions & 1 deletion man/geom_qq2.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading