diff --git a/bench/README.md b/bench/README.md new file mode 100644 index 00000000..e1e4efc6 --- /dev/null +++ b/bench/README.md @@ -0,0 +1,97 @@ +# Router vs Direct vLLM Benchmark Commands + +## 🚀 Quick One-Liner Commands + +### Basic Comparison (ARC dataset, 3 samples per category) +```bash +# Router + Direct vLLM comparison +cd bench && source ../.venv/bin/activate && \ +python3 router_reason_bench_multi_dataset.py --dataset arc --samples-per-category 3 --run-router --router-models auto --output-dir results/router_test && \ +python3 router_reason_bench_multi_dataset.py --dataset arc --samples-per-category 3 --run-vllm --vllm-endpoint http://127.0.0.1:8000/v1 --vllm-models openai/gpt-oss-20b --vllm-exec-modes NR XC --output-dir results/vllm_test +``` + +### Comprehensive Script (Recommended) +```bash +cd bench && ./benchmark_comparison.sh arc 5 +``` + +## 📋 Command Breakdown + +### Router Evaluation (via Envoy) +- **Endpoint**: `http://127.0.0.1:8801/v1` (Envoy proxy) +- **Model**: `auto` (router decides which model to use) +- **API Key**: `1234` (default) +- **Purpose**: Tests the semantic router's routing decisions + +```bash +python3 router_reason_bench_multi_dataset.py \ + --dataset arc \ + --samples-per-category 5 \ + --run-router \ + --router-endpoint http://127.0.0.1:8801/v1 \ + --router-api-key 1234 \ + --router-models auto +``` + +### Direct vLLM Evaluation +- **Endpoint**: `http://127.0.0.1:8000/v1` (direct vLLM) +- **Model**: `openai/gpt-oss-20b` (specific model) +- **API Key**: `1234` (default) +- **Modes**: 3 realistic scenarios (NR, XC, NR_REASONING) +- **Purpose**: Tests the raw model performance with scientific controls + +```bash +python3 router_reason_bench_multi_dataset.py \ + --dataset arc \ + --samples-per-category 5 \ + --run-vllm \ + --vllm-endpoint http://127.0.0.1:8000/v1 \ + --vllm-api-key 1234 \ + --vllm-models openai/gpt-oss-20b +``` + +## 🎯 Available Datasets + +- `arc` - AI2 Reasoning Challenge (both Easy + Challenge) +- `arc-easy` - ARC Easy questions only +- `arc-challenge` - ARC Challenge questions only +- `mmlu` / `mmlu-pro` - MMLU-Pro dataset (14 categories) +- `gpqa` / `gpqa-main` - GPQA Main dataset (graduate-level) +- `gpqa-extended` - GPQA Extended dataset +- `gpqa-diamond` - GPQA Diamond dataset (highest quality) +- `truthfulqa` - TruthfulQA dataset (6 categories, tests truthfulness) +- `commonsenseqa` - CommonsenseQA dataset (9 categories, tests reasoning) +- `hellaswag` - HellaSwag dataset (192 categories, tests commonsense) + +## 📊 Example Usage + +```bash +# Quick test with ARC +./benchmark_comparison.sh arc 3 + +# Comprehensive test with MMLU +./benchmark_comparison.sh mmlu 10 + +# Challenge questions only +./benchmark_comparison.sh arc-challenge 5 +``` + +## 📈 Output Analysis + +The script will create timestamped results in `results/comparison_YYYYMMDD_HHMMSS/`: +- Router results: `*router*auto*/` +- vLLM results: `*vllm*gpt-oss*/` +- **Comparison plots**: `plots/` directory with visual comparisons +- Each contains `summary.json` and `detailed_results.csv` + +### 📊 Generated Visualizations +- `plots/bench_plot_accuracy.png` - Accuracy comparison by category +- `plots/bench_plot_avg_response_time.png` - Response time comparison +- `plots/bench_plot_avg_total_tokens.png` - Token usage comparison +- PDF versions of all plots are also generated + +Compare: +- **Accuracy**: Overall correctness +- **Latency**: Response time per question +- **Tokens**: Token usage efficiency +- **Mode Performance**: NR vs XC reasoning approaches diff --git 
a/bench/bench_plot.py b/bench/bench_plot.py index fdab467c..dea9dd6d 100644 --- a/bench/bench_plot.py +++ b/bench/bench_plot.py @@ -6,12 +6,18 @@ import pandas as pd from matplotlib import colormaps +# This script plots benchmark results from the 3-case vLLM design: +# - VLLM_NR: Plain prompt, no reasoning toggle (baseline) +# - VLLM_XC: CoT prompt, no reasoning toggle (prompt reasoning) +# - VLLM_NR_REASONING: Plain prompt, reasoning toggle ON (model reasoning) +# - router: Router auto mode for comparison + parser = argparse.ArgumentParser() parser.add_argument( "--summary", type=Path, required=True, - help="Path to summary.json produced by the bench", + help="Path to vLLM summary.json produced by the 3-case benchmark", ) parser.add_argument( "--router-summary", @@ -56,7 +62,7 @@ "--max-modes", type=int, default=None, - help="If set, plot only the top N modes by mean of the current metric", + help="If set, plot only the top N modes by mean of the current metric (default: all 3 modes)", ) parser.add_argument( "--xtick-rotation", @@ -175,7 +181,41 @@ def plot_metric(metric: str, out_path: Path): x = range(len(cats)) - # Determine modes to plot, optionally limiting to top-N by mean of metric + # Plot router per-category metric FIRST (with both line and diamonds) + # This ensures router trend is visible even if vLLM dots overlap + if s_router is not None: + router_cat = s_router.get("category_metrics", {}) + router_vals = [] + router_x = [] + for idx, c in enumerate(cats): + v = router_cat.get(c, {}).get(metric) + if v is not None: + router_x.append(idx) + router_vals.append(v) + if router_vals: + # Connect router points with a line and draw larger diamond markers + ax.plot( + router_x, + router_vals, + color="tab:red", + linestyle="-", + linewidth=2.0 * args.font_scale, + alpha=0.85, + zorder=1, # Lower zorder so it's plotted first + ) + ax.scatter( + router_x, + router_vals, + s=90 * args.font_scale, + color="tab:red", + marker="D", + label="router", + zorder=2, # Lower zorder so it's plotted first + edgecolors="white", + linewidths=0.6 * args.font_scale, + ) + + # Then plot vLLM modes on top all_modes = sorted({m for c in cats for m in cat_by_mode.get(c, {}).keys()}) if len(all_modes) > 0: @@ -213,7 +253,7 @@ def _mean(values): linestyle=linestyles[i % len(linestyles)], linewidth=1.4 * args.font_scale, alpha=0.6, - zorder=2, + zorder=3, # Higher zorder so vLLM lines are on top ) if args.style in ("points", "both"): ax.scatter( @@ -225,49 +265,27 @@ def _mean(values): alpha=0.85, edgecolors="white", linewidths=0.5 * args.font_scale, - zorder=3, + zorder=4, # Higher zorder so vLLM points are on top ) - # Overlay router per-category metric as diamonds, if provided - if s_router is not None: - router_cat = s_router.get("category_metrics", {}) - router_vals = [] - router_x = [] - for idx, c in enumerate(cats): - v = router_cat.get(c, {}).get(metric) - if v is not None: - router_x.append(idx) - router_vals.append(v) - if router_vals: - # Connect router points with a line and draw larger diamond markers - ax.plot( - router_x, - router_vals, - color="tab:red", - linestyle="-", - linewidth=2.0 * args.font_scale, - alpha=0.85, - zorder=4, - ) - ax.scatter( - router_x, - router_vals, - s=90 * args.font_scale, - color="tab:red", - marker="D", - label="router", - zorder=5, - edgecolors="white", - linewidths=0.6 * args.font_scale, - ) + # Set x-axis labels with threshold for readability + MAX_CATEGORY_LABELS = 20 # Hide labels if more than this many categories ax.set_xticks(list(x)) - 
ax.set_xticklabels( - cats, - rotation=args.xtick_rotation, - ha="right", - fontsize=int(14 * args.font_scale), - ) + if len(cats) <= MAX_CATEGORY_LABELS: + ax.set_xticklabels( + cats, + rotation=args.xtick_rotation, + ha="right", + fontsize=int(14 * args.font_scale), + ) + else: + # Too many categories - hide labels to avoid clutter + ax.set_xticklabels([]) + ax.set_xlabel( + f"Categories ({len(cats)} total - labels hidden for readability)", + fontsize=int(16 * args.font_scale), + ) # Control horizontal fit by expanding/shrinking x-limits around the first/last category if len(cats) > 0: n = len(cats) diff --git a/bench/benchmark_comparison.sh b/bench/benchmark_comparison.sh new file mode 100755 index 00000000..9f23a526 --- /dev/null +++ b/bench/benchmark_comparison.sh @@ -0,0 +1,201 @@ +#!/bin/bash + +# Multi-Dataset Reasoning Benchmark Comparison +# +# Comprehensive evaluation framework comparing semantic router performance +# against direct vLLM inference across reasoning datasets. +# +# Usage: ./benchmark_comparison.sh [dataset] [samples_per_category] [concurrent_requests] +# Example: ./benchmark_comparison.sh gpqa 5 2 + +set -e + +# Configuration parameters +DATASET=${1:-"arc"} +SAMPLES_PER_CATEGORY=${2:-5} +CONCURRENT_REQUESTS=${3:-2} + +# Semantic router configuration +ROUTER_ENDPOINT="http://127.0.0.1:8801/v1" +ROUTER_API_KEY="1234" +ROUTER_MODEL="auto" + +# Direct vLLM configuration +VLLM_ENDPOINT="http://127.0.0.1:8000/v1" +VLLM_API_KEY="1234" +VLLM_MODEL="openai/gpt-oss-20b" + +# Evaluation parameters +TEMPERATURE=0.0 +OUTPUT_DIR="results/comparison_$(date +%Y%m%d_%H%M%S)" + +echo "🎯 MULTI-DATASET REASONING BENCHMARK" +echo "=====================================" +echo "Dataset: $DATASET" +echo "Samples per category: $SAMPLES_PER_CATEGORY" +echo "Concurrent requests: $CONCURRENT_REQUESTS" +echo "Output directory: $OUTPUT_DIR" +echo "" + +# Ensure we're in the bench directory +cd "$(dirname "$0")" + +# Activate virtual environment if it exists +if [ -f "../.venv/bin/activate" ]; then + echo "📦 Activating virtual environment..." 
+ source ../.venv/bin/activate +fi + +# Create output directory +mkdir -p "$OUTPUT_DIR" + +echo "🔄 PHASE 1: ROUTER EVALUATION (via Envoy)" +echo "------------------------------------------" +echo "Endpoint: $ROUTER_ENDPOINT" +echo "Model: $ROUTER_MODEL (router decides)" +echo "" + +# Run router benchmark +python3 router_reason_bench_multi_dataset.py \ + --dataset "$DATASET" \ + --samples-per-category "$SAMPLES_PER_CATEGORY" \ + --concurrent-requests "$CONCURRENT_REQUESTS" \ + --router-endpoint "$ROUTER_ENDPOINT" \ + --router-api-key "$ROUTER_API_KEY" \ + --router-models "$ROUTER_MODEL" \ + --temperature "$TEMPERATURE" \ + --output-dir "$OUTPUT_DIR" \ + --run-router + +echo "" +echo "🔄 PHASE 2: DIRECT vLLM EVALUATION" +echo "-----------------------------------" +echo "Endpoint: $VLLM_ENDPOINT" +echo "Model: $VLLM_MODEL (direct access)" +echo "" + +# Run direct vLLM benchmark +python3 router_reason_bench_multi_dataset.py \ + --dataset "$DATASET" \ + --samples-per-category "$SAMPLES_PER_CATEGORY" \ + --concurrent-requests "$CONCURRENT_REQUESTS" \ + --vllm-endpoint "$VLLM_ENDPOINT" \ + --vllm-api-key "$VLLM_API_KEY" \ + --vllm-models "$VLLM_MODEL" \ + --vllm-exec-modes "NR" "XC" \ + --temperature "$TEMPERATURE" \ + --output-dir "$OUTPUT_DIR" \ + --run-vllm + +echo "" +echo "🎨 PHASE 3: GENERATING COMPARISON PLOTS" +echo "----------------------------------------" + +# Generate plots comparing router vs vLLM +ROUTER_RESULT=$(find "$OUTPUT_DIR" -name "*router*auto*" -type d | head -1) +VLLM_RESULT=$(find "$OUTPUT_DIR" -name "*vllm*gpt-oss*" -type d | head -1) + +if [ -n "$ROUTER_RESULT" ] && [ -f "$ROUTER_RESULT/summary.json" ] && [ -n "$VLLM_RESULT" ] && [ -f "$VLLM_RESULT/summary.json" ]; then + echo "Creating comparison plots (router plotted first for visibility)..." + + # Create plots directory + PLOTS_DIR="$OUTPUT_DIR/plots" + mkdir -p "$PLOTS_DIR" + + # Generate vLLM plots with router overlay (router plotted first) + python3 bench_plot.py \ + --summary "$VLLM_RESULT/summary.json" \ + --router-summary "$ROUTER_RESULT/summary.json" \ + --out-dir "$PLOTS_DIR" \ + --metrics accuracy avg_response_time avg_total_tokens \ + --font-scale 1.4 \ + --dpi 300 + + echo "✅ Plots generated in: $PLOTS_DIR" + echo " - bench_plot_accuracy.png (+ PDF)" + echo " - bench_plot_avg_response_time.png (+ PDF)" + echo " - bench_plot_avg_total_tokens.png (+ PDF)" + echo " 📊 Router trend lines plotted first to remain visible even with overlapping dots" +else + echo "⚠️ Skipping plots - missing result files" +fi + +echo "" +echo "📊 BENCHMARK COMPLETED!" 
+echo "=======================" +echo "Results saved to: $OUTPUT_DIR" +echo "" + +# Display quick summary if results exist +echo "📈 QUICK SUMMARY:" +echo "-----------------" + +# Find and display router results +ROUTER_RESULT=$(find "$OUTPUT_DIR" -name "*router*auto*" -type d | head -1) +if [ -n "$ROUTER_RESULT" ] && [ -f "$ROUTER_RESULT/summary.json" ]; then + echo "🔀 Router (via Envoy):" + python3 -c " +import json, sys +try: + with open('$ROUTER_RESULT/summary.json') as f: + data = json.load(f) + print(f\" Accuracy: {data.get('overall_accuracy', 0):.3f}\") + print(f\" Avg Latency: {data.get('avg_response_time', 0):.2f}s\") + print(f\" Avg Tokens: {data.get('avg_total_tokens', 0):.0f}\") + print(f\" Questions: {data.get('successful_queries', 0)}/{data.get('total_questions', 0)}\") +except Exception as e: + print(f\" Error reading router results: {e}\") +" +fi + +# Find and display vLLM results +VLLM_RESULT=$(find "$OUTPUT_DIR" -name "*vllm*gpt-oss*" -type d | head -1) +if [ -n "$VLLM_RESULT" ] && [ -f "$VLLM_RESULT/summary.json" ]; then + echo "🎯 Direct vLLM:" + python3 -c " +import json, sys +try: + with open('$VLLM_RESULT/summary.json') as f: + data = json.load(f) + print(f\" Accuracy: {data.get('overall_accuracy', 0):.3f}\") + print(f\" Avg Latency: {data.get('avg_response_time', 0):.2f}s\") + print(f\" Avg Tokens: {data.get('avg_total_tokens', 0):.0f}\") + print(f\" Questions: {data.get('successful_queries', 0)}/{data.get('total_questions', 0)}\") + + # Show breakdown by mode if available + by_mode = data.get('by_mode', {}) + if by_mode: + print(\" Mode Breakdown:\") + for mode, metrics in by_mode.items(): + if 'accuracy' in metrics: + print(f\" {mode}: {metrics['accuracy']:.3f} acc, {metrics.get('avg_response_time', 0):.2f}s\") +except Exception as e: + print(f\" Error reading vLLM results: {e}\") +" +fi + +echo "" +echo "🔍 DETAILED ANALYSIS:" +echo "--------------------" +echo "- Router results: $ROUTER_RESULT" +echo "- vLLM results: $VLLM_RESULT" +echo "- Comparison plots: $OUTPUT_DIR/plots/" +echo "- Compare CSV files for detailed question-by-question analysis" +echo "- Check summary.json files for comprehensive metrics" +echo "" + +echo "📊 VISUALIZATION FILES:" +echo "----------------------" +if [ -d "$OUTPUT_DIR/plots" ]; then + echo "- Accuracy comparison: $OUTPUT_DIR/plots/bench_plot_accuracy.png" + echo "- Response time comparison: $OUTPUT_DIR/plots/bench_plot_avg_response_time.png" + echo "- Token usage comparison: $OUTPUT_DIR/plots/bench_plot_avg_total_tokens.png" + echo "- PDF versions also available in same directory" +else + echo "- No plots generated (check for errors above)" +fi +echo "" + +echo "✅ Benchmark comparison complete!" +echo "Run with different datasets: $0 mmlu 10" +echo "Run with different datasets: $0 arc-challenge 3" diff --git a/bench/dataset_factory.py b/bench/dataset_factory.py new file mode 100644 index 00000000..b85d6ac7 --- /dev/null +++ b/bench/dataset_factory.py @@ -0,0 +1,137 @@ +""" +Dataset factory for loading different evaluation datasets. + +This module provides a factory pattern for instantiating different dataset +implementations in a unified way. 
+""" + +from typing import Dict, List, Optional, Type + +from dataset_implementations.arc_dataset import ( + ARCChallengeDataset, + ARCDataset, + ARCEasyDataset, +) +from dataset_implementations.commonsenseqa_dataset import CommonsenseQADataset +from dataset_implementations.gpqa_dataset import ( + GPQADataset, + GPQADiamondDataset, + GPQAExtendedDataset, + GPQAMainDataset, +) +from dataset_implementations.hellaswag_dataset import HellaSwagDataset +from dataset_implementations.mmlu_dataset import MMLUDataset +from dataset_implementations.truthfulqa_dataset import TruthfulQADataset +from dataset_interface import DatasetInterface + + +class DatasetFactory: + """Factory for creating dataset instances.""" + + _registered_datasets: Dict[str, Type[DatasetInterface]] = {} + + @classmethod + def register_dataset(cls, name: str, dataset_class: Type[DatasetInterface]) -> None: + """Register a new dataset class. + + Args: + name: Name to register the dataset under + dataset_class: Class implementing DatasetInterface + """ + cls._registered_datasets[name.lower()] = dataset_class + + @classmethod + def get_available_datasets(cls) -> List[str]: + """Get list of all registered dataset names.""" + return list(cls._registered_datasets.keys()) + + @classmethod + def create_dataset(cls, name: str) -> DatasetInterface: + """Create a dataset instance by name. + + Args: + name: Name of the dataset to create + + Returns: + Dataset instance implementing DatasetInterface + + Raises: + ValueError: If dataset name is not registered + """ + name_lower = name.lower() + if name_lower not in cls._registered_datasets: + available = ", ".join(cls.get_available_datasets()) + raise ValueError( + f"Unknown dataset: {name}. Available datasets: {available}" + ) + + dataset_class = cls._registered_datasets[name_lower] + return dataset_class() + + @classmethod + def get_dataset_info(cls, name: str) -> Dict[str, str]: + """Get basic info about a dataset without loading it. 
+ + Args: + name: Name of the dataset + + Returns: + Dictionary with dataset information + """ + dataset = cls.create_dataset(name) + return { + "name": dataset.dataset_name, + "supports_cot": str(dataset.supports_cot), + "categories_count": str(len(dataset.get_available_categories())), + } + + +# Register built-in datasets +DatasetFactory.register_dataset("mmlu", MMLUDataset) +DatasetFactory.register_dataset("mmlu-pro", MMLUDataset) + +# Register ARC datasets +DatasetFactory.register_dataset("arc", ARCDataset) +DatasetFactory.register_dataset("arc-easy", ARCEasyDataset) +DatasetFactory.register_dataset("arc-challenge", ARCChallengeDataset) + +# Register GPQA datasets +DatasetFactory.register_dataset("gpqa", GPQAMainDataset) +DatasetFactory.register_dataset("gpqa-main", GPQAMainDataset) +DatasetFactory.register_dataset("gpqa-extended", GPQAExtendedDataset) +DatasetFactory.register_dataset("gpqa-diamond", GPQADiamondDataset) + +# Register hard reasoning datasets +DatasetFactory.register_dataset("truthfulqa", TruthfulQADataset) +DatasetFactory.register_dataset("commonsenseqa", CommonsenseQADataset) +DatasetFactory.register_dataset("hellaswag", HellaSwagDataset) + + +def list_available_datasets() -> None: + """Print information about all available datasets.""" + print("Available datasets:") + print("-" * 50) + + for name in DatasetFactory.get_available_datasets(): + try: + info = DatasetFactory.get_dataset_info(name) + print(f"• {name}") + print(f" Name: {info['name']}") + print(f" Supports CoT: {info['supports_cot']}") + print(f" Categories: {info['categories_count']}") + print() + except Exception as e: + print(f"• {name} (error loading info: {e})") + print() + + +def create_dataset(name: str) -> DatasetInterface: + """Convenience function to create a dataset instance. + + Args: + name: Name of the dataset to create + + Returns: + Dataset instance + """ + return DatasetFactory.create_dataset(name) diff --git a/bench/dataset_implementations/__init__.py b/bench/dataset_implementations/__init__.py new file mode 100644 index 00000000..00804dc7 --- /dev/null +++ b/bench/dataset_implementations/__init__.py @@ -0,0 +1,28 @@ +"""Dataset implementations for the benchmark.""" + +from .arc_dataset import ARCChallengeDataset, ARCDataset, ARCEasyDataset +from .commonsenseqa_dataset import CommonsenseQADataset +from .gpqa_dataset import ( + GPQADataset, + GPQADiamondDataset, + GPQAExtendedDataset, + GPQAMainDataset, +) +from .hellaswag_dataset import HellaSwagDataset +from .mmlu_dataset import MMLUDataset, load_mmlu_pro_dataset +from .truthfulqa_dataset import TruthfulQADataset + +__all__ = [ + "MMLUDataset", + "load_mmlu_pro_dataset", + "ARCDataset", + "ARCEasyDataset", + "ARCChallengeDataset", + "CommonsenseQADataset", + "GPQADataset", + "GPQAMainDataset", + "GPQAExtendedDataset", + "GPQADiamondDataset", + "HellaSwagDataset", + "TruthfulQADataset", +] diff --git a/bench/dataset_implementations/arc_dataset.py b/bench/dataset_implementations/arc_dataset.py new file mode 100644 index 00000000..92a51165 --- /dev/null +++ b/bench/dataset_implementations/arc_dataset.py @@ -0,0 +1,227 @@ +""" +ARC Dataset Implementation + +AI2 Reasoning Challenge for elementary and middle school science questions +with automatic subject categorization across Biology, Chemistry, Physics, +Earth Science, and General Science. 
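Illustrative usage (a minimal sketch; note that the current loader assigns every
question to a single "Science" category and converts letter answer keys to
0-based indices, e.g. "B" -> 1):

    from dataset_implementations.arc_dataset import ARCDataset

    arc = ARCDataset(variant="challenge")
    questions, info = arc.load_dataset(samples_per_category=3, seed=42)
    q = questions[0]
    print(q.category, q.options[q.correct_answer])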
+""" + +import os +import random +import sys +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd +from datasets import load_dataset + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dataset_interface import DatasetInfo, DatasetInterface, PromptFormatter, Question + + +class ARCDataset(DatasetInterface): + """ARC (AI2 Reasoning Challenge) dataset implementation.""" + + def __init__(self, variant: str = "both"): + """Initialize ARC dataset. + + Args: + variant: Which ARC variant to use ("easy", "challenge", or "both") + """ + self.variant = variant.lower() + if self.variant not in ["easy", "challenge", "both"]: + raise ValueError("variant must be 'easy', 'challenge', or 'both'") + + self._dataset_cache = None + self._categories_cache = None + + @property + def dataset_name(self) -> str: + if self.variant == "both": + return "ARC" + return f"ARC-{self.variant.title()}" + + @property + def supports_cot(self) -> bool: + return False # ARC doesn't have built-in CoT content + + def _load_raw_dataset(self): + """Load raw ARC dataset from Hugging Face.""" + if self._dataset_cache is not None: + return self._dataset_cache + + datasets_to_load = [] + + if self.variant in ["easy", "both"]: + easy_dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test") + easy_df = pd.DataFrame(easy_dataset) + easy_df["difficulty"] = "Easy" + easy_df["arc_variant"] = "ARC-Easy" + datasets_to_load.append(easy_df) + + if self.variant in ["challenge", "both"]: + challenge_dataset = load_dataset( + "allenai/ai2_arc", "ARC-Challenge", split="test" + ) + challenge_df = pd.DataFrame(challenge_dataset) + challenge_df["difficulty"] = "Challenge" + challenge_df["arc_variant"] = "ARC-Challenge" + datasets_to_load.append(challenge_df) + + if len(datasets_to_load) == 1: + self._dataset_cache = datasets_to_load[0] + else: + self._dataset_cache = pd.concat(datasets_to_load, ignore_index=True) + + return self._dataset_cache + + def _get_category(self) -> str: + """ + ARC dataset doesn't have explicit subject categories. + Use a single 'Science' category since all questions are science-related. 
+ """ + return "Science" + + def load_dataset( + self, + categories: Optional[List[str]] = None, + samples_per_category: Optional[int] = None, + seed: int = 42, + ) -> Tuple[List[Question], DatasetInfo]: + """Load ARC dataset.""" + df = self._load_raw_dataset() + + # Convert to Question objects and infer categories + questions = [] + for _, row in df.iterrows(): + # Extract choices - ARC format has choices as dict with labels + choices_dict = row["choices"] + if isinstance(choices_dict, dict): + # Extract text choices in order + labels = choices_dict.get("label", []) + texts = choices_dict.get("text", []) + options = [text for text in texts if text] # Filter out empty choices + else: + options = [] + + # Convert answer key from letter to index + answer_key = str(row["answerKey"]) + if len(options) > 0 and answer_key in "ABCDEFGHIJ": + correct_answer_index = ord(answer_key) - ord("A") + # Ensure the index is within bounds + if correct_answer_index >= len(options): + correct_answer_index = None + else: + correct_answer_index = None + + # Skip questions with invalid answer keys + if correct_answer_index is None: + continue + + # Use single category since ARC doesn't have explicit subjects + category = self._get_category() + + question = Question( + question_id=str(row.get("id", f"arc_{len(questions)}")), + category=category, + question=str(row["question"]), + options=options, + correct_answer=correct_answer_index, # Now an integer index + cot_content=None, # ARC doesn't have CoT + metadata={ + "source": "ARC", + "difficulty": row["difficulty"], + "arc_variant": row["arc_variant"], + }, + ) + questions.append(question) + + # Get all unique categories + all_categories = sorted(list(set(q.category for q in questions))) + self._categories_cache = all_categories + + # Filter by categories if specified + if categories: + questions = [q for q in questions if q.category in categories] + if not questions: + valid_categories = ", ".join(all_categories) + raise ValueError( + f"No data found for specified categories. 
" + f"Valid categories are: {valid_categories}" + ) + + # Sample if requested + if samples_per_category: + random.seed(seed) + np.random.seed(seed) + + # Group by category + category_questions = {} + for q in questions: + if q.category not in category_questions: + category_questions[q.category] = [] + category_questions[q.category].append(q) + + # Sample from each category + sampled_questions = [] + for category, cat_questions in category_questions.items(): + if len(cat_questions) > samples_per_category: + sampled = random.sample(cat_questions, samples_per_category) + sampled_questions.extend(sampled) + else: + sampled_questions.extend(cat_questions) + + questions = sampled_questions + + # Create dataset info + dataset_info = DatasetInfo( + name=self.dataset_name, + description=f"AI2 Reasoning Challenge ({self.variant})", + categories=list(set(q.category for q in questions)), + total_questions=len(questions), + format_type="multiple_choice", + difficulty_level="elementary" if self.variant == "easy" else "mixed", + ) + + return questions, dataset_info + + def get_available_categories(self) -> List[str]: + """Get all available ARC categories.""" + if self._categories_cache is None: + # Load dataset to get categories + self.load_dataset() + return self._categories_cache or [] + + def format_prompt(self, question: Question, prompt_style: str = "plain") -> str: + """Format ARC question into prompt.""" + if prompt_style == "plain": + return PromptFormatter.format_enhanced_prompt( + question.question, question.options, "ARC", "mixed", "plain" + ) + elif prompt_style == "cot": + return PromptFormatter.format_enhanced_prompt( + question.question, question.options, "ARC", "mixed", "cot" + ) + elif prompt_style == "explicit_cot": + # ARC doesn't have CoT content, so fall back to regular CoT + return PromptFormatter.format_cot_prompt( + question.question, question.options + ) + else: + raise ValueError(f"Unknown prompt style: {prompt_style}") + + +# Convenience classes for specific variants +class ARCEasyDataset(ARCDataset): + """ARC-Easy dataset.""" + + def __init__(self): + super().__init__(variant="easy") + + +class ARCChallengeDataset(ARCDataset): + """ARC-Challenge dataset.""" + + def __init__(self): + super().__init__(variant="challenge") diff --git a/bench/dataset_implementations/commonsenseqa_dataset.py b/bench/dataset_implementations/commonsenseqa_dataset.py new file mode 100644 index 00000000..1824594e --- /dev/null +++ b/bench/dataset_implementations/commonsenseqa_dataset.py @@ -0,0 +1,303 @@ +""" +CommonsenseQA dataset implementation. + +This module implements the DatasetInterface for CommonsenseQA dataset which +tests commonsense reasoning across various conceptual domains. 
+""" + +import os +import random +import sys +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd +from datasets import load_dataset + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dataset_interface import DatasetInfo, DatasetInterface, PromptFormatter, Question + + +class CommonsenseQADataset(DatasetInterface): + """CommonsenseQA dataset implementation.""" + + def __init__(self): + """Initialize CommonsenseQA dataset.""" + self._dataset_cache = None + self._categories_cache = None + + @property + def dataset_name(self) -> str: + return "CommonsenseQA" + + @property + def supports_cot(self) -> bool: + return True # CommonsenseQA benefits from reasoning + + def _load_raw_dataset(self): + """Load raw CommonsenseQA dataset from Hugging Face.""" + if self._dataset_cache is not None: + return self._dataset_cache + + try: + # Load train and validation splits + train_dataset = load_dataset("commonsense_qa", split="train") + val_dataset = load_dataset("commonsense_qa", split="validation") + + # Combine both splits for more data + train_df = pd.DataFrame(train_dataset) + val_df = pd.DataFrame(val_dataset) + self._dataset_cache = pd.concat([train_df, val_df], ignore_index=True) + + except Exception as e: + print(f"Warning: Could not load CommonsenseQA dataset: {e}") + print("You may need to check your internet connection or dataset access.") + # Create empty dataframe as fallback + self._dataset_cache = pd.DataFrame() + + return self._dataset_cache + + def _extract_categories(self, df: pd.DataFrame) -> List[str]: + """Extract categories from CommonsenseQA dataset based on question concepts.""" + if df.empty: + return [] + + # Use question_concept as the basis for categorization + concepts = df["question_concept"].unique() + + # Group related concepts into broader categories + def categorize_concept(concept: str) -> str: + concept_lower = concept.lower() + + # Physical objects and materials + if any( + word in concept_lower + for word in [ + "tool", + "object", + "material", + "container", + "furniture", + "clothing", + "food", + ] + ): + return "Physical Objects" + + # Human activities and behaviors + elif any( + word in concept_lower + for word in [ + "activity", + "action", + "behavior", + "work", + "play", + "exercise", + ] + ): + return "Human Activities" + + # Locations and places + elif any( + word in concept_lower + for word in ["place", "location", "building", "room", "area", "space"] + ): + return "Places & Locations" + + # Emotions and mental states + elif any( + word in concept_lower + for word in ["emotion", "feeling", "mental", "mind", "thought", "mood"] + ): + return "Emotions & Mental States" + + # Social and relationships + elif any( + word in concept_lower + for word in [ + "people", + "person", + "social", + "relationship", + "family", + "friend", + ] + ): + return "Social & Relationships" + + # Time and events + elif any( + word in concept_lower + for word in ["time", "event", "occasion", "period", "moment"] + ): + return "Time & Events" + + # Animals and nature + elif any( + word in concept_lower + for word in ["animal", "nature", "plant", "wildlife", "creature"] + ): + return "Animals & Nature" + + # Abstract concepts + elif any( + word in concept_lower + for word in ["concept", "idea", "principle", "theory", "abstract"] + ): + return "Abstract Concepts" + + else: + return "General Knowledge" + + # Add category column to dataframe + if "category" not in df.columns: + df["category"] = 
df["question_concept"].apply(categorize_concept) + + return sorted(df["category"].unique().tolist()) + + def get_available_categories(self) -> List[str]: + """Get all available categories in the dataset.""" + if self._categories_cache is None: + df = self._load_raw_dataset() + self._categories_cache = self._extract_categories(df) + return self._categories_cache + + def load_dataset( + self, + categories: Optional[List[str]] = None, + samples_per_category: Optional[int] = None, + seed: int = 42, + ) -> Tuple[List[Question], DatasetInfo]: + """Load CommonsenseQA dataset with filtering and sampling.""" + df = self._load_raw_dataset() + + if df.empty: + return [], DatasetInfo( + name=self.dataset_name, + categories=[], + total_questions=0, + ) + + # Extract categories + all_categories = self._extract_categories(df) + + # Filter by categories if specified + if categories: + df = df[df["category"].isin(categories)] + if df.empty: + valid_categories = ", ".join(all_categories) + raise ValueError( + f"No data found for specified categories. Valid categories are: {valid_categories}" + ) + + # Sample questions per category if specified + if samples_per_category: + random.seed(seed) + np.random.seed(seed) + sampled_dfs = [] + for category in df["category"].unique(): + category_df = df[df["category"] == category] + if len(category_df) > samples_per_category: + sampled_df = category_df.sample( + samples_per_category, random_state=seed + ) + sampled_dfs.append(sampled_df) + else: + sampled_dfs.append(category_df) + df = pd.concat(sampled_dfs) if sampled_dfs else pd.DataFrame() + + # Convert to Question objects + questions = [] + for _, row in df.iterrows(): + # Extract multiple choice options + choices = row["choices"] + choice_texts = choices["text"] + choice_labels = choices["label"] # ['A', 'B', 'C', 'D', 'E'] + + # Find correct answer index + answer_key = row["answerKey"] + correct_idx = choice_labels.index(answer_key) + + question = Question( + question_id=row["id"], + question=row["question"], + options=choice_texts, + correct_answer=correct_idx, # 0-indexed + category=row["category"], + cot_content=None, # CommonsenseQA doesn't provide CoT + ) + questions.append(question) + + dataset_info = DatasetInfo( + name=self.dataset_name, + description="CommonsenseQA tests commonsense reasoning across various conceptual domains", + categories=sorted(df["category"].unique().tolist()) if not df.empty else [], + total_questions=len(questions), + format_type="multiple_choice", + difficulty_level="hard", + ) + + return questions, dataset_info + + def format_prompt(self, question: Question, style: str = "plain") -> str: + """Format a question into a prompt.""" + formatter = PromptFormatter() + + if style == "plain": + return formatter.format_enhanced_prompt( + question.question, question.options, "CommonsenseQA", "hard", "plain" + ) + elif style == "cot": + return formatter.format_enhanced_prompt( + question.question, question.options, "CommonsenseQA", "hard", "cot" + ) + elif style == "explicit_cot": + return formatter.format_explicit_cot_prompt( + question.question, question.options, question.cot_content + ) + else: + raise ValueError(f"Unknown prompt style: {style}") + + +class CommonsenseQAPromptFormatter(PromptFormatter): + """Prompt formatter for CommonsenseQA questions.""" + + def format_plain_prompt(self, question: str, options: List[str]) -> str: + """Format a plain prompt for CommonsenseQA.""" + formatted_options = "" + for i, option in enumerate(options): + letter = chr(ord("A") + i) + 
formatted_options += f"{letter}) {option}\n" + + prompt = ( + f"Question: {question}\n\n" + f"Options:\n{formatted_options}\n" + f"Please choose the answer that demonstrates the best commonsense reasoning. " + f"Provide your answer in the format 'Answer: [letter]'." + ) + return prompt + + def format_cot_prompt(self, question: str, options: List[str]) -> str: + """Format a chain-of-thought prompt for CommonsenseQA.""" + formatted_options = "" + for i, option in enumerate(options): + letter = chr(ord("A") + i) + formatted_options += f"{letter}) {option}\n" + + prompt = ( + f"Question: {question}\n\n" + f"Options:\n{formatted_options}\n" + f"Please think step-by-step about this question using commonsense reasoning. " + f"Consider what you know about the world and how things typically work. " + f"Then provide your final answer in the format 'Answer: [letter]'." + ) + return prompt + + def format_explicit_cot_prompt( + self, question: str, options: List[str], cot_content: Optional[str] + ) -> str: + """Format an explicit chain-of-thought prompt for CommonsenseQA.""" + # CommonsenseQA doesn't provide CoT content, so fall back to regular CoT + return self.format_cot_prompt(question, options) diff --git a/bench/dataset_implementations/gpqa_dataset.py b/bench/dataset_implementations/gpqa_dataset.py new file mode 100644 index 00000000..b0fdd403 --- /dev/null +++ b/bench/dataset_implementations/gpqa_dataset.py @@ -0,0 +1,280 @@ +""" +GPQA Dataset Implementation + +Graduate-level Google-proof Q&A dataset for advanced scientific reasoning +evaluation. Supports Main, Extended, and Diamond variants with Chain-of-Thought +reasoning content. +""" + +import os +import random +import sys +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd +from datasets import load_dataset + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dataset_interface import DatasetInfo, DatasetInterface, PromptFormatter, Question + + +class GPQADataset(DatasetInterface): + """GPQA (Graduate-level Google-proof Q&A) dataset implementation.""" + + def __init__(self, subset: str = "gpqa_main"): + """Initialize GPQA dataset. + + Args: + subset: Which GPQA subset to use ("gpqa_main", "gpqa_extended", or "gpqa_diamond") + """ + self.subset = subset + valid_subsets = ["gpqa_main", "gpqa_extended", "gpqa_diamond"] + if self.subset not in valid_subsets: + raise ValueError(f"subset must be one of {valid_subsets}") + + self._dataset_cache = None + self._categories_cache = None + + @property + def dataset_name(self) -> str: + return f"GPQA-{self.subset.replace('gpqa_', '').title()}" + + @property + def supports_cot(self) -> bool: + return True # GPQA has reasoning explanations + + def _load_raw_dataset(self): + """Load raw GPQA dataset from Hugging Face.""" + if self._dataset_cache is not None: + return self._dataset_cache + + try: + # Try loading from the official GPQA dataset + dataset = load_dataset("Idavidrein/gpqa", self.subset, split="train") + self._dataset_cache = pd.DataFrame(dataset) + except Exception as e: + # Fallback: try alternative dataset names or warn user + print(f"Warning: Could not load GPQA dataset {self.subset}: {e}") + print( + "You may need to install the dataset manually or check the dataset name." 
+ ) + # Create empty dataframe as fallback + self._dataset_cache = pd.DataFrame() + + return self._dataset_cache + + def _standardize_subject_category(self, subject: str) -> str: + """Standardize subject names to consistent categories.""" + subject_lower = subject.lower() if subject else "" + + # Map various subject names to standard categories + if any(word in subject_lower for word in ["physics", "phys"]): + return "Physics" + elif any(word in subject_lower for word in ["chemistry", "chem"]): + return "Chemistry" + elif any(word in subject_lower for word in ["biology", "bio"]): + return "Biology" + elif any(word in subject_lower for word in ["math", "mathematics"]): + return "Mathematics" + else: + return "Other" + + def load_dataset( + self, + categories: Optional[List[str]] = None, + samples_per_category: Optional[int] = None, + seed: int = 42, + ) -> Tuple[List[Question], DatasetInfo]: + """Load GPQA dataset.""" + df = self._load_raw_dataset() + + if df.empty: + # Return empty dataset if loading failed + return [], DatasetInfo( + name=self.dataset_name, + description="GPQA dataset (failed to load)", + categories=[], + total_questions=0, + format_type="multiple_choice", + difficulty_level="graduate", + ) + + # Convert to Question objects + questions = [] + for _, row in df.iterrows(): + # Handle different possible column names for GPQA + question_text = str(row.get("Question", row.get("question", ""))) + + # Extract multiple choice options + options = [] + correct_answer = None + + # GPQA has correct answer and incorrect answers as separate columns + correct_answer_text = None + if "Correct Answer" in row and pd.notna(row["Correct Answer"]): + correct_answer_text = str(row["Correct Answer"]) + elif "Answer" in row and pd.notna(row["Answer"]): + correct_answer_text = str(row["Answer"]) + elif "answer" in row and pd.notna(row["answer"]): + correct_answer_text = str(row["answer"]) + + # Collect all answer options + incorrect_answers = [] + for i in [1, 2, 3]: + col_name = f"Incorrect Answer {i}" + if col_name in row and pd.notna(row[col_name]): + incorrect_answers.append(str(row[col_name])) + + # Create options list with correct answer in random position + if correct_answer_text and incorrect_answers: + options = incorrect_answers + [correct_answer_text] + random.shuffle(options) # Randomize order + correct_answer = options.index( + correct_answer_text + ) # Find index after shuffle + else: + # Fallback: try other formats + options = [] + correct_answer = None + + # Try to extract from individual option columns (A, B, C, D) + for letter in ["A", "B", "C", "D"]: + if letter in row and pd.notna(row[letter]): + options.append(str(row[letter])) + + if options and correct_answer_text: + # Try to find correct answer in options + try: + correct_answer = options.index(correct_answer_text) + except ValueError: + correct_answer = 0 # Default to first option if not found + + # Get subject/category + subject = row.get( + "Subject", row.get("subject", row.get("Category", "Other")) + ) + category = self._standardize_subject_category(str(subject)) + + # Get explanation/reasoning if available + explanation = None + for col in ["Explanation", "explanation", "reasoning", "Reasoning"]: + if col in row and pd.notna(row[col]): + explanation = str(row[col]) + break + + # Skip questions without proper multiple choice format + if not options or correct_answer is None: + continue + + question = Question( + question_id=str(row.get("Record ID", f"gpqa_{len(questions)}")), + category=category, + 
question=question_text, + options=options, + correct_answer=correct_answer, + cot_content=explanation, + metadata={ + "source": "GPQA", + "subset": self.subset, + "difficulty": "graduate", + "subject": str(subject), + }, + ) + questions.append(question) + + # Get all unique categories + all_categories = sorted(list(set(q.category for q in questions))) + self._categories_cache = all_categories + + # Filter by categories if specified + if categories: + questions = [q for q in questions if q.category in categories] + if not questions: + valid_categories = ", ".join(all_categories) + raise ValueError( + f"No data found for specified categories. " + f"Valid categories are: {valid_categories}" + ) + + # Sample if requested + if samples_per_category: + random.seed(seed) + np.random.seed(seed) + + # Group by category + category_questions = {} + for q in questions: + if q.category not in category_questions: + category_questions[q.category] = [] + category_questions[q.category].append(q) + + # Sample from each category + sampled_questions = [] + for category, cat_questions in category_questions.items(): + if len(cat_questions) > samples_per_category: + sampled = random.sample(cat_questions, samples_per_category) + sampled_questions.extend(sampled) + else: + sampled_questions.extend(cat_questions) + + questions = sampled_questions + + # Create dataset info + dataset_info = DatasetInfo( + name=self.dataset_name, + description="Graduate-level Google-proof Q&A benchmark", + categories=list(set(q.category for q in questions)), + total_questions=len(questions), + format_type="multiple_choice", + difficulty_level="graduate", + ) + + return questions, dataset_info + + def get_available_categories(self) -> List[str]: + """Get all available GPQA categories.""" + if self._categories_cache is None: + # Load dataset to get categories + self.load_dataset() + return self._categories_cache or [] + + def format_prompt(self, question: Question, prompt_style: str = "plain") -> str: + """Format GPQA question into prompt.""" + if prompt_style == "plain": + return PromptFormatter.format_enhanced_prompt( + question.question, question.options, "GPQA", "graduate", "plain" + ) + elif prompt_style == "cot": + return PromptFormatter.format_enhanced_prompt( + question.question, question.options, "GPQA", "graduate", "cot" + ) + elif prompt_style == "explicit_cot": + return PromptFormatter.format_explicit_cot_prompt( + question.question, question.options, question.cot_content + ) + else: + raise ValueError(f"Unknown prompt style: {prompt_style}") + + +# Convenience classes for specific subsets +class GPQAMainDataset(GPQADataset): + """GPQA Main dataset.""" + + def __init__(self): + super().__init__(subset="gpqa_main") + + +class GPQAExtendedDataset(GPQADataset): + """GPQA Extended dataset.""" + + def __init__(self): + super().__init__(subset="gpqa_extended") + + +class GPQADiamondDataset(GPQADataset): + """GPQA Diamond dataset (highest quality subset).""" + + def __init__(self): + super().__init__(subset="gpqa_diamond") diff --git a/bench/dataset_implementations/hellaswag_dataset.py b/bench/dataset_implementations/hellaswag_dataset.py new file mode 100644 index 00000000..8d875d19 --- /dev/null +++ b/bench/dataset_implementations/hellaswag_dataset.py @@ -0,0 +1,232 @@ +""" +HellaSwag dataset implementation. + +This module implements the DatasetInterface for HellaSwag dataset which +tests commonsense reasoning about everyday activities and situations. 
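Illustrative usage (a minimal sketch; each row's `ctx` becomes the question
context followed by "What happens next?", the four `endings` become the options,
and the string `label` is parsed into a 0-based correct index):

    from dataset_implementations.hellaswag_dataset import HellaSwagDataset

    hs = HellaSwagDataset()
    questions, info = hs.load_dataset(samples_per_category=2, seed=42)
    q = questions[0]
    print(q.category)                    # cleaned activity label, e.g. "Archery"
    print(q.options[q.correct_answer])   # the ending marked correct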
+""" + +import os +import random +import sys +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd +from datasets import load_dataset + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dataset_interface import DatasetInfo, DatasetInterface, PromptFormatter, Question + + +class HellaSwagDataset(DatasetInterface): + """HellaSwag dataset implementation.""" + + def __init__(self): + """Initialize HellaSwag dataset.""" + self._dataset_cache = None + self._categories_cache = None + + @property + def dataset_name(self) -> str: + return "HellaSwag" + + @property + def supports_cot(self) -> bool: + return True # HellaSwag benefits from reasoning about context + + def _load_raw_dataset(self): + """Load raw HellaSwag dataset from Hugging Face.""" + if self._dataset_cache is not None: + return self._dataset_cache + + try: + # Load train and validation splits + train_dataset = load_dataset("hellaswag", split="train") + val_dataset = load_dataset("hellaswag", split="validation") + + # Combine both splits for more data + train_df = pd.DataFrame(train_dataset) + val_df = pd.DataFrame(val_dataset) + self._dataset_cache = pd.concat([train_df, val_df], ignore_index=True) + + except Exception as e: + print(f"Warning: Could not load HellaSwag dataset: {e}") + print("You may need to check your internet connection or dataset access.") + # Create empty dataframe as fallback + self._dataset_cache = pd.DataFrame() + + return self._dataset_cache + + def _extract_categories(self, df: pd.DataFrame) -> List[str]: + """Extract categories from HellaSwag dataset using activity labels.""" + if df.empty: + return [] + + # Use activity_label as categories, but clean them up + def clean_activity_label(label: str) -> str: + """Clean up activity labels to make them more readable.""" + # Remove underscores and capitalize properly + cleaned = label.replace("_", " ").title() + + # Handle some common cases + replacements = { + "Tv": "TV", + "Diy": "DIY", + "Atv": "ATV", + "Bmx": "BMX", + "Sumo": "Sumo Wrestling", + "Mma": "MMA", + } + + for old, new in replacements.items(): + cleaned = cleaned.replace(old, new) + + return cleaned + + # Add cleaned category column + if "category" not in df.columns: + df["category"] = df["activity_label"].apply(clean_activity_label) + + return sorted(df["category"].unique().tolist()) + + def get_available_categories(self) -> List[str]: + """Get all available categories in the dataset.""" + if self._categories_cache is None: + df = self._load_raw_dataset() + self._categories_cache = self._extract_categories(df) + return self._categories_cache + + def load_dataset( + self, + categories: Optional[List[str]] = None, + samples_per_category: Optional[int] = None, + seed: int = 42, + ) -> Tuple[List[Question], DatasetInfo]: + """Load HellaSwag dataset with filtering and sampling.""" + df = self._load_raw_dataset() + + if df.empty: + return [], DatasetInfo( + name=self.dataset_name, + categories=[], + total_questions=0, + ) + + # Extract categories + all_categories = self._extract_categories(df) + + # Filter by categories if specified + if categories: + df = df[df["category"].isin(categories)] + if df.empty: + valid_categories = ", ".join(all_categories) + raise ValueError( + f"No data found for specified categories. 
Valid categories are: {valid_categories}" + ) + + # Sample questions per category if specified + if samples_per_category: + random.seed(seed) + np.random.seed(seed) + sampled_dfs = [] + for category in df["category"].unique(): + category_df = df[df["category"] == category] + if len(category_df) > samples_per_category: + sampled_df = category_df.sample( + samples_per_category, random_state=seed + ) + sampled_dfs.append(sampled_df) + else: + sampled_dfs.append(category_df) + df = pd.concat(sampled_dfs) if sampled_dfs else pd.DataFrame() + + # Convert to Question objects + questions = [] + for _, row in df.iterrows(): + # Construct the full context + context = row["ctx"] # This is the full context (ctx_a + ctx_b combined) + endings = row["endings"] # List of 4 possible endings + correct_idx = int(str(row["label"])) # Convert string label to int (0-3) + + question = Question( + question_id=f"hellaswag_{row['ind']}", + question=f"Context: {context}\n\nWhat happens next?", + options=endings, + correct_answer=correct_idx, # 0-indexed + category=row["category"], + cot_content=None, # HellaSwag doesn't provide CoT + ) + questions.append(question) + + dataset_info = DatasetInfo( + name=self.dataset_name, + description="HellaSwag tests commonsense reasoning about everyday activities and situations", + categories=sorted(df["category"].unique().tolist()) if not df.empty else [], + total_questions=len(questions), + format_type="multiple_choice", + difficulty_level="moderate", + ) + + return questions, dataset_info + + def format_prompt(self, question: Question, style: str = "plain") -> str: + """Format a question into a prompt.""" + formatter = PromptFormatter() + + if style == "plain": + return formatter.format_enhanced_prompt( + question.question, question.options, "HellaSwag", "moderate", "plain" + ) + elif style == "cot": + return formatter.format_enhanced_prompt( + question.question, question.options, "HellaSwag", "moderate", "cot" + ) + elif style == "explicit_cot": + return formatter.format_explicit_cot_prompt( + question.question, question.options, question.cot_content + ) + else: + raise ValueError(f"Unknown prompt style: {style}") + + +class HellaSwagPromptFormatter(PromptFormatter): + """Prompt formatter for HellaSwag questions.""" + + def format_plain_prompt(self, question: str, options: List[str]) -> str: + """Format a plain prompt for HellaSwag.""" + formatted_options = "" + for i, option in enumerate(options): + letter = chr(ord("A") + i) + formatted_options += f"{letter}) {option}\n" + + prompt = ( + f"{question}\n\n" + f"Options:\n{formatted_options}\n" + f"Please choose the most logical and natural continuation. " + f"Provide your answer in the format 'Answer: [letter]'." + ) + return prompt + + def format_cot_prompt(self, question: str, options: List[str]) -> str: + """Format a chain-of-thought prompt for HellaSwag.""" + formatted_options = "" + for i, option in enumerate(options): + letter = chr(ord("A") + i) + formatted_options += f"{letter}) {option}\n" + + prompt = ( + f"{question}\n\n" + f"Options:\n{formatted_options}\n" + f"Please think step-by-step about what would most likely happen next in this situation. " + f"Consider the context, the activity being performed, and what would be the most natural continuation. " + f"Then provide your final answer in the format 'Answer: [letter]'." 
+ ) + return prompt + + def format_explicit_cot_prompt( + self, question: str, options: List[str], cot_content: Optional[str] + ) -> str: + """Format an explicit chain-of-thought prompt for HellaSwag.""" + # HellaSwag doesn't provide CoT content, so fall back to regular CoT + return self.format_cot_prompt(question, options) diff --git a/bench/dataset_implementations/mmlu_dataset.py b/bench/dataset_implementations/mmlu_dataset.py new file mode 100644 index 00000000..5b6411e5 --- /dev/null +++ b/bench/dataset_implementations/mmlu_dataset.py @@ -0,0 +1,159 @@ +""" +MMLU-Pro Dataset Implementation + +Academic knowledge evaluation across 14 subject categories with +Chain-of-Thought reasoning support. +""" + +import os +import random +import sys +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd +from datasets import load_dataset + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dataset_interface import DatasetInfo, DatasetInterface, PromptFormatter, Question + + +class MMLUDataset(DatasetInterface): + """MMLU-Pro dataset implementation.""" + + def __init__(self): + self._dataset_cache = None + self._categories_cache = None + + @property + def dataset_name(self) -> str: + return "MMLU-Pro" + + @property + def supports_cot(self) -> bool: + return True + + def load_dataset( + self, + categories: Optional[List[str]] = None, + samples_per_category: Optional[int] = None, + seed: int = 42, + ) -> Tuple[List[Question], DatasetInfo]: + """Load MMLU-Pro dataset.""" + # Load raw dataset + if self._dataset_cache is None: + dataset = load_dataset("TIGER-Lab/MMLU-Pro", split="test") + self._dataset_cache = pd.DataFrame(dataset) + + df = self._dataset_cache.copy() + all_categories = sorted(df["category"].unique().tolist()) + self._categories_cache = all_categories + + # Filter by categories if specified + if categories: + df = df[df["category"].isin(categories)] + if df.empty: + valid_categories = ", ".join(all_categories) + raise ValueError( + f"No data found for specified categories. 
" + f"Valid categories are: {valid_categories}" + ) + + # Sample if requested + if samples_per_category: + random.seed(seed) + np.random.seed(seed) + sampled_dfs = [] + for category in df["category"].unique(): + category_df = df[df["category"] == category] + if len(category_df) > samples_per_category: + sampled_df = category_df.sample( + samples_per_category, random_state=seed + ) + sampled_dfs.append(sampled_df) + else: + sampled_dfs.append(category_df) + df = pd.concat(sampled_dfs) + + # Convert to Question objects + questions = [] + for _, row in df.iterrows(): + question = Question( + question_id=str(row.get("question_id", f"mmlu_{len(questions)}")), + category=str(row["category"]), + question=str(row["question"]), + options=row["options"] if isinstance(row["options"], list) else [], + correct_answer=str(row["answer"]), + cot_content=( + row.get("cot_content") if pd.notna(row.get("cot_content")) else None + ), + metadata={ + "source": "MMLU-Pro", + "difficulty": row.get("difficulty", "unknown"), + }, + ) + questions.append(question) + + # Create dataset info + dataset_info = DatasetInfo( + name="MMLU-Pro", + description="Massive Multitask Language Understanding - Professional", + categories=list(df["category"].unique()), + total_questions=len(questions), + format_type="multiple_choice", + difficulty_level="undergraduate", + ) + + return questions, dataset_info + + def get_available_categories(self) -> List[str]: + """Get all available MMLU categories.""" + if self._categories_cache is None: + # Load dataset to get categories + self.load_dataset() + return self._categories_cache or [] + + def format_prompt(self, question: Question, prompt_style: str = "plain") -> str: + """Format MMLU question into prompt.""" + if prompt_style == "plain": + return PromptFormatter.format_plain_prompt( + question.question, question.options + ) + elif prompt_style == "cot": + return PromptFormatter.format_cot_prompt( + question.question, question.options + ) + elif prompt_style == "explicit_cot": + return PromptFormatter.format_explicit_cot_prompt( + question.question, question.options, question.cot_content + ) + else: + raise ValueError(f"Unknown prompt style: {prompt_style}") + + +# Legacy compatibility function +def load_mmlu_pro_dataset( + categories: Optional[List[str]] = None, + samples_per_category: Optional[int] = None, + seed: int = 42, +) -> Tuple[pd.DataFrame, List[str]]: + """Legacy function for backward compatibility.""" + mmlu = MMLUDataset() + questions, dataset_info = mmlu.load_dataset(categories, samples_per_category, seed) + + # Convert back to DataFrame format for compatibility + records = [] + for q in questions: + record = { + "question_id": q.question_id, + "category": q.category, + "question": q.question, + "options": q.options, + "answer": q.correct_answer, + "cot_content": q.cot_content, + } + records.append(record) + + df = pd.DataFrame(records) + return df, dataset_info.categories diff --git a/bench/dataset_implementations/truthfulqa_dataset.py b/bench/dataset_implementations/truthfulqa_dataset.py new file mode 100644 index 00000000..afdb8853 --- /dev/null +++ b/bench/dataset_implementations/truthfulqa_dataset.py @@ -0,0 +1,226 @@ +""" +TruthfulQA dataset implementation. + +This module implements the DatasetInterface for TruthfulQA dataset which +tests whether language models are truthful in generating answers to questions. 
+""" + +import os +import random +import sys +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd +from datasets import load_dataset + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dataset_interface import DatasetInfo, DatasetInterface, PromptFormatter, Question + + +class TruthfulQADataset(DatasetInterface): + """TruthfulQA dataset implementation.""" + + def __init__(self): + """Initialize TruthfulQA dataset.""" + self._dataset_cache = None + self._categories_cache = None + + @property + def dataset_name(self) -> str: + return "TruthfulQA" + + @property + def supports_cot(self) -> bool: + return True # TruthfulQA benefits from reasoning + + def _load_raw_dataset(self): + """Load raw TruthfulQA dataset from Hugging Face.""" + if self._dataset_cache is not None: + return self._dataset_cache + + try: + # Load the multiple choice version + dataset = load_dataset("truthful_qa", "multiple_choice", split="validation") + self._dataset_cache = pd.DataFrame(dataset) + except Exception as e: + print(f"Warning: Could not load TruthfulQA dataset: {e}") + print("You may need to check your internet connection or dataset access.") + # Create empty dataframe as fallback + self._dataset_cache = pd.DataFrame() + + return self._dataset_cache + + def _extract_categories(self, df: pd.DataFrame) -> List[str]: + """Extract categories from TruthfulQA dataset. + + TruthfulQA doesn't have explicit categories, so we'll create them + based on question topics/themes. + """ + if df.empty: + return [] + + # For now, we'll use a single "Truthfulness" category + # In the future, we could implement topic classification + def get_category() -> str: + """ + TruthfulQA doesn't have explicit categories. + All questions test truthfulness and misconception detection. + """ + return "Truthfulness" + + # Add single category since TruthfulQA doesn't have explicit subjects + if "category" not in df.columns: + df["category"] = get_category() + + return sorted(df["category"].unique().tolist()) + + def get_available_categories(self) -> List[str]: + """Get all available categories in the dataset.""" + if self._categories_cache is None: + df = self._load_raw_dataset() + self._categories_cache = self._extract_categories(df) + return self._categories_cache + + def load_dataset( + self, + categories: Optional[List[str]] = None, + samples_per_category: Optional[int] = None, + seed: int = 42, + ) -> Tuple[List[Question], DatasetInfo]: + """Load TruthfulQA dataset with filtering and sampling.""" + df = self._load_raw_dataset() + + if df.empty: + return [], DatasetInfo( + name=self.dataset_name, + categories=[], + total_questions=0, + ) + + # Extract categories + all_categories = self._extract_categories(df) + + # Filter by categories if specified + if categories: + df = df[df["category"].isin(categories)] + if df.empty: + valid_categories = ", ".join(all_categories) + raise ValueError( + f"No data found for specified categories. 
Valid categories are: {valid_categories}" + ) + + # Sample questions per category if specified + if samples_per_category: + random.seed(seed) + np.random.seed(seed) + sampled_dfs = [] + for category in df["category"].unique(): + category_df = df[df["category"] == category] + if len(category_df) > samples_per_category: + sampled_df = category_df.sample( + samples_per_category, random_state=seed + ) + sampled_dfs.append(sampled_df) + else: + sampled_dfs.append(category_df) + df = pd.concat(sampled_dfs) if sampled_dfs else pd.DataFrame() + + # Convert to Question objects + questions = [] + for _, row in df.iterrows(): + # Extract multiple choice options + mc1_targets = row["mc1_targets"] + choices = mc1_targets["choices"] + labels = mc1_targets["labels"] + + # Find the correct answer (label = 1) + correct_idx = None + for i, label in enumerate(labels): + if label == 1: + correct_idx = i + break + + if correct_idx is not None: + question = Question( + question_id=f"truthfulqa_{len(questions)}", + question=row["question"], + options=choices, + correct_answer=correct_idx, # 0-indexed + category=row["category"], + cot_content=None, # TruthfulQA doesn't provide CoT + ) + questions.append(question) + + dataset_info = DatasetInfo( + name=self.dataset_name, + description="TruthfulQA tests whether language models are truthful in generating answers to questions", + categories=sorted(df["category"].unique().tolist()) if not df.empty else [], + total_questions=len(questions), + format_type="multiple_choice", + difficulty_level="hard", + ) + + return questions, dataset_info + + def format_prompt(self, question: Question, style: str = "plain") -> str: + """Format a question into a prompt.""" + formatter = PromptFormatter() + + if style == "plain": + return formatter.format_enhanced_prompt( + question.question, question.options, "TruthfulQA", "hard", "plain" + ) + elif style == "cot": + return formatter.format_enhanced_prompt( + question.question, question.options, "TruthfulQA", "hard", "cot" + ) + elif style == "explicit_cot": + return formatter.format_explicit_cot_prompt( + question.question, question.options, question.cot_content + ) + else: + raise ValueError(f"Unknown prompt style: {style}") + + +class TruthfulQAPromptFormatter(PromptFormatter): + """Prompt formatter for TruthfulQA questions.""" + + def format_plain_prompt(self, question: str, options: List[str]) -> str: + """Format a plain prompt for TruthfulQA.""" + formatted_options = "" + for i, option in enumerate(options): + letter = chr(ord("A") + i) + formatted_options += f"{letter}) {option}\n" + + prompt = ( + f"Question: {question}\n\n" + f"Options:\n{formatted_options}\n" + f"Please choose the most truthful and accurate answer. " + f"Provide your answer in the format 'Answer: [letter]'." + ) + return prompt + + def format_cot_prompt(self, question: str, options: List[str]) -> str: + """Format a chain-of-thought prompt for TruthfulQA.""" + formatted_options = "" + for i, option in enumerate(options): + letter = chr(ord("A") + i) + formatted_options += f"{letter}) {option}\n" + + prompt = ( + f"Question: {question}\n\n" + f"Options:\n{formatted_options}\n" + f"Please think step-by-step about which answer is most truthful and accurate. " + f"Consider whether each option represents a fact or a common misconception. " + f"Then provide your final answer in the format 'Answer: [letter]'." 
+ ) + return prompt + + def format_explicit_cot_prompt( + self, question: str, options: List[str], cot_content: Optional[str] + ) -> str: + """Format an explicit chain-of-thought prompt for TruthfulQA.""" + # TruthfulQA doesn't provide CoT content, so fall back to regular CoT + return self.format_cot_prompt(question, options) diff --git a/bench/dataset_interface.py b/bench/dataset_interface.py new file mode 100644 index 00000000..d4c3c1fe --- /dev/null +++ b/bench/dataset_interface.py @@ -0,0 +1,356 @@ +""" +Multi-Dataset Evaluation Interface + +Provides abstract base classes and standardized interfaces for reasoning +dataset evaluation across MMLU, ARC, GPQA, TruthfulQA, CommonsenseQA, and HellaSwag. + +Key Features: +- Unified Question and DatasetInfo data structures +- Abstract DatasetInterface for consistent implementations +- Enhanced PromptFormatter with dataset-specific optimizations +- Support for Chain-of-Thought (CoT) reasoning modes +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd + + +@dataclass +class Question: + """ + Standardized question representation for multi-choice reasoning tasks. + + Attributes: + question_id: Unique identifier for the question + category: Subject or topic category + question: The question text + options: List of answer choices + correct_answer: Index (int) of the correct option + cot_content: Optional chain-of-thought reasoning + metadata: Additional dataset-specific information + """ + + question_id: str + category: str + question: str + options: List[str] + correct_answer: str + cot_content: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + +@dataclass +class DatasetInfo: + """ + Dataset metadata and configuration information. + + Attributes: + name: Dataset name (e.g., "GPQA-Main", "ARC-Challenge") + description: Brief description of the dataset + categories: List of available subject categories + total_questions: Total number of questions loaded + format_type: Question format (typically "multiple_choice") + difficulty_level: Complexity level (e.g., "graduate", "undergraduate") + """ + + name: str + description: str + categories: List[str] + total_questions: int + format_type: str + difficulty_level: str + + +class DatasetInterface(ABC): + """Abstract base class for all dataset implementations.""" + + @abstractmethod + def load_dataset( + self, + categories: Optional[List[str]] = None, + samples_per_category: Optional[int] = None, + seed: int = 42, + ) -> Tuple[List[Question], DatasetInfo]: + """Load and return questions from the dataset. + + Args: + categories: List of categories to filter by. If None, load all. + samples_per_category: Max samples per category. If None, load all. + seed: Random seed for reproducible sampling. + + Returns: + Tuple of (questions_list, dataset_info) + """ + pass + + @abstractmethod + def get_available_categories(self) -> List[str]: + """Get list of all available categories in the dataset.""" + pass + + @abstractmethod + def format_prompt(self, question: Question, prompt_style: str = "plain") -> str: + """Format a question into a prompt string. 
+ + Args: + question: Question object to format + prompt_style: Style of prompt ("plain", "cot", "explicit_cot") + + Returns: + Formatted prompt string + """ + pass + + @property + @abstractmethod + def dataset_name(self) -> str: + """Return the name of this dataset.""" + pass + + @property + @abstractmethod + def supports_cot(self) -> bool: + """Return True if dataset has chain-of-thought content.""" + pass + + +class PromptFormatter: + """Utility class for formatting prompts consistently across datasets.""" + + @staticmethod + def get_dataset_specific_instructions(dataset_name: str, difficulty: str) -> str: + """Get dataset-specific instructions to improve accuracy.""" + dataset_name = dataset_name.lower() + difficulty = difficulty.lower() + + if "gpqa" in dataset_name: + return ( + "- This is a graduate-level scientific question\n" + "- Consider the underlying scientific principles\n" + "- Eliminate obviously incorrect options first\n" + ) + elif "truthfulqa" in dataset_name: + return ( + "- This question may contain common misconceptions\n" + "- Be wary of answers that sound plausible but are incorrect\n" + "- Choose the most factually accurate option\n" + ) + elif "hellaswag" in dataset_name: + return ( + "- Choose the most natural and logical continuation\n" + "- Consider common sense and typical sequences of events\n" + "- Think about what would realistically happen next\n" + ) + elif "commonsenseqa" in dataset_name: + return ( + "- Apply common sense reasoning\n" + "- Consider everyday knowledge and experiences\n" + "- Think about typical cause-and-effect relationships\n" + ) + elif "arc" in dataset_name: + return ( + "- This is a science question requiring logical reasoning\n" + "- Apply scientific knowledge and principles\n" + "- Consider the most scientifically accurate answer\n" + ) + elif "mmlu" in dataset_name: + return ( + "- This requires specific domain knowledge\n" + "- Choose the most accurate and complete answer\n" + "- Consider technical precision and accuracy\n" + ) + else: + return "" + + @staticmethod + def get_letter_mapping() -> Dict[int, str]: + """Get A-Z letter mapping for options (supports up to 26 options).""" + return { + 0: "A", + 1: "B", + 2: "C", + 3: "D", + 4: "E", + 5: "F", + 6: "G", + 7: "H", + 8: "I", + 9: "J", + 10: "K", + 11: "L", + 12: "M", + 13: "N", + 14: "O", + 15: "P", + 16: "Q", + 17: "R", + 18: "S", + 19: "T", + 20: "U", + 21: "V", + 22: "W", + 23: "X", + 24: "Y", + 25: "Z", + } + + @staticmethod + def format_options(options: List[str]) -> str: + """Format options list into lettered format.""" + letter_mapping = PromptFormatter.get_letter_mapping() + formatted = "" + for i, option in enumerate(options): + if option.lower() != "n/a": + if i in letter_mapping: + formatted += f"{letter_mapping[i]}) {option}\n" + else: + # Fallback for options beyond Z (unlikely but safe) + formatted += f"{i+1}.) 
{option}\n" + return formatted.rstrip() + + @staticmethod + def format_plain_prompt(question: str, options: List[str]) -> str: + """Format a basic multiple choice prompt.""" + formatted_options = PromptFormatter.format_options(options) + return ( + f"Question: {question}\n\nOptions:\n{formatted_options}\n\n" + "Instructions:\n" + "- Read the question carefully\n" + "- Consider each option thoroughly\n" + "- Choose the single best answer\n" + "- Respond with ONLY the format: Answer: [letter]\n" + "- Do not include any other text after your answer\n\n" + "Your response:" + ) + + @staticmethod + def format_cot_prompt(question: str, options: List[str]) -> str: + """Format a chain-of-thought prompt.""" + formatted_options = PromptFormatter.format_options(options) + return ( + f"Question: {question}\n\nOptions:\n{formatted_options}\n\n" + "Instructions:\n" + "- Think through this step-by-step\n" + "- Analyze each option carefully\n" + "- Explain your reasoning briefly\n" + "- End with your final answer in the exact format: Answer: [letter]\n\n" + "Your response:" + ) + + @staticmethod + def format_explicit_cot_prompt( + question: str, options: List[str], cot_content: Optional[str] + ) -> str: + """Format a prompt with explicit CoT content.""" + formatted_options = PromptFormatter.format_options(options) + cot_section = f"\nExplanation: {cot_content}\n" if cot_content else "\n" + return ( + f"Question: {question}\n\nOptions:\n{formatted_options}" + f"{cot_section}\n" + "Instructions:\n" + "- Use the provided explanation as guidance\n" + "- Consider how it applies to each option\n" + "- Choose the best answer based on the reasoning\n" + "- Provide your final answer in the exact format: Answer: [letter]\n\n" + "Your response:" + ) + + @staticmethod + def format_enhanced_prompt( + question: str, + options: List[str], + dataset_name: str, + difficulty: str, + prompt_style: str = "plain", + ) -> str: + """Format an enhanced prompt with dataset-specific guidance.""" + formatted_options = PromptFormatter.format_options(options) + dataset_instructions = PromptFormatter.get_dataset_specific_instructions( + dataset_name, difficulty + ) + + if prompt_style == "cot": + base_instructions = ( + "Instructions:\n" + "- Think through this step-by-step\n" + "- Analyze each option carefully\n" + ) + if dataset_instructions: + base_instructions += dataset_instructions + base_instructions += ( + "- Explain your reasoning briefly\n" + "- End with your final answer in the exact format: Answer: [letter]\n\n" + ) + else: # plain + base_instructions = ( + "Instructions:\n" + "- Read the question carefully\n" + "- Consider each option thoroughly\n" + ) + if dataset_instructions: + base_instructions += dataset_instructions + base_instructions += ( + "- Choose the single best answer\n" + "- Respond with ONLY the format: Answer: [letter]\n" + "- Do not include any other text after your answer\n\n" + ) + + return ( + f"Question: {question}\n\nOptions:\n{formatted_options}\n\n" + f"{base_instructions}" + "Your response:" + ) + + +def questions_to_dataframe(questions: List[Question]) -> pd.DataFrame: + """Convert list of Question objects to pandas DataFrame for compatibility.""" + records = [] + for q in questions: + record = { + "question_id": q.question_id, + "category": q.category, + "question": q.question, + "options": q.options, + "answer": q.correct_answer, + "cot_content": q.cot_content, + } + # Add metadata fields if present + if q.metadata: + record.update(q.metadata) + records.append(record) + return 
pd.DataFrame(records) + + +def dataframe_to_questions(df: pd.DataFrame) -> List[Question]: + """Convert pandas DataFrame back to list of Question objects.""" + questions = [] + for _, row in df.iterrows(): + # Extract metadata (any columns not in the standard Question fields) + standard_fields = { + "question_id", + "category", + "question", + "options", + "answer", + "cot_content", + } + metadata = { + k: v for k, v in row.items() if k not in standard_fields and pd.notna(v) + } + + question = Question( + question_id=str(row["question_id"]), + category=str(row["category"]), + question=str(row["question"]), + options=row["options"] if isinstance(row["options"], list) else [], + correct_answer=str(row["answer"]), + cot_content=( + row.get("cot_content") if pd.notna(row.get("cot_content")) else None + ), + metadata=metadata if metadata else None, + ) + questions.append(question) + return questions diff --git a/bench/router_reason_bench.py b/bench/router_reason_bench.py index 1bf666a4..f3567f98 100644 --- a/bench/router_reason_bench.py +++ b/bench/router_reason_bench.py @@ -17,8 +17,10 @@ # This benchmark supports two usage patterns: # 1) Router-transparent: send a single neutral prompt; router/model decides reasoning. -# 2) Policy evaluation: run NR (neutral), XC (explicit CoT), and optionally AR (automatic reasoning via extra_body) -# per question, then aggregate according to policies like Always-NR, Always-XC, CR-XC, Oracle, etc. +# 2) vLLM 3-case evaluation: run realistic scenarios that match router decision patterns: +# - NR: Plain prompt, no reasoning toggle (baseline/fast) +# - XC: CoT prompt, no reasoning toggle (prompt-based reasoning) +# - NR_REASONING: Plain prompt, reasoning toggle ON (model-based reasoning) ANSWER_PATTERN = re.compile(r"(?:answer(?:\sis)?:?\s*)([A-J])", re.IGNORECASE) @@ -76,7 +78,7 @@ def parse_args(): type=str, nargs="+", default=["NR", "XC"], - help="Prompt styles to run on vLLM: NR (neutral), XC (explicit CoT)", + help="DEPRECATED: vLLM now runs 3 fixed realistic modes: NR (plain), XC (CoT), NR_REASONING (plain+toggle)", ) parser.add_argument( "--run-router", @@ -340,7 +342,17 @@ def call_model( total_tokens = getattr(usage, "total_tokens", None) if usage else None return text, True, prompt_tokens, completion_tokens, total_tokens except Exception as e: - print(f"Model call failed: {e}") + print(f"❌ Model call failed: {e}") + print(f" Error type: {type(e).__name__}") + print(f" Model: {model}") + print(f" Endpoint: {getattr(client, '_base_url', 'unknown')}") + print(f" API key set: {'Yes' if getattr(client, 'api_key', None) else 'No'}") + if hasattr(e, "response"): + print(f" HTTP status: {getattr(e.response, 'status_code', 'unknown')}") + print(f" Response text: {getattr(e.response, 'text', 'unknown')}") + import traceback + + print(f" Full traceback: {traceback.format_exc()}") return "ERROR", False, None, None, None @@ -352,7 +364,7 @@ def build_extra_body_for_model( - DeepSeek v3.1: {"chat_template_kwargs": {"thinking": true/false}} - GPT-OSS: {"reasoning_effort": "low|medium|high"} when ON; if not provided, then low """ - # reasoning: True -> ON, False -> OFF, None -> base + # reasoning: True -> ON, False -> OFF, None -> base (default behavior) lower = model_name.lower() if (("ds" in lower) or ("deepseek" in lower)) and ( @@ -360,10 +372,11 @@ def build_extra_body_for_model( ): if reasoning is True: return {"chat_template_kwargs": {"thinking": True}} - if reasoning is None or reasoning is False: + elif reasoning is False: return {"chat_template_kwargs": 
{"thinking": False}} - # Base: do not set thinking for DeepSeek - return None + else: # reasoning is None (base mode) + # Base: do not set thinking for DeepSeek - let it use default behavior + return None # Qwen3 family if "qwen3" in lower: @@ -375,12 +388,13 @@ def build_extra_body_for_model( # GPT OSS family if "gpt-oss" in lower or "openai/gpt-oss" in lower or "gpt_oss" in lower: - # Base -> low effort, On -> provided effort (e.g., high) if reasoning is True: return {"reasoning_effort": "high"} - if reasoning is None or reasoning is False: + elif reasoning is False: return {"reasoning_effort": "low"} - return None + else: # reasoning is None (base mode) + # Base: do not set reasoning_effort - let it use default behavior + return None return None @@ -450,8 +464,17 @@ def evaluate_model_router_transparent( max_tokens: int, temperature: float, ) -> pd.DataFrame: + """ + Evaluate router in transparent mode - send plain prompts and let router decide reasoning. + + This represents the 'auto' mode where the router internally decides whether to use + reasoning or not based on the question complexity. + """ client = OpenAI(base_url=endpoint, api_key=api_key or None) print(f"Using model: {model}, endpoint: {endpoint}") + print( + f"API key provided: {'Yes' if api_key else 'No'} (length: {len(api_key) if api_key else 0})" + ) results: List[Dict[str, Any]] = [] questions_data = df.to_dict("records") @@ -491,37 +514,57 @@ def evaluate_model_vllm_multimode( temperature: float, exec_modes: List[str], ) -> pd.DataFrame: - """Run vLLM with NR/XC prompts and reasoning ON/OFF variants.""" - client = OpenAI(base_url=endpoint, api_key=api_key or None) + """Run vLLM with 3 realistic reasoning scenarios. + + The 3 scenarios represent real-world router decision patterns: + 1. NR - Plain prompt, no reasoning toggle (fast baseline) + 2. XC - CoT prompt, no reasoning toggle (prompt-based reasoning) + 3. 
NR_REASONING - Plain prompt, reasoning toggle ON (model-based reasoning) + """ + client = OpenAI(base_url=endpoint, api_key=api_key or "dummy-key") print(f"Using vLLM model: {model}, endpoint: {endpoint}") results: List[Dict[str, Any]] = [] questions_data = df.to_dict("records") - # Define mode variants: (label, prompt_mode, reasoning_flag) - mode_variants: List[Tuple[str, str, Optional[bool]]] = [] - for m in exec_modes: - if m.upper() == "NR": - mode_variants.extend( - [ - ("VLLM_NR_base", "NR", None), - ("VLLM_NR_reason_on", "NR", True), - ("VLLM_NR_reason_off", "NR", False), - ] - ) - elif m.upper() == "XC": - mode_variants.extend( - [ - ("VLLM_XC_base", "XC", None), - ("VLLM_XC_reason_on", "XC", True), - ("VLLM_XC_reason_off", "XC", False), - ] - ) + # Define 3 realistic mode variants: (label, prompt_mode, reasoning_flag) + # For DeepSeek and Qwen3 models, explicitly set reasoning flags for all modes + model_lower = model.lower() + is_deepseek_or_qwen = ( + (("ds" in model_lower) or ("deepseek" in model_lower)) + and ("v31" in model_lower or "v3.1" in model_lower or "v3" in model_lower) + ) or ("qwen3" in model_lower) + + if is_deepseek_or_qwen: + mode_variants: List[Tuple[str, str, Optional[bool]]] = [ + ("VLLM_NR", "NR", False), # Plain prompt, reasoning OFF (baseline) + ("VLLM_XC", "XC", False), # CoT prompt, reasoning OFF (prompt reasoning) + ( + "VLLM_NR_REASONING", + "NR", + True, + ), # Plain prompt, reasoning ON (model reasoning) + ] + else: + mode_variants: List[Tuple[str, str, Optional[bool]]] = [ + ("VLLM_NR", "NR", None), # Plain prompt, no toggle (baseline) + ("VLLM_XC", "XC", None), # CoT prompt, no toggle (prompt reasoning) + ( + "VLLM_NR_REASONING", + "NR", + True, + ), # Plain prompt, toggle ON (model reasoning) + ] def run_variants(q: Dict[str, Any]) -> List[Dict[str, Any]]: local_records: List[Dict[str, Any]] = [] for label, prompt_mode, reasoning_flag in mode_variants: extra_body = build_extra_body_for_model(model, reasoning_flag) + # Debug: print extra_body for first question to verify configuration + if q == questions_data[0]: + print( + f" {label}: reasoning_flag={reasoning_flag}, extra_body={extra_body}" + ) rec = process_question_single( client, model, diff --git a/bench/router_reason_bench_multi_dataset.py b/bench/router_reason_bench_multi_dataset.py new file mode 100644 index 00000000..30470aac --- /dev/null +++ b/bench/router_reason_bench_multi_dataset.py @@ -0,0 +1,850 @@ +""" +Multi-Dataset Reasoning Benchmark + +A comprehensive evaluation framework for comparing semantic router performance +against direct vLLM inference across various reasoning datasets. 
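+
+The direct-vLLM path runs up to three fixed cases per question: NR (plain prompt),
+XC (explicit-CoT prompt, only when the dataset ships CoT content), and NR_REASONING
+(plain prompt with the model-specific reasoning toggle passed via extra_body). The
+router path sends plain prompts and lets the router decide whether to reason.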
+ +Features: +- Dataset-agnostic architecture supporting MMLU, ARC, GPQA, TruthfulQA, CommonsenseQA, HellaSwag +- Optimized token limits per dataset complexity +- Multiple reasoning modes (NR, XC, NR_REASONING) +- Structured response parsing with robust answer extraction +- Comprehensive metrics and visualization +""" + +import argparse +import json +import os +import random +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from dataset_factory import DatasetFactory, list_available_datasets +from dataset_interface import DatasetInfo, Question, questions_to_dataframe +from openai import OpenAI +from tqdm import tqdm + +# Robust answer extraction patterns for structured response parsing +ANSWER_PATTERN_PRIMARY = re.compile(r"(?:answer\s*:?\s*)([A-Z])", re.IGNORECASE) +ANSWER_PATTERN_FINAL = re.compile(r"(?:final\s*answer\s*:?\s*)([A-Z])", re.IGNORECASE) +ANSWER_PATTERN_CONCLUSION = re.compile( + r"(?:therefore|thus|so).*?([A-Z])", re.IGNORECASE +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Multi-Dataset Reasoning Benchmark: Comprehensive evaluation framework for semantic router vs direct vLLM" + ) + + # Dataset selection + parser.add_argument( + "--dataset", + type=str, + default="mmlu", + help="Dataset to evaluate on. Use --list-datasets to see available options.", + ) + parser.add_argument( + "--list-datasets", + action="store_true", + help="List all available datasets and exit", + ) + + # Semantic router configuration + parser.add_argument( + "--router-endpoint", + type=str, + default=os.environ.get("ROUTER_ENDPOINT", "http://127.0.0.1:8801/v1"), + help="Semantic router endpoint URL", + ) + parser.add_argument( + "--router-api-key", + type=str, + default=os.environ.get( + "ROUTER_API_KEY", os.environ.get("OPENAI_API_KEY", "1234") + ), + help="API key for router endpoint", + ) + parser.add_argument( + "--router-models", + type=str, + nargs="+", + default=["auto"], + help="Router models to evaluate (default: auto).", + ) + + # Direct vLLM configuration + parser.add_argument( + "--vllm-endpoint", + type=str, + default=os.environ.get("VLLM_ENDPOINT", ""), + help="Direct vLLM endpoint URL", + ) + parser.add_argument( + "--vllm-api-key", + type=str, + default=os.environ.get("VLLM_API_KEY", os.environ.get("OPENAI_API_KEY", "")), + help="API key for vLLM endpoint", + ) + parser.add_argument( + "--vllm-models", + type=str, + nargs="+", + default=[], + help="Direct vLLM models to evaluate (leave empty to fetch from endpoint).", + ) + + # vLLM reasoning modes + parser.add_argument( + "--vllm-exec-modes", + type=str, + nargs="+", + default=["NR", "XC"], + help="vLLM reasoning modes: NR (neutral), XC (chain-of-thought), NR_REASONING (reasoning-enabled)", + ) + parser.add_argument( + "--run-router", + action="store_true", + help="Evaluate semantic router performance", + ) + parser.add_argument( + "--run-vllm", + action="store_true", + help="Evaluate direct vLLM performance across multiple reasoning modes", + ) + + # Dataset filtering options + parser.add_argument( + "--categories", + type=str, + nargs="+", + default=None, + help="List of categories to evaluate. If not provided, all available categories will be used.", + ) + parser.add_argument( + "--samples-per-category", + type=int, + default=5, + help="Number of questions to sample per category. 
If not provided, all questions will be used.", + ) + + # Execution options + parser.add_argument( + "--concurrent-requests", + type=int, + default=1, + help="Number of concurrent requests to make", + ) + parser.add_argument( + "--output-dir", + type=str, + default="results/reasonbench", + help="Directory to save results", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=None, + help="Maximum number of tokens to generate (default: dataset-optimal)", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Temperature for text generation", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility", + ) + parser.add_argument( + "--ar-extra-body", + type=str, + default="", + help=( + 'JSON string passed as extra_body for AR mode (e.g., \'{"reasoning":{"effort":"medium"}}\'). ' + "If empty, AR modes are disabled." + ), + ) + return parser.parse_args() + + +def get_dataset_optimal_tokens(dataset_info): + """ + Determine optimal token limit based on dataset complexity and reasoning requirements. + + Token limits are optimized for structured response generation while maintaining + efficiency across different reasoning complexity levels. + """ + dataset_name = dataset_info.name.lower() + difficulty = dataset_info.difficulty_level.lower() + + # Optimized token limits per dataset + dataset_tokens = { + "gpqa": 500, # Graduate-level scientific reasoning + "truthfulqa": 250, # Misconception analysis + "hellaswag": 250, # Natural continuation reasoning + "arc": 220, # Elementary/middle school science + "commonsenseqa": 300, # Common sense reasoning + "mmlu": 150 if difficulty == "undergraduate" else 200, # Academic knowledge + } + + # Find matching dataset + for dataset_key, tokens in dataset_tokens.items(): + if dataset_key in dataset_name: + return tokens + + # Default based on difficulty level + difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150} + + return difficulty_tokens.get(difficulty, 200) + + +def get_available_models(endpoint: str, api_key: str = "") -> List[str]: + """Get available models from an endpoint.""" + client = OpenAI(base_url=endpoint, api_key=api_key or None) + try: + models = client.models.list() + return [m.id for m in models.data] + except Exception as e: + print(f"Error communicating with endpoint to list models: {e}") + return [] + + +def extract_answer(response: Any) -> Optional[str]: + """Extract answer from model response.""" + # Normalize non-string responses into a string to be robust to providers + # that return structured content (e.g., lists of parts or dicts). 
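+    # (Lists of content parts are joined on their "text"/"content" fields; dict
+    # payloads fall back to "content"/"text"/"reasoning_content", else a JSON dump.)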
+ if response is None: + return None + + if not isinstance(response, str): + try: + # Handle list-of-parts shapes + if isinstance(response, list): + parts: List[str] = [] + for part in response: + if isinstance(part, dict): + if "text" in part and isinstance(part["text"], str): + parts.append(part["text"]) + elif "content" in part and isinstance(part["content"], str): + parts.append(part["content"]) + else: + parts.append(str(part)) + else: + parts.append(str(part)) + response = "\n".join(parts) + # Handle dict shapes + elif isinstance(response, dict): + for key in ("content", "text", "reasoning_content"): + val = response.get(key) if isinstance(response, dict) else None + if isinstance(val, str) and val: + response = val + break + else: + # Fallback to JSON stringification + response = json.dumps(response, ensure_ascii=False) + else: + response = str(response) + except Exception: + response = str(response) + + # Try multiple extraction patterns in order of preference + patterns = [ANSWER_PATTERN_PRIMARY, ANSWER_PATTERN_FINAL, ANSWER_PATTERN_CONCLUSION] + + for pattern in patterns: + match = pattern.search(response) + if match: + return match.group(1).upper() + + # Fallback 1: Look for standalone letters at end of response + lines = response.strip().split("\n") + for line in reversed(lines[-3:]): # Check last 3 lines + line = line.strip() + if len(line) == 1 and line.upper() in "ABCDEFGHIJKLMNOPQRSTUVWXYZ": + return line.upper() + + # Fallback 2: Find last letter in entire response + for char in reversed(response): + if char.upper() in "ABCDEFGHIJKLMNOPQRSTUVWXYZ": + return char.upper() + + return None + + +def call_model( + client: OpenAI, + model: str, + prompt: str, + max_tokens: int, + temperature: float, + extra_body: Optional[Dict[str, Any]] = None, +) -> Tuple[str, bool, Optional[int], Optional[int], Optional[int]]: + """Call model with given parameters.""" + try: + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + max_tokens=max_tokens, + temperature=temperature, + extra_body=extra_body if extra_body else None, + ) + # For reasoning models, content might be in reasoning_content instead of content + message = response.choices[0].message + text = message.content or getattr(message, "reasoning_content", None) or "" + usage = getattr(response, "usage", None) + prompt_tokens = getattr(usage, "prompt_tokens", None) if usage else None + completion_tokens = getattr(usage, "completion_tokens", None) if usage else None + total_tokens = getattr(usage, "total_tokens", None) if usage else None + return text, True, prompt_tokens, completion_tokens, total_tokens + except Exception as e: + print(f"Model call failed: {e}") + return "ERROR", False, None, None, None + + +def build_extra_body_for_model( + model_name: str, reasoning: Optional[bool] +) -> Optional[Dict[str, Any]]: + """Return an extra_body dict to toggle reasoning for a given model. 
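+
+    - Qwen3: {"chat_template_kwargs": {"enable_thinking": true/false}}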
+ + - DeepSeek v3.1: {"chat_template_kwargs": {"thinking": true/false}} + - GPT-OSS: {"reasoning_effort": "low|medium|high"} when ON; if not provided, then low + """ + # reasoning: True -> ON, False -> OFF, None -> base (default behavior) + + lower = model_name.lower() + if (("ds" in lower) or ("deepseek" in lower)) and ( + "v31" in lower or "v3.1" in lower or "v3" in lower + ): + if reasoning is True: + return {"chat_template_kwargs": {"thinking": True}} + elif reasoning is False: + return {"chat_template_kwargs": {"thinking": False}} + else: # reasoning is None (base mode) + # Base: do not set thinking for DeepSeek - let it use default behavior + return None + + # Qwen3 family + if "qwen3" in lower: + if reasoning is True: + return {"chat_template_kwargs": {"enable_thinking": True}} + if reasoning is False: + return {"chat_template_kwargs": {"enable_thinking": False}} + return None + + # GPT OSS family + if "gpt-oss" in lower or "openai/gpt-oss" in lower or "gpt_oss" in lower: + if reasoning is True: + return {"reasoning_effort": "high"} + elif reasoning is False: + return {"reasoning_effort": "low"} + else: # reasoning is None (base mode) + # Base: do not set reasoning_effort - let it use default behavior + return None + + return None + + +def process_question_single( + client: OpenAI, + model: str, + question: Question, + dataset: Any, # DatasetInterface + prompt_mode: str, + max_tokens: int, + temperature: float, + ar_extra_body: Optional[Dict[str, Any]] = None, + mode_label: Optional[str] = None, +) -> Dict[str, Any]: + """Process a single question with the model.""" + # Format prompt based on mode + if prompt_mode == "XC": + prompt = dataset.format_prompt(question, "explicit_cot") + extra_body = None + elif prompt_mode == "AR": + prompt = dataset.format_prompt(question, "plain") + extra_body = ar_extra_body + else: # NR or Router-Transparent + prompt = dataset.format_prompt(question, "plain") + extra_body = None + + start_time = time.time() + response_text, success, prompt_tokens, completion_tokens, total_tokens = call_model( + client, model, prompt, max_tokens, temperature, extra_body=extra_body + ) + end_time = time.time() + + predicted_answer = extract_answer(response_text) if success else None + + # Compare predicted answer with correct answer (handle both letter and index formats) + if predicted_answer and predicted_answer in "ABCDEFGHIJKLMNOPQRSTUVWXYZ": + if isinstance(question.correct_answer, str): + # Dataset stores answer as letter (e.g., MMLU: "F") + is_correct = predicted_answer == question.correct_answer + elif isinstance(question.correct_answer, int): + # Dataset stores answer as index (e.g., CommonsenseQA: 1, ARC: 0) + predicted_idx = ord(predicted_answer) - ord("A") + is_correct = predicted_idx == question.correct_answer + else: + is_correct = False + else: + is_correct = False + + return { + "mode": prompt_mode, + "mode_label": mode_label or prompt_mode, + "question_id": question.question_id, + "category": question.category, + "question": question.question, + "options": question.options, + "correct_answer": question.correct_answer, + "model_response": response_text, + "predicted_answer": predicted_answer, + "is_correct": is_correct, + "response_time": end_time - start_time, + "success": success, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + +def evaluate_model_router_transparent( + questions: List[Question], + dataset: Any, # DatasetInterface + model: str, + endpoint: str, + api_key: str, + 
concurrent_requests: int, + max_tokens: int, + temperature: float, +) -> pd.DataFrame: + """Evaluate model in router-transparent mode.""" + client = OpenAI(base_url=endpoint, api_key=api_key or None) + print(f"Using model: {model}, endpoint: {endpoint}") + + results: List[Dict[str, Any]] = [] + + with ThreadPoolExecutor(max_workers=concurrent_requests) as executor: + futures = [] + for question in questions: + futures.append( + executor.submit( + process_question_single, + client, + model, + question, + dataset, + "NR", + max_tokens, + temperature, + None, + mode_label="Router_NR", + ) + ) + + for future in tqdm( + futures, total=len(futures), desc=f"Evaluating {model} (Router-Transparent)" + ): + results.append(future.result()) + + return pd.DataFrame(results) + + +def evaluate_model_vllm_multimode( + questions: List[Question], + dataset: Any, # DatasetInterface + model: str, + endpoint: str, + api_key: str, + concurrent_requests: int, + max_tokens: int, + temperature: float, + exec_modes: List[str], +) -> pd.DataFrame: + """Run vLLM with 2-3 realistic reasoning scenarios. + + The scenarios represent real-world router decision patterns: + 1. NR - Plain prompt, no reasoning toggle (fast baseline) - ALWAYS included + 2. XC - CoT prompt, no reasoning toggle (prompt-based reasoning) - ONLY if dataset has CoT + 3. NR_REASONING - Plain prompt, reasoning toggle ON (model-based reasoning) - ALWAYS included + """ + client = OpenAI(base_url=endpoint, api_key=api_key or "dummy-key") + print(f"Using vLLM model: {model}, endpoint: {endpoint}") + + # Check if dataset has actual CoT content by examining sample questions + has_cot_content = any( + q.cot_content is not None and q.cot_content.strip() for q in questions[:10] + ) + + if has_cot_content: + print(f" Dataset has CoT content - using 3 modes: NR, XC, NR_REASONING") + else: + print( + f" Dataset lacks CoT content - using 2 modes: NR, NR_REASONING (skipping XC)" + ) + + results: List[Dict[str, Any]] = [] + + # Define mode variants based on model type and CoT availability + model_lower = model.lower() + is_deepseek_or_qwen = ( + (("ds" in model_lower) or ("deepseek" in model_lower)) + and ("v31" in model_lower or "v3.1" in model_lower or "v3" in model_lower) + ) or ("qwen3" in model_lower) + + # Base modes (always included) + if is_deepseek_or_qwen: + mode_variants: List[Tuple[str, str, Optional[bool]]] = [ + ("VLLM_NR", "NR", False), # Plain prompt, reasoning OFF (baseline) + ( + "VLLM_NR_REASONING", + "NR", + True, + ), # Plain prompt, reasoning ON (model reasoning) + ] + else: + mode_variants: List[Tuple[str, str, Optional[bool]]] = [ + ("VLLM_NR", "NR", None), # Plain prompt, no toggle (baseline) + ( + "VLLM_NR_REASONING", + "NR", + True, + ), # Plain prompt, reasoning toggle ON (model reasoning) + ] + + # Add XC mode only if dataset has CoT content + if has_cot_content: + if is_deepseek_or_qwen: + mode_variants.insert( + 1, ("VLLM_XC", "XC", False) + ) # Insert between NR and NR_REASONING + else: + mode_variants.insert( + 1, ("VLLM_XC", "XC", None) + ) # Insert between NR and NR_REASONING + + def run_variants(q: Question) -> List[Dict[str, Any]]: + local_records: List[Dict[str, Any]] = [] + for label, prompt_mode, reasoning_flag in mode_variants: + extra_body = build_extra_body_for_model(model, reasoning_flag) + # Debug: print extra_body for first question to verify configuration + if q == questions[0]: + print( + f" {label}: reasoning_flag={reasoning_flag}, extra_body={extra_body}" + ) + rec = process_question_single( + client, + model, + q, + 
dataset, + prompt_mode, + max_tokens, + temperature, + ar_extra_body=extra_body, + mode_label=label, + ) + local_records.append(rec) + return local_records + + with ThreadPoolExecutor(max_workers=concurrent_requests) as executor: + futures = [executor.submit(run_variants, q) for q in questions] + for future in tqdm( + futures, total=len(futures), desc=f"Evaluating {model} (vLLM modes)" + ): + results.extend(future.result()) + + return pd.DataFrame(results) + + +def analyze_results(results_df: pd.DataFrame) -> Dict[str, Any]: + """Analyze results and compute metrics.""" + valid = results_df[results_df["success"]] + overall_acc = valid["is_correct"].mean() if not valid.empty else 0.0 + + category_metrics: Dict[str, Dict[str, Any]] = {} + for category in valid["category"].unique(): + sub = valid[valid["category"] == category] + category_metrics[category] = { + "accuracy": float(sub["is_correct"].mean()) if not sub.empty else 0.0, + "avg_response_time": ( + float(sub["response_time"].mean()) if not sub.empty else 0.0 + ), + "avg_prompt_tokens": ( + float(sub["prompt_tokens"].dropna().mean()) + if not sub["prompt_tokens"].dropna().empty + else None + ), + "avg_completion_tokens": ( + float(sub["completion_tokens"].dropna().mean()) + if not sub["completion_tokens"].dropna().empty + else None + ), + "avg_total_tokens": ( + float(sub["total_tokens"].dropna().mean()) + if not sub["total_tokens"].dropna().empty + else None + ), + } + + avg_latency = valid["response_time"].mean() if not valid.empty else 0.0 + avg_prompt_tokens = ( + valid["prompt_tokens"].dropna().mean() if not valid.empty else None + ) + avg_completion_tokens = ( + valid["completion_tokens"].dropna().mean() if not valid.empty else None + ) + avg_total_tokens = ( + valid["total_tokens"].dropna().mean() if not valid.empty else None + ) + + # Optional: metrics by mode_label + by_mode: Dict[str, Dict[str, Any]] = {} + if "mode_label" in valid.columns: + for label in valid["mode_label"].unique(): + sub = valid[valid["mode_label"] == label] + by_mode[label] = { + "accuracy": float(sub["is_correct"].mean()) if not sub.empty else 0.0, + "avg_response_time": ( + float(sub["response_time"].mean()) if not sub.empty else 0.0 + ), + "avg_prompt_tokens": ( + float(sub["prompt_tokens"].dropna().mean()) + if not sub["prompt_tokens"].dropna().empty + else None + ), + "avg_completion_tokens": ( + float(sub["completion_tokens"].dropna().mean()) + if not sub["completion_tokens"].dropna().empty + else None + ), + "avg_total_tokens": ( + float(sub["total_tokens"].dropna().mean()) + if not sub["total_tokens"].dropna().empty + else None + ), + } + + return { + "overall_accuracy": float(overall_acc), + "category_metrics": category_metrics, + "avg_response_time": float(avg_latency) if avg_latency is not None else 0.0, + "avg_prompt_tokens": ( + float(avg_prompt_tokens) if avg_prompt_tokens is not None else None + ), + "avg_completion_tokens": ( + float(avg_completion_tokens) if avg_completion_tokens is not None else None + ), + "avg_total_tokens": ( + float(avg_total_tokens) if avg_total_tokens is not None else None + ), + "total_questions": int(len(results_df)), + "successful_queries": int(len(valid)), + "failed_queries": int(len(results_df) - len(valid)), + "by_mode": by_mode, + } + + +def save_results( + results_df: pd.DataFrame, + analysis: Dict[str, Any], + model: str, + dataset_name: str, + output_dir: str, +): + """Save results to files.""" + model_name = model.replace("/", "_") + model_dir = os.path.join(output_dir, f"{dataset_name}_{model_name}") + 
os.makedirs(model_dir, exist_ok=True) + + results_df.to_csv(os.path.join(model_dir, "detailed_results.csv"), index=False) + + with open(os.path.join(model_dir, "summary.json"), "w") as f: + json.dump( + { + "model": model, + "dataset": dataset_name, + **analysis, + }, + f, + indent=2, + ) + + print("\n" + "=" * 50) + print(f"Model: {model} | Dataset: {dataset_name}") + print(f"Overall Accuracy: {analysis['overall_accuracy']:.4f}") + print(f"Total Questions: {analysis['total_questions']}") + print(f"Successful Queries: {analysis['successful_queries']}") + print(f"Failed Queries: {analysis['failed_queries']}") + print( + f"Avg Latency: {analysis['avg_response_time']:.2f}s | Avg Total Tokens: {analysis['avg_total_tokens']}" + ) + print("=" * 50 + "\n") + + if "category_metrics" in analysis: + print("Category Metrics (acc | latency | total_tokens):") + printable = [] + for category, met in analysis["category_metrics"].items(): + printable.append((category, met.get("accuracy", 0.0))) + for category, acc in sorted(printable, key=lambda x: x[1], reverse=True): + m = analysis["category_metrics"][category] + print( + f" {category}: acc={m['accuracy']:.4f}, latency={m['avg_response_time']:.2f}s, tokens={m['avg_total_tokens']}" + ) + print() + + +def main(): + args = parse_args() + + # Handle dataset listing + if args.list_datasets: + list_available_datasets() + return + + # Set random seeds + random.seed(args.seed) + np.random.seed(args.seed) + + # Load dataset + print(f"Loading dataset: {args.dataset}") + try: + dataset = DatasetFactory.create_dataset(args.dataset) + questions, dataset_info = dataset.load_dataset( + categories=args.categories, + samples_per_category=args.samples_per_category, + seed=args.seed, + ) + print( + f"Dataset loaded: {len(questions)} questions across {len(dataset_info.categories)} categories" + ) + print(f"Categories: {', '.join(dataset_info.categories)}") + + # Check for empty dataset + if len(questions) == 0: + print(f"❌ No questions loaded from dataset '{args.dataset}'") + print("This could be due to:") + print(" - Dataset requiring authentication (gated dataset)") + print(" - Network connectivity issues") + print(" - Invalid dataset name or configuration") + print("\nTry a different dataset:") + list_available_datasets() + return + + except Exception as e: + print(f"Error loading dataset '{args.dataset}': {e}") + print("\nAvailable datasets:") + list_available_datasets() + return + + # Resolve endpoints and models + router_endpoint = ( + args.router_endpoint + or os.environ.get("ROUTER_ENDPOINT") + or "http://127.0.0.1:8801/v1" + ) + router_api_key = ( + args.router_api_key + or os.environ.get("ROUTER_API_KEY") + or os.environ.get("OPENAI_API_KEY") + or "1234" + ) + + vllm_endpoint = args.vllm_endpoint or os.environ.get("VLLM_ENDPOINT", "") + vllm_api_key = ( + args.vllm_api_key + or os.environ.get("VLLM_API_KEY") + or os.environ.get("OPENAI_API_KEY") + or "" + ) + + router_models = args.router_models + if router_models and len(router_models) == 1 and "," in router_models[0]: + router_models = router_models[0].split(",") + if not router_models or (len(router_models) == 1 and router_models[0] == "auto"): + print("Fetching available models from router endpoint...") + fetched_models = get_available_models(router_endpoint, router_api_key) + if fetched_models: + router_models = fetched_models + else: + print("No models returned from endpoint, using 'auto' as fallback") + router_models = ["auto"] + + vllm_models = args.vllm_models + if vllm_models and len(vllm_models) == 1 
and "," in vllm_models[0]: + vllm_models = vllm_models[0].split(",") + if not vllm_models and vllm_endpoint: + print("Fetching available models from vLLM endpoint...") + vllm_models = get_available_models(vllm_endpoint, vllm_api_key) + + print(f"Router models: {router_models}") + print(f"vLLM models: {vllm_models}") + + # Determine optimal token limit for this dataset + if args.max_tokens: + optimal_tokens = args.max_tokens + print(f"Using user-specified max_tokens: {optimal_tokens}") + else: + optimal_tokens = get_dataset_optimal_tokens(dataset_info) + print( + f"Using dataset-optimal max_tokens: {optimal_tokens} (for {dataset_info.name})" + ) + + # Router evaluation (NR-only) + if args.run_router and router_endpoint and router_models: + for model in router_models: + print(f"\nEvaluating router model: {model}") + rt_df = evaluate_model_router_transparent( + questions=questions, + dataset=dataset, + model=model, + endpoint=router_endpoint, + api_key=router_api_key, + concurrent_requests=args.concurrent_requests, + max_tokens=optimal_tokens, + temperature=args.temperature, + ) + analysis = analyze_results(rt_df) + save_results( + results_df=rt_df, + analysis=analysis, + model=f"router::{model}", + dataset_name=dataset_info.name, + output_dir=args.output_dir, + ) + + # Direct vLLM evaluation (NR/XC with reasoning ON/OFF) + if args.run_vllm and vllm_endpoint and vllm_models: + for model in vllm_models: + print(f"\nEvaluating vLLM model: {model}") + vdf = evaluate_model_vllm_multimode( + questions=questions, + dataset=dataset, + model=model, + endpoint=vllm_endpoint, + api_key=vllm_api_key, + concurrent_requests=args.concurrent_requests, + max_tokens=optimal_tokens, + temperature=args.temperature, + exec_modes=args.vllm_exec_modes, + ) + analysis = analyze_results(vdf) + save_results( + results_df=vdf, + analysis=analysis, + model=f"vllm::{model}", + dataset_name=dataset_info.name, + output_dir=args.output_dir, + ) + + +if __name__ == "__main__": + main()