
Commit 6f76dc3

feat: 52 added a module to calculate basic epidemiological indicators (#172)
* Added function to calculate Incidence rate
* Added risk ratio function
* Added type annotations
* minor changes
1 parent b26f27c commit 6f76dc3

12 files changed (+248, −407 lines)


docs/tutorials/forecast_switzerland/forecast_swiss.py

Lines changed: 0 additions & 2 deletions
@@ -102,7 +102,6 @@ def train_eval_single_canton(
     predict_n=14,
     look_back=14,
 ):
-
     """
     Function to train and evaluate the model for one georegion.
@@ -410,7 +409,6 @@ def train_all_cantons(
         )

         if any(df_c[target_name] > 1):
-
             ngb_m.train(
                 target_name,
                 df_c,

docs/tutorials/forecast_switzerland/train_models.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@
     path="saved_models_dash",
 )

-
 target_curve_name = "total_hosp"
 predictors = ["foph_test_d", "foph_cases_d", "foph_hosp_d", "foph_hospcapacity_d"]
 ini_date = "2020-05-01"

epigraphhub/analysis/clustering.py

Lines changed: 1 addition & 2 deletions
@@ -40,7 +40,7 @@ def get_lag(
     x = pd.Series(x).rolling(7).mean().dropna().values
     y = pd.Series(y).rolling(7).mean().dropna().values
     corr = correlate(x, y, mode="full") / np.sqrt(np.dot(x, x) * np.dot(y, y))
-    slice = np.s_[(len(corr) - maxlags) // 2 : -(len(corr) - maxlags) // 2]
+    slice = np.s_[(len(corr) - maxlags) // 2: -(len(corr) - maxlags) // 2]
     corr = corr[slice]
     lags = correlation_lags(x.size, y.size, mode="full")
     lags = lags[slice]
@@ -327,7 +327,6 @@ def plot_clusters(
     if normalize:

         for i in inc_canton.columns:
-
             inc_canton[i] = inc_canton[i] / max(inc_canton[i])

     figs = []
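The change to get_lag above only adjusts slice spacing; for context, the surrounding logic estimates the lag between two series by normalized full cross-correlation, trimmed to roughly maxlags lags around zero. A self-contained sketch of that idea using scipy.signal directly (the synthetic series and seed are illustrative, not from the repository):

import numpy as np
from scipy.signal import correlate, correlation_lags

# Two series where y trails x by roughly 3 steps (toy data).
rng = np.random.default_rng(42)
x = rng.poisson(20, size=120).astype(float)
y = np.roll(x, 3) + rng.normal(0, 1, size=120)

maxlags = 21
# Normalized full cross-correlation, then keep only the lags nearest zero.
corr = correlate(x, y, mode="full") / np.sqrt(np.dot(x, x) * np.dot(y, y))
lags = correlation_lags(x.size, y.size, mode="full")
window = np.s_[(len(corr) - maxlags) // 2 : -(len(corr) - maxlags) // 2]
corr, lags = corr[window], lags[window]

print(lags[np.argmax(corr)])  # lag of the strongest correlation in the window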

epigraphhub/analysis/epistats.py

Lines changed: 45 additions & 186 deletions
@@ -1,14 +1,13 @@
 from typing import Union

-import arviz as az
 import numpy as np
 import pandas as pd
-import plotly.express as px
-import pymc3 as pm
 import scipy.stats as st
+from scipy.stats.contingency import relative_risk
+from scipy.stats._result_classes import RelativeRiskResult


-def prevalence(pop_size: int, positives: int, a: float = 1, b: float = 1) -> object:
+def posterior_prevalence(pop_size: int, positives: int, a: float = 1, b: float = 1) -> st.rv_continuous:
     """
     Returns the Bayesian posterior prevalence of a disease for a point in time.
     It assumes number of cases follow a binomial distribution with probability described as a beta(a,b) distribution
@@ -23,203 +22,63 @@ def prevalence(pop_size: int, positives: int, a: float = 1, b: float = 1) -> obj
     b : float, optional
         prior beta parameter beta, by default 1

-    Returns
-    -------
-    object
-        Returns a scipy stats frozen beta distribution that represents the posterior probability of the prevalence
+    Args:
+        pop_size: population size
+        positives: number of positives
+        a: prior beta parameter alpha
+        b: prior beta parameter beta
+
+    Returns:
+        object: Returns a scipy stats frozen beta distribution that represents the posterior probability of the prevalence
     """
     a, b = 1, 1  # prior beta parameters
     pa = a + positives
     pb = b + pop_size - positives
     return st.beta(pa, pb)


-def inf_pos_prob_cases_hosp(
-    df: pd.DataFrame,
-    alpha: float = 0.5,
-    beta: float = 0.5,
-    draws: int = 2000,
-    tune: int = 500,
-) -> az.data.inference_data.InferenceData:
-    """
-    This function compute the posterior probability distribution for the prevalence of cases and the probability of hospitalization over time.
-
-    Parameters
-    ----------
-    df : pd.DataFrame
-        It takes as input a dataframe with four columns:
-        - cases: Number of cases over time.
-        - hospiotalizations: Number of hospitalizations over time.
-        - tests: Number of tests over time.
-        - tests_pos: Proportion of the positive tests over time.
-        This data frame must have a datetime index.
-    alpha:float
-        The alpha parameter of the Beta distribution
-    beta:float
-        The beta parameter of the Beta distribution
-    draws: int
-        The number of samples to draw.
-    tune: int
-        Number of iterations to tune. Samplers adjust the step sizes,
-        scalings or similar during tuning. Tuning samples will be drawn in addition to the number specified in the
-        draws argument, and will be discarded.
-
-    Returns
-    -------
-    az.data.inference_data.InferenceData
-        An array with the posterior probabilities infered.
-    """
-
-    with pm.Model() as var_bin:
-        prev = pm.Beta("prevalence", alpha, beta, shape=len(df["cases"]))
-
-        cases = pm.Binomial("cases", n=df["tests"].values, p=prev, observed=df["cases"])
-
-        probs = pm.Beta("phosp", alpha, beta, shape=len(df["cases"]))
-
-        hosp = pm.Binomial(
-            "hospitalizations", n=df["cases"], p=probs, observed=df["hospitalizations"]
-        )
-
-    with var_bin:
-        tracevb = pm.sample(draws, tune=tune, return_inferencedata=True)
-
-    return tracevb
-
-
-def plot_pos_prob_prev(
-    df: pd.DataFrame,
-    tracevb: az.data.inference_data.InferenceData,
-    ci: bool = False,
-    save: bool = False,
-    name: Union[str, None] = None,
-    plot: bool = True,
-):
+@np.vectorize
+def incidence_rate(pop_size: int, new_cases: int, scaling: float = 1e5) -> Union[float, np.ndarray, np.ndarray]:
     """
-    This function plots the posterior probability distribution of the prevalence
-    generated with the `inf_pos_prob_cases_hosp()` function.
-
+    incidence is defined as the number of new cases in a population over a period of time, typically 1 year. The incidence rate is also usually scale to 100k people to facilitate comparisons between localities with different populations.
     Parameters
     ----------
-    df : pd.DataFrame
-        A dataframe with four columns:
-        - cases: Number of cases over time.
-        - hospitalizations: Number of hospitalizations over time.
-        - tests: Number of tests over time.
-        - tests_pos: Proportion of the positive tests over time.
-        This data frame must have a datetime index.
-    tracevb : az.data.inference_data.InferenceData
-        the return of the inf_pos_prob_cases_hosp() function.
-    ci : bool, optional
-        If True the confidence interval is computed, by default False
-    save : bool, optional
-        If True the plot is saved, by default False
+    pop_size: population pop_size
+    new_cases: number of new cases observed in the period
+    scaling: number to scale the rate to. If ommitted, the rate is return as cases per 100k.

     Returns
     -------
-    fig
-        A plotly figure.
+    A float or a np.ndarray of floats
+
+    Examples
+    --------
+    >>> incidence_rate(1000, 5)
+    500.0
+    >>> incidence_rate([1000,5000,10000], [5,5,5])
+    array([500, 100, 50])
     """
+    IR = new_cases / pop_size * scaling
+    return IR

-    Prev_post = pd.DataFrame(
-        index=df["cases"].index,
-        data={
-            "median": tracevb.posterior.prevalence.median(axis=(0, 1)),
-            "lower": np.percentile(tracevb.posterior.prevalence, 2.5, axis=(0, 1)),
-            "upper": np.percentile(tracevb.posterior.prevalence, 97.5, axis=(0, 1)),
-        },
-    )
-
-    fig = px.line(Prev_post.rolling(7).mean().dropna())
-
-    if ci:
-        fig.add_scatter(
-            x=Prev_post.index,
-            y=Prev_post.lower,
-            mode="none",
-            fill="tonexty",
-            name="95% CI",
-        )
-
-    fig.add_scatter(
-        x=df["tests_pos"].index,
-        y=df["tests_pos"].values,
-        name="Test positivity",
-        mode="markers",
-    )
-    fig.update_layout(
-        title="Estimated prevalence of COVID",
-        yaxis_title="Prevalence of infected",
-        xaxis_title="Time (days)",
-    )
-    if save:
-        if name == None:
-            fig.write_image("prevalence_est.png", scale=3)
-        else:
-            fig.write_image(f"{name}.png", scale=3)
-
-    if plot:
-        fig.show()

-    return fig
-
-
-def plot_pos_prob_hosp(
-    df: pd.DataFrame,
-    tracevb: az.data.inference_data.InferenceData,
-    save: bool = False,
-    name: Union[str, None] = None,
-    plot: bool = False,
-):
+def risk_ratio(exposed_cases: int, exposed_total: int, control_cases: int, control_total: int) -> RelativeRiskResult:
     """
-    This function plots the posterior probability distribution of hospitalization
-    generated with the `inf_pos_prob_cases_hosp()` function.
-
-    Parameters
-    ----------
-    df : pd.DataFrame
-        A dataframe with four columns:
-        - cases: Number of cases over time.
-        - hospitalizations: Number of hospitalizations over time.
-        - tests: Number of tests over time.
-        - tests_pos: Proportion of the positive tests over time.
-        This data frame must have a datetime index.
-    tracevb : az.data.inference_data.InferenceData
-        the return of the inf_pos_prob_cases_hosp() function.
-    ci : bool, optional
-        If True the confidence interval is computed, by default False
-    save : bool, optional
-        If True the plot is saved, by default False
-
-    Returns
-    -------
-    fig
-        A plotly figure.
+    Also known as relative risk, computed the risk of contracting a disease given exposure to a risk factor.
+    Parameters:
+        exposed_cases: number of cases in the exposed group
+        exposed_total: size of the exposed group
+        control_cases: number of cases in the control group
+        control_total: size of the control group
+    Returns:
+        RelativeRiskResult object
+
+    Examples:
+        >>> rr = risk_ratio(27, 122, 44, 487)
+        >>> rr.relative_risk
+        2.4495156482861398
+        >>> rr.confidence_interval(confidence_level=0.95)
+        ConfidenceInterval(low=1.5836990926700116, high=3.7886786315466354)
     """
-
-    Phosp_post = pd.DataFrame(
-        index=df.index,
-        data={
-            "median": tracevb.posterior.phosp.median(axis=(0, 1)),
-            "lower": np.percentile(tracevb.posterior.phosp, 2.5, axis=(0, 1)),
-            "upper": np.percentile(tracevb.posterior.phosp, 97.5, axis=(0, 1)),
-        },
-    )
-
-    fig = px.line(Phosp_post.rolling(7).mean().dropna())
-
-    fig.update_layout(
-        title="Estimated Probability of Hospitalization",
-        yaxis_title="probability",
-    )
-
-    if plot:
-        fig.show()
-
-    if save:
-        if name == None:
-            fig.write_image("prob_hosp_est.png", scale=3)
-        else:
-            fig.write_image(f"{name}.png", scale=3)
-
-    return fig
+    rr = relative_risk(exposed_cases, exposed_total, control_cases, control_total)
+    return rr
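The three helpers added here share a simple pattern: posterior_prevalence returns a frozen scipy Beta distribution (so callers can take the mean, a credible interval, or samples), incidence_rate is vectorized over array inputs, and risk_ratio is a thin wrapper around scipy.stats.contingency.relative_risk. A minimal usage sketch, assuming the module is importable as epigraphhub.analysis.epistats as added in this commit (all numbers are illustrative, except the risk_ratio figures, which come from the docstring example above):

from epigraphhub.analysis.epistats import (
    incidence_rate,
    posterior_prevalence,
    risk_ratio,
)

# Incidence per 100k inhabitants for three hypothetical region sizes.
rates = incidence_rate([1_000, 5_000, 10_000], [5, 5, 5])
print(rates)  # cases per 100k for each region

# Beta-Binomial posterior prevalence from 50 positives in 2,000 tests.
post = posterior_prevalence(2_000, 50)
print(post.mean())          # posterior mean prevalence
print(post.interval(0.95))  # 95% credible interval

# Relative risk of disease given exposure to a risk factor.
rr = risk_ratio(27, 122, 44, 487)
print(rr.relative_risk)
print(rr.confidence_interval(confidence_level=0.95))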

epigraphhub/analysis/preprocessing.py

Lines changed: 1 addition & 2 deletions
@@ -145,7 +145,6 @@ def get_next_n_days(ini_date: str, next_days: int) -> list:
     a = datetime.strptime(ini_date, "%Y-%m-%d")

     for i in np.arange(1, next_days + 1):
-
         d_i = datetime.strftime(a + timedelta(days=int(i)), "%Y-%m-%d")

         next_dates.append(datetime.strptime(d_i, "%Y-%m-%d"))
@@ -195,7 +194,7 @@ def lstm_split_data(
     data = np.empty((n_ts, look_back + predict_n, df.shape[1]))
     for i in range(n_ts):  # - predict_):
         # print(i, df[i: look_back+i+predict_n,0])
-        data[i, :, :] = df[i : look_back + i + predict_n, :]
+        data[i, :, :] = df[i: look_back + i + predict_n, :]
     # train_size = int(n_ts * ratio)
     train_size = int(df.shape[0] * ratio) - look_back - predict_n + 1
     # print(train_size)
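The lstm_split_data line touched above sits in a loop that turns a time-series matrix into overlapping windows of look_back + predict_n rows, one window per training sample. A toy sketch of that windowing (the n_ts formula here is computed for the toy shapes and mirrors the variable names in the diff):

import numpy as np

look_back, predict_n = 3, 2
df = np.arange(20, dtype=float).reshape(10, 2)   # 10 time steps, 2 features
n_ts = df.shape[0] - look_back - predict_n + 1   # number of full windows

data = np.empty((n_ts, look_back + predict_n, df.shape[1]))
for i in range(n_ts):
    # each sample is a window of look_back + predict_n consecutive rows
    data[i, :, :] = df[i : look_back + i + predict_n, :]

print(data.shape)  # (6, 5, 2): 6 overlapping windows of 5 time steps each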

epigraphhub/data/epigraphhub_db.py

Lines changed: 0 additions & 2 deletions
@@ -126,7 +126,6 @@ def get_data_by_location(
     # separe the columns by comma to apply in the sql query
     s_columns = ""
     for i in columns:
-
         s_columns = s_columns + i + ","

     s_columns = s_columns[:-1]
@@ -135,7 +134,6 @@ def get_data_by_location(
     query = f"select {s_columns} from {schema}.{table_name}"

     if len(loc) == 1:
-
         query = f"select {s_columns} from {schema}.{table_name} where {loc_column} = '{loc[0]}' ;"

     if len(loc) > 1 and loc != "All":
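For reference, the query construction touched above joins the requested column names with commas and, for a single location, appends a WHERE filter. A sketch with placeholder schema, table, and column values (none of these identifiers are taken from the actual database):

columns = ["datum", "entity", "total_hosp"]        # placeholder column names
schema, table_name = "switzerland", "foph_hosp_d"  # placeholder schema/table
loc_column, loc = "georegion", ["GE"]

# build the comma-separated column list for the sql query
s_columns = ""
for i in columns:
    s_columns = s_columns + i + ","
s_columns = s_columns[:-1]  # drop the trailing comma

query = f"select {s_columns} from {schema}.{table_name}"
if len(loc) == 1:
    query = f"select {s_columns} from {schema}.{table_name} where {loc_column} = '{loc[0]}' ;"
print(query)

The same column list could also be built with ",".join(columns), which avoids trimming the trailing comma.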
