@@ -10,7 +10,7 @@
 import re
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
@@ -21,7 +21,7 @@
 # Constants
 ANSWER_PATTERN = re.compile(r"(?:answer(?:\sis)?:?\s*)(A|B|C|D)", re.IGNORECASE)
 TIMEOUT_SECONDS = 120
-MAX_RETRIES = 1
+MAX_RETRIES = 1  # No retries
 
 
 def parse_args():
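Note: the answer-extraction pattern is easy to sanity-check in isolation. A minimal standalone sketch, mirroring the constant above (the test strings are illustrative):

import re

# Same pattern as ANSWER_PATTERN: "answer", optional " is", optional colon,
# optional whitespace, then one choice letter captured in group 1.
pattern = re.compile(r"(?:answer(?:\sis)?:?\s*)(A|B|C|D)", re.IGNORECASE)

assert pattern.search("The answer is: B").group(1) == "B"
assert pattern.search("Answer: c").group(1).upper() == "C"
assert pattern.search("I would pick D") is None  # no "answer" keyword; fallback path

Mind the single backslash in \s: a doubled \\s inside the raw string would match a literal backslash and never fire on normal model output.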
@@ -64,7 +64,7 @@ def parse_args():
     parser.add_argument(
         "--max-tokens",
         type=int,
-        default=2048,
+        default=2048,  # Make it sufficient for the model to answer the question
         help="Maximum number of tokens to generate",
     )
     parser.add_argument(
@@ -77,6 +77,7 @@ def parse_args():
 
 
 def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
+    """Get the list of available models from the vLLM OpenAI API endpoint."""
     client = OpenAI(
         base_url=endpoint,
         api_key=api_key,
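Note: vLLM exposes an OpenAI-compatible API, so listing models is a single call on the client. A hedged usage sketch (the endpoint URL and model id are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")  # placeholder endpoint
model_ids = [m.id for m in client.models.list().data]
print(model_ids)  # e.g. ["meta-llama/Meta-Llama-3-8B-Instruct"]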
@@ -92,43 +93,55 @@ def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
 def load_arc_challenge_dataset(
     samples: Optional[int] = None, seed: int = 42
 ) -> pd.DataFrame:
+    """Load the ARC Challenge dataset."""
     dataset = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="train")
     df = pd.DataFrame(dataset)
+
     if samples:
         random.seed(seed)
         np.random.seed(seed)
         if len(df) > samples:
             df = df.sample(samples, random_state=seed)
+
     return df
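Note: the seed plus random_state pairing is what makes repeat runs sample the same questions. A self-contained sketch of that behavior on a toy frame:

import pandas as pd

df = pd.DataFrame({"question": [f"q{i}" for i in range(100)]})
a = df.sample(10, random_state=42)
b = df.sample(10, random_state=42)
assert a.equals(b)  # same random_state, same subsample on every run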
 
 
 def format_cot_prompt_arc(
     question: str, choices: Dict[str, List[str]], use_cot: bool = False
 ) -> str:
+    """Format the prompt for the model with or without Chain-of-Thought."""
     formatted_options = ""
+
     for label, text in zip(choices["label"], choices["text"]):
         formatted_options += f"{label}) {text}\n"
+
     if use_cot:
         prompt = f"Question: {question}\n\nOptions:\n{formatted_options}\n\nPlease solve this step-by-step, then provide your final answer in the format 'Answer: [letter]'."
     else:
         prompt = f"Question: {question}\n\nOptions:\n{formatted_options}\n\nPlease choose the correct answer from the options above. Provide your answer in the format 'Answer: [letter]'."
+
     return prompt
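Note: with a toy question in the ARC choices layout (parallel label/text lists), the direct-mode prompt renders as sketched below; the question is illustrative and the whitespace follows the f-strings as reconstructed above:

choices = {
    "label": ["A", "B", "C", "D"],
    "text": ["granite", "quartz", "pumice", "obsidian"],
}
print(format_cot_prompt_arc("Which rock floats on water?", choices))
# Question: Which rock floats on water?
#
# Options:
# A) granite
# B) quartz
# C) pumice
# D) obsidian
#
#
# Please choose the correct answer from the options above. Provide your answer in the format 'Answer: [letter]'.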
 
 
 def extract_answer_arc(response: str) -> Optional[str]:
+    """Extract the answer letter from the model's response."""
+    # Try to find the answer using the regex pattern
     match = ANSWER_PATTERN.search(response)
     if match:
         return match.group(1).upper()
+
     # fallback: last occurrence of A/B/C/D
     for char in reversed(response):
         if char.upper() in "ABCD":
             return char.upper()
+
     return None
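Note: the two-stage extraction (regex first, reverse character scan second) behaves like this; the strings are illustrative:

assert extract_answer_arc("Reasoning... Answer: C") == "C"
# No "answer" keyword, so the reverse scan returns the last A/B/C/D seen:
assert extract_answer_arc("Between B and D, I lean D.") == "D"
assert extract_answer_arc("Unsure.") is None
# Caveat: the scan also matches letters inside words, e.g. the "a" in "idea":
assert extract_answer_arc("No idea.") == "A"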
 
 
 def call_model_with_retry(
     client: OpenAI, model: str, prompt: str, max_tokens: int, temperature: float
-) -> (str, bool):
+) -> Tuple[str, bool]:
+    """Call the model with retry logic for handling timeouts and errors."""
     for attempt in range(MAX_RETRIES):
         try:
             response = client.chat.completions.create(
@@ -139,7 +152,7 @@ def call_model_with_retry(
             )
             return response.choices[0].message.content, True
         except Exception as e:
-            if attempt < MAX_RETRIES - 1:
+            if attempt < MAX_RETRIES - 1:  # Exponential backoff
                 delay = 2**attempt
                 print(
                     f"Error calling model (attempt {attempt + 1}/{MAX_RETRIES}), retrying in {delay}s: {e}"
@@ -158,21 +171,28 @@ def process_question_arc(
     max_tokens: int,
     temperature: float,
 ) -> Dict[str, Any]:
+    """Process a single question and return the results."""
     question = question_data["question"]
     choices = question_data["choices"]
     correct_answer = question_data["answerKey"]
+
     prompt = format_cot_prompt_arc(question, choices, use_cot)
+
+    # Append the prompt and correct answer to a file
     with open("arc_challenge_vllm_eval.txt", "a") as f:
         f.write(f"Prompt: {prompt}\n")
         f.write(f"Correct answer: {correct_answer}\n\n")
+
     start_time = time.time()
     response_text, success = call_model_with_retry(
         client, model, prompt, max_tokens, temperature
     )
     end_time = time.time()
+
     predicted_answer = extract_answer_arc(response_text) if success else None
     is_correct = (predicted_answer == correct_answer) if predicted_answer else False
     print(f"Predicted answer: {predicted_answer}, Correct answer: {correct_answer}")
+
     return {
         "id": question_data["id"],
         "question": question,
@@ -196,10 +216,14 @@ def evaluate_model_arc(
     max_tokens: int,
     temperature: float,
 ) -> pd.DataFrame:
+    """Evaluate a model on the ARC Challenge dataset."""
     client = OpenAI(base_url=endpoint, api_key=api_key if api_key else "dummy")
     print(f"Using model: {model}, endpoint: {endpoint}, api_key: {api_key}")
     results = []
+
+    # Convert DataFrame rows to dictionaries for processing
     questions_data = df.to_dict("records")
+
     with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
         futures = []
         for question_data in questions_data:
@@ -213,21 +237,30 @@ def evaluate_model_arc(
                     temperature,
                 )
                 futures.append(future)
+
         for future in tqdm(futures, total=len(futures), desc=f"Evaluating {model}"):
             result = future.result()
             results.append(result)
+
     results_df = pd.DataFrame(results)
     return results_df
 
 
 def analyze_results_arc(results_df: pd.DataFrame) -> Dict[str, float]:
+    """Analyze the results and compute statistics."""
+    # Skip failed requests in the analysis
     valid_results = results_df[results_df["success"]]
+
+    # Overall accuracy
     overall_accuracy = (
         valid_results["is_correct"].mean() if not valid_results.empty else 0.0
     )
+
+    # Compute average response time
     avg_response_time = (
         valid_results["response_time"].mean() if not valid_results.empty else 0.0
     )
+
     return {
         "overall_accuracy": overall_accuracy,
         "avg_response_time": avg_response_time,
@@ -244,14 +277,23 @@ def save_results_arc(
     output_dir: str,
     use_cot: bool,
 ):
+    """Save the results and analysis to files."""
     model_name = model.replace("/", "_")
     cot_suffix = "cot" if use_cot else "direct"
+
+    # Create output directory if it doesn't exist
     os.makedirs(output_dir, exist_ok=True)
     model_dir = os.path.join(output_dir, f"{model_name}_{cot_suffix}")
     os.makedirs(model_dir, exist_ok=True)
+
+    # Save detailed results
     results_df.to_csv(os.path.join(model_dir, "detailed_results.csv"), index=False)
+
+    # Save analysis
     with open(os.path.join(model_dir, "analysis.json"), "w") as f:
         json.dump(analysis, f, indent=2)
+
+    # Save summary
     summary = {
         "model": model,
         "approach": "Chain-of-Thought" if use_cot else "Direct",
@@ -261,8 +303,11 @@ def save_results_arc(
         "failed_queries": analysis["failed_queries"],
         "avg_response_time": analysis["avg_response_time"],
     }
+
     with open(os.path.join(model_dir, "summary.json"), "w") as f:
         json.dump(summary, f, indent=2)
+
+    # Print summary
     print("\n" + "=" * 50)
     print(f"Model: {model}")
     print(f"Approach: {'Chain-of-Thought' if use_cot else 'Direct'}")
@@ -276,8 +321,12 @@ def save_results_arc(
 
 
 def main():
     args = parse_args()
+
+    # Set random seed for reproducibility
     random.seed(args.seed)
     np.random.seed(args.seed)
+
+    # Get available models if not specified
     if not args.models:
         print("Fetching available models from vLLM endpoint...")
         models = get_available_models(args.endpoint, args.api_key)
@@ -287,12 +336,19 @@ def main():
         )
         return
     args.models = models
+
     if args.models and len(args.models) == 1 and "," in args.models[0]:
         args.models = args.models[0].split(",")
+
     print(f"Models to evaluate: {args.models}")
+
+    # Load dataset
     print("Loading ARC Challenge dataset...")
     df = load_arc_challenge_dataset(samples=args.samples, seed=args.seed)
+
     print(f"Dataset loaded: {len(df)} questions")
+
+    # Evaluate each model
     for model in args.models:
         print(f"\nEvaluating model: {model}")
         results_df = evaluate_model_arc(
@@ -305,6 +361,8 @@ def main():
             max_tokens=args.max_tokens,
             temperature=args.temperature,
         )
+
+        # Analyze and save results
         analysis = analyze_results_arc(results_df)
         save_results_arc(
             results_df=results_df,