
Commit fac50b1

Merge pull request #137 from rootfs/fix-max-token-bench
chore: add just max token for different models in router bench
2 parents: 0a3af6e + 9d22297

File tree

1 file changed (+72 / -24 lines changed)


bench/vllm_semantic_router_bench/router_reason_bench_multi_dataset.py

Lines changed: 72 additions & 24 deletions
@@ -178,35 +178,62 @@ def parse_args():
     return parser.parse_args()


-def get_dataset_optimal_tokens(dataset_info):
+def get_dataset_optimal_tokens(dataset_info, model_name=None):
     """
-    Determine optimal token limit based on dataset complexity and reasoning requirements.
+    Determine optimal token limit based on dataset complexity, reasoning requirements, and model capabilities.

     Token limits are optimized for structured response generation while maintaining
-    efficiency across different reasoning complexity levels.
+    efficiency across different reasoning complexity levels and model architectures.
+
+    Args:
+        dataset_info: Dataset information object
+        model_name: Model identifier (e.g., "openai/gpt-oss-20b", "Qwen/Qwen3-30B-A3B")
     """
     dataset_name = dataset_info.name.lower()
     difficulty = dataset_info.difficulty_level.lower()

-    # Optimized token limits per dataset (increased for reasoning mode support)
-    dataset_tokens = {
-        "gpqa": 1500,  # Graduate-level scientific reasoning
+    # Determine model type and capabilities
+    model_multiplier = 1.0
+    if model_name:
+        model_lower = model_name.lower()
+        if "qwen" in model_lower:
+            # Qwen models are more efficient and can handle longer contexts
+            model_multiplier = 1.5
+        elif "deepseek" in model_lower:
+            # DeepSeek models (e.g., V3.1) are capable and can handle longer contexts
+            model_multiplier = 1.5
+        elif "gpt-oss" in model_lower:
+            # GPT-OSS models use baseline token limits
+            model_multiplier = 1.0
+        # Default to baseline for unknown models
+
+    # Base token limits per dataset (optimized for the gpt-oss-20b baseline)
+    base_dataset_tokens = {
+        "gpqa": 3000,  # Graduate-level scientific reasoning (increased for complex multi-step reasoning)
         "truthfulqa": 800,  # Misconception analysis
         "hellaswag": 800,  # Natural continuation reasoning
         "arc": 800,  # Elementary/middle school science
         "commonsenseqa": 1000,  # Common sense reasoning
-        "mmlu": 600 if difficulty == "undergraduate" else 800,  # Academic knowledge
+        "mmlu": 3000,  # Academic knowledge (increased for complex technical domains like engineering/chemistry)
     }

-    # Find matching dataset
-    for dataset_key, tokens in dataset_tokens.items():
+    # Find matching dataset and apply model multiplier
+    base_tokens = None
+    for dataset_key, tokens in base_dataset_tokens.items():
         if dataset_key in dataset_name:
-            return tokens
+            base_tokens = tokens
+            break
+
+    # Fallback to difficulty-based tokens if dataset not found
+    if base_tokens is None:
+        difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+        base_tokens = difficulty_tokens.get(difficulty, 200)

-    # Default based on difficulty level
-    difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+    # Apply model-specific multiplier and round to nearest 50
+    final_tokens = int(base_tokens * model_multiplier)
+    final_tokens = ((final_tokens + 25) // 50) * 50  # Round to nearest 50

-    return difficulty_tokens.get(difficulty, 200)
+    return final_tokens


 def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
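For illustration, here is a minimal sketch of how the rewritten function combines the base limits, the model multiplier, and the rounding step. The DatasetInfo stub is hypothetical, standing in for whatever object the bench actually passes as dataset_info; the expected values follow directly from the code above.

# Hypothetical stand-in for the bench's dataset_info argument (illustration only).
class DatasetInfo:
    def __init__(self, name, difficulty_level):
        self.name = name
        self.difficulty_level = difficulty_level

gpqa = DatasetInfo("GPQA", "graduate")

# Baseline gpt-oss model: 3000 * 1.0 = 3000 (already a multiple of 50).
assert get_dataset_optimal_tokens(gpqa, "openai/gpt-oss-20b") == 3000

# Qwen gets the 1.5x multiplier: 3000 * 1.5 = 4500.
assert get_dataset_optimal_tokens(gpqa, "Qwen/Qwen3-30B-A3B") == 4500

# Unmatched dataset name falls back to difficulty ("hard" -> 300); 300 * 1.5 = 450.
misc = DatasetInfo("mystery-bench", "hard")
assert get_dataset_optimal_tokens(misc, "deepseek-ai/DeepSeek-V3.1") == 450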
@@ -507,6 +534,20 @@ def evaluate_model_vllm_multimode(
         q.cot_content is not None and q.cot_content.strip() for q in questions[:10]
     )

+    # Debug: Show CoT content status for first few questions
+    print(f"  CoT Debug - Checking first 10 questions:")
+    for i, q in enumerate(questions[:10]):
+        cot_status = (
+            "None"
+            if q.cot_content is None
+            else (
+                f"'{q.cot_content[:50]}...'"
+                if len(q.cot_content) > 50
+                else f"'{q.cot_content}'"
+            )
+        )
+        print(f"    Q{i+1}: CoT = {cot_status}")
+
     if has_cot_content:
         print(f"  Dataset has CoT content - using 3 modes: NR, XC, NR_REASONING")
     else:
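The nested conditional in the added debug block only builds a display string: "None", the full CoT quoted, or the first 50 characters plus an ellipsis. Standalone, with a made-up sample string:

cot = "Step 1: balance the reaction. Step 2: compute the yield of the limiting reagent."
preview = f"'{cot[:50]}...'" if len(cot) > 50 else f"'{cot}'"
print(preview)  # quoted 50-character prefix followed by '...'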
@@ -827,28 +868,31 @@ def main():
     print(f"Router models: {router_models}")
     print(f"vLLM models: {vllm_models}")

-    # Determine optimal token limit for this dataset
-    if args.max_tokens:
-        optimal_tokens = args.max_tokens
-        print(f"Using user-specified max_tokens: {optimal_tokens}")
-    else:
-        optimal_tokens = get_dataset_optimal_tokens(dataset_info)
-        print(
-            f"Using dataset-optimal max_tokens: {optimal_tokens} (for {dataset_info.name})"
-        )
+    # Function to get optimal tokens for a specific model
+    # For fair comparison, use consistent token limits regardless of model name
+    def get_model_optimal_tokens(model_name):
+        if args.max_tokens:
+            return args.max_tokens
+        else:
+            # Use base dataset tokens without model-specific multipliers for fair comparison
+            return get_dataset_optimal_tokens(dataset_info, model_name=None)

     # Router evaluation (NR-only)
     if args.run_router and router_endpoint and router_models:
         for model in router_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating router model: {model}")
+            print(
+                f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)"
+            )
             rt_df = evaluate_model_router_transparent(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=router_endpoint,
                 api_key=router_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
             )
             analysis = analyze_results(rt_df)
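Note the design choice here: although get_dataset_optimal_tokens now accepts a model_name, the wrapper deliberately passes model_name=None, so the Qwen/DeepSeek multipliers are bypassed in this path and every model in a run gets the same dataset-level budget; that is the "fair comparison" the comment refers to. The multipliers only take effect if a caller passes a model name explicitly. Absent a user override, the selection reduces to this equivalent one-liner (a restatement, not code from the diff):

model_tokens = args.max_tokens or get_dataset_optimal_tokens(dataset_info, model_name=None)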
@@ -863,15 +907,19 @@ def main():
     # Direct vLLM evaluation (NR/XC with reasoning ON/OFF)
     if args.run_vllm and vllm_endpoint and vllm_models:
         for model in vllm_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating vLLM model: {model}")
+            print(
+                f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)"
+            )
             vdf = evaluate_model_vllm_multimode(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=vllm_endpoint,
                 api_key=vllm_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
                 exec_modes=args.vllm_exec_modes,
             )
