
Commit 0f2ef2f

feat: add results plotting script
1 parent 234f699 commit 0f2ef2f

File tree

1 file changed, +180 -0 lines changed


bench/bench_plot.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
"""Plot per-category benchmark metrics from a bench run.

Reads per-mode metrics from summary.json (falling back to detailed_results.csv
when the summary lacks them) and writes one scatter plot per requested metric.
"""

import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import colormaps

parser = argparse.ArgumentParser()
parser.add_argument(
    "--summary",
    type=Path,
    required=True,
    help="Path to summary.json produced by the bench",
)
parser.add_argument(
    "--router-summary",
    type=Path,
    required=False,
    help="Optional path to router summary.json to overlay",
)
parser.add_argument(
    "--metrics",
    type=str,
    nargs="+",
    default=["accuracy", "avg_response_time", "avg_total_tokens"],
    choices=["accuracy", "avg_response_time", "avg_total_tokens"],
    help="One or more metrics to plot (default: all)",
)
parser.add_argument(
    "--out-dir",
    type=Path,
    default=Path("."),
    help="Directory to save plots (default: current directory)",
)
args = parser.parse_args()
summary_path = args.summary

with open(summary_path) as f:
    s = json.load(f)

s_router = None
if args.router_summary:
    with open(args.router_summary) as f:
        s_router = json.load(f)


def derive_metrics(summary_json: dict, summary_path: Path):
    """Return (category_by_mode, category_ranges), recomputing them from
    detailed_results.csv when the summary does not already contain them."""
    cat_by_mode = summary_json.get("category_by_mode")
    cat_ranges = summary_json.get("category_ranges")
    if cat_by_mode is not None and cat_ranges is not None:
        return cat_by_mode, cat_ranges

    csv_path = summary_path.parent / "detailed_results.csv"
    if not csv_path.exists():
        raise SystemExit(f"Missing fields in summary and CSV not found: {csv_path}")
    df = pd.read_csv(csv_path)
    # Keep only successful rows; if the CSV has no "success" column, keep all rows.
    if "success" in df.columns:
        df = df[df["success"] == True]
    if "mode_label" not in df.columns:
        raise SystemExit(
            "detailed_results.csv lacks 'mode_label' column; cannot compute per-mode stats"
        )

    grouped = (
        df.groupby(["category", "mode_label"]).agg(
            accuracy=("is_correct", "mean"),
            avg_response_time=("response_time", "mean"),
            avg_prompt_tokens=("prompt_tokens", "mean"),
            avg_completion_tokens=("completion_tokens", "mean"),
            avg_total_tokens=("total_tokens", "mean"),
        )
    ).reset_index()

    cat_by_mode = {}
    cat_ranges = {}
    for cat in grouped["category"].unique():
        sub = grouped[grouped["category"] == cat]
        modes = {}
        for _, row in sub.iterrows():
            modes[str(row["mode_label"])] = {
                "accuracy": (
                    float(row["accuracy"]) if pd.notna(row["accuracy"]) else 0.0
                ),
                "avg_response_time": (
                    float(row["avg_response_time"])
                    if pd.notna(row["avg_response_time"])
                    else 0.0
                ),
                "avg_prompt_tokens": (
                    float(row["avg_prompt_tokens"])
                    if pd.notna(row["avg_prompt_tokens"])
                    else None
                ),
                "avg_completion_tokens": (
                    float(row["avg_completion_tokens"])
                    if pd.notna(row["avg_completion_tokens"])
                    else None
                ),
                "avg_total_tokens": (
                    float(row["avg_total_tokens"])
                    if pd.notna(row["avg_total_tokens"])
                    else None
                ),
            }
        cat_by_mode[cat] = modes

        # Per-category min/max ranges for the headline metrics.
        def _mm(values):
            values = [v for v in values if v is not None]
            if not values:
                return {"min": 0.0, "max": 0.0}
            return {"min": float(min(values)), "max": float(max(values))}

        acc_vals = [v.get("accuracy") for v in modes.values()]
        lat_vals = [v.get("avg_response_time") for v in modes.values()]
        tok_vals = [v.get("avg_total_tokens") for v in modes.values()]
        cat_ranges[cat] = {
            "accuracy": _mm(acc_vals),
            "avg_response_time": _mm(lat_vals),
            "avg_total_tokens": _mm(tok_vals),
        }
    return cat_by_mode, cat_ranges


cat_by_mode, cat_ranges = derive_metrics(s, summary_path)

cats = sorted(cat_ranges.keys())


def plot_metric(metric: str, out_path: Path):
    """Scatter-plot one metric per category, one point per mode, with an
    optional router overlay, and save the figure to out_path."""
    fig, ax = plt.subplots(figsize=(14, 6))

    x = range(len(cats))

    # Overlay each mode as points
    all_modes = sorted({m for c in cats for m in cat_by_mode.get(c, {}).keys()})
    if len(all_modes) > 0:
        palette = colormaps.get_cmap("tab10").resampled(len(all_modes))
        for i, mode in enumerate(all_modes):
            ys = []
            for c in cats:
                ys.append(cat_by_mode.get(c, {}).get(mode, {}).get(metric))
            ax.scatter(x, ys, s=20, color=palette.colors[i], label=mode, alpha=0.8)

    # Overlay router per-category metric as diamonds, if provided
    if s_router is not None:
        router_cat = s_router.get("category_metrics", {})
        router_vals = []
        router_x = []
        for idx, c in enumerate(cats):
            v = router_cat.get(c, {}).get(metric)
            if v is not None:
                router_x.append(idx)
                router_vals.append(v)
        if router_vals:
            ax.scatter(
                router_x,
                router_vals,
                s=50,
                color="tab:red",
                marker="D",
                label="router",
                zorder=4,
            )

    ax.set_xticks(list(x))
    ax.set_xticklabels(cats, rotation=60, ha="right")
    ylabel = metric.replace("_", " ")
    ax.set_ylabel(ylabel)
    ax.set_title(f"Per-category {ylabel} per-mode values")
    ax.legend(ncol=3, fontsize=8)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)


args.out_dir.mkdir(parents=True, exist_ok=True)
for metric in args.metrics:
    out_path = args.out_dir / f"bench_plot_{metric}.png"
    plot_metric(metric, out_path)
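
For review context, a minimal sketch (not part of the commit) of the summary.json shape the script reads: category_by_mode and category_ranges for the main summary, plus category_metrics in the optional router summary. Only the key names mirror what bench_plot.py looks up; the category name, mode labels, and numeric values below are illustrative assumptions, not output from a real bench run.

import json
from pathlib import Path

# Hypothetical example input; key names follow the reads in bench_plot.py,
# the concrete category ("math"), modes ("base", "reasoning"), and numbers
# are made up for illustration.
example_summary = {
    "category_by_mode": {
        "math": {
            "base": {"accuracy": 0.62, "avg_response_time": 1.4, "avg_total_tokens": 310.0},
            "reasoning": {"accuracy": 0.81, "avg_response_time": 4.9, "avg_total_tokens": 950.0},
        },
    },
    "category_ranges": {
        "math": {
            "accuracy": {"min": 0.62, "max": 0.81},
            "avg_response_time": {"min": 1.4, "max": 4.9},
            "avg_total_tokens": {"min": 310.0, "max": 950.0},
        },
    },
}
Path("summary.json").write_text(json.dumps(example_summary, indent=2))
# Then, for example: python bench/bench_plot.py --summary summary.json --out-dir plots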
