Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ scratch/
CLAUDE.md

/.quarto/
Rplots.pdf
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ Suggests:
mgcv,
survey,
testthat (>= 3.0.0),
vdiffr
vdiffr,
withr
Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Remotes:
Remotes:
r-causal/tidysmd,
r-causal/propensity
40 changes: 30 additions & 10 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
# Generated by roxygen2: do not edit by hand

S3method(plot_calibration,data.frame)
S3method(plot_calibration,glm)
S3method(plot_calibration,lm)
S3method(autoplot,halfmoon_auc)
S3method(autoplot,halfmoon_balance)
S3method(autoplot,halfmoon_calibration)
S3method(autoplot,halfmoon_ess)
S3method(autoplot,halfmoon_qq)
S3method(autoplot,halfmoon_roc)
S3method(plot,halfmoon_auc)
S3method(plot,halfmoon_balance)
S3method(plot,halfmoon_calibration)
S3method(plot,halfmoon_ess)
S3method(plot,halfmoon_qq)
S3method(plot,halfmoon_roc)
S3method(plot_model_calibration,data.frame)
S3method(plot_model_calibration,glm)
S3method(plot_model_calibration,halfmoon_calibration)
S3method(plot_model_calibration,lm)
S3method(plot_qq,default)
S3method(plot_qq,halfmoon_qq)
S3method(plot_stratified_residuals,data.frame)
S3method(plot_stratified_residuals,glm)
S3method(plot_stratified_residuals,lm)
Expand All @@ -11,15 +26,21 @@ export(StatRoc)
export(add_ess_header)
export(bal_corr)
export(bal_energy)
export(bal_ess)
export(bal_ks)
export(bal_model_auc)
export(bal_model_roc_curve)
export(bal_prognostic_score)
export(bal_qq)
export(bal_smd)
export(bal_vr)
export(bind_matches)
export(check_auc)
export(check_balance)
export(check_calibration)
export(check_ess)
export(check_model_auc)
export(check_model_calibration)
export(check_model_roc_curve)
export(check_qq)
export(contains)
export(ends_with)
export(ess)
Expand All @@ -38,20 +59,19 @@ export(num_range)
export(one_of)
export(peek_vars)
export(plot_balance)
export(plot_calibration)
export(plot_ess)
export(plot_mirror_distributions)
export(plot_model_auc)
export(plot_model_calibration)
export(plot_model_roc_curve)
export(plot_qq)
export(plot_roc_auc)
export(plot_roc_curve)
export(plot_stratified_residuals)
export(qq)
export(roc_curve)
export(starts_with)
export(stat_qq2)
export(stat_roc)
export(tidy_smd)
export(weighted_quantile)
importFrom(ggplot2,autoplot)
importFrom(rlang,"%||%")
importFrom(rlang,.data)
importFrom(rlang,.env)
Expand Down
104 changes: 104 additions & 0 deletions R/autoplot-methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#' Autoplot Methods for halfmoon Objects
#'
#' These methods provide automatic plot generation for halfmoon data objects
#' using ggplot2's autoplot interface. Each method dispatches to the appropriate
#' plot_*() function as follows:
#'
#' - `autoplot.halfmoon_balance` calls [plot_balance()]
#' - `autoplot.halfmoon_ess` calls [plot_ess()]
#' - `autoplot.halfmoon_calibration` calls [plot_model_calibration()]
#' - `autoplot.halfmoon_roc` calls [plot_model_roc_curve()]
#' - `autoplot.halfmoon_auc` calls [plot_model_auc()]
#' - `autoplot.halfmoon_qq` calls [plot_qq()]
#'
#' @param object A halfmoon data object with appropriate class
#' @param ... Additional arguments passed to the underlying plot_*() function
#'
#' @return A ggplot2 object
#' @name autoplot-halfmoon
NULL

#' @rdname autoplot-halfmoon
#' @export
#' @importFrom ggplot2 autoplot
autoplot.halfmoon_balance <- function(object, ...) {
# Delegate to plot_balance(), forwarding `object` and `...` unchanged.
plot_balance(object, ...)
}

#' @rdname autoplot-halfmoon
#' @export
autoplot.halfmoon_ess <- function(object, ...) {
# Delegate to plot_ess(), forwarding `object` and `...` unchanged.
plot_ess(object, ...)
}

#' @rdname autoplot-halfmoon
#' @export
autoplot.halfmoon_calibration <- function(object, ...) {
# Delegate to plot_model_calibration(), forwarding `object` and `...` unchanged.
plot_model_calibration(object, ...)
}

#' @rdname autoplot-halfmoon
#' @export
autoplot.halfmoon_roc <- function(object, ...) {
# Delegate to plot_model_roc_curve(), forwarding `object` and `...` unchanged.
plot_model_roc_curve(object, ...)
}

#' @rdname autoplot-halfmoon
#' @export
autoplot.halfmoon_auc <- function(object, ...) {
# Delegate to plot_model_auc(), forwarding `object` and `...` unchanged.
plot_model_auc(object, ...)
}

