
Commit f3ee07f

feat: implement dataset-agnostic benchmark with multi-category evaluation support. Add ARC, GPQA, TruthfulQA, CommonsenseQA, and HellaSwag datasets with optimized token limits and robust answer extraction.
1 parent 7128765 commit f3ee07f

14 files changed: +3205 / -75 lines

bench/README.md

Lines changed: 97 additions & 0 deletions
# Router vs Direct vLLM Benchmark Commands

## 🚀 Quick One-Liner Commands

### Basic Comparison (ARC dataset, 3 samples per category)

```bash
# Router + Direct vLLM comparison
cd bench && source ../.venv/bin/activate && \
python3 router_reason_bench_multi_dataset.py --dataset arc --samples-per-category 3 --run-router --router-models auto --output-dir results/router_test && \
python3 router_reason_bench_multi_dataset.py --dataset arc --samples-per-category 3 --run-vllm --vllm-endpoint http://127.0.0.1:8000/v1 --vllm-models openai/gpt-oss-20b --vllm-exec-modes NR XC --output-dir results/vllm_test
```

### Comprehensive Script (Recommended)

```bash
cd bench && ./benchmark_comparison.sh arc 5
```

## 📋 Command Breakdown

### Router Evaluation (via Envoy)

- **Endpoint**: `http://127.0.0.1:8801/v1` (Envoy proxy)
- **Model**: `auto` (the router decides which model to use)
- **API Key**: `1234` (default)
- **Purpose**: Tests the semantic router's routing decisions

```bash
python3 router_reason_bench_multi_dataset.py \
  --dataset arc \
  --samples-per-category 5 \
  --run-router \
  --router-endpoint http://127.0.0.1:8801/v1 \
  --router-api-key 1234 \
  --router-models auto
```

### Direct vLLM Evaluation

- **Endpoint**: `http://127.0.0.1:8000/v1` (direct vLLM)
- **Model**: `openai/gpt-oss-20b` (a specific model)
- **API Key**: `1234` (default)
- **Modes**: 3 scenarios: `NR` (plain prompt, baseline), `XC` (CoT prompt), `NR_REASONING` (reasoning toggle on)
- **Purpose**: Measures raw model performance under controlled prompting and reasoning settings

```bash
python3 router_reason_bench_multi_dataset.py \
  --dataset arc \
  --samples-per-category 5 \
  --run-vllm \
  --vllm-endpoint http://127.0.0.1:8000/v1 \
  --vllm-api-key 1234 \
  --vllm-models openai/gpt-oss-20b
```

## 🎯 Available Datasets

- `arc` - AI2 Reasoning Challenge (both Easy + Challenge)
- `arc-easy` - ARC Easy questions only
- `arc-challenge` - ARC Challenge questions only
- `mmlu` / `mmlu-pro` - MMLU-Pro dataset (14 categories)
- `gpqa` / `gpqa-main` - GPQA Main dataset (graduate-level)
- `gpqa-extended` - GPQA Extended dataset
- `gpqa-diamond` - GPQA Diamond dataset (highest quality)
- `truthfulqa` - TruthfulQA dataset (6 categories, tests truthfulness)
- `commonsenseqa` - CommonsenseQA dataset (9 categories, tests reasoning)
- `hellaswag` - HellaSwag dataset (192 categories, tests commonsense)

## 📊 Example Usage

```bash
# Quick test with ARC
./benchmark_comparison.sh arc 3

# Comprehensive test with MMLU
./benchmark_comparison.sh mmlu 10

# Challenge questions only
./benchmark_comparison.sh arc-challenge 5
```

## 📈 Output Analysis

The script creates timestamped results in `results/comparison_YYYYMMDD_HHMMSS/`:

- Router results: `*router*auto*/`
- vLLM results: `*vllm*gpt-oss*/`
- **Comparison plots**: `plots/` directory with visual comparisons
- Each results directory contains `summary.json` and `detailed_results.csv`

### 📊 Generated Visualizations

- `plots/bench_plot_accuracy.png` - Accuracy comparison by category
- `plots/bench_plot_avg_response_time.png` - Response time comparison
- `plots/bench_plot_avg_total_tokens.png` - Token usage comparison
- PDF versions of all plots are also generated

Metrics to compare:

- **Accuracy**: Overall correctness
- **Latency**: Response time per question
- **Tokens**: Token usage efficiency
- **Mode Performance**: NR vs XC reasoning approaches
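
For quick checks without regenerating plots, the per-category numbers can also be diffed directly. The sketch below is illustrative, not part of the benchmark: it assumes each `summary.json` exposes a `category_metrics` mapping keyed by category (the same structure `bench_plot.py` reads) with an `accuracy` field per category, and it uses the output paths from the quick-start commands above.

```python
import json
from pathlib import Path


def load_category_metric(summary_path: Path, metric: str = "accuracy") -> dict:
    """Return {category: value} for one metric from a bench summary.json."""
    summary = json.loads(summary_path.read_text())
    return {
        cat: vals.get(metric)
        for cat, vals in summary.get("category_metrics", {}).items()
        if vals.get(metric) is not None
    }


router = load_category_metric(Path("results/router_test/summary.json"))
vllm = load_category_metric(Path("results/vllm_test/summary.json"))

# Print per-category deltas for the categories both runs covered
for cat in sorted(router.keys() & vllm.keys()):
    delta = router[cat] - vllm[cat]
    print(f"{cat:30s} router={router[cat]:.3f} vllm={vllm[cat]:.3f} delta={delta:+.3f}")
```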

bench/bench_plot.py

Lines changed: 61 additions & 43 deletions
```diff
@@ -6,12 +6,18 @@
 import pandas as pd
 from matplotlib import colormaps
 
+# This script plots benchmark results from the 3-case vLLM design:
+# - VLLM_NR: Plain prompt, no reasoning toggle (baseline)
+# - VLLM_XC: CoT prompt, no reasoning toggle (prompt reasoning)
+# - VLLM_NR_REASONING: Plain prompt, reasoning toggle ON (model reasoning)
+# - router: Router auto mode for comparison
+
 parser = argparse.ArgumentParser()
 parser.add_argument(
     "--summary",
     type=Path,
     required=True,
-    help="Path to summary.json produced by the bench",
+    help="Path to vLLM summary.json produced by the 3-case benchmark",
 )
 parser.add_argument(
     "--router-summary",
@@ -56,7 +62,7 @@
     "--max-modes",
     type=int,
     default=None,
-    help="If set, plot only the top N modes by mean of the current metric",
+    help="If set, plot only the top N modes by mean of the current metric (default: all 3 modes)",
 )
 parser.add_argument(
     "--xtick-rotation",
@@ -175,7 +181,41 @@ def plot_metric(metric: str, out_path: Path):
 
     x = range(len(cats))
 
-    # Determine modes to plot, optionally limiting to top-N by mean of metric
+    # Plot router per-category metric FIRST (with both line and diamonds)
+    # This ensures router trend is visible even if vLLM dots overlap
+    if s_router is not None:
+        router_cat = s_router.get("category_metrics", {})
+        router_vals = []
+        router_x = []
+        for idx, c in enumerate(cats):
+            v = router_cat.get(c, {}).get(metric)
+            if v is not None:
+                router_x.append(idx)
+                router_vals.append(v)
+        if router_vals:
+            # Connect router points with a line and draw larger diamond markers
+            ax.plot(
+                router_x,
+                router_vals,
+                color="tab:red",
+                linestyle="-",
+                linewidth=2.0 * args.font_scale,
+                alpha=0.85,
+                zorder=1,  # Lower zorder so it's plotted first
+            )
+            ax.scatter(
+                router_x,
+                router_vals,
+                s=90 * args.font_scale,
+                color="tab:red",
+                marker="D",
+                label="router",
+                zorder=2,  # Lower zorder so it's plotted first
+                edgecolors="white",
+                linewidths=0.6 * args.font_scale,
+            )
+
+    # Then plot vLLM modes on top
     all_modes = sorted({m for c in cats for m in cat_by_mode.get(c, {}).keys()})
     if len(all_modes) > 0:
 
@@ -213,7 +253,7 @@ def _mean(values):
                 linestyle=linestyles[i % len(linestyles)],
                 linewidth=1.4 * args.font_scale,
                 alpha=0.6,
-                zorder=2,
+                zorder=3,  # Higher zorder so vLLM lines are on top
             )
         if args.style in ("points", "both"):
             ax.scatter(
@@ -225,49 +265,27 @@
                 alpha=0.85,
                 edgecolors="white",
                 linewidths=0.5 * args.font_scale,
-                zorder=3,
+                zorder=4,  # Higher zorder so vLLM points are on top
             )
 
-    # Overlay router per-category metric as diamonds, if provided
-    if s_router is not None:
-        router_cat = s_router.get("category_metrics", {})
-        router_vals = []
-        router_x = []
-        for idx, c in enumerate(cats):
-            v = router_cat.get(c, {}).get(metric)
-            if v is not None:
-                router_x.append(idx)
-                router_vals.append(v)
-        if router_vals:
-            # Connect router points with a line and draw larger diamond markers
-            ax.plot(
-                router_x,
-                router_vals,
-                color="tab:red",
-                linestyle="-",
-                linewidth=2.0 * args.font_scale,
-                alpha=0.85,
-                zorder=4,
-            )
-            ax.scatter(
-                router_x,
-                router_vals,
-                s=90 * args.font_scale,
-                color="tab:red",
-                marker="D",
-                label="router",
-                zorder=5,
-                edgecolors="white",
-                linewidths=0.6 * args.font_scale,
-            )
+    # Set x-axis labels with threshold for readability
+    MAX_CATEGORY_LABELS = 20  # Hide labels if more than this many categories
 
     ax.set_xticks(list(x))
-    ax.set_xticklabels(
-        cats,
-        rotation=args.xtick_rotation,
-        ha="right",
-        fontsize=int(14 * args.font_scale),
-    )
+    if len(cats) <= MAX_CATEGORY_LABELS:
+        ax.set_xticklabels(
+            cats,
+            rotation=args.xtick_rotation,
+            ha="right",
+            fontsize=int(14 * args.font_scale),
+        )
+    else:
+        # Too many categories - hide labels to avoid clutter
+        ax.set_xticklabels([])
+        ax.set_xlabel(
+            f"Categories ({len(cats)} total - labels hidden for readability)",
+            fontsize=int(16 * args.font_scale),
+        )
     # Control horizontal fit by expanding/shrinking x-limits around the first/last category
     if len(cats) > 0:
         n = len(cats)
```
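
The zorder changes above are easiest to see in isolation. Below is a self-contained sketch with made-up data (not benchmark output) showing the two techniques this diff applies in `plot_metric`: draw the router series first at low zorder so the vLLM mode series renders on top, and hide x tick labels once the category count passes a threshold (e.g. HellaSwag's 192 categories).

```python
import matplotlib

matplotlib.use("Agg")  # render without a display
import matplotlib.pyplot as plt

# Hypothetical per-category accuracies for a router run and one vLLM mode
cats = [f"cat{i}" for i in range(30)]
router = [0.80 + 0.003 * i for i in range(30)]
vllm_nr = [0.70 + 0.004 * i for i in range(30)]

fig, ax = plt.subplots(figsize=(10, 4))
x = range(len(cats))

# Router first: line at zorder=1, diamond markers at zorder=2
ax.plot(x, router, color="tab:red", linewidth=2.0, alpha=0.85, zorder=1)
ax.scatter(x, router, s=90, color="tab:red", marker="D", label="router",
           zorder=2, edgecolors="white", linewidths=0.6)

# vLLM mode on top: line at zorder=3, points at zorder=4
ax.plot(x, vllm_nr, color="tab:blue", linewidth=1.4, alpha=0.6, zorder=3)
ax.scatter(x, vllm_nr, color="tab:blue", label="VLLM_NR", zorder=4,
           edgecolors="white", linewidths=0.5)

# Hide tick labels when there are more categories than the threshold
MAX_CATEGORY_LABELS = 20
ax.set_xticks(list(x))
if len(cats) <= MAX_CATEGORY_LABELS:
    ax.set_xticklabels(cats, rotation=45, ha="right")
else:
    ax.set_xticklabels([])
    ax.set_xlabel(f"Categories ({len(cats)} total - labels hidden for readability)")

ax.legend()
fig.savefig("zorder_demo.png", bbox_inches="tight")
```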
