|
2 | 2 | from sklearn.preprocessing import MinMaxScaler |
3 | 3 | from Murray.main import select_controls, SyntheticControl |
4 | 4 | from Murray.auxiliary import market_correlations, handle_duplicates |
| 5 | +from Murray.plots import calculate_confidence_bands, calculate_optimal_noise_scale |
5 | 6 | import pandas as pd |
6 | 7 | from logger_config import get_logger |
7 | 8 |
|
@@ -215,3 +216,144 @@ def stat_func(x): |
215 | 216 |
|
216 | 217 | logger.info("run_geo_evaluation completed successfully") |
217 | 218 | return results_evaluation |
| 219 | + |
| 220 | + |
def get_evaluation_chart_data(
    data_input,
    start_treatment,
    end_treatment,
    treatment_group,
    significance_level=0.05,
):
    """
    Collect every series and scalar needed to plot evaluation charts.

    Runs ``run_geo_evaluation`` once, then derives the plotting payload
    from its output: point difference, cumulative effect, confidence
    bands for each plotted series, aggregate bounds, ATT, incremental
    lift, and the pre/post treatment-period slices.

    Args:
        data_input: Input dataframe with ``location`` and ``time`` columns.
        start_treatment: Treatment start date (parsed day-first).
        end_treatment: Treatment end date.
        treatment_group: List of treatment locations.
        significance_level: Significance level for confidence bands.

    Returns:
        dict: Dictionary containing all data needed for chart plotting.
    """
    logger.info("Starting get_evaluation_chart_data")

    # Base evaluation; spend does not affect the chart series, so pass 0.
    results = run_geo_evaluation(
        data_input, start_treatment, end_treatment, treatment_group, spend=0
    )

    treatment = results["treatment"]
    counterfactual = results["counterfactual"]
    period = results["period"]
    length_treatment = results["length_treatment"]

    # Build the date axis from a single (arbitrary) location's timeline;
    # all locations are assumed to share the same dates — TODO confirm.
    sample_location = data_input["location"].unique()[0]
    location_df = data_input[data_input["location"] == sample_location].copy()
    location_df["time"] = pd.to_datetime(location_df["time"])
    dates = location_df["time"].dt.date.astype(str).tolist()

    # Positional index of the first row matching the treatment start date.
    # NOTE(review): idxmax() silently yields position 0 when no row
    # matches — verify upstream guarantees the date is present.
    start_treatment = pd.to_datetime(start_treatment, dayfirst=True)
    match_label = (location_df["time"].dt.date == start_treatment.date()).idxmax()
    start_position_treatment = location_df.index.get_loc(match_label)

    # Effect series: per-point gap, plus its cumulative sum over the last
    # `period` points (zeros before that window).
    point_difference = treatment - counterfactual
    window_start = len(treatment) - period
    cumulative_effect = (
        [0] * window_start
        + np.cumsum(point_difference[window_start:]).tolist()
    )

    # Restrict each series to the treatment period.
    y_treatment = treatment[start_position_treatment:]
    point_difference_treatment = point_difference[start_position_treatment:]
    cumulative_effect_treatment = cumulative_effect[start_position_treatment:]

    # Confidence bands at the requested confidence level.
    ci = 1 - significance_level
    noise_scale = calculate_optimal_noise_scale(y_treatment, counterfactual)

    lower_bound, upper_bound = calculate_confidence_bands(
        y_treatment, noise_scale=noise_scale, ci=ci
    )
    lower_bound_pd, upper_bound_pd = calculate_confidence_bands(
        point_difference_treatment, ci=ci
    )
    lower_bound_ce, upper_bound_ce = calculate_confidence_bands(
        cumulative_effect_treatment, ci=ci
    )

    # Post-period slices and their gap, reused by several aggregates below.
    post_treatment = treatment[start_position_treatment:]
    post_counterfactual = counterfactual[start_position_treatment:]
    post_gap = post_treatment - post_counterfactual

    # Aggregate values over the treatment window.
    lower_bound_value = np.sum(lower_bound)
    upper_bound_value = np.sum(upper_bound)
    prediction_value = np.sum(post_treatment)

    # ATT (mean gap, scaled by treatment length) and total incremental.
    att = np.mean(post_gap) / length_treatment
    incremental = np.sum(post_gap)

    # Pre-period slices: one lookback window before treatment start.
    pre_start = start_position_treatment - period
    pre_treatment = treatment[pre_start:start_position_treatment]
    pre_counterfactual = counterfactual[pre_start:start_position_treatment]

    chart_data = {
        # Base series
        "dates": dates,
        "treatment": treatment.tolist(),
        "counterfactual": counterfactual.tolist(),
        "point_difference": point_difference.tolist(),
        "cumulative_effect": cumulative_effect,

        # Treatment period data
        "treatment_dates": dates[start_position_treatment:],
        "y_treatment": y_treatment.tolist(),
        "point_difference_treatment": point_difference_treatment.tolist(),
        "cumulative_effect_treatment": cumulative_effect_treatment,

        # Confidence bands
        "lower_bound": lower_bound.tolist(),
        "upper_bound": upper_bound.tolist(),
        "lower_bound_pd": lower_bound_pd.tolist(),
        "upper_bound_pd": upper_bound_pd.tolist(),
        "lower_bound_ce": lower_bound_ce.tolist(),
        "upper_bound_ce": upper_bound_ce.tolist(),

        # Aggregate values
        "lower_bound_value": float(lower_bound_value),
        "upper_bound_value": float(upper_bound_value),
        "prediction_value": float(prediction_value),
        "att": float(att),
        "incremental": float(incremental),

        # Pre/post treatment periods
        "pre_treatment": pre_treatment.tolist(),
        "pre_counterfactual": pre_counterfactual.tolist(),
        "post_treatment": post_treatment.tolist(),
        "post_counterfactual": post_counterfactual.tolist(),

        # Metadata
        "start_position_treatment": start_position_treatment,
        "period": period,
        "length_treatment": length_treatment,

        # Include key metrics from original evaluation
        "p_value": results["p_value"],
        "power": results["power"],
        "percenge_lift": results["percenge_lift"],
        "MAPE": results["MAPE"],
        "SMAPE": results["SMAPE"],
        "observed_stat": results["observed_stat"],
        "null_stats": results["null_stats"].tolist(),
        "control_group": results["control_group"],
        "weights": results["weights"],
    }

    logger.info("get_evaluation_chart_data completed successfully")
    return chart_data
0 commit comments