Commit 63006fe

add ml for causal bonus section
1 parent 7d3fd80 commit 63006fe

File tree: 94 files changed, +7310 -33 lines changed


exercises/15-bonus-ml-for-causal-exercises.qmd

Lines changed: 15 additions & 9 deletions
@@ -179,7 +179,7 @@ ate_gcomp
 
 ## Your Turn 1
 
-1. First, create a character vector `sl_library` that specifies the following algorithms: "SL.glm", "SL.ranger", "SL.xgboost", "SL.gam". Then, Fit a SuperLearner for the exposure model using the `SuperLearner` package. The predictors for this model should be the confounders identified in the DAG: `park_ticket_season`, `park_close`, and `park_temperature_high`. The outcome is `park_extra_magic_morning`.
+1. First, create a character vector `sl_library` that specifies the following algorithms: "SL.glm", "SL.ranger", "SL.gam". Then, fit a SuperLearner for the exposure model using the `SuperLearner` package. The predictors for this model should be the confounders identified in the DAG: `park_ticket_season`, `park_close`, and `park_temperature_high`. The outcome is `park_extra_magic_morning`.
 2. Fit a SuperLearner for the outcome model using the `SuperLearner` package. The predictors for this model should be the confounders plus the exposure: `park_extra_magic_morning`, `park_ticket_season`, `park_close`, and `park_temperature_high`. The outcome is `wait_minutes_posted_avg`.
 3. Inspect the fitted SuperLearner objects.
 
@@ -251,7 +251,6 @@ outcome_rmse
 sl_library_extended <- c(
   "SL.glm",
   "SL.ranger",
-  "SL.xgboost",
   "SL.earth",
   "SL.gam",
   "SL.glm.interaction",
@@ -310,16 +309,23 @@ tidy(ipw_model) |>
 ```{r}
 # G-computation with SuperLearner outcome model
 # Step 1: Create counterfactual datasets
-seven_dwarfs_clone <- seven_dwarfs |>
-  mutate(park_close = as.numeric(park_close))
+# For SuperLearner prediction, we need only the columns used in the model
 
 # Dataset where everyone is treated, `park_extra_magic_morning` = 1
-data_all_treated <- seven_dwarfs_clone |>
-  mutate(park_extra_magic_morning = ___)
+data_all_treated <- seven_dwarfs |>
+  select(park_extra_magic_morning, park_ticket_season, park_close, park_temperature_high) |>
+  mutate(
+    park_close = as.numeric(park_close),
+    park_extra_magic_morning = ___
+  )
 
 # Dataset where everyone is control, `park_extra_magic_morning` = 0
-data_all_control <- seven_dwarfs_clone |>
-  mutate(park_extra_magic_morning = ___)
+data_all_control <- seven_dwarfs |>
+  select(park_extra_magic_morning, park_ticket_season, park_close, park_temperature_high) |>
+  mutate(
+    park_close = as.numeric(park_close),
+    park_extra_magic_morning = ___
+  )
 
 # Step 2: Predict outcomes under each scenario using SuperLearner
 pred_treated <- predict(______, newdata = ______)$pred[, 1]
@@ -350,7 +356,7 @@ outcome_sl_bounded <- SuperLearner(
   X = seven_dwarfs |>
     select(__________, park_ticket_season, park_close, park_temperature_high) |>
     mutate(park_close = as.numeric(park_close)),
-  family = binomial(),
+  family = quasibinomial(),
   SL.library = __________,
   cvControl = list(V = 5)
 )
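
The `quasibinomial()` switch in this hunk goes with fitting the outcome model on a continuous outcome that has been rescaled to [0, 1] (hence `outcome_sl_bounded`), presumably to avoid `binomial()`'s non-integer-successes warnings. A minimal, self-contained sketch of the usual bounding transform (`y` here is a stand-in, not an object from the exercises):

```r
set.seed(1)
y <- rnorm(100, mean = 30, sd = 10) # stand-in continuous outcome

# Rescale to [0, 1] so logit-scale machinery (and quasi-likelihood
# binomial fits) can be applied to a continuous outcome
min_y <- min(y)
max_y <- max(y)
y_bounded <- (y - min_y) / (max_y - min_y)

# Estimates on the bounded scale map back to the original scale
# by multiplying differences by (max_y - min_y)
range(y_bounded)
```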

slides/raw/15-bonus-ml-for-causal.html

Lines changed: 1168 additions & 0 deletions
Large diffs are not rendered by default.

slides/raw/15-bonus-ml-for-causal.qmd

Lines changed: 41 additions & 24 deletions
@@ -88,7 +88,6 @@ dag_base <- dag |>
     aes(x, y, xend = xend, yend = yend, color = status)
   ) +
   geom_dag_point() +
-  geom_dag_label_repel2(aes(label = label), seed = 1234) +
   scale_color_okabe_ito(na.value = "grey90") +
   theme_dag() +
   theme(legend.position = "none")
@@ -171,11 +170,15 @@ halfmoon::plot_mirror_distributions(
   ggokabeito::scale_fill_okabe_ito()
 ```
 
-## G-computation
+## G-computation {background-color="#23373B"}
 
 1. Fit a model for `y ~ x + z` where z is all confounders
 2. Create a duplicate of your data set for each level of `x`
 3. Set the value of x to a single value for each cloned data set (e.g. `x = 1` for one, `x = 0` for the other)
+
+
+## G-computation {background-color="#23373B"}
+
 4. Make predictions using the model on the cloned data sets
 5. Calculate the estimate you want, e.g. `mean(x_1) - mean(x_0)`
 
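The five-step recipe in this hunk maps to only a few lines of R. A minimal, self-contained sketch on simulated data (all names here are hypothetical; the slides themselves use the NHEFS data):

```r
library(dplyr)

set.seed(2025)
n <- 1000
dat <- tibble(
  z = rnorm(n),                  # confounder
  x = rbinom(n, 1, plogis(z)),   # exposure affected by z
  y = 1 + 2 * x + z + rnorm(n)   # outcome; true ATE is 2
)

# 1. Fit an outcome model with the exposure and confounders
fit <- lm(y ~ x + z, data = dat)

# 2-3. Clone the data, setting x to 1 in one copy and 0 in the other
dat_x1 <- mutate(dat, x = 1)
dat_x0 <- mutate(dat, x = 0)

# 4-5. Predict on each clone and contrast the means
mean(predict(fit, newdata = dat_x1)) - mean(predict(fit, newdata = dat_x0))
```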
@@ -258,6 +261,7 @@ dag_base +
 
 ```{r}
 #| echo: false
+#| fig-width: 10
 smk_wt_dag <- dagify(
   qsmk ~ sex + race + age + education +
     smokeintensity + smokeyrs + exercise + active + wt71,
@@ -288,7 +292,7 @@ smk_wt_dag |>
   geom_dag_edges() +
   geom_dag_point() +
   geom_dag_label_repel(aes(label = label), seed = 1234) +
-  scale_color_okabe_ito(na.value = "grey90") +
+  scale_color_okabe_ito(na.value = "grey60") +
   theme_dag() +
   theme(legend.position = "none")
 ```
@@ -397,14 +401,18 @@ dagify(
 
 ```{r}
 #| echo: false
-knitr::include_graphics("img/ml_algorithms.png")
+knitr::include_graphics("img/superlearner.png")
 ```
 
+::: tiny
+Image source: Sherri Rose
+:::
+
 ## Ensemble Algorithms with SuperLearner
 
 ```{r}
 #| echo: false
-knitr::include_graphics("img/superlearner.png")
+knitr::include_graphics("img/ml_algorithms.png")
 ```
 
 :::{.fragment}
@@ -414,12 +422,12 @@ Given a set of candidate algorithms (and hyperparameters), stacked ensembles com
 ## SuperLearner: Exposure Model
 
 ```{r}
-#| code-line-numbers: "|1,3,8|4-7|"
+#| code-line-numbers: "|1,3,10|4-7|"
 #| cache: true
-sl_library <- c("SL.glm", "SL.ranger", "SL.xgboost", "SL.gam")
+sl_library <- c("SL.glm", "SL.ranger", "SL.gam")
 
 propensity_sl <- SuperLearner(
-  Y = nhefs_complete_uc$qsmk |> as.integer(),
+  Y = as.integer(nhefs_complete_uc$qsmk == "Yes"),
   X = nhefs_complete_uc |>
     select(sex, race, age, education, smokeintensity,
       smokeyrs, exercise, active, wt71) |>
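
For trying the updated call outside the course data, a self-contained sketch of a `SuperLearner` fit on simulated data (the library vector mirrors the new `sl_library`; everything else is illustrative, and the `ranger` and `gam` packages must be installed for those wrappers):

```r
library(SuperLearner)

set.seed(42)
n <- 500
X <- data.frame(z1 = rnorm(n), z2 = rnorm(n))
y <- rbinom(n, 1, plogis(0.5 * X$z1 - X$z2))

sl_fit <- SuperLearner(
  Y = y,                # numeric 0/1, like as.integer(qsmk == "Yes")
  X = X,                # data frame of predictors
  family = binomial(),
  SL.library = c("SL.glm", "SL.ranger", "SL.gam"),
  cvControl = list(V = 5)
)

sl_fit                  # cross-validated risk and weight per learner
head(sl_fit$SL.predict) # ensemble predictions on the training data
```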
@@ -460,7 +468,7 @@ outcome_sl <- SuperLearner(
 outcome_sl
 ```
 
-## *Your Turn 1*
+## *Your Turn 1* {.small}
 
 ```{r}
 #| echo: false
@@ -497,7 +505,7 @@ tidy(ipw_model)
 ## G-computation with SuperLearner
 
 ```{r}
-#| code-line-numbers: "|1-2,7|4-5,8|10"
+#| code-line-numbers: "|1,5,7,11|13-14|16"
 data_all_quit <- nhefs_complete_uc |>
   select(qsmk, sex, race, age, education, smokeintensity,
     smokeyrs, exercise, active, wt71) |>
@@ -542,17 +550,19 @@ countdown::countdown(minutes = 8)
 
 - In **IPW** and **G-computation**, we estimate the average treatment effect (ATE) using predictions from the exposure and outcome models. But these algorithms optimize for the predictions, not the ATE.
 - In **TMLE**, we adjust the predictions to specifically target the ATE. We change the bias-variance tradeoff to focus on the ATE rather than just minimizing prediction error. This is a debiasing step that also improves the efficiency of the estimate!
-- Targeting is a general technique that can be applied to many problems, not just causal ones
 
 ## Targeted Learning: valid statistical inference
-- In **IPW** and **G-computation**, we can using ML algorithms to make predictions, but we cannot easily get valid confidence intervals. Bootstrapping is often used, but it can be computationally intensive and not always valid.
+- In **IPW** and **G-computation**, we cannot easily get valid confidence intervals with ML. Bootstrapping is often used, but it can be computationally intensive and not always valid.
 - In **TMLE**, we can use the influence curve to get valid confidence intervals. The influence curve is a way to estimate the variance of the TMLE estimate, even when using complex ML algorithms.
 
 ## The TMLE Algorithm {background-color="#23373B"}
 
 1. Start with SuperLearner predictions for the outcome
 2. Calculate the propensity scores using SuperLearner
 3. Create the clever covariate using the propensity scores
+
+## The TMLE Algorithm {background-color="#23373B"}
+
 4. Fit the fluctuation model to learn how much to adjust the outcome predictions
 5. Update the predictions with the targeted adjustment
 6. Calculate the TMLE estimate and standard error using the influence curve
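
Written out, steps 3-5 amount to fitting one extra coefficient. In standard TMLE notation, with initial outcome predictions $\hat{Q}(A, W)$ and propensity scores $\hat{g}(W)$ (this mirrors the code in the hunks below), the clever covariate is

$$
H(A, W) = \frac{A}{\hat{g}(W)} - \frac{1 - A}{1 - \hat{g}(W)}
$$

and the fluctuation model estimates a single coefficient $\hat{\varepsilon}$ on $H$, giving the targeted update

$$
\operatorname{logit} \hat{Q}^*(A, W) = \operatorname{logit} \hat{Q}(A, W) + \hat{\varepsilon}\, H(A, W).
$$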
@@ -562,6 +572,7 @@ countdown::countdown(minutes = 8)
 ```{r}
 #| echo: true
 #| cache: true
+#| cache.lazy: false
 # For TMLE with continuous outcomes, fit SuperLearner on bounded Y
 min_y <- min(nhefs_complete_uc$wt82_71)
 max_y <- max(nhefs_complete_uc$wt82_71)
@@ -574,7 +585,7 @@ outcome_sl_bounded <- SuperLearner(
     select(qsmk, sex, race, age, education, smokeintensity,
       smokeyrs, exercise, active, wt71) |>
     mutate(across(everything(), as.numeric)),
-  family = binomial(),
+  family = quasibinomial(),
   SL.library = sl_library,
   cvControl = list(V = 5)
 )
@@ -588,7 +599,7 @@ initial_pred_no_quit <- predict(outcome_sl_bounded, newdata = data_all_no_quit)$
 
 # Predictions for observed treatment
 initial_pred_observed <- ifelse(
-  nhefs_complete_uc$qsmk == 1,
+  nhefs_complete_uc$qsmk == "Yes",
   initial_pred_quit,
   initial_pred_no_quit
 )
@@ -597,9 +608,9 @@ initial_pred_observed <- ifelse(
 ## TMLE Step 2: Clever Covariate
 
 ```{r}
-#| echo: true
+#| code-line-numbers: "|3-4"
 clever_covariate <- ifelse(
-  nhefs_complete_uc$qsmk == 1,
+  nhefs_complete_uc$qsmk == "Yes",
   1 / propensity_scores,
   -1 / (1 - propensity_scores)
 )
@@ -614,7 +625,8 @@ clever_covariate <- ifelse(
 ## TMLE Step 3: Targeting
 
 ```{r}
-#| echo: true
+#| code-line-numbers: "|4|5|6|7|8|11-12"
+#| output-location: fragment
 # Fluctuation model - learns how much to adjust
 # Use binomial family and work on logit scale
 fluctuation_model <- glm(
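
The hunk cuts off before the arguments of `glm()`. For orientation, the standard TMLE fluctuation regression has the following shape; this is a sketch under that assumption, not a quote of the off-screen slide code, with toy inputs so it runs standalone:

```r
# Toy inputs (shapes only, not real data; the slides use the NHEFS objects)
set.seed(7)
n <- 200
y_bounded <- runif(n)                       # outcome rescaled to [0, 1]
propensity_scores <- runif(n, 0.2, 0.8)
a <- rbinom(n, 1, propensity_scores)        # observed exposure
initial_pred_observed <- runif(n, 0.1, 0.9) # initial outcome predictions
clever_covariate <- ifelse(a == 1, 1 / propensity_scores, -1 / (1 - propensity_scores))

# Fluctuation model: intercept-free logistic regression of the bounded
# outcome on the clever covariate, with the initial predictions entering
# as a fixed offset on the logit scale
# (binomial() on a non-integer outcome warns; the fit is still usable)
fluctuation_model <- glm(
  y_bounded ~ -1 + clever_covariate + offset(qlogis(initial_pred_observed)),
  family = binomial()
)

epsilon <- coef(fluctuation_model) # single coefficient: size of the shift
```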
@@ -637,7 +649,7 @@ epsilon
 ## TMLE Step 4: Update Predictions
 
 ```{r}
-#| code-line-numbers: "|1-2|5-6"
+#| code-line-numbers: "|2-3|6-7"
 # Update predictions on logit scale, then transform back
 logit_pred_quit <- qlogis(initial_pred_quit) + epsilon * (1 / propensity_scores)
 logit_pred_no_quit <- qlogis(initial_pred_no_quit) + epsilon * (-1 / (1 - propensity_scores))
@@ -710,15 +722,15 @@ targeted_ate <- mean(
   targeted_pred_quit - targeted_pred_no_quit
 ) * (max_y - min_y)
 
-c(initial = initial_ate, targeted = targeted_ate)
+tibble(initial = initial_ate, targeted = targeted_ate)
 ```
 
 ## TMLE Inference
 
 ```{r}
-#| echo: true
+#| output-location: slide
 targeted_pred_observed <- ifelse(
-  nhefs_complete_uc$qsmk == 1,
+  nhefs_complete_uc$qsmk == "Yes",
   targeted_pred_quit,
   targeted_pred_no_quit
 )
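
The hunk stops before the variance calculation. A sketch of the usual influence-curve standard error for the ATE (an assumption about the truncated code, not a quote of it; the names are the objects defined on the preceding slides, and the `(max_y - min_y)` factor undoes the outcome bounding):

```r
# Influence curve for the ATE: clever-covariate-weighted residual plus
# the centered plug-in contrast, mapped back to the original scale
y_range <- max_y - min_y
ic <- clever_covariate * (y_bounded - targeted_pred_observed) * y_range +
  (targeted_pred_quit - targeted_pred_no_quit) * y_range -
  targeted_ate

# Standard error from the empirical variance of the influence curve
se <- sqrt(var(ic) / length(ic))
c(
  estimate = targeted_ate,
  lower_ci = targeted_ate - 1.96 * se,
  upper_ci = targeted_ate + 1.96 * se
)
```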
@@ -743,12 +755,12 @@ tibble(
 
 ```{r}
 #| cache: true
-#| output-location: slide
+#| cache.lazy: false
 library(tmle)
 
 tmle_result <- tmle(
   Y = nhefs_complete_uc$wt82_71,
-  A = nhefs_complete_uc$qsmk |> as.integer(),
+  A = as.integer(nhefs_complete_uc$qsmk == "Yes"),
   W = nhefs_complete_uc |>
     select(sex, race, age, education, smokeintensity,
       smokeyrs, exercise, active, wt71) |>
@@ -757,9 +769,14 @@ tmle_result <- tmle(
   g.SL.library = sl_library
 )
 
-summary(tmle_result)
+tibble(
+  ate = tmle_result$estimates$ATE$psi,
+  lower_ci = tmle_result$estimates$ATE$CI[[1]],
+  upper_ci = tmle_result$estimates$ATE$CI[[2]]
+)
 ```
 
+
 ## *Your Turn 4*
 
 ```{r}
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+tidyverse
+ggplot2
+tibble
+tidyr
+readr
+purrr
+dplyr
+stringr
+forcats
+lubridate
+broom
+causaldata
+touringplans
+propensity
+nnls
+foreach
+gam
+SuperLearner
+Matrix
+glmnet
+tmle
+yardstick
+ggdag
+ggokabeito
+patchwork
3.05 KB · Binary file not shown.
8.18 MB · Binary file not shown.
285 Bytes · Binary file not shown.
3 KB · Binary file not shown.
14.3 MB · Binary file not shown.
272 Bytes · Binary file not shown.
