Skip to content

Commit ca70313

Browse files
committed
update: medrxiv version
1 parent f671254 commit ca70313

24 files changed

+202
-39
lines changed

R/manu/app_conv.R

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ plot_Rhat_time_alt <- function(model_summary, title_i) {
103103
p2 <- df_summary_disc %>% ggplot() + geom_col(aes(x = rhat, y = as.character(variable))) + theme_bw() +
104104
geom_vline(xintercept = 1.1, color = "red", linetype = "dashed") +
105105
labs(x = "Rhat", y = "ID") +
106+
scale_x_continuous(labels = seq(0, 2, 0.2), breaks = seq(0, 2, 0.2)) +
106107
ggtitle("B. Convergence diagnostics for timing of infection \nindividuals with posterior P(Z) > 0.5")
107108

108109
p1 + p2 + plot_annotation(title = paste0("TIMING CONVERGENCE DIAGNOSITICS FOR ", title_i)) &
@@ -115,16 +116,16 @@ plot_Rhat_time_alt <- function(model_summary, title_i) {
115116
# CASE STUDY 1: COP
116117
model_summary <- readRDS(here::here("outputs", "fits", "simulated_data_hpc", paste0("cop", "_", "0.1"), "model_summary.RDS"))
117118
p1 <- generate_convergence_plot(model_summary, "SIMULATED DATA WITH COP AND 0.1 UNCERTAINTTY IN OBSERVATIONAL ERROR")
118-
ggsave(here::here("outputs", "figs", "supp", "conv", "cop_0.1_full.png"))
119+
ggsave(here::here("outputs", "figs", "supp", "conv", "cop_0.1_full.png"), height = 12, width = 15)
119120
p2 <- plot_Rhat_time_alt(model_summary, "SIMULATED DATA WITH COP AND 0.1 UNCERTAINTTY IN OBSERVATIONAL ERROR")
120-
ggsave(here::here("outputs", "figs", "supp", "conv", "cop_0.1_time.png"))
121+
ggsave(here::here("outputs", "figs", "supp", "conv", "cop_0.1_time.png"), height = 12, width = 15)
121122

122123
# CASE STUDY 1: NO COP
123124
model_summary <- readRDS(here::here("outputs", "fits", "simulated_data_hpc", paste0("no_cop", "_", "0.1"), "model_summary.RDS"))
124125
p1 <- generate_convergence_plot(model_summary, "SIMULATED DATA NO COP AND 0.1 UNCERTAINTTY IN OBSERVATIONAL ERROR")
125-
ggsave(here::here("outputs", "figs", "supp", "conv", "no_cop_0.1_full.png"))
126+
ggsave(here::here("outputs", "figs", "supp", "conv", "no_cop_0.1_full.png"), height = 12, width = 15)
126127
p2 <- plot_Rhat_time_alt(model_summary, "SIMULATED DATA NO COP AND 0.1 UNCERTAINTTY IN OBSERVATIONAL ERROR")
127-
ggsave(here::here("outputs", "figs", "supp", "conv", "no_cop_0.1_time.png"))
128+
ggsave(here::here("outputs", "figs", "supp", "conv", "no_cop_0.1_time.png"), height = 12, width = 15)
128129

129130
# CASE STUDY 2: TRANSVIR, NO PCR
130131
model_summary <- readRDS(here::here("outputs", "fits", "transvir_data", "wave2_no_pcr", "model_summary.RDS"))

R/manu/app_conv_alt.R

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
2+
generate_convergence_plot <- function(model_summary, title_i) {
3+
df_smi_df <- calcScaleModelIndicator(model_summary)
4+
5+
df_smi_df <- df_smi_df %>% filter(.chain != 4)
6+
7+
fit <- model_summary$fit
8+
p1 <- (fit$post$mcmc[1:3] %>% bayesplot::mcmc_trace()) + theme_minimal() + theme(legend.position = "top") + ggtitle("A. Trace plots for fitted parameters")
9+
p2 <- fit$post$lpost %>% filter(chain_no !=4) %>% ggplot() + geom_line(aes(x = sample_no, y = lpost, color = chain_no)) + theme_minimal() + theme(legend.position = "top") + ggtitle("B. Trace plots for log posterior") +
10+
labs(x = "Sample number", y = "Log-posterior")
11+
12+
p3 <- df_conver_stat <- summarise_draws(fit$post$mcmc[1:3] ) %>% select(variable, rhat, ess_bulk, ess_tail) %>%
13+
pivot_longer(!variable, names_to = "stat", values_to = "value") %>%
14+
ggplot() +
15+
geom_col(aes(y = variable, x = value)) +
16+
facet_wrap(~stat, scales = "free") + theme_minimal() + ggtitle("C. Convergence diagnosis for fitted parameters") +
17+
labs(x = "Value", y = "Parameter")
18+
pA <- p1 / p2 / p3
19+
20+
21+
22+
pdims_trace <- df_smi_df %>%
23+
ggplot() +
24+
geom_line(aes(x = .iteration, y = dims, color = as.character(.chain))) +
25+
labs(x = "Iteration", y = "Model dimension", color = "Chain") +
26+
ggtitle("D. Trace plots for transdimensional convergence: dimensions of model")
27+
28+
29+
pdims_hist <- df_smi_df %>%
30+
ggplot() +
31+
geom_histogram(aes(x = dims, fill = as.character(.chain))) +
32+
labs(x = "Model dimension", y = "Count", fill = "Chain")
33+
34+
p1 <- pdims_trace + pdims_hist + plot_layout(guides = "collect") & theme_minimal() & theme(legend.position = "top")
35+
36+
37+
38+
psmi_trace <- df_smi_df %>%
39+
ggplot() +
40+
geom_line(aes(x = .iteration, y = sMI, color = as.character(.chain))) +
41+
labs(x = "Iteration", y = "Log2 of SMI", color = "Chain") +
42+
ggtitle("E. Trace plots for transdimensional convergence: SMI of model")
43+
44+
45+
psmi_hist <- df_smi_df %>%
46+
ggplot() +
47+
geom_histogram(aes(x = sMI, fill = as.character(.chain))) +
48+
labs(x = "Log2 of SMI", y = "Count", fill = "Chain")
49+
50+
p2 <- psmi_trace + psmi_hist + plot_layout(guides = "collect") & theme_minimal() & theme(legend.position = "top")
51+
52+
p3 <- summarise_draws(df_smi_df) %>% select(variable, rhat, ess_bulk, ess_tail) %>% pivot_longer(!variable, names_to = "stat", values_to = "value") %>%
53+
filter(variable == "sMI") %>%
54+
ggplot()+
55+
geom_col(aes(y = "", x = value)) +
56+
facet_grid(cols = vars(stat), scales = "free") + theme_minimal() + ggtitle("F. Convergence diagnosis for transdimensional convergence") +
57+
labs(x = "Value", y = "Model")
58+
59+
pB <- p1 / p2 / p3
60+
61+
62+
(pA | pB) + plot_annotation(title = paste0("CONVERGENCE DIAGNOSITICS FOR ", title_i)) &
63+
theme(title = element_text(size = 12))
64+
}
65+
66+
67+
plot_Rhat_time_alt <- function(model_summary, title_i) {
68+
69+
outputfull <- model_summary$post
70+
71+
model_outline <- model_summary$fit$model
72+
bio_all <- model_outline$infoModel$biomarkers
73+
74+
fit_states_dt <- as.data.table(outputfull$fit_states) %>% filter(chain_no != 4)
75+
S <- fit_states_dt %>% filter(id == 1) %>% nrow
76+
77+
ids <- fit_states_dt %>% group_by(id) %>% summarise(prob = sum(inf_ind) / S) %>% filter(prob > 0.5) %>% pull(id) %>% unique
78+
79+
if (length(ids) == 0) {
80+
cat("No individuals have posterior prob of infection > 0.5")
81+
} else {
82+
# extract values here
83+
df_mcmc_time <- fit_states_dt %>% filter(id %in% ids) %>% filter(inf_ind == 1) %>%
84+
select(id, chain_no, sample, inf_time, !!bio_all) %>% rename(chain = chain_no)
85+
86+
df_mcmc_time_wide <- df_mcmc_time %>%
87+
select(id, sample, chain, inf_time) %>% unique %>%
88+
pivot_wider(!chain, names_from = "id", values_from = "inf_time")
89+
90+
cols <- ncol(df_mcmc_time_wide)
91+
92+
df_summary_disc <-
93+
map_df(2:cols,
94+
~df_mcmc_time_wide %>% select(sample, .x) %>% drop_na %>% summarise_draws() %>% .[2, ]
95+
)
96+
97+
p1 <- df_mcmc_time %>%
98+
ggplot() +
99+
stat_pointinterval(aes(x = inf_time, y = as.character(id), color = as.character(chain)),
100+
position = position_dodge(0.4)) + theme_bw() +
101+
labs(x = "Time in study", y = "ID", color = "Chain number") +
102+
ggtitle("A. Trace plots for timing of infection for individuals \nwith posterior P(Z) > 0.5")
103+
104+
p2 <- df_summary_disc %>% ggplot() + geom_col(aes(x = rhat, y = as.character(variable))) + theme_bw() +
105+
geom_vline(xintercept = 1.1, color = "red", linetype = "dashed") +
106+
labs(x = "Rhat", y = "ID") +
107+
scale_x_continuous(labels = seq(0, 2, 0.2), breaks = seq(0, 2, 0.2)) +
108+
ggtitle("B. Convergence diagnostics for timing of infection \nindividuals with posterior P(Z) > 0.5")
109+
110+
p1 + p2 + plot_annotation(title = paste0("TIMING CONVERGENCE DIAGNOSITICS FOR ", title_i)) &
111+
theme(title = element_text(size = 12))
112+
113+
}
114+
115+
}
116+
117+
118+
119+
# CASE STUDY 2: TRANSVIR, NO PCR
120+
model_summary <- readRDS(here::here("outputs", "fits", "transvir_data", "wave2_no_pcr", "model_summary.RDS"))
121+
p1 <- generate_convergence_plot(model_summary, "EMPIRICAL DATA WITH NO PCR")
122+
ggsave(here::here("outputs", "figs", "supp", "conv", "wave2_no_pcr_full.png"), height = 20)
123+
p2 <- plot_Rhat_time_alt(model_summary, "EMPIRICAL DATA WITH NO PCR")
124+
ggsave(here::here("outputs", "figs", "supp", "conv", "wave2_no_pcr_time.png"), height = 20)

R/manu/fig1_2.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ p4A <- df_post %>% group_by(name, type, uncert) %>% summarise(mean = mean(value)
582582
ggplot() +
583583
geom_line(aes(x = name, y = mean, color = uncert, group = uncert), size = 1.8, alpha = 0.6) + theme_bw() +
584584
geom_line(data = data_plot, aes(x = name, y = obs_cop), color = "red", linetype = "dashed", size = 2, alpha = 0.7) +
585-
labs(y = "Protection probability (COP)", x = "Titre value at exposure (log)",
585+
labs(y = "Protection probability", x = "Titre value at exposure (log)",
586586
color = "Uncertainty") +
587587
scale_x_continuous(breaks = seq(0, 5, 0.5)) + ylim(0, 1) + ggtitle("Simulated data with COP")
588588

@@ -592,7 +592,7 @@ p4B <- df_post %>% group_by(name, type, uncert) %>% summarise(mean = mean(value)
592592
ggplot() +
593593
geom_line(aes(x = name, y = mean, color = uncert, group = uncert), size = 1.8, alpha = 0.6) + theme_bw() +
594594
geom_line(data = data_plot, aes(x = name, y = obs_no_cop), color = "red", linetype = "dashed", size = 2, alpha = 0.7) +
595-
labs(y = "Protection probability (COP)", x = "Titre value at exposure (log)",
595+
labs(y = "Protection probability", x = "Titre value at exposure (log)",
596596
color = "Uncertainty") +
597597
scale_x_continuous(breaks = seq(0, 5, 0.5)) + ylim(0, 1) + ggtitle("Simulated data without COP")
598598

0 commit comments

Comments
 (0)