import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import colormaps

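# Command-line interface: where to read bench output and which metrics to plot.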
parser = argparse.ArgumentParser()
parser.add_argument(
    "--summary",
    type=Path,
    required=True,
    help="Path to summary.json produced by the bench",
)
parser.add_argument(
    "--router-summary",
    type=Path,
    required=False,
    help="Optional path to router summary.json to overlay",
)
parser.add_argument(
    "--metrics",
    type=str,
    nargs="+",
    default=["accuracy", "avg_response_time", "avg_total_tokens"],
    choices=["accuracy", "avg_response_time", "avg_total_tokens"],
    help="One or more metrics to plot (default: all)",
)
parser.add_argument(
    "--out-dir",
    type=Path,
    default=Path("."),
    help="Directory to save plots (default: current directory)",
)
args = parser.parse_args()
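# Example invocation (script and file names are illustrative):
#   python plot_bench.py --summary runs/summary.json --metrics accuracy --out-dir plots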
summary_path = args.summary

with open(summary_path) as f:
    s = json.load(f)

s_router = None
if args.router_summary:
    with open(args.router_summary) as f:
        s_router = json.load(f)


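# Prefer precomputed per-category stats from summary.json; otherwise fall back
# to recomputing them from the detailed_results.csv next to it.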
def derive_metrics(summary_json: dict, summary_path: Path):
    cat_by_mode = summary_json.get("category_by_mode")
    cat_ranges = summary_json.get("category_ranges")
    if cat_by_mode is not None and cat_ranges is not None:
        return cat_by_mode, cat_ranges

    csv_path = summary_path.parent / "detailed_results.csv"
    if not csv_path.exists():
        raise SystemExit(f"Missing fields in summary and CSV not found: {csv_path}")
    df = pd.read_csv(csv_path)
    # df.get("success", True) returns a scalar when the column is missing,
    # which breaks boolean indexing, so check for the column explicitly.
    if "success" in df.columns:
        df = df[df["success"] == True]  # noqa: E712
    if "mode_label" not in df.columns:
        raise SystemExit(
            "detailed_results.csv lacks 'mode_label' column; cannot compute per-mode stats"
        )

    grouped = (
        df.groupby(["category", "mode_label"])
        .agg(
            accuracy=("is_correct", "mean"),
            avg_response_time=("response_time", "mean"),
            avg_prompt_tokens=("prompt_tokens", "mean"),
            avg_completion_tokens=("completion_tokens", "mean"),
            avg_total_tokens=("total_tokens", "mean"),
        )
        .reset_index()
    )

    def _as_float(value, default=None):
        # Group means come back as NaN when a group has no usable rows.
        return float(value) if pd.notna(value) else default

    def _min_max(values):
        values = [v for v in values if v is not None]
        if not values:
            return {"min": 0.0, "max": 0.0}
        return {"min": float(min(values)), "max": float(max(values))}

    cat_by_mode = {}
    cat_ranges = {}
    for cat in grouped["category"].unique():
        sub = grouped[grouped["category"] == cat]
        modes = {}
        for _, row in sub.iterrows():
            modes[str(row["mode_label"])] = {
                "accuracy": _as_float(row["accuracy"], default=0.0),
                "avg_response_time": _as_float(row["avg_response_time"], default=0.0),
                "avg_prompt_tokens": _as_float(row["avg_prompt_tokens"]),
                "avg_completion_tokens": _as_float(row["avg_completion_tokens"]),
                "avg_total_tokens": _as_float(row["avg_total_tokens"]),
            }
        cat_by_mode[cat] = modes
        # Record min/max inside the loop so each category gets its own
        # ranges entry.
        cat_ranges[cat] = {
            "accuracy": _min_max([v["accuracy"] for v in modes.values()]),
            "avg_response_time": _min_max(
                [v["avg_response_time"] for v in modes.values()]
            ),
            "avg_total_tokens": _min_max(
                [v["avg_total_tokens"] for v in modes.values()]
            ),
        }
    return cat_by_mode, cat_ranges

cat_by_mode, cat_ranges = derive_metrics(s, summary_path)

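# Sorted category names define the x-axis order.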
cats = sorted(cat_ranges.keys())

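# One figure per metric: categories along the x-axis, one scatter series per
# mode, with the router (if given) overlaid as red diamonds.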
def plot_metric(metric: str, out_path: Path):
    fig, ax = plt.subplots(figsize=(14, 6))

    x = range(len(cats))

    # Overlay each mode as points; missing values become NaN so matplotlib
    # leaves a gap rather than raising.
    all_modes = sorted({m for c in cats for m in cat_by_mode.get(c, {})})
    if all_modes:
        palette = colormaps.get_cmap("tab10").resampled(len(all_modes))
        for i, mode in enumerate(all_modes):
            ys = []
            for c in cats:
                v = cat_by_mode.get(c, {}).get(mode, {}).get(metric)
                ys.append(float("nan") if v is None else v)
            ax.scatter(x, ys, s=20, color=palette.colors[i], label=mode, alpha=0.8)
    # Overlay router per-category metric as diamonds, if provided
    if s_router is not None:
        router_cat = s_router.get("category_metrics", {})
        router_x = []
        router_vals = []
        for idx, c in enumerate(cats):
            v = router_cat.get(c, {}).get(metric)
            if v is not None:
                router_x.append(idx)
                router_vals.append(v)
        if router_vals:
            ax.scatter(
                router_x,
                router_vals,
                s=50,
                color="tab:red",
                marker="D",
                label="router",
                zorder=4,
            )

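    # Label the axes, title the figure, and save to disk.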
    ax.set_xticks(list(x))
    ax.set_xticklabels(cats, rotation=60, ha="right")
    ylabel = metric.replace("_", " ")
    ax.set_ylabel(ylabel)
    ax.set_title(f"Per-category {ylabel} by mode")
    ax.legend(ncol=3, fontsize=8)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)


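# Write one PNG per requested metric into --out-dir.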
args.out_dir.mkdir(parents=True, exist_ok=True)
for metric in args.metrics:
    out_path = args.out_dir / f"bench_plot_{metric}.png"
    plot_metric(metric, out_path)