97 changes: 97 additions & 0 deletions bench/README.md
@@ -0,0 +1,97 @@
# Router vs Direct vLLM Benchmark Commands

## 🚀 Quick One-Liner Commands

### Basic Comparison (ARC dataset, 3 samples per category)
```bash
# Router + Direct vLLM comparison
cd bench && source ../.venv/bin/activate && \
python3 router_reason_bench_multi_dataset.py --dataset arc --samples-per-category 3 --run-router --router-models auto --output-dir results/router_test && \
python3 router_reason_bench_multi_dataset.py --dataset arc --samples-per-category 3 --run-vllm --vllm-endpoint http://127.0.0.1:8000/v1 --vllm-models openai/gpt-oss-20b --vllm-exec-modes NR XC --output-dir results/vllm_test
```

### Comprehensive Script (Recommended)
```bash
cd bench && ./benchmark_comparison.sh arc 5
```

## 📋 Command Breakdown

### Router Evaluation (via Envoy)
- **Endpoint**: `http://127.0.0.1:8801/v1` (Envoy proxy)
- **Model**: `auto` (router decides which model to use)
- **API Key**: `1234` (default)
- **Purpose**: Tests the semantic router's routing decisions

```bash
python3 router_reason_bench_multi_dataset.py \
--dataset arc \
--samples-per-category 5 \
--run-router \
--router-endpoint http://127.0.0.1:8801/v1 \
--router-api-key 1234 \
--router-models auto
```

### Direct vLLM Evaluation
- **Endpoint**: `http://127.0.0.1:8000/v1` (direct vLLM)
- **Model**: `openai/gpt-oss-20b` (specific model)
- **API Key**: `1234` (default)
- **Modes**: 3 execution modes — NR (plain prompt, baseline), XC (chain-of-thought prompt), NR_REASONING (plain prompt with the reasoning toggle on)
- **Purpose**: Tests raw model performance under controlled prompt and reasoning settings

```bash
python3 router_reason_bench_multi_dataset.py \
--dataset arc \
--samples-per-category 5 \
--run-vllm \
--vllm-endpoint http://127.0.0.1:8000/v1 \
--vllm-api-key 1234 \
--vllm-models openai/gpt-oss-20b
```

## 🎯 Available Datasets

- `arc` - AI2 Reasoning Challenge (both Easy + Challenge)
- `arc-easy` - ARC Easy questions only
- `arc-challenge` - ARC Challenge questions only
- `mmlu` / `mmlu-pro` - MMLU-Pro dataset (14 categories)
- `gpqa` / `gpqa-main` - GPQA Main dataset (graduate-level)
- `gpqa-extended` - GPQA Extended dataset
- `gpqa-diamond` - GPQA Diamond dataset (highest quality)
- `truthfulqa` - TruthfulQA dataset (6 categories, tests truthfulness)
- `commonsenseqa` - CommonsenseQA dataset (9 categories, tests reasoning)
- `hellaswag` - HellaSwag dataset (192 categories, tests commonsense)

## 📊 Example Usage

```bash
# Quick test with ARC
./benchmark_comparison.sh arc 3

# Comprehensive test with MMLU
./benchmark_comparison.sh mmlu 10

# Challenge questions only
./benchmark_comparison.sh arc-challenge 5
```
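
To sweep several datasets back-to-back, the same comparison script can be wrapped in a small loop; a minimal sketch (dataset names come from the list above, and the sample count of 5 is only illustrative):

```bash
# Run the router-vs-vLLM comparison for several datasets in sequence.
# Dataset identifiers are from the "Available Datasets" section; tune samples as needed.
for ds in arc mmlu gpqa truthfulqa; do
  ./benchmark_comparison.sh "$ds" 5
done
```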

## 📈 Output Analysis

The script will create timestamped results in `results/comparison_YYYYMMDD_HHMMSS/`:
- Router results: `*router*auto*/`
- vLLM results: `*vllm*gpt-oss*/`
- **Comparison plots**: `plots/` directory with visual comparisons
- Each contains `summary.json` and `detailed_results.csv`

### 📊 Generated Visualizations
- `plots/bench_plot_accuracy.png` - Accuracy comparison by category
- `plots/bench_plot_avg_response_time.png` - Response time comparison
- `plots/bench_plot_avg_total_tokens.png` - Token usage comparison
- PDF versions of all plots are also generated
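
The plots above are produced by `bench/bench_plot.py`, which can also be run by hand against a pair of summaries. A minimal sketch, using only the flags visible in the script's argument parser; the result paths are placeholders for your own timestamped run:

```bash
# Re-plot the benchmark comparison from one run's summaries.
# --summary points at the vLLM summary.json, --router-summary at the router one.
# The directory names below are placeholders; substitute your timestamped paths.
python3 bench_plot.py \
  --summary results/comparison_20250101_120000/vllm_gpt-oss/summary.json \
  --router-summary results/comparison_20250101_120000/router_auto/summary.json
```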

Compare:
- **Accuracy**: Overall correctness
- **Latency**: Response time per question
- **Tokens**: Token usage efficiency
- **Mode Performance**: NR vs XC vs NR_REASONING execution modes
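
For a quick look at the raw numbers without plotting, the per-category metrics can be read straight from `summary.json`; a minimal sketch with `jq`, assuming the directory layout described above (the `category_metrics` key is the same one `bench_plot.py` reads, and the glob is illustrative):

```bash
# Dump per-category metrics from the router run of a comparison.
# The path pattern mirrors the results layout above; adjust to your run directory.
jq '.category_metrics' results/comparison_*/*router*auto*/summary.json
```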
104 changes: 61 additions & 43 deletions bench/bench_plot.py
@@ -6,12 +6,18 @@
import pandas as pd
from matplotlib import colormaps

# This script plots benchmark results from the 3-case vLLM design:
# - VLLM_NR: Plain prompt, no reasoning toggle (baseline)
# - VLLM_XC: CoT prompt, no reasoning toggle (prompt reasoning)
# - VLLM_NR_REASONING: Plain prompt, reasoning toggle ON (model reasoning)
# - router: Router auto mode for comparison

parser = argparse.ArgumentParser()
parser.add_argument(
"--summary",
type=Path,
required=True,
help="Path to summary.json produced by the bench",
help="Path to vLLM summary.json produced by the 3-case benchmark",
)
parser.add_argument(
"--router-summary",
@@ -56,7 +62,7 @@
"--max-modes",
type=int,
default=None,
help="If set, plot only the top N modes by mean of the current metric",
help="If set, plot only the top N modes by mean of the current metric (default: all 3 modes)",
)
parser.add_argument(
"--xtick-rotation",
@@ -175,7 +181,41 @@ def plot_metric(metric: str, out_path: Path):

x = range(len(cats))

# Determine modes to plot, optionally limiting to top-N by mean of metric
# Plot router per-category metric FIRST (with both line and diamonds)
# This ensures router trend is visible even if vLLM dots overlap
if s_router is not None:
router_cat = s_router.get("category_metrics", {})
router_vals = []
router_x = []
for idx, c in enumerate(cats):
v = router_cat.get(c, {}).get(metric)
if v is not None:
router_x.append(idx)
router_vals.append(v)
if router_vals:
# Connect router points with a line and draw larger diamond markers
ax.plot(
router_x,
router_vals,
color="tab:red",
linestyle="-",
linewidth=2.0 * args.font_scale,
alpha=0.85,
zorder=1, # Lower zorder so it's plotted first
)
ax.scatter(
router_x,
router_vals,
s=90 * args.font_scale,
color="tab:red",
marker="D",
label="router",
zorder=2, # Lower zorder so it's plotted first
edgecolors="white",
linewidths=0.6 * args.font_scale,
)

# Then plot vLLM modes on top
all_modes = sorted({m for c in cats for m in cat_by_mode.get(c, {}).keys()})
if len(all_modes) > 0:

@@ -213,7 +253,7 @@ def _mean(values):
linestyle=linestyles[i % len(linestyles)],
linewidth=1.4 * args.font_scale,
alpha=0.6,
zorder=2,
zorder=3, # Higher zorder so vLLM lines are on top
)
if args.style in ("points", "both"):
ax.scatter(
@@ -225,49 +265,27 @@ def _mean(values):
alpha=0.85,
edgecolors="white",
linewidths=0.5 * args.font_scale,
zorder=3,
zorder=4, # Higher zorder so vLLM points are on top
)

# Overlay router per-category metric as diamonds, if provided
if s_router is not None:
router_cat = s_router.get("category_metrics", {})
router_vals = []
router_x = []
for idx, c in enumerate(cats):
v = router_cat.get(c, {}).get(metric)
if v is not None:
router_x.append(idx)
router_vals.append(v)
if router_vals:
# Connect router points with a line and draw larger diamond markers
ax.plot(
router_x,
router_vals,
color="tab:red",
linestyle="-",
linewidth=2.0 * args.font_scale,
alpha=0.85,
zorder=4,
)
ax.scatter(
router_x,
router_vals,
s=90 * args.font_scale,
color="tab:red",
marker="D",
label="router",
zorder=5,
edgecolors="white",
linewidths=0.6 * args.font_scale,
)
# Set x-axis labels with threshold for readability
MAX_CATEGORY_LABELS = 20 # Hide labels if more than this many categories

ax.set_xticks(list(x))
ax.set_xticklabels(
cats,
rotation=args.xtick_rotation,
ha="right",
fontsize=int(14 * args.font_scale),
)
if len(cats) <= MAX_CATEGORY_LABELS:
ax.set_xticklabels(
cats,
rotation=args.xtick_rotation,
ha="right",
fontsize=int(14 * args.font_scale),
)
else:
# Too many categories - hide labels to avoid clutter
ax.set_xticklabels([])
ax.set_xlabel(
f"Categories ({len(cats)} total - labels hidden for readability)",
fontsize=int(16 * args.font_scale),
)
# Control horizontal fit by expanding/shrinking x-limits around the first/last category
if len(cats) > 0:
n = len(cats)