
Commit 14cb752

JaredforReal authored
fix: add comments for readability (#135)
* add comments for readability
* Fix trailing whitespace in ARC eval script

Signed-off-by: JaredforReal <[email protected]>
1 parent 4a6c433 commit 14cb752

2 files changed (+64, -7 lines)

src/training/model_eval/arc_challenge_vllm_eval.py

Lines changed: 63 additions & 5 deletions
@@ -10,7 +10,7 @@
 import re
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 import pandas as pd
@@ -21,7 +21,7 @@
 # Constants
 ANSWER_PATTERN = re.compile(r"(?:answer(?:\sis)?:?\s*)(A|B|C|D)", re.IGNORECASE)
 TIMEOUT_SECONDS = 120
-MAX_RETRIES = 1
+MAX_RETRIES = 1  # No retries


 def parse_args():
@@ -64,7 +64,7 @@ def parse_args():
     parser.add_argument(
         "--max-tokens",
         type=int,
-        default=2048,
+        default=2048,  # Make it sufficient for the model to answer the question
         help="Maximum number of tokens to generate",
     )
     parser.add_argument(
@@ -77,6 +77,7 @@ def parse_args():


 def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
+    """Get the list of available models from the vLLM OpenAI API endpoint."""
     client = OpenAI(
         base_url=endpoint,
         api_key=api_key,
@@ -92,43 +93,55 @@ def get_available_models(endpoint: str, api_key: str = "") -> List[str]:
 def load_arc_challenge_dataset(
     samples: Optional[int] = None, seed: int = 42
 ) -> pd.DataFrame:
+    """Load the ARC Challenge dataset"""
     dataset = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="train")
     df = pd.DataFrame(dataset)
+
     if samples:
         random.seed(seed)
         np.random.seed(seed)
         if len(df) > samples:
             df = df.sample(samples, random_state=seed)
+
     return df


 def format_cot_prompt_arc(
     question: str, choices: Dict[str, List[str]], use_cot: bool = False
 ) -> str:
+    """Format the prompt for the model with or without Chain-of-Thought."""
     formatted_options = ""
+
     for label, text in zip(choices["label"], choices["text"]):
         formatted_options += f"{label}) {text}\n"
+
     if use_cot:
         prompt = f"Question: {question}\n\nOptions:\n{formatted_options}\n\nPlease solve this step-by-step, then provide your final answer in the format 'Answer: [letter]'."
     else:
         prompt = f"Question: {question}\n\nOptions:\n{formatted_options}\n\nPlease choose the correct answer from the options above. Provide your answer in the format 'Answer: [letter]'."
+
     return prompt


 def extract_answer_arc(response: str) -> Optional[str]:
+    """Extract the answer letter from the model's response."""
+    # Try to find the answer using regex pattern
     match = ANSWER_PATTERN.search(response)
     if match:
         return match.group(1).upper()
+
     # fallback: last occurrence of A/B/C/D
     for char in reversed(response):
         if char.upper() in "ABCD":
             return char.upper()
+
     return None


 def call_model_with_retry(
     client: OpenAI, model: str, prompt: str, max_tokens: int, temperature: float
-) -> (str, bool):
+) -> Tuple[str, bool]:
+    """Call the model with retry logic for handling timeouts and errors."""
     for attempt in range(MAX_RETRIES):
         try:
             response = client.chat.completions.create(
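A note on `extract_answer_arc` above: extraction is two-stage, regex first, then a last-letter fallback. A minimal standalone sketch of that behavior, re-declaring the pattern and function from this diff (and assuming the intended single-backslash `\s` in the pattern):

```python
import re
from typing import Optional

# Same pattern and logic as in the script above.
ANSWER_PATTERN = re.compile(r"(?:answer(?:\sis)?:?\s*)(A|B|C|D)", re.IGNORECASE)

def extract_answer_arc(response: str) -> Optional[str]:
    match = ANSWER_PATTERN.search(response)
    if match:
        return match.group(1).upper()
    # fallback: last occurrence of A/B/C/D
    for char in reversed(response):
        if char.upper() in "ABCD":
            return char.upper()
    return None

print(extract_answer_arc("Answer: B"))         # B  (regex path)
print(extract_answer_arc("I'd pick option C"))  # C  (fallback: last A-D letter)
print(extract_answer_arc("I am not sure"))      # A  (caveat: any stray a/b/c/d character satisfies the fallback)
```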
@@ -139,7 +152,7 @@ def call_model_with_retry(
             )
             return response.choices[0].message.content, True
         except Exception as e:
-            if attempt < MAX_RETRIES - 1:
+            if attempt < MAX_RETRIES - 1:  # Exponential backoff
                 delay = 2**attempt
                 print(
                     f"Error calling model (attempt {attempt+1}/{MAX_RETRIES}), retrying in {delay}s: {e}"
@@ -158,21 +171,28 @@ def process_question_arc(
     max_tokens: int,
     temperature: float,
 ) -> Dict[str, Any]:
+    """Process a single question and return the results."""
     question = question_data["question"]
     choices = question_data["choices"]
     correct_answer = question_data["answerKey"]
+
     prompt = format_cot_prompt_arc(question, choices, use_cot)
+
+    # append the prompt and correct answer to a file
     with open("arc_challenge_vllm_eval.txt", "a") as f:
         f.write(f"Prompt: {prompt}\n")
         f.write(f"Correct answer: {correct_answer}\n\n")
+
     start_time = time.time()
     response_text, success = call_model_with_retry(
         client, model, prompt, max_tokens, temperature
     )
     end_time = time.time()
+
     predicted_answer = extract_answer_arc(response_text) if success else None
     is_correct = (predicted_answer == correct_answer) if predicted_answer else False
     print(f"Predicted answer: {predicted_answer}, Correct answer: {correct_answer}")
+
     return {
         "id": question_data["id"],
         "question": question,
@@ -196,10 +216,14 @@ def evaluate_model_arc(
     max_tokens: int,
     temperature: float,
 ) -> pd.DataFrame:
+    """Evaluate a model on the ARC Challenge dataset."""
     client = OpenAI(base_url=endpoint, api_key=api_key if api_key else "dummy")
     print(f"Using model: {model}, endpoint: {endpoint}, api_key: {api_key}")
     results = []
+
+    # Convert DataFrame rows to dictionaries for processing
     questions_data = df.to_dict("records")
+
     with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
         futures = []
         for question_data in questions_data:
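The comment added above labels a small but important step: `df.to_dict("records")` converts each DataFrame row into a plain dict, which is what `process_question_arc` receives as `question_data`. A quick illustration:

```python
import pandas as pd

df = pd.DataFrame({"id": ["q1", "q2"], "answerKey": ["A", "C"]})
print(df.to_dict("records"))
# [{'id': 'q1', 'answerKey': 'A'}, {'id': 'q2', 'answerKey': 'C'}]
```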
@@ -213,21 +237,30 @@ def evaluate_model_arc(
                     temperature,
                 )
                 futures.append(future)
+
         for future in tqdm(futures, total=len(futures), desc=f"Evaluating {model}"):
             result = future.result()
             results.append(result)
+
     results_df = pd.DataFrame(results)
     return results_df


 def analyze_results_arc(results_df: pd.DataFrame) -> Dict[str, float]:
+    """Analyze the results and compute statistics."""
+    # Skip failed requests in the analysis
     valid_results = results_df[results_df["success"]]
+
+    # Overall accuracy
     overall_accuracy = (
         valid_results["is_correct"].mean() if not valid_results.empty else 0.0
     )
+
+    # Compute average response time
     avg_response_time = (
         valid_results["response_time"].mean() if not valid_results.empty else 0.0
     )
+
     return {
         "overall_accuracy": overall_accuracy,
         "avg_response_time": avg_response_time,
@@ -244,14 +277,23 @@ def save_results_arc(
     output_dir: str,
     use_cot: bool,
 ):
+    """Save the results and analysis to files."""
     model_name = model.replace("/", "_")
     cot_suffix = "cot" if use_cot else "direct"
+
+    # Create output directory if it doesn't exist
     os.makedirs(output_dir, exist_ok=True)
     model_dir = os.path.join(output_dir, f"{model_name}_{cot_suffix}")
     os.makedirs(model_dir, exist_ok=True)
+
+    # Save detailed results
     results_df.to_csv(os.path.join(model_dir, "detailed_results.csv"), index=False)
+
+    # Save analysis
     with open(os.path.join(model_dir, "analysis.json"), "w") as f:
         json.dump(analysis, f, indent=2)
+
+    # Save summary
     summary = {
         "model": model,
         "approach": "Chain-of-Thought" if use_cot else "Direct",
@@ -261,8 +303,11 @@ def save_results_arc(
         "failed_queries": analysis["failed_queries"],
         "avg_response_time": analysis["avg_response_time"],
     }
+
     with open(os.path.join(model_dir, "summary.json"), "w") as f:
         json.dump(summary, f, indent=2)
+
+    # Print summary
     print("\n" + "=" * 50)
     print(f"Model: {model}")
     print(f"Approach: {'Chain-of-Thought' if use_cot else 'Direct'}")
@@ -276,8 +321,12 @@ def save_results_arc(

 def main():
     args = parse_args()
+
+    # Set random seed for reproducibility
     random.seed(args.seed)
     np.random.seed(args.seed)
+
+    # Get available models if not specified
     if not args.models:
         print("Fetching available models from vLLM endpoint...")
         models = get_available_models(args.endpoint, args.api_key)
@@ -287,12 +336,19 @@ def main():
             )
             return
         args.models = models
+
     if args.models and len(args.models) == 1 and "," in args.models[0]:
         args.models = args.models[0].split(",")
+
     print(f"Models to evaluate: {args.models}")
+
+    # Load dataset
     print("Loading ARC Challenge dataset...")
     df = load_arc_challenge_dataset(samples=args.samples, seed=args.seed)
+
     print(f"Dataset loaded: {len(df)} questions")
+
+    # Evaluate each model
     for model in args.models:
         print(f"\nEvaluating model: {model}")
         results_df = evaluate_model_arc(
@@ -305,6 +361,8 @@ def main():
             max_tokens=args.max_tokens,
             temperature=args.temperature,
         )
+
+        # Analyze and save results
         analysis = analyze_results_arc(results_df)
         save_results_arc(
             results_df=results_df,
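Besides comments and docstrings, this file gets one genuine typing fix: the old `-> (str, bool)` annotation is a tuple expression, which Python accepts at runtime but type checkers such as mypy reject; `Tuple[str, bool]` is the correct generic form (hence the new `Tuple` import at the top of the diff). A before/after sketch:

```python
from typing import Tuple

def before() -> (str, bool):  # runs, but mypy flags the annotation as invalid
    return "ok", True

def after() -> Tuple[str, bool]:  # the form this commit switches to
    return "ok", True
```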

src/training/model_eval/mmlu_pro_vllm_eval.py

Lines changed: 1 addition & 2 deletions
@@ -9,7 +9,6 @@
 import random
 import re
 import time
-from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional, Tuple

@@ -230,6 +229,7 @@ def process_question(
     correct_answer = question_data["answer"]

     prompt = format_cot_prompt(question, options, use_cot)
+
     # append the prompt, category and correct answer to a file
     with open("mmlu_pro_vllm_eval.txt", "a") as f:
         f.write(f"Category: {question_data['category']}\n")
@@ -240,7 +240,6 @@ def process_question(
     response_text, success = call_model_with_retry(
         client, model, prompt, max_tokens, temperature
     )
-    # print(f"Response: {response_text}")
     end_time = time.time()

     predicted_answer = extract_answer(response_text) if success else None
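One shared detail the `# append ...` comments draw attention to: both eval scripts log every prompt to a flat text file opened in `"a"` mode, so the log grows across runs instead of being overwritten. A small sketch of that behavior, using a hypothetical file name:

```python
# Append mode ("a") adds to the file on every run; "w" would truncate it first.
for run in range(2):
    with open("demo_eval_log.txt", "a") as f:
        f.write(f"run {run}: prompt logged\n")

with open("demo_eval_log.txt") as f:
    print(f.read())
# run 0: prompt logged
# run 1: prompt logged
# (two more lines are appended each time this snippet runs again)
```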
