@@ -178,35 +178,62 @@ def parse_args():
     return parser.parse_args()


-def get_dataset_optimal_tokens(dataset_info):
+def get_dataset_optimal_tokens(dataset_info, model_name=None):
     """
-    Determine optimal token limit based on dataset complexity and reasoning requirements.
+    Determine optimal token limit based on dataset complexity, reasoning requirements, and model capabilities.

     Token limits are optimized for structured response generation while maintaining
-    efficiency across different reasoning complexity levels.
+    efficiency across different reasoning complexity levels and model architectures.
+
+    Args:
+        dataset_info: Dataset information object
+        model_name: Model identifier (e.g., "openai/gpt-oss-20b", "Qwen/Qwen3-30B-A3B")
     """
     dataset_name = dataset_info.name.lower()
     difficulty = dataset_info.difficulty_level.lower()

-    # Optimized token limits per dataset (increased for reasoning mode support)
-    dataset_tokens = {
-        "gpqa": 1500,  # Graduate-level scientific reasoning
+    # Determine model type and capabilities
+    model_multiplier = 1.0
+    if model_name:
+        model_lower = model_name.lower()
+        if "qwen" in model_lower:
+            # Qwen models are more efficient and can handle longer contexts
+            model_multiplier = 1.5
+        elif "deepseek" in model_lower:
+            # DeepSeek models (e.g., V3.1) are capable and can handle longer contexts
+            model_multiplier = 1.5
+        elif "gpt-oss" in model_lower:
+            # GPT-OSS models use baseline token limits
+            model_multiplier = 1.0
+        # Default to baseline for unknown models
+
+    # Base token limits per dataset (optimized for the gpt-oss-20b baseline)
+    base_dataset_tokens = {
+        "gpqa": 3000,  # Graduate-level scientific reasoning (increased for complex multi-step reasoning)
         "truthfulqa": 800,  # Misconception analysis
         "hellaswag": 800,  # Natural continuation reasoning
         "arc": 800,  # Elementary/middle school science
         "commonsenseqa": 1000,  # Common sense reasoning
-        "mmlu": 600 if difficulty == "undergraduate" else 800,  # Academic knowledge
+        "mmlu": 3000,  # Academic knowledge (increased for technical domains like engineering/chemistry)
     }

-    # Find matching dataset
-    for dataset_key, tokens in dataset_tokens.items():
+    # Find matching dataset and apply model multiplier
+    base_tokens = None
+    for dataset_key, tokens in base_dataset_tokens.items():
         if dataset_key in dataset_name:
-            return tokens
+            base_tokens = tokens
+            break
+
+    # Fall back to difficulty-based tokens if the dataset is not found
+    if base_tokens is None:
+        difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+        base_tokens = difficulty_tokens.get(difficulty, 200)

-    # Default based on difficulty level
-    difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+    # Apply the model-specific multiplier and round to the nearest 50
+    final_tokens = int(base_tokens * model_multiplier)
+    final_tokens = ((final_tokens + 25) // 50) * 50  # Round to nearest 50

-    return difficulty_tokens.get(difficulty, 200)
+    return final_tokens


 def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
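
Taken together, the new lookup resolves limits as sketched below. This is a usage sketch, not part of the patch: `SimpleNamespace` stands in for the real `dataset_info` object (whose class is not shown in this diff), and the dataset names are hypothetical. The last case also works through the rounding step: an unknown dataset at `easy` difficulty gives 150 × 1.5 = 225 tokens, which `((225 + 25) // 50) * 50` rounds up to 250.

```python
# Hypothetical usage sketch, assuming the patched get_dataset_optimal_tokens is in scope.
from types import SimpleNamespace

gpqa = SimpleNamespace(name="gpqa_diamond", difficulty_level="graduate")
custom = SimpleNamespace(name="custom_eval", difficulty_level="easy")

assert get_dataset_optimal_tokens(gpqa) == 3000  # table hit, 1.0x baseline
assert get_dataset_optimal_tokens(gpqa, model_name="Qwen/Qwen3-30B-A3B") == 4500  # 3000 * 1.5
assert get_dataset_optimal_tokens(custom, model_name="Qwen/Qwen3-30B-A3B") == 250  # 150 * 1.5 = 225, rounded to 250
```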
@@ -507,6 +534,20 @@ def evaluate_model_vllm_multimode(
         q.cot_content is not None and q.cot_content.strip() for q in questions[:10]
     )

+    # Debug: Show CoT content status for the first few questions
+    print(f"  CoT Debug - Checking first 10 questions:")
+    for i, q in enumerate(questions[:10]):
+        cot_status = (
+            "None"
+            if q.cot_content is None
+            else (
+                f"'{q.cot_content[:50]}...'"
+                if len(q.cot_content) > 50
+                else f"'{q.cot_content}'"
+            )
+        )
+        print(f"    Q{i + 1}: CoT = {cot_status}")
+
     if has_cot_content:
         print(f"  Dataset has CoT content - using 3 modes: NR, XC, NR_REASONING")
     else:
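
The nested conditional expression above is compact but hard to scan. An equivalent standalone helper, shown here only for illustration (the name `preview_cot` is hypothetical, not part of the patch), makes the truncation rules explicit:

```python
# Hypothetical helper mirroring the cot_status expression in the debug loop above.
def preview_cot(cot_content, limit=50):
    if cot_content is None:
        return "None"
    if len(cot_content) > limit:
        return f"'{cot_content[:limit]}...'"  # truncate long chains with an ellipsis
    return f"'{cot_content}'"

assert preview_cot(None) == "None"
assert preview_cot("short chain") == "'short chain'"
assert preview_cot("x" * 80) == "'" + "x" * 50 + "...'"
```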
@@ -827,28 +868,31 @@ def main():
     print(f"Router models: {router_models}")
     print(f"vLLM models: {vllm_models}")

-    # Determine optimal token limit for this dataset
-    if args.max_tokens:
-        optimal_tokens = args.max_tokens
-        print(f"Using user-specified max_tokens: {optimal_tokens}")
-    else:
-        optimal_tokens = get_dataset_optimal_tokens(dataset_info)
-        print(
-            f"Using dataset-optimal max_tokens: {optimal_tokens} (for {dataset_info.name})"
-        )
+    # Function to get optimal tokens for a specific model
+    # For fair comparison, use consistent token limits regardless of model name
+    def get_model_optimal_tokens(model_name):
+        if args.max_tokens:
+            return args.max_tokens
+        else:
+            # Use base dataset tokens without model-specific multipliers for fair comparison
+            return get_dataset_optimal_tokens(dataset_info, model_name=None)

     # Router evaluation (NR-only)
     if args.run_router and router_endpoint and router_models:
         for model in router_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating router model: {model}")
+            print(
+                f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)"
+            )
             rt_df = evaluate_model_router_transparent(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=router_endpoint,
                 api_key=router_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
             )
             analysis = analyze_results(rt_df)
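
The closure gives an explicit `--max-tokens` priority over the dataset-derived limit, and passes `model_name=None` so the multiplier never fires and every model shares one limit. A minimal standalone sketch of that precedence, assuming the patched `get_dataset_optimal_tokens` above is in scope (`resolve_tokens` and the sample `dataset_info` are hypothetical stand-ins for the closure and its enclosing state):

```python
# Hypothetical sketch of the token-resolution precedence inside main().
from argparse import Namespace
from types import SimpleNamespace

dataset_info = SimpleNamespace(name="mmlu_pro", difficulty_level="hard")

def resolve_tokens(args, model_name):
    # User override wins; otherwise all models get the same dataset limit
    if args.max_tokens:
        return args.max_tokens
    return get_dataset_optimal_tokens(dataset_info, model_name=None)

assert resolve_tokens(Namespace(max_tokens=None), "Qwen/Qwen3-30B-A3B") == 3000  # shared dataset limit
assert resolve_tokens(Namespace(max_tokens=1024), "Qwen/Qwen3-30B-A3B") == 1024  # explicit override
```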
@@ -863,15 +907,19 @@ def main():
     # Direct vLLM evaluation (NR/XC with reasoning ON/OFF)
     if args.run_vllm and vllm_endpoint and vllm_models:
         for model in vllm_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating vLLM model: {model}")
+            print(
+                f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)"
+            )
             vdf = evaluate_model_vllm_multimode(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=vllm_endpoint,
                 api_key=vllm_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
                 exec_modes=args.vllm_exec_modes,
             )