|
| 1 | + |
| 2 | +generate_convergence_plot <- function(model_summary, title_i) { |
| 3 | + df_smi_df <- calcScaleModelIndicator(model_summary) |
| 4 | + |
| 5 | + df_smi_df <- df_smi_df %>% filter(.chain != 4) |
| 6 | + |
| 7 | + fit <- model_summary$fit |
| 8 | + p1 <- (fit$post$mcmc[1:3] %>% bayesplot::mcmc_trace()) + theme_minimal() + theme(legend.position = "top") + ggtitle("A. Trace plots for fitted parameters") |
| 9 | + p2 <- fit$post$lpost %>% filter(chain_no !=4) %>% ggplot() + geom_line(aes(x = sample_no, y = lpost, color = chain_no)) + theme_minimal() + theme(legend.position = "top") + ggtitle("B. Trace plots for log posterior") + |
| 10 | + labs(x = "Sample number", y = "Log-posterior") |
| 11 | + |
| 12 | + p3 <- df_conver_stat <- summarise_draws(fit$post$mcmc[1:3] ) %>% select(variable, rhat, ess_bulk, ess_tail) %>% |
| 13 | + pivot_longer(!variable, names_to = "stat", values_to = "value") %>% |
| 14 | + ggplot() + |
| 15 | + geom_col(aes(y = variable, x = value)) + |
| 16 | + facet_wrap(~stat, scales = "free") + theme_minimal() + ggtitle("C. Convergence diagnosis for fitted parameters") + |
| 17 | + labs(x = "Value", y = "Parameter") |
| 18 | + pA <- p1 / p2 / p3 |
| 19 | + |
| 20 | + |
| 21 | + |
| 22 | + pdims_trace <- df_smi_df %>% |
| 23 | + ggplot() + |
| 24 | + geom_line(aes(x = .iteration, y = dims, color = as.character(.chain))) + |
| 25 | + labs(x = "Iteration", y = "Model dimension", color = "Chain") + |
| 26 | + ggtitle("D. Trace plots for transdimensional convergence: dimensions of model") |
| 27 | + |
| 28 | + |
| 29 | + pdims_hist <- df_smi_df %>% |
| 30 | + ggplot() + |
| 31 | + geom_histogram(aes(x = dims, fill = as.character(.chain))) + |
| 32 | + labs(x = "Model dimension", y = "Count", fill = "Chain") |
| 33 | + |
| 34 | + p1 <- pdims_trace + pdims_hist + plot_layout(guides = "collect") & theme_minimal() & theme(legend.position = "top") |
| 35 | + |
| 36 | + |
| 37 | + |
| 38 | + psmi_trace <- df_smi_df %>% |
| 39 | + ggplot() + |
| 40 | + geom_line(aes(x = .iteration, y = sMI, color = as.character(.chain))) + |
| 41 | + labs(x = "Iteration", y = "Log2 of SMI", color = "Chain") + |
| 42 | + ggtitle("E. Trace plots for transdimensional convergence: SMI of model") |
| 43 | + |
| 44 | + |
| 45 | + psmi_hist <- df_smi_df %>% |
| 46 | + ggplot() + |
| 47 | + geom_histogram(aes(x = sMI, fill = as.character(.chain))) + |
| 48 | + labs(x = "Log2 of SMI", y = "Count", fill = "Chain") |
| 49 | + |
| 50 | + p2 <- psmi_trace + psmi_hist + plot_layout(guides = "collect") & theme_minimal() & theme(legend.position = "top") |
| 51 | + |
| 52 | + p3 <- summarise_draws(df_smi_df) %>% select(variable, rhat, ess_bulk, ess_tail) %>% pivot_longer(!variable, names_to = "stat", values_to = "value") %>% |
| 53 | + filter(variable == "sMI") %>% |
| 54 | + ggplot()+ |
| 55 | + geom_col(aes(y = "", x = value)) + |
| 56 | + facet_grid(cols = vars(stat), scales = "free") + theme_minimal() + ggtitle("F. Convergence diagnosis for transdimensional convergence") + |
| 57 | + labs(x = "Value", y = "Model") |
| 58 | + |
| 59 | + pB <- p1 / p2 / p3 |
| 60 | + |
| 61 | + |
| 62 | + (pA | pB) + plot_annotation(title = paste0("CONVERGENCE DIAGNOSITICS FOR ", title_i)) & |
| 63 | + theme(title = element_text(size = 12)) |
| 64 | +} |
| 65 | + |
| 66 | + |
| 67 | +plot_Rhat_time_alt <- function(model_summary, title_i) { |
| 68 | + |
| 69 | + outputfull <- model_summary$post |
| 70 | + |
| 71 | + model_outline <- model_summary$fit$model |
| 72 | + bio_all <- model_outline$infoModel$biomarkers |
| 73 | + |
| 74 | + fit_states_dt <- as.data.table(outputfull$fit_states) %>% filter(chain_no != 4) |
| 75 | + S <- fit_states_dt %>% filter(id == 1) %>% nrow |
| 76 | + |
| 77 | + ids <- fit_states_dt %>% group_by(id) %>% summarise(prob = sum(inf_ind) / S) %>% filter(prob > 0.5) %>% pull(id) %>% unique |
| 78 | + |
| 79 | + if (length(ids) == 0) { |
| 80 | + cat("No individuals have posterior prob of infection > 0.5") |
| 81 | + } else { |
| 82 | + # extract values here |
| 83 | + df_mcmc_time <- fit_states_dt %>% filter(id %in% ids) %>% filter(inf_ind == 1) %>% |
| 84 | + select(id, chain_no, sample, inf_time, !!bio_all) %>% rename(chain = chain_no) |
| 85 | + |
| 86 | + df_mcmc_time_wide <- df_mcmc_time %>% |
| 87 | + select(id, sample, chain, inf_time) %>% unique %>% |
| 88 | + pivot_wider(!chain, names_from = "id", values_from = "inf_time") |
| 89 | + |
| 90 | + cols <- ncol(df_mcmc_time_wide) |
| 91 | + |
| 92 | + df_summary_disc <- |
| 93 | + map_df(2:cols, |
| 94 | + ~df_mcmc_time_wide %>% select(sample, .x) %>% drop_na %>% summarise_draws() %>% .[2, ] |
| 95 | + ) |
| 96 | + |
| 97 | + p1 <- df_mcmc_time %>% |
| 98 | + ggplot() + |
| 99 | + stat_pointinterval(aes(x = inf_time, y = as.character(id), color = as.character(chain)), |
| 100 | + position = position_dodge(0.4)) + theme_bw() + |
| 101 | + labs(x = "Time in study", y = "ID", color = "Chain number") + |
| 102 | + ggtitle("A. Trace plots for timing of infection for individuals \nwith posterior P(Z) > 0.5") |
| 103 | + |
| 104 | + p2 <- df_summary_disc %>% ggplot() + geom_col(aes(x = rhat, y = as.character(variable))) + theme_bw() + |
| 105 | + geom_vline(xintercept = 1.1, color = "red", linetype = "dashed") + |
| 106 | + labs(x = "Rhat", y = "ID") + |
| 107 | + scale_x_continuous(labels = seq(0, 2, 0.2), breaks = seq(0, 2, 0.2)) + |
| 108 | + ggtitle("B. Convergence diagnostics for timing of infection \nindividuals with posterior P(Z) > 0.5") |
| 109 | + |
| 110 | + p1 + p2 + plot_annotation(title = paste0("TIMING CONVERGENCE DIAGNOSITICS FOR ", title_i)) & |
| 111 | + theme(title = element_text(size = 12)) |
| 112 | + |
| 113 | + } |
| 114 | + |
| 115 | +} |
| 116 | + |
| 117 | + |
| 118 | + |
| 119 | +# CASE STUDY 2: TRANSVIR, NO PCR |
| 120 | +model_summary <- readRDS(here::here("outputs", "fits", "transvir_data", "wave2_no_pcr", "model_summary.RDS")) |
| 121 | +p1 <- generate_convergence_plot(model_summary, "EMPIRICAL DATA WITH NO PCR") |
| 122 | +ggsave(here::here("outputs", "figs", "supp", "conv", "wave2_no_pcr_full.png"), height = 20) |
| 123 | +p2 <- plot_Rhat_time_alt(model_summary, "EMPIRICAL DATA WITH NO PCR") |
| 124 | +ggsave(here::here("outputs", "figs", "supp", "conv", "wave2_no_pcr_time.png"), height = 20) |
0 commit comments