Skip to content

Commit 7544f3f

Browse files
authored
Merge pull request #48 from entropyx/feature/all-services
Feature/all services
2 parents 660c6cc + 072bdd8 commit 7544f3f

12 files changed

Lines changed: 2430 additions & 253 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ logs/
2222
specs/
2323
ai_docs/
2424
.claude/
25-
25+
*.json

Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ RUN apt-get update && apt-get install -y \
1111

1212
# Copy and install Python dependencies
1313
COPY requirements.txt .
14-
RUN pip install -r requirements.txt
14+
RUN pip install --upgrade pip
15+
RUN pip install --no-cache-dir -r requirements.txt
1516

1617
# Copy the rest of the application
1718
COPY . .

Murray/main.py

Lines changed: 362 additions & 200 deletions
Large diffs are not rendered by default.

Murray/plots.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,13 @@ def plot_mde_results(results_by_size, sensitivity_results, periods):
251251
"""
252252
Generates an interactive heatmap showing penalized MDE values that account for
253253
counterfactual quality and time period.
254+
Args:
255+
results_by_size: Dictionary containing simulation results
256+
sensitivity_results: Dictionary containing sensitivity results
257+
periods: List of periods to evaluate
258+
259+
Returns:
260+
fig: Interactive heatmap figure
254261
"""
255262
holdout_by_location = {
256263
size: data["Holdout Percentage"] for size, data in results_by_size.items()
@@ -262,39 +269,54 @@ def plot_mde_results(results_by_size, sensitivity_results, periods):
262269

263270
def calculate_penalty_score(mde, period_idx, total_periods, size, results_by_size):
264271
"""
265-
Calculates a penalty score based on MDE, counterfactual quality, and time period.
272+
Calculates a score based on MDE, counterfactual quality (MAPE, SMAPE), p-value, statistical power, and time period.
273+
Longer periods are considered better as they provide more statistical confidence.
266274
Returns both the score and its components for hover information.
267275
"""
268276
if pd.isna(mde):
269-
return None, None, None, None
277+
return None, None, None, None, None, None, None
270278

271279
# Quality metrics
272280
mape = results_by_size[size].get("MAPE", 0)
273281
smape = results_by_size[size].get("SMAPE", 0)
274282

283+
# Statistical metrics
284+
p_value = results_by_size[size].get("p_value", 1.0)
285+
power = results_by_size[size].get("power", 0.0)
286+
275287
# Normalize metrics
276288
mape_factor = min(mape / 100, 1)
277289
smape_factor = min(smape / 100, 1)
278290
quality_score = (mape_factor + smape_factor) / 2
279291

280-
# Time factor
281-
time_score = (period_idx + 1) / total_periods
292+
# Normalize p-value (lower is better)
293+
p_value_score = 1 - min(p_value, 1)
294+
295+
# Normalize power (higher is better)
296+
power_score = min(power, 1)
282297

283298
# MDE factor
284299
mde_factor = min(mde, 1)
285300

301+
# Time factor - longer periods are better
302+
time_score = (period_idx + 1) / total_periods
303+
286304
# Calculate final score
287-
quality_weight = 0.85
288-
time_weight = 0.05
289-
mde_weight = 0.15
305+
quality_weight = 0.20
306+
p_value_weight = 0.15
307+
power_weight = 0.55
308+
mde_weight = 0.09
309+
time_weight = 0.01
290310

291311
final_score = (
292312
quality_weight * quality_score
293-
+ time_weight * (1 - time_score)
294-
+ mde_weight * mde_factor
313+
+ p_value_weight * p_value_score
314+
+ power_weight * power_score
315+
+ mde_weight * (1 - mde_factor)
316+
+ time_weight * (1 - time_score)
295317
)
296318

297-
return final_score, mde, mape, smape
319+
return final_score, mde, mape, smape, p_value, power, time_score
298320

299321
heatmap_data = pd.DataFrame()
300322
hover_data = []
@@ -306,17 +328,18 @@ def calculate_penalty_score(mde, period_idx, total_periods, size, results_by_siz
306328

307329
for period_idx, period in enumerate(periods):
308330
mde = period_results.get(period, {}).get("MDE", None)
309-
score, original_mde, mape, smape = calculate_penalty_score(
331+
score, original_mde, mape, smape, p_value, power, time_score = calculate_penalty_score(
310332
mde, period_idx, len(periods), size, results_by_size
311333
)
312334
row.append(score)
313335
hover_row.append(
314336
{
315-
"Original MDE": (
316-
f"{original_mde:.2%}" if original_mde is not None else "N/A"
317-
),
337+
"MDE": f"{original_mde:.2%}" if original_mde is not None else "N/A",
318338
"MAPE": f"{mape:.2f}%" if mape is not None else "N/A",
319339
"SMAPE": f"{smape:.2f}%" if smape is not None else "N/A",
340+
"P-Value": f"{p_value:.4f}" if p_value is not None else "N/A",
341+
"Statistical Power": f"{power:.2%}" if power is not None else "N/A",
342+
"Period Score": f"{time_score*100:.0f}%" if time_score is not None else "N/A"
320343
}
321344
)
322345
heatmap_data[size] = row
@@ -374,10 +397,13 @@ def calculate_penalty_score(mde, period_idx, total_periods, size, results_by_siz
374397
textfont={"size": 12, "color": "black"},
375398
hovertemplate=(
376399
"Treatment size: %{customdata}<br>"
377-
+ "Penalty Score: %{text}<br>"
378-
+ "Original MDE: %{customdata:Original MDE}<br>"
400+
+ "Combined Score: %{text}<br>"
401+
+ "MDE: %{customdata:MDE}<br>"
379402
+ "MAPE: %{customdata:MAPE}<br>"
380403
+ "SMAPE: %{customdata:SMAPE}<br>"
404+
+ "P-Value: %{customdata:P-Value}<br>"
405+
+ "Statistical Power: %{customdata:Statistical Power}<br>"
406+
+ "Period Score: %{customdata:Period Score}<br>"
381407
+ "<extra></extra>"
382408
),
383409
showscale=True,

Murray/post_analysis.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from sklearn.preprocessing import MinMaxScaler
33
from Murray.main import select_controls, SyntheticControl
44
from Murray.auxiliary import market_correlations, handle_duplicates
5+
from Murray.plots import calculate_confidence_bands, calculate_optimal_noise_scale
56
import pandas as pd
67
from logger_config import get_logger
78

@@ -215,3 +216,144 @@ def stat_func(x):
215216

216217
logger.info("run_geo_evaluation completed successfully")
217218
return results_evaluation
219+
220+
221+
def get_evaluation_chart_data(
222+
data_input,
223+
start_treatment,
224+
end_treatment,
225+
treatment_group,
226+
significance_level=0.05,
227+
):
228+
"""
229+
Extract only the data needed for plotting charts from evaluation results.
230+
231+
Args:
232+
data_input: Input dataframe
233+
start_treatment: Treatment start date
234+
end_treatment: Treatment end date
235+
treatment_group: List of treatment locations
236+
significance_level: Significance level for confidence bands
237+
238+
Returns:
239+
dict: Dictionary containing all data needed for chart plotting
240+
"""
241+
logger.info("Starting get_evaluation_chart_data")
242+
243+
# First run the evaluation to get base results
244+
results = run_geo_evaluation(
245+
data_input, start_treatment, end_treatment, treatment_group, spend=0
246+
)
247+
248+
# Extract base values
249+
treatment = results["treatment"]
250+
counterfactual = results["counterfactual"]
251+
period = results["period"]
252+
length_treatment = results["length_treatment"]
253+
254+
# Get date information
255+
random_state = data_input["location"].unique()[0]
256+
filtered_data = data_input[data_input["location"] == random_state].copy()
257+
filtered_data["time"] = pd.to_datetime(filtered_data["time"])
258+
dates = filtered_data["time"].dt.date.astype(str).tolist()
259+
260+
# Calculate treatment start position
261+
start_treatment = pd.to_datetime(start_treatment, dayfirst=True)
262+
start_idx = (filtered_data["time"].dt.date == start_treatment.date()).idxmax()
263+
start_position_treatment = filtered_data.index.get_loc(start_idx)
264+
265+
# Calculate derived series
266+
point_difference = treatment - counterfactual
267+
cumulative_effect = ([0] * (len(treatment) - period)) + (
268+
np.cumsum(point_difference[len(treatment) - period:])
269+
).tolist()
270+
271+
# Extract treatment period data
272+
y_treatment = treatment[start_position_treatment:]
273+
point_difference_treatment = point_difference[start_position_treatment:]
274+
cumulative_effect_treatment = cumulative_effect[start_position_treatment:]
275+
276+
# Calculate confidence bands
277+
ci = 1 - significance_level
278+
noise_scale = calculate_optimal_noise_scale(y_treatment, counterfactual)
279+
280+
lower_bound, upper_bound = calculate_confidence_bands(
281+
y_treatment, noise_scale=noise_scale, ci=ci
282+
)
283+
lower_bound_pd, upper_bound_pd = calculate_confidence_bands(
284+
point_difference_treatment, ci=ci
285+
)
286+
lower_bound_ce, upper_bound_ce = calculate_confidence_bands(
287+
cumulative_effect_treatment, ci=ci
288+
)
289+
290+
# Calculate aggregate values
291+
lower_bound_value = np.sum(lower_bound)
292+
upper_bound_value = np.sum(upper_bound)
293+
prediction_value = np.sum(treatment[start_position_treatment:])
294+
295+
# Calculate ATT and incremental
296+
att = np.mean(treatment[start_position_treatment:] - counterfactual[start_position_treatment:])
297+
att = att / length_treatment
298+
incremental = np.sum(treatment[start_position_treatment:] - counterfactual[start_position_treatment:])
299+
300+
# Calculate pre/post treatment data
301+
pre_treatment = treatment[start_position_treatment - period : start_position_treatment]
302+
pre_counterfactual = counterfactual[start_position_treatment - period : start_position_treatment]
303+
post_treatment = treatment[start_position_treatment:]
304+
post_counterfactual = counterfactual[start_position_treatment:]
305+
306+
chart_data = {
307+
# Base series
308+
"dates": dates,
309+
"treatment": treatment.tolist(),
310+
"counterfactual": counterfactual.tolist(),
311+
"point_difference": point_difference.tolist(),
312+
"cumulative_effect": cumulative_effect,
313+
314+
# Treatment period data
315+
"treatment_dates": dates[start_position_treatment:],
316+
"y_treatment": y_treatment.tolist(),
317+
"point_difference_treatment": point_difference_treatment.tolist(),
318+
"cumulative_effect_treatment": cumulative_effect_treatment,
319+
320+
# Confidence bands
321+
"lower_bound": lower_bound.tolist(),
322+
"upper_bound": upper_bound.tolist(),
323+
"lower_bound_pd": lower_bound_pd.tolist(),
324+
"upper_bound_pd": upper_bound_pd.tolist(),
325+
"lower_bound_ce": lower_bound_ce.tolist(),
326+
"upper_bound_ce": upper_bound_ce.tolist(),
327+
328+
# Aggregate values
329+
"lower_bound_value": float(lower_bound_value),
330+
"upper_bound_value": float(upper_bound_value),
331+
"prediction_value": float(prediction_value),
332+
"att": float(att),
333+
"incremental": float(incremental),
334+
335+
# Pre/post treatment periods
336+
"pre_treatment": pre_treatment.tolist(),
337+
"pre_counterfactual": pre_counterfactual.tolist(),
338+
"post_treatment": post_treatment.tolist(),
339+
"post_counterfactual": post_counterfactual.tolist(),
340+
341+
# Metadata
342+
"start_position_treatment": start_position_treatment,
343+
"period": period,
344+
"length_treatment": length_treatment,
345+
346+
# Include key metrics from original evaluation
347+
"p_value": results["p_value"],
348+
"power": results["power"],
349+
"percenge_lift": results["percenge_lift"],
350+
"MAPE": results["MAPE"],
351+
"SMAPE": results["SMAPE"],
352+
"observed_stat": results["observed_stat"],
353+
"null_stats": results["null_stats"].tolist(),
354+
"control_group": results["control_group"],
355+
"weights": results["weights"],
356+
}
357+
358+
logger.info("get_evaluation_chart_data completed successfully")
359+
return chart_data

0 commit comments

Comments
 (0)