#' @rdname autoplot-halfmoon
#' @export
autoplot.halfmoon_qq <- function(object, ...) {
# Delegate to plot_qq(), forwarding `object` and `...` unchanged.
plot_qq(object, ...)
}

#' Plot Methods for halfmoon Objects
#'
#' These methods provide standard plot generation for halfmoon data objects.
#' They create the plot using autoplot() and then print it.
#'
#' @param x A halfmoon data object with appropriate class
#' @param ... Additional arguments passed to autoplot()
#'
#' @return Invisibly returns the ggplot2 object after printing
#' @name plot-halfmoon
NULL

#' @rdname plot-halfmoon
#' @export
plot.halfmoon_balance <- function(x, ...) {
  # Build the plot through the autoplot() method for this class.
  p <- autoplot(x, ...)
  # Print explicitly so the plot is drawn even when plot() is called inside a
  # loop or function, where a returned value is not auto-printed.
  print(p)
  # Hand the plot back invisibly, as documented, to avoid a second print at
  # top level while still allowing `p <- plot(x)`.
  invisible(p)
}

#' @rdname plot-halfmoon
#' @export
plot.halfmoon_ess <- function(x, ...) {
  # Build the plot through the autoplot() method for this class.
  p <- autoplot(x, ...)
  # Print explicitly so the plot is drawn even when plot() is called inside a
  # loop or function, where a returned value is not auto-printed.
  print(p)
  # Hand the plot back invisibly, as documented, to avoid a second print at
  # top level while still allowing `p <- plot(x)`.
  invisible(p)
}

#' @rdname plot-halfmoon
#' @export
plot.halfmoon_calibration <- function(x, ...) {
  # Build the plot through the autoplot() method for this class.
  p <- autoplot(x, ...)
  # Print explicitly so the plot is drawn even when plot() is called inside a
  # loop or function, where a returned value is not auto-printed.
  print(p)
  # Hand the plot back invisibly, as documented, to avoid a second print at
  # top level while still allowing `p <- plot(x)`.
  invisible(p)
}

#' @rdname plot-halfmoon
#' @export
plot.halfmoon_roc <- function(x, ...) {
  # Build the plot through the autoplot() method for this class.
  p <- autoplot(x, ...)
  # Print explicitly so the plot is drawn even when plot() is called inside a
  # loop or function, where a returned value is not auto-printed.
  print(p)
  # Hand the plot back invisibly, as documented, to avoid a second print at
  # top level while still allowing `p <- plot(x)`.
  invisible(p)
}

#' @rdname plot-halfmoon
#' @export
plot.halfmoon_auc <- function(x, ...) {
  # Build the plot through the autoplot() method for this class.
  p <- autoplot(x, ...)
  # Print explicitly so the plot is drawn even when plot() is called inside a
  # loop or function, where a returned value is not auto-printed.
  print(p)
  # Hand the plot back invisibly, as documented, to avoid a second print at
  # top level while still allowing `p <- plot(x)`.
  invisible(p)
}

