1010import re
1111import time
1212from concurrent .futures import ThreadPoolExecutor
13- from typing import Any , Dict , List , Optional
13+ from typing import Any , Dict , List , Optional , Tuple
1414
1515import numpy as np
1616import pandas as pd
# Constants

# Matches phrases such as "Answer: B", "answer is c" or "answer:D"
# (case-insensitive) and captures the option letter in group 1.
# The whitespace escapes are single-backslash: inside a raw string "\\s"
# would match a literal backslash followed by "s", so the pattern would never
# match ordinary model output. The trailing \b stops the first letter of a
# word (e.g. the "c" in "answer cannot ...") from being captured as an option.
ANSWER_PATTERN = re.compile(r"(?:answer(?:\sis)?:?\s*)(A|B|C|D)\b", re.IGNORECASE)
TIMEOUT_SECONDS = 120
# Number of attempts made per request; 1 means the call is never retried.
MAX_RETRIES = 1
2525
2626
2727def parse_args ():
@@ -64,7 +64,7 @@ def parse_args():
6464 parser .add_argument (
6565 "--max-tokens" ,
6666 type = int ,
67- default = 2048 ,
67+ default = 2048 , # Make it sufficient for the model to answer the question
6868 help = "Maximum number of tokens to generate" ,
6969 )
7070 parser .add_argument (
@@ -77,6 +77,7 @@ def parse_args():
7777
7878
7979def get_available_models (endpoint : str , api_key : str = "" ) -> List [str ]:
80+ """Get the list of available models from the vLLM OpenAI API endpoint."""
8081 client = OpenAI (
8182 base_url = endpoint ,
8283 api_key = api_key ,
@@ -92,43 +93,55 @@ def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
def load_arc_challenge_dataset(
    samples: Optional[int] = None, seed: int = 42
) -> pd.DataFrame:
    """Load the ARC-Challenge "train" split into a DataFrame.

    Args:
        samples: Optional cap on the number of rows returned. When falsy
            (``None`` or ``0``) the full split is returned.
        seed: Seed applied to the ``random`` and ``numpy`` global generators
            and used as ``random_state`` for the row sampling.

    Returns:
        The full split, or ``samples`` randomly drawn rows when the split
        holds more rows than requested.
    """
    frame = pd.DataFrame(
        load_dataset("allenai/ai2_arc", "ARC-Challenge", split="train")
    )

    # No sub-sampling requested: hand back everything untouched.
    if not samples:
        return frame

    # Seed the global generators whenever a sample size is given, even if the
    # split turns out to be too small to need down-sampling.
    random.seed(seed)
    np.random.seed(seed)

    if len(frame) <= samples:
        return frame
    return frame.sample(samples, random_state=seed)
103107
104108
def format_cot_prompt_arc(
    question: str, choices: Dict[str, List[str]], use_cot: bool = False
) -> str:
    """Build the multiple-choice prompt sent to the model.

    Args:
        question: The question text.
        choices: Mapping with parallel ``"label"`` and ``"text"`` lists; each
            pair is rendered on its own line as ``"<label>) <text>"``.
        use_cot: When true, ask the model to reason step-by-step before
            answering; otherwise ask for the answer directly.

    Returns:
        The prompt string. Both variants instruct the model to reply in the
        form ``'Answer: [letter]'``.
    """
    # Join once instead of growing a string with += inside the loop.
    formatted_options = "".join(
        f"{label}) {text}\n"
        for label, text in zip(choices["label"], choices["text"])
    )

    if use_cot:
        instruction = (
            "Please solve this step-by-step, then provide your final answer "
            "in the format 'Answer: [letter]'."
        )
    else:
        instruction = (
            "Please choose the correct answer from the options above. "
            "Provide your answer in the format 'Answer: [letter]'."
        )

    return f"Question: {question}\n\nOptions:\n{formatted_options}\n\n{instruction}"
116124
117125
def extract_answer_arc(response: Optional[str]) -> Optional[str]:
    """Extract the chosen option letter from a model response.

    Args:
        response: Raw text returned by the model. May be ``None`` or empty,
            e.g. when the completion carries no message content.

    Returns:
        The upper-cased letter ``"A"``-``"D"``, or ``None`` when the response
        is empty or contains none of those letters.
    """
    # A completion can legitimately have no content; searching None would
    # raise TypeError, so treat it as "no answer" instead.
    if not response:
        return None

    # Preferred: an explicit "Answer: X" style statement.
    match = ANSWER_PATTERN.search(response)
    if match:
        return match.group(1).upper()

    # Fallback: the last occurrence of A/B/C/D (either case) in the text.
    return next(
        (char.upper() for char in reversed(response) if char.upper() in "ABCD"),
        None,
    )
127139
128140
129141def call_model_with_retry (
130142 client : OpenAI , model : str , prompt : str , max_tokens : int , temperature : float
131- ) -> (str , bool ):
143+ ) -> Tuple [str , bool ]:
144+ """Call the model with retry logic for handling timeouts and errors."""
132145 for attempt in range (MAX_RETRIES ):
133146 try :
134147 response = client .chat .completions .create (
@@ -139,7 +152,7 @@ def call_model_with_retry(
139152 )
140153 return response .choices [0 ].message .content , True
141154 except Exception as e :
142- if attempt < MAX_RETRIES - 1 :
155+ if attempt < MAX_RETRIES - 1 : # Exponential backoff
143156 delay = 2 ** attempt
144157 print (
145158 f"Error calling model (attempt { attempt + 1 } /{ MAX_RETRIES } ), retrying in { delay } s: { e } "
@@ -158,21 +171,28 @@ def process_question_arc(
158171 max_tokens : int ,
159172 temperature : float ,
160173) -> Dict [str , Any ]:
174+ """Process a single question and return the results."""
161175 question = question_data ["question" ]
162176 choices = question_data ["choices" ]
163177 correct_answer = question_data ["answerKey" ]
178+
164179 prompt = format_cot_prompt_arc (question , choices , use_cot )
180+
181+ # append the prompt and correct answer to a file
165182 with open ("arc_challenge_vllm_eval.txt" , "a" ) as f :
166183 f .write (f"Prompt: { prompt } \n " )
167184 f .write (f"Correct answer: { correct_answer } \n \n " )
185+
168186 start_time = time .time ()
169187 response_text , success = call_model_with_retry (
170188 client , model , prompt , max_tokens , temperature
171189 )
172190 end_time = time .time ()
191+
173192 predicted_answer = extract_answer_arc (response_text ) if success else None
174193 is_correct = (predicted_answer == correct_answer ) if predicted_answer else False
175194 print (f"Predicted answer: { predicted_answer } , Correct answer: { correct_answer } " )
195+
176196 return {
177197 "id" : question_data ["id" ],
178198 "question" : question ,
@@ -196,10 +216,14 @@ def evaluate_model_arc(
196216 max_tokens : int ,
197217 temperature : float ,
198218) -> pd .DataFrame :
219+ """Evaluate a model on the ARC Challenge dataset."""
199220 client = OpenAI (base_url = endpoint , api_key = api_key if api_key else "dummy" )
200221 print (f"Using model: { model } , endpoint: { endpoint } , api_key: { api_key } " )
201222 results = []
223+
224+ # Convert DataFrame rows to dictionaries for processing
202225 questions_data = df .to_dict ("records" )
226+
203227 with ThreadPoolExecutor (max_workers = concurrent_requests ) as executor :
204228 futures = []
205229 for question_data in questions_data :
@@ -213,21 +237,30 @@ def evaluate_model_arc(
213237 temperature ,
214238 )
215239 futures .append (future )
240+
216241 for future in tqdm (futures , total = len (futures ), desc = f"Evaluating { model } " ):
217242 result = future .result ()
218243 results .append (result )
244+
219245 results_df = pd .DataFrame (results )
220246 return results_df
221247
222248
223249def analyze_results_arc (results_df : pd .DataFrame ) -> Dict [str , float ]:
250+ """Analyze the results and compute statistics."""
251+ # Skip failed requests in the analysis
224252 valid_results = results_df [results_df ["success" ]]
253+
254+ # Overall accuracy
225255 overall_accuracy = (
226256 valid_results ["is_correct" ].mean () if not valid_results .empty else 0.0
227257 )
258+
259+ # Compute average response time
228260 avg_response_time = (
229261 valid_results ["response_time" ].mean () if not valid_results .empty else 0.0
230262 )
263+
231264 return {
232265 "overall_accuracy" : overall_accuracy ,
233266 "avg_response_time" : avg_response_time ,
@@ -244,14 +277,23 @@ def save_results_arc(
244277 output_dir : str ,
245278 use_cot : bool ,
246279):
280+ """Save the results and analysis to files."""
247281 model_name = model .replace ("/" , "_" )
248282 cot_suffix = "cot" if use_cot else "direct"
283+
284+ # Create output directory if it doesn't exist
249285 os .makedirs (output_dir , exist_ok = True )
250286 model_dir = os .path .join (output_dir , f"{ model_name } _{ cot_suffix } " )
251287 os .makedirs (model_dir , exist_ok = True )
288+
289+ # Save detailed results
252290 results_df .to_csv (os .path .join (model_dir , "detailed_results.csv" ), index = False )
291+
292+ # Save analysis
253293 with open (os .path .join (model_dir , "analysis.json" ), "w" ) as f :
254294 json .dump (analysis , f , indent = 2 )
295+
296+ # Save summary
255297 summary = {
256298 "model" : model ,
257299 "approach" : "Chain-of-Thought" if use_cot else "Direct" ,
@@ -261,8 +303,11 @@ def save_results_arc(
261303 "failed_queries" : analysis ["failed_queries" ],
262304 "avg_response_time" : analysis ["avg_response_time" ],
263305 }
306+
264307 with open (os .path .join (model_dir , "summary.json" ), "w" ) as f :
265308 json .dump (summary , f , indent = 2 )
309+
310+ # Print summary
266311 print ("\n " + "=" * 50 )
267312 print (f"Model: { model } " )
268313 print (f"Approach: { 'Chain-of-Thought' if use_cot else 'Direct' } " )
@@ -276,8 +321,12 @@ def save_results_arc(
276321
277322def main ():
278323 args = parse_args ()
324+
325+ # Set random seed for reproducibility
279326 random .seed (args .seed )
280327 np .random .seed (args .seed )
328+
329+ # Get available models if not specified
281330 if not args .models :
282331 print ("Fetching available models from vLLM endpoint..." )
283332 models = get_available_models (args .endpoint , args .api_key )
@@ -287,12 +336,19 @@ def main():
287336 )
288337 return
289338 args .models = models
339+
290340 if args .models and len (args .models ) == 1 and "," in args .models [0 ]:
291341 args .models = args .models [0 ].split ("," )
342+
292343 print (f"Models to evaluate: { args .models } " )
344+
345+ # Load dataset
293346 print ("Loading ARC Challenge dataset..." )
294347 df = load_arc_challenge_dataset (samples = args .samples , seed = args .seed )
348+
295349 print (f"Dataset loaded: { len (df )} questions" )
350+
351+ # Evaluate each model
296352 for model in args .models :
297353 print (f"\n Evaluating model: { model } " )
298354 results_df = evaluate_model_arc (
@@ -305,6 +361,8 @@ def main():
305361 max_tokens = args .max_tokens ,
306362 temperature = args .temperature ,
307363 )
364+
365+ # Analyze and save results
308366 analysis = analyze_results_arc (results_df )
309367 save_results_arc (
310368 results_df = results_df ,
0 commit comments