
Commit 2bc24ff

chore: adjust max tokens for different models in router bench

Signed-off-by: Huamin Chen <[email protected]>
Parent: 14cb752

bench/vllm_semantic_router_bench/router_reason_bench_multi_dataset.py

Lines changed: 52 additions & 24 deletions
@@ -178,35 +178,62 @@ def parse_args():
     return parser.parse_args()
 
 
-def get_dataset_optimal_tokens(dataset_info):
+def get_dataset_optimal_tokens(dataset_info, model_name=None):
     """
-    Determine optimal token limit based on dataset complexity and reasoning requirements.
+    Determine optimal token limit based on dataset complexity, reasoning requirements, and model capabilities.
 
     Token limits are optimized for structured response generation while maintaining
-    efficiency across different reasoning complexity levels.
+    efficiency across different reasoning complexity levels and model architectures.
+
+    Args:
+        dataset_info: Dataset information object
+        model_name: Model identifier (e.g., "openai/gpt-oss-20b", "Qwen/Qwen3-30B-A3B")
     """
     dataset_name = dataset_info.name.lower()
     difficulty = dataset_info.difficulty_level.lower()
 
-    # Optimized token limits per dataset (increased for reasoning mode support)
-    dataset_tokens = {
-        "gpqa": 1500,  # Graduate-level scientific reasoning
+    # Determine model type and capabilities
+    model_multiplier = 1.0
+    if model_name:
+        model_lower = model_name.lower()
+        if "qwen" in model_lower:
+            # Qwen models are more efficient and can handle longer contexts
+            model_multiplier = 1.5
+        elif "deepseek" in model_lower:
+            # DeepSeek models (e.g., V3.1) are capable and can handle longer contexts
+            model_multiplier = 1.5
+        elif "gpt-oss" in model_lower:
+            # GPT-OSS models use baseline token limits
+            model_multiplier = 1.0
+        # Default to baseline for unknown models
+
+    # Base token limits per dataset (optimized for gpt-oss-20b baseline)
+    base_dataset_tokens = {
+        "gpqa": 2000,  # Graduate-level scientific reasoning (increased from 1500 due to truncation)
         "truthfulqa": 800,  # Misconception analysis
         "hellaswag": 800,  # Natural continuation reasoning
         "arc": 800,  # Elementary/middle school science
         "commonsenseqa": 1000,  # Common sense reasoning
-        "mmlu": 600 if difficulty == "undergraduate" else 800,  # Academic knowledge
+        "mmlu": 1600,  # Academic knowledge (increased from 1200 due to continued truncation at 150 tokens)
     }
 
-    # Find matching dataset
-    for dataset_key, tokens in dataset_tokens.items():
+    # Find matching dataset and apply model multiplier
+    base_tokens = None
+    for dataset_key, tokens in base_dataset_tokens.items():
         if dataset_key in dataset_name:
-            return tokens
+            base_tokens = tokens
+            break
+
+    # Fallback to difficulty-based tokens if dataset not found
+    if base_tokens is None:
+        difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+        base_tokens = difficulty_tokens.get(difficulty, 200)
 
-    # Default based on difficulty level
-    difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+    # Apply model-specific multiplier and round to nearest 50
+    final_tokens = int(base_tokens * model_multiplier)
+    final_tokens = ((final_tokens + 25) // 50) * 50  # Round to nearest 50
 
-    return difficulty_tokens.get(difficulty, 200)
+    return final_tokens
 
 
 def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
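
For reference, a minimal sketch of what the updated get_dataset_optimal_tokens yields for a few dataset/model pairs. The DatasetInfo dataclass is a hypothetical stand-in for the benchmark's dataset-info object (only the two attributes the function reads are modeled), and the model identifiers are illustrative:

from dataclasses import dataclass

# Hypothetical stand-in: the real benchmark passes its own dataset-info
# object; get_dataset_optimal_tokens only reads .name and .difficulty_level.
@dataclass
class DatasetInfo:
    name: str
    difficulty_level: str

gpqa = DatasetInfo(name="GPQA", difficulty_level="graduate")
mmlu = DatasetInfo(name="MMLU", difficulty_level="undergraduate")

print(get_dataset_optimal_tokens(gpqa, "openai/gpt-oss-20b"))         # 2000 (baseline, 1.0x)
print(get_dataset_optimal_tokens(gpqa, "Qwen/Qwen3-30B-A3B"))         # 3000 (2000 * 1.5)
print(get_dataset_optimal_tokens(mmlu, "deepseek-ai/DeepSeek-V3.1"))  # 2400 (1600 * 1.5)

# Unmatched datasets fall back to the difficulty table before the multiplier.
custom = DatasetInfo(name="CustomEval", difficulty_level="hard")
print(get_dataset_optimal_tokens(custom, "Qwen/Qwen3-30B-A3B"))       # 450 (300 * 1.5)

The closing rounding step, ((final_tokens + 25) // 50) * 50, snaps every derived limit to the nearest multiple of 50, so the 1.5x multiplier never produces odd-looking limits.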
@@ -827,28 +854,27 @@ def main():
     print(f"Router models: {router_models}")
     print(f"vLLM models: {vllm_models}")
 
-    # Determine optimal token limit for this dataset
-    if args.max_tokens:
-        optimal_tokens = args.max_tokens
-        print(f"Using user-specified max_tokens: {optimal_tokens}")
-    else:
-        optimal_tokens = get_dataset_optimal_tokens(dataset_info)
-        print(
-            f"Using dataset-optimal max_tokens: {optimal_tokens} (for {dataset_info.name})"
-        )
+    # Function to get optimal tokens for a specific model
+    def get_model_optimal_tokens(model_name):
+        if args.max_tokens:
+            return args.max_tokens
+        else:
+            return get_dataset_optimal_tokens(dataset_info, model_name)
 
     # Router evaluation (NR-only)
     if args.run_router and router_endpoint and router_models:
         for model in router_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating router model: {model}")
+            print(f"Using max_tokens: {model_tokens} (model-optimized for {model})")
             rt_df = evaluate_model_router_transparent(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=router_endpoint,
                 api_key=router_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
             )
             analysis = analyze_results(rt_df)
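
In main(), the single precomputed optimal_tokens is replaced by a closure over args and dataset_info, so each model resolves its own limit while an explicit --max-tokens still overrides everything. A condensed sketch of that precedence, reusing the DatasetInfo stand-in from the sketch above:

import argparse

# Assumption for this sketch: parse_args() leaves max_tokens as None
# unless the user passes --max-tokens.
args = argparse.Namespace(max_tokens=None)
dataset_info = DatasetInfo(name="MMLU", difficulty_level="undergraduate")

def get_model_optimal_tokens(model_name):
    if args.max_tokens:
        return args.max_tokens  # user-specified value wins for every model
    return get_dataset_optimal_tokens(dataset_info, model_name)

print(get_model_optimal_tokens("Qwen/Qwen3-30B-A3B"))  # 2400 (computed default)
args.max_tokens = 512
print(get_model_optimal_tokens("Qwen/Qwen3-30B-A3B"))  # 512 (override)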
@@ -863,15 +889,17 @@ def main():
     # Direct vLLM evaluation (NR/XC with reasoning ON/OFF)
     if args.run_vllm and vllm_endpoint and vllm_models:
         for model in vllm_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating vLLM model: {model}")
+            print(f"Using max_tokens: {model_tokens} (model-optimized for {model})")
             vdf = evaluate_model_vllm_multimode(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=vllm_endpoint,
                 api_key=vllm_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
                 exec_modes=args.vllm_exec_modes,
             )
