@@ -178,35 +178,62 @@ def parse_args():
     return parser.parse_args()


-def get_dataset_optimal_tokens(dataset_info):
+def get_dataset_optimal_tokens(dataset_info, model_name=None):
     """
-    Determine optimal token limit based on dataset complexity and reasoning requirements.
+    Determine optimal token limit based on dataset complexity, reasoning requirements, and model capabilities.

     Token limits are optimized for structured response generation while maintaining
-    efficiency across different reasoning complexity levels.
+    efficiency across different reasoning complexity levels and model architectures.
+
+    Args:
+        dataset_info: Dataset information object
+        model_name: Model identifier (e.g., "openai/gpt-oss-20b", "Qwen/Qwen3-30B-A3B")
     """
     dataset_name = dataset_info.name.lower()
     difficulty = dataset_info.difficulty_level.lower()

-    # Optimized token limits per dataset (increased for reasoning mode support)
-    dataset_tokens = {
-        "gpqa": 1500,  # Graduate-level scientific reasoning
+    # Determine model type and capabilities
+    model_multiplier = 1.0
+    if model_name:
+        model_lower = model_name.lower()
+        if "qwen" in model_lower:
+            # Qwen models are more efficient and can handle longer contexts
+            model_multiplier = 1.5
+        elif "deepseek" in model_lower:
+            # DeepSeek models (e.g., V3.1) are capable and can handle longer contexts
+            model_multiplier = 1.5
+        elif "gpt-oss" in model_lower:
+            # GPT-OSS models use baseline token limits
+            model_multiplier = 1.0
+        # Default to baseline for unknown models
+
+    # Base token limits per dataset (optimized for gpt-oss-20b baseline)
+    base_dataset_tokens = {
+        "gpqa": 3000,  # Graduate-level scientific reasoning (increased for complex multi-step reasoning)
         "truthfulqa": 800,  # Misconception analysis
         "hellaswag": 800,  # Natural continuation reasoning
         "arc": 800,  # Elementary/middle school science
         "commonsenseqa": 1000,  # Common sense reasoning
-        "mmlu": 600 if difficulty == "undergraduate" else 800,  # Academic knowledge
+        "mmlu": 3000,  # Academic knowledge (increased for complex technical domains like engineering/chemistry)
     }

-    # Find matching dataset
-    for dataset_key, tokens in dataset_tokens.items():
+    # Find matching dataset and apply model multiplier
+    base_tokens = None
+    for dataset_key, tokens in base_dataset_tokens.items():
         if dataset_key in dataset_name:
-            return tokens
+            base_tokens = tokens
+            break
+
+    # Fallback to difficulty-based tokens if dataset not found
+    if base_tokens is None:
+        difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+        base_tokens = difficulty_tokens.get(difficulty, 200)

-    # Default based on difficulty level
-    difficulty_tokens = {"graduate": 300, "hard": 300, "moderate": 200, "easy": 150}
+    # Apply model-specific multiplier and round to nearest 50
+    final_tokens = int(base_tokens * model_multiplier)
+    final_tokens = ((final_tokens + 25) // 50) * 50  # Round to nearest 50

-    return difficulty_tokens.get(difficulty, 200)
+    return final_tokens


 def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
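To sanity-check the new limit logic above: a GPQA-style dataset starts from a base of 3000 tokens, a Qwen or DeepSeek model name applies the 1.5x multiplier, anything else stays at 1.0, and the result is rounded to the nearest 50. A minimal sketch, assuming the function above is in scope and using types.SimpleNamespace as a stand-in for the project's dataset_info object (the real class is not shown in this diff):

# Illustrative only; SimpleNamespace mimics the attributes the helper reads.
from types import SimpleNamespace

gpqa = SimpleNamespace(name="gpqa", difficulty_level="graduate")
other = SimpleNamespace(name="custom_set", difficulty_level="easy")

print(get_dataset_optimal_tokens(gpqa))                                   # 3000 (no model name -> 1.0x)
print(get_dataset_optimal_tokens(gpqa, model_name="Qwen/Qwen3-30B-A3B"))  # 4500 (3000 * 1.5, already a multiple of 50)
print(get_dataset_optimal_tokens(other))                                  # 150 (difficulty fallback for "easy")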
@@ -507,6 +534,20 @@ def evaluate_model_vllm_multimode(
         q.cot_content is not None and q.cot_content.strip() for q in questions[:10]
     )

+    # Debug: Show CoT content status for first few questions
+    print(f"  CoT Debug - Checking first 10 questions:")
+    for i, q in enumerate(questions[:10]):
+        cot_status = (
+            "None"
+            if q.cot_content is None
+            else (
+                f"'{q.cot_content[:50]}...'"
+                if len(q.cot_content) > 50
+                else f"'{q.cot_content}'"
+            )
+        )
+        print(f"    Q{i + 1}: CoT = {cot_status}")
+
     if has_cot_content:
         print(f"  Dataset has CoT content - using 3 modes: NR, XC, NR_REASONING")
     else:
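The nested conditional in the debug block above is compact but dense. An equivalent helper (a readability sketch only, not part of this commit) makes the three cases explicit:

def preview_cot(cot_content, limit=50):
    # Mirrors the inline expression: "None", a truncated quoted preview, or the full quoted string.
    if cot_content is None:
        return "None"
    if len(cot_content) > limit:
        return f"'{cot_content[:limit]}...'"
    return f"'{cot_content}'"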
@@ -827,28 +868,31 @@ def main():
     print(f"Router models: {router_models}")
     print(f"vLLM models: {vllm_models}")

-    # Determine optimal token limit for this dataset
-    if args.max_tokens:
-        optimal_tokens = args.max_tokens
-        print(f"Using user-specified max_tokens: {optimal_tokens}")
-    else:
-        optimal_tokens = get_dataset_optimal_tokens(dataset_info)
-        print(
-            f"Using dataset-optimal max_tokens: {optimal_tokens} (for {dataset_info.name})"
-        )
+    # Function to get optimal tokens for a specific model
+    # For fair comparison, use consistent token limits regardless of model name
+    def get_model_optimal_tokens(model_name):
+        if args.max_tokens:
+            return args.max_tokens
+        else:
+            # Use base dataset tokens without model-specific multipliers for fair comparison
+            return get_dataset_optimal_tokens(dataset_info, model_name=None)

     # Router evaluation (NR-only)
     if args.run_router and router_endpoint and router_models:
         for model in router_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating router model: {model}")
+            print(
+                f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)"
+            )
             rt_df = evaluate_model_router_transparent(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=router_endpoint,
                 api_key=router_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
             )
             analysis = analyze_results(rt_df)
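Note that get_model_optimal_tokens always passes model_name=None, so the per-model multipliers introduced in get_dataset_optimal_tokens are deliberately bypassed here: every model in a run receives the same dataset-level limit unless --max_tokens overrides it. For example (illustrative, assuming a GPQA dataset_info and no --max_tokens flag), both calls below resolve to 3000 tokens:

# Inside main(), after the closure above is defined:
assert get_model_optimal_tokens("openai/gpt-oss-20b") == 3000
assert get_model_optimal_tokens("Qwen/Qwen3-30B-A3B") == 3000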
@@ -863,15 +907,19 @@ def main():
     # Direct vLLM evaluation (NR/XC with reasoning ON/OFF)
     if args.run_vllm and vllm_endpoint and vllm_models:
         for model in vllm_models:
+            model_tokens = get_model_optimal_tokens(model)
             print(f"\nEvaluating vLLM model: {model}")
+            print(
+                f"Using max_tokens: {model_tokens} (dataset-optimized for fair comparison)"
+            )
             vdf = evaluate_model_vllm_multimode(
                 questions=questions,
                 dataset=dataset,
                 model=model,
                 endpoint=vllm_endpoint,
                 api_key=vllm_api_key,
                 concurrent_requests=args.concurrent_requests,
-                max_tokens=optimal_tokens,
+                max_tokens=model_tokens,
                 temperature=args.temperature,
                 exec_modes=args.vllm_exec_modes,
             )