# Convergence-diagnostic figures: plotting helpers followed by the script that saves them.
#' Reshape convergence statistics into long format for plotting
#'
#' Runs `summarise_draws()` on `draws`, keeps `variable`, `rhat`, `ess_bulk`
#' and `ess_tail`, and pivots the three statistics into `stat` / `value`
#' columns (one row per variable and statistic).
#'
#' @param draws Object accepted by `summarise_draws()`.
#' @return A long data frame with columns `variable`, `stat` and `value`.
.convergence_stats_long <- function(draws) {
  summarise_draws(draws) %>%
    select(variable, rhat, ess_bulk, ess_tail) %>%
    pivot_longer(!variable, names_to = "stat", values_to = "value")
}

#' Build the combined convergence-diagnostics figure for one fitted model
#'
#' Left column (panels A-C): trace plots of `fit$post$mcmc`, a trace of the
#' log-posterior stored in `fit$post$lpost`, and bar charts of rhat, bulk ESS
#' and tail ESS per fitted parameter.
#' Right column (panels D-F): trace plot and histogram of `dims`, trace plot
#' and histogram of `sMI`, and the same three statistics for `sMI` only.
#'
#' @param model_summary List with a `fit` element; `fit$post$mcmc` and
#'   `fit$post$lpost` are read here, and the whole object is passed to
#'   `calcScaleModelIndicator()`.
#' @param title_i Character string appended to the overall figure title.
#' @return A patchwork composition of the six panels with a figure-level title.
generate_convergence_plot <- function(model_summary, title_i) {
  # NOTE(review): calcScaleModelIndicator() is defined elsewhere; the code below
  # relies on its result having columns `.iteration`, `.chain`, `dims`, `sMI`.
  smi_draws <- calcScaleModelIndicator(model_summary)
  fit <- model_summary$fit

  # ---- Left column: fitted parameters ----
  par_trace <- bayesplot::mcmc_trace(fit$post$mcmc) +
    theme_minimal() +
    theme(legend.position = "top") +
    ggtitle("A. Trace plots for fitted parameters")

  lpost_trace <- ggplot(fit$post$lpost) +
    geom_line(aes(x = sample_no, y = lpost, color = chain_no)) +
    theme_minimal() +
    theme(legend.position = "top") +
    ggtitle("B. Trace plots for log posterior") +
    labs(x = "Sample number", y = "Log-posterior")

  par_stats <- .convergence_stats_long(fit$post$mcmc) %>%
    ggplot() +
    geom_col(aes(y = variable, x = value)) +
    facet_wrap(~stat, scales = "free") +
    theme_minimal() +
    ggtitle("C. Convergence diagnosis for fitted parameters") +
    labs(x = "Value", y = "Parameter")

  pA <- par_trace / lpost_trace / par_stats

  # ---- Right column: transdimensional convergence ----
  dims_trace <- ggplot(smi_draws) +
    geom_line(aes(x = .iteration, y = dims, color = as.character(.chain))) +
    labs(x = "Iteration", y = "Model dimension", color = "Chain") +
    ggtitle("D. Trace plots for transdimensional convergence: dimensions of model")

  dims_hist <- ggplot(smi_draws) +
    geom_histogram(aes(x = dims, fill = as.character(.chain))) +
    labs(x = "Model dimension", y = "Count", fill = "Chain")

  # `&` applies the theme to every panel of the patchwork, not just the last.
  dims_row <- dims_trace + dims_hist + plot_layout(guides = "collect") &
    theme_minimal() &
    theme(legend.position = "top")

  smi_trace <- ggplot(smi_draws) +
    geom_line(aes(x = .iteration, y = sMI, color = as.character(.chain))) +
    labs(x = "Iteration", y = "Log2 of SMI", color = "Chain") +
    ggtitle("E. Trace plots for transdimensional convergence: SMI of model")

  smi_hist <- ggplot(smi_draws) +
    geom_histogram(aes(x = sMI, fill = as.character(.chain))) +
    labs(x = "Log2 of SMI", y = "Count", fill = "Chain")

  smi_row <- smi_trace + smi_hist + plot_layout(guides = "collect") &
    theme_minimal() &
    theme(legend.position = "top")

  # Only the sMI rows are shown; a constant y value gives one bar per facet.
  smi_stats <- .convergence_stats_long(smi_draws) %>%
    filter(variable == "sMI") %>%
    ggplot() +
    geom_col(aes(y = " ", x = value)) +
    facet_grid(cols = vars(stat), scales = "free") +
    theme_minimal() +
    ggtitle("F. Convergence diagnosis for transdimensional convergence") +
    labs(x = "Value", y = "Model")

  pB <- dims_row / smi_row / smi_stats

  (pA | pB) +
    plot_annotation(title = paste0("CONVERGENCE DIAGNOSTICS FOR ", title_i)) &
    theme(title = element_text(size = 12))
}
# Infection-timing convergence diagnostics ----
#' Plot convergence diagnostics for inferred infection times
#'
#' Selects individuals whose share of rows with `inf_ind == 1` in
#' `model_summary$post$fit_states` exceeds 0.5, then draws (A) per-chain point
#' intervals of `inf_time` for each of them and (B) the rhat reported by
#' `summarise_draws()` for each individual's `inf_time` column, with a dashed
#' reference line at 1.1.
#'
#' @param model_summary List with `post$fit_states` (columns `id`, `inf_ind`,
#'   `chain_no`, `sample`, `inf_time` plus the biomarker columns) and
#'   `fit$model$infoModel$biomarkers` (names of the biomarker columns).
#' @param title_i Character string appended to the overall figure title.
#' @return A patchwork of the two panels, or `NULL` (invisibly, after a
#'   message) when no individual passes the 0.5 threshold.
plot_Rhat_time_alt <- function(model_summary, title_i) {
  biomarkers <- model_summary$fit$model$infoModel$biomarkers
  fit_states_dt <- as.data.table(model_summary$post$fit_states)

  # Denominator for the infection probability: the number of rows held for
  # id 1. NOTE(review): assumes every id has the same number of rows and that
  # id 1 exists -- confirm against the fitting code.
  n_draws <- nrow(filter(fit_states_dt, id == 1))

  ids <- fit_states_dt %>%
    group_by(id) %>%
    summarise(prob = sum(inf_ind) / n_draws) %>%
    filter(prob > 0.5) %>%
    pull(id)

  if (length(ids) == 0) {
    message("No individuals have posterior prob of infection > 0.5")
    return(invisible(NULL))
  }

  # Draws in which the selected individuals are flagged as infected.
  df_mcmc_time <- fit_states_dt %>%
    filter(id %in% ids, inf_ind == 1) %>%
    select(id, chain_no, sample, inf_time, !!biomarkers) %>%
    rename(chain = chain_no)

  # One column of infection times per individual. `id_cols` is named
  # explicitly because passing it by position is deprecated in tidyr.
  df_mcmc_time_wide <- df_mcmc_time %>%
    select(id, sample, chain, inf_time) %>%
    unique() %>%
    pivot_wider(id_cols = !chain, names_from = "id", values_from = "inf_time")

  # Column 1 is `sample`; every later column belongs to one individual.
  id_cols_idx <- seq_len(ncol(df_mcmc_time_wide))[-1]

  df_summary_disc <- map_df(
    id_cols_idx,
    ~ {
      id_summary <- df_mcmc_time_wide %>%
        select(sample, all_of(.x)) %>%
        drop_na() %>%
        summarise_draws()
      # Row 1 summarises `sample`; row 2 is the individual's infection time.
      id_summary[2, ]
    }
  )

  p_time <- ggplot(df_mcmc_time) +
    stat_pointinterval(
      aes(x = inf_time, y = as.character(id), color = as.character(chain)),
      position = position_dodge(0.4)
    ) +
    theme_bw() +
    labs(x = "Time in study", y = "ID", color = "Chain number") +
    ggtitle("A. Trace plots for timing of infection for individuals \n with posterior P(Z) > 0.5")

  p_rhat <- ggplot(df_summary_disc) +
    geom_col(aes(x = rhat, y = as.character(variable))) +
    theme_bw() +
    geom_vline(xintercept = 1.1, color = "red", linetype = "dashed") +
    labs(x = "Rhat", y = "ID") +
    ggtitle("B. Convergence diagnostics for timing of infection \n individuals with posterior P(Z) > 0.5")

  p_time + p_rhat +
    plot_annotation(title = paste0("TIMING CONVERGENCE DIAGNOSTICS FOR ", title_i)) &
    theme(title = element_text(size = 12))
}
# Figure generation ----
# Generate and save the convergence figures for every case study.
#
# Each entry gives the sub-directory of outputs/fits holding model_summary.RDS,
# the stem of the output file names, the figure title, and the height passed to
# ggsave() for the full figure (NA = ggsave's default).
conv_cases <- list(
  # CASE STUDY 1: COP
  list(
    fit_dir = file.path("simulated_data_hpc", "cop_0.1"),
    stem = "cop_0.1",
    title = "SIMULATED DATA WITH COP AND 0.1 UNCERTAINTY IN OBSERVATIONAL ERROR",
    full_height = NA
  ),
  # CASE STUDY 1: NO COP
  list(
    fit_dir = file.path("simulated_data_hpc", "no_cop_0.1"),
    stem = "no_cop_0.1",
    title = "SIMULATED DATA NO COP AND 0.1 UNCERTAINTY IN OBSERVATIONAL ERROR",
    full_height = NA
  ),
  # CASE STUDY 2: TRANSVIR, NO PCR
  list(
    fit_dir = file.path("transvir_data", "wave2_no_pcr"),
    stem = "wave2_no_pcr",
    title = "EMPIRICAL DATA WITH NO PCR",
    full_height = 20
  ),
  # CASE STUDY 2: TRANSVIR, PCR
  list(
    fit_dir = file.path("transvir_data", "wave2_base"),
    stem = "wave2_pcr",
    title = "EMPIRICAL DATA WITH PCR",
    full_height = 20
  )
)

conv_fig_dir <- here::here("outputs", "figs", "supp", "conv")

for (case_i in conv_cases) {
  model_summary <- readRDS(
    here::here("outputs", "fits", case_i$fit_dir, "model_summary.RDS")
  )

  # Pass the plot explicitly rather than relying on ggsave()'s last_plot()
  # default, so each file is guaranteed to contain the figure built for it.
  p1 <- generate_convergence_plot(model_summary, case_i$title)
  ggsave(
    file.path(conv_fig_dir, paste0(case_i$stem, "_full.png")),
    plot = p1,
    height = case_i$full_height
  )

  # plot_Rhat_time_alt() yields NULL when no individual qualifies; skip the
  # save in that case instead of writing a stale plot under the timing name.
  p2 <- plot_Rhat_time_alt(model_summary, case_i$title)
  if (is.null(p2)) {
    message("No timing diagnostics to save for: ", case_i$title)
  } else {
    ggsave(
      file.path(conv_fig_dir, paste0(case_i$stem, "_time.png")),
      plot = p2
    )
  }
}
# End of script.