#' @rdname plot-halfmoon
#' @export
plot.halfmoon_qq <- function(x, ...) {
  # Build the plot through the autoplot() method for this class.
  p <- autoplot(x, ...)
  # Print explicitly so the plot is drawn even when plot() is called inside a
  # loop or function, where a returned value is not auto-printed.
  print(p)
  # Hand the plot back invisibly, as documented, to avoid a second print at
  # top level while still allowing `p <- plot(x)`.
  invisible(p)
}
49 changes: 49 additions & 0 deletions R/bal_ess.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#' Calculate Effective Sample Size for Single Weight Vector
#'
#' Computes the effective sample size (ESS) for a single weighting scheme.
#' This is a wrapper around [ess()] that follows the bal_*() naming convention
#' for API consistency.
#'
#' @details
#' The effective sample size (ESS) is calculated using the classical formula:
#' \eqn{ESS = (\sum w)^2 / \sum(w^2)}.
#'
#' ESS reflects how many observations you would have if all were equally weighted.
#' When weights vary substantially, the ESS can be much smaller than the actual
#' number of observations, indicating that a few observations carry
#' disproportionately large weights.
#'
#' **Diagnostic Value**:
#' * A large discrepancy between ESS and the actual sample size indicates that
#' a few observations carry disproportionately large weights
#' * A small ESS signals that weighted estimates are more sensitive to a handful
#' of observations, inflating the variance and standard errors
#' * If ESS is much lower than the total sample size, consider investigating
#' why some weights are extremely large or small
#'
#' @param .wts A numeric vector of weights or a single weight column from a data frame.
#' @inheritParams balance_params
#'
#' @return A single numeric value representing the effective sample size.
#'
#' @family balance functions
#' @seealso [ess()] for the underlying implementation, [check_ess()] for
#' computing ESS across multiple weighting schemes
#'
#' @examples
#' # ESS for ATE weights
#' bal_ess(nhefs_weights$w_ate)
#'
#' # ESS for ATT weights
#' bal_ess(nhefs_weights$w_att)
#'
#' # With missing values
#' weights_with_na <- nhefs_weights$w_ate
#' weights_with_na[1:5] <- NA
#' bal_ess(weights_with_na, na.rm = TRUE)
#'
#' @export
bal_ess <- function(.wts, na.rm = FALSE) {
# Thin alias: hand `.wts` to ess() and forward `na.rm` as given, so the
# bal_*() family has an ESS entry point without duplicating the computation.
ess(.wts, na.rm = na.rm)
}
130 changes: 130 additions & 0 deletions R/bal_model_auc.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#' Calculate Single AUC for Model Balance Assessment
#'
#' Computes the Area Under the ROC Curve (AUC) for a single weighting scheme
#' or unweighted data. In causal inference, an AUC around 0.5 indicates good
#' balance between treatment groups.
#'
#' @details
#' The AUC provides a single metric for assessing propensity score balance.
#' When propensity scores achieve perfect balance, the weighted distribution
#' of scores should be identical between treatment groups, resulting in an
#' AUC of 0.5 (chance performance).
#'
#' AUC values significantly different from 0.5 indicate systematic differences
#' in propensity score distributions between groups, suggesting inadequate
#' balance.
#'
#' @param .data A data frame containing the variables.
#' @param .truth The treatment/outcome variable (unquoted).
#' @param .estimate The propensity score or fitted values (unquoted).
#' @param .wts Optional single weight variable (unquoted). If NULL, computes
#' unweighted AUC.
#' @inheritParams balance_params
#' @inheritParams treatment_param
#'
#' @return A numeric value representing the AUC. Values around 0.5 indicate
#' good balance, while values closer to 0 or 1 indicate poor balance.
#'
#' @family balance functions
#' @seealso [check_model_auc()] for computing AUC across multiple weights,
#' [bal_model_roc_curve()] for the full ROC curve
#'
#' @examples
#' # Unweighted AUC
#' bal_model_auc(nhefs_weights, qsmk, .fitted)
#'
#' # Weighted AUC
#' bal_model_auc(nhefs_weights, qsmk, .fitted, w_ate)
#'
#' @export
bal_model_auc <- function(
  .data,
  .truth,
  .estimate,
  .wts = NULL,
  na.rm = TRUE,
  treatment_level = NULL
) {
  # Attribute validation failures to bal_model_auc() itself, consistent with
  # the other checks below. `rlang::caller_env()` evaluated in this frame
  # would instead point at whoever called bal_model_auc().
  validate_data_frame(.data, call = rlang::current_env())

  truth_quo <- rlang::enquo(.truth)
  estimate_quo <- rlang::enquo(.estimate)
  wts_quo <- rlang::enquo(.wts)

  # Resolve the tidyselect expressions to column names
  truth_name <- names(tidyselect::eval_select(truth_quo, .data))
  estimate_name <- names(tidyselect::eval_select(estimate_quo, .data))

  if (length(truth_name) != 1) {
    abort(
      "{.arg .truth} must select exactly one variable",
      error_class = "halfmoon_arg_error",
      call = rlang::current_env()
    )
  }
  if (length(estimate_name) != 1) {
    abort(
      "{.arg .estimate} must select exactly one variable",
      error_class = "halfmoon_arg_error",
      call = rlang::current_env()
    )
  }

  # Pull the selected columns out as plain vectors
  truth <- .data[[truth_name]]
  estimate <- .data[[estimate_name]]

  # Weights are optional: `weights` stays NULL for the unweighted AUC
  weights <- NULL
  if (!rlang::quo_is_null(wts_quo)) {
    weight_vars <- names(tidyselect::eval_select(wts_quo, .data))
    if (length(weight_vars) != 1) {
      abort(
        "{.arg .wts} must select exactly one variable or be NULL",
        error_class = "halfmoon_arg_error",
        call = rlang::current_env()
      )
    }
    weights <- extract_weight_data(.data[[weight_vars[1]]])
  }

  # Missing values: either drop incomplete rows across every vector in use,
  # or (na.rm = FALSE) short-circuit to NA as soon as any NA is present.
  if (na.rm) {
    keep <- if (is.null(weights)) {
      stats::complete.cases(truth, estimate)
    } else {
      stats::complete.cases(truth, estimate, weights)
    }
    truth <- truth[keep]
    estimate <- estimate[keep]
    if (!is.null(weights)) {
      weights <- weights[keep]
    }
  } else if (anyNA(truth) || anyNA(estimate) || anyNA(weights)) {
    # anyNA(NULL) is FALSE, so the unweighted case needs no special branch
    return(NA_real_)
  }

  # Compute ROC curve
  roc_data <- compute_roc_curve_imp(
    truth,
    estimate,
    weights = weights,
    treatment_level = treatment_level,
    call = rlang::current_env()
  )

  # Convert to (FPR, TPR) coordinates and integrate with compute_auc()
  fpr <- 1 - roc_data$specificity
  tpr <- roc_data$sensitivity

  compute_auc(fpr, tpr)
}
Loading
Loading