Skip to content

Commit e1c8e5f

Browse files
committed
Introduce accuracy testing for LLMs in tt-forge.
We currently use an identical approach to accuracy testing as tt-metal, for 1-1 comparison.
- Edit llm_benchmark.py to add option "accuracy_testing". When this is true, top1 and top5 metrics are calculated using precomputed CPU model outputs stored in the reference_outputs folder (e.g. benchmark/tt-xla/reference_outputs/Qwen2.5-0.5B-Instruct.refpt). When accuracy_testing is true, we don't perform PCC checks.
- Implement a TokenAccuracy class that manages computing top1 and top5 metrics from reference_outputs.
- Add accuracy tests in test_llms.py.
- Add tests to perf-bench-matrix.json.
Generating reference outputs:
- Add a generate_reference_outputs.py script that loads a Huggingface model, runs it on the "Tale of Two Cities" text corpus, and generates a .refpt file containing reference tokens and top-5 predictions for each position.
- Add a /reference_outputs directory with the reference .refpt files and a README that explains how reference files are created and used.
1 parent 3fe9490 commit e1c8e5f

28 files changed

+1032
-24
lines changed

.github/workflows/perf-bench-matrix.json

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,111 @@
281281
"name": "unet_for_conditional_generation",
282282
"pyreq": "accelerate datasets diffusers==0.36.0 loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
283283
"pytest": "benchmark/tt-xla/test_encoders.py::test_unet_for_conditional_generation"
284+
},
285+
{
286+
"name": "llama_3_2_1b_instruct_accuracy",
287+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
288+
"pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_2_1b_accuracy"
289+
},
290+
{
291+
"name": "llama_3_2_3b_instruct_accuracy",
292+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
293+
"pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_2_3b_accuracy"
294+
},
295+
{
296+
"name": "llama_3_1_8b_instruct_accuracy",
297+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
298+
"pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_1_8b_accuracy"
299+
},
300+
{
301+
"name": "mistral_7b_accuracy",
302+
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1 protobuf sentencepiece",
303+
"pytest": "benchmark/tt-xla/test_llms.py::test_mistral_7b_accuracy"
304+
},
305+
{
306+
"name": "qwen_2_5_7b_instruct_accuracy",
307+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
308+
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_7b_accuracy"
309+
},
310+
{
311+
"name": "google_gemma-1.1-2b-it_accuracy",
312+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
313+
"pytest": "benchmark/tt-xla/test_llms.py::test_gemma_1_1_2b_accuracy"
314+
},
315+
{
316+
"name": "google_gemma-2-2b-it_accuracy",
317+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
318+
"pytest": "benchmark/tt-xla/test_llms.py::test_gemma_2_2b_accuracy"
319+
},
320+
{
321+
"name": "microsoft_phi-1_accuracy",
322+
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
323+
"pytest": "benchmark/tt-xla/test_llms.py::test_phi1_accuracy"
324+
},
325+
{
326+
"name": "microsoft_phi-1_5_accuracy",
327+
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
328+
"pytest": "benchmark/tt-xla/test_llms.py::test_phi1_5_accuracy"
329+
},
330+
{
331+
"name": "microsoft_phi-2_accuracy",
332+
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
333+
"pytest": "benchmark/tt-xla/test_llms.py::test_phi2_accuracy"
334+
},
335+
{
336+
"name": "tiiuae_falcon3-1b-base_accuracy",
337+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
338+
"pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_1b_accuracy"
339+
},
340+
{
341+
"name": "tiiuae_falcon3-3b-base_accuracy",
342+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
343+
"pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_3b_accuracy"
344+
},
345+
{
346+
"name": "tiiuae_falcon3-7b-base_accuracy",
347+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
348+
"pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_7b_accuracy"
349+
},
350+
{
351+
"name": "qwen_2_5_0_5b_instruct_accuracy",
352+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
353+
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_0_5b_accuracy"
354+
},
355+
{
356+
"name": "qwen_2_5_1_5b_instruct_accuracy",
357+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
358+
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_1_5b_accuracy"
359+
},
360+
{
361+
"name": "qwen_2_5_3b_instruct_accuracy",
362+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
363+
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_3b_accuracy"
364+
},
365+
{
366+
"name": "qwen_3_0_6b_accuracy",
367+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
368+
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_0_6b_accuracy"
369+
},
370+
{
371+
"name": "qwen_3_1_7b_accuracy",
372+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
373+
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_1_7b_accuracy"
374+
},
375+
{
376+
"name": "qwen_3_4b_accuracy",
377+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
378+
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_4b_accuracy"
379+
},
380+
{
381+
"name": "qwen_3_8b_accuracy",
382+
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
383+
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_8b_accuracy"
384+
},
385+
{
386+
"name": "ministral_8b_accuracy",
387+
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
388+
"pytest": "benchmark/tt-xla/test_llms.py::test_ministral_8b_accuracy"
284389
}
285390
]
286391
}
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Generate reference outputs for LLM accuracy testing.
7+
8+
This script loads a HuggingFace model, runs it on the "Tale of Two Cities" text corpus,
9+
and generates a .refpt file containing reference tokens and top-5 predictions for each position.
10+
11+
The .refpt files are used by the TokenAccuracy class for measuring TOP1 and TOP5 accuracy
12+
during model inference testing.
13+
14+
Usage:
15+
python3 benchmark/tt-xla/generate_reference_outputs.py \\
16+
--model "meta-llama/Llama-3.2-1B-Instruct" \\
17+
--output_file "benchmark/tt-xla/reference_outputs/Llama-3.2-1B-Instruct.refpt" \\
18+
--total_length 1024
19+
20+
Output format (.refpt file):
21+
{
22+
'reference_tokens': torch.Tensor, # Shape: [1, total_length]
23+
'top5_tokens': torch.Tensor, # Shape: [total_length, 5]
24+
}
25+
"""
26+
27+
import argparse
28+
import bz2
29+
import os
30+
31+
import torch
32+
from loguru import logger
33+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
34+
35+
36+
def generate_reference_outputs(total_length: int, output_file: str, model_name: str) -> None:
    """
    Generate reference outputs for accuracy testing using HuggingFace models.

    Runs the model over the bundled "Tale of Two Cities" corpus in chunks and
    records, for every position, the top-5 next-token predictions. Saves a
    .refpt file (via torch.save) containing:
        'top5_tokens':      torch.Tensor of shape [total_length - 1, 5] (CPU)
        'reference_tokens': torch.Tensor of shape [1, total_length] (CPU)
    Also prints a per-token prediction table and per-100-token accuracy
    summaries to stdout.

    Args:
        total_length: Number of tokens to process from Tale of Two Cities
        output_file: Path to save .refpt file
        model_name: HuggingFace model name (e.g., 'meta-llama/Llama-3.2-1B-Instruct')

    Raises:
        FileNotFoundError: if the compressed corpus file is missing from the
            reference_outputs directory next to this script.
    """
    # Set device (used for the token tensors; the model itself is placed by
    # device_map="auto" below)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load model and tokenizer from HuggingFace
    config = AutoConfig.from_pretrained(model_name)

    # Qwen only: add rope scaling to the config, for long context support.
    # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct#processing-long-texts
    if "Qwen" in model_name:
        config.rope_scaling = {"factor": 4.0, "original_max_position_embeddings": 32768, "type": "yarn"}

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config, device_map="auto")
    model.eval()

    # Load the book text (bzip2-compressed corpus shipped alongside this script)
    current_file_path = os.path.abspath(__file__)
    current_file_dir = os.path.dirname(current_file_path)
    prompt_file = os.path.join(current_file_dir, "reference_outputs", "tale-of-two-cities.txt.bz2")

    if not os.path.exists(prompt_file):
        raise FileNotFoundError(
            f"Tale of Two Cities text file not found: {prompt_file}\n"
            f"Please ensure the file exists in the reference_outputs directory."
        )

    logger.info(f"Loading text from {prompt_file}")
    with bz2.open(prompt_file, "rt", encoding="utf-8") as f:
        text = f.read()

    # Encode text to tokens; truncate to exactly total_length positions
    encoded_tokens = tokenizer.encode(text, add_special_tokens=True)[:total_length]
    encoded_tokens_tensor = torch.tensor(encoded_tokens, device=device).unsqueeze(0)  # Shape [1, seq_len] on device

    logger.info(f"Processing {len(encoded_tokens)} tokens")
    logger.info(f"Model: {model_name}")
    logger.info(f"Output file: {output_file}")

    # Header for the per-token prediction table printed below
    print(f"{'Progress':<15}{'Correct':<8}{'Actual':<15}{'Top 5 Predictions':<75}")
    print("-" * 113)

    # Initialize lists to store results
    all_top1_correct = []  # bool per position: top-1 prediction matched the actual next token
    all_top5_correct = []  # bool per position: actual next token appeared in the top-5
    all_top5_tokens = []  # per-chunk [chunk, 5] tensors, concatenated at the end
    segment_accuracies = []  # (top1%, top5%) per 100-token segment
    chunk_size = 1024  # NOTE: with the default total_length=1024 the loop runs a single chunk

    with torch.no_grad():
        # Iterate to total_length - 1 because the last token has no "next token" target.
        for chunk_start in range(0, total_length - 1, chunk_size):
            chunk_end = min(chunk_start + chunk_size, total_length)
            # Get input and target chunks; targets are the inputs shifted by one position
            chunk_tokens = encoded_tokens_tensor[:, chunk_start:chunk_end]
            chunk_next_tokens = encoded_tokens[chunk_start + 1 : chunk_end + 1]
            actual_chunk_size = min(len(chunk_tokens[0]), len(chunk_next_tokens))

            # Trim input chunk if needed (last chunk has one fewer target than inputs)
            chunk_tokens = chunk_tokens[:, :actual_chunk_size]

            # Process chunk using HuggingFace model
            outputs = model(chunk_tokens.to(device))
            logits = outputs.logits

            # Compute top-5 predictions (softmax is monotonic, so topk over probs
            # equals topk over logits; kept for parity with the logged values)
            probs = torch.softmax(logits, dim=-1)
            _, chunk_top5_tokens = torch.topk(probs, k=5, dim=-1)  # Shape: [1, chunk_size, 5]
            chunk_top5_tokens = chunk_top5_tokens.squeeze(0)  # Shape: [chunk_size, 5]

            # Get next tokens tensor
            chunk_next_tokens_tensor = torch.tensor(
                chunk_next_tokens[:actual_chunk_size], device=device
            )  # Move to same device

            # Calculate correctness
            chunk_top1_correct = chunk_top5_tokens[:, 0] == chunk_next_tokens_tensor
            chunk_top5_correct = torch.any(chunk_top5_tokens == chunk_next_tokens_tensor.unsqueeze(1), dim=1)

            # Store results
            all_top1_correct.extend(chunk_top1_correct.tolist())
            all_top5_correct.extend(chunk_top5_correct.tolist())
            all_top5_tokens.append(chunk_top5_tokens)

            # Print predictions for this chunk
            for i in range(len(chunk_next_tokens)):
                global_pos = chunk_start + i
                next_token = chunk_next_tokens[i]

                # Strip newline/CR/form-feed so decoded tokens stay on one table row
                sanitize = lambda x: x.replace("\n", "").replace("\r", "").replace("\x0c", "")
                actual_token = sanitize(tokenizer.decode([next_token]))
                top5_tokens = [sanitize(tokenizer.decode([t.item()])) for t in chunk_top5_tokens[i]]
                # 'x' = top-1 hit, '-' = only within top-5, ' ' = miss
                correct = "x" if chunk_top1_correct[i] else ("-" if chunk_top5_correct[i] else " ")
                top5_str = " ".join(f"{t:<14}" for t in top5_tokens)

                progress_str = f"{global_pos+1}/{total_length-1}"
                print(f"{progress_str:<15}{correct:<8}{actual_token:<15}{top5_str}")

                # Calculate and store segment accuracies every 100 tokens
                # (the second condition also closes out a final partial segment)
                if (global_pos + 1) % 100 == 0 or global_pos == total_length - 2:
                    start_idx = (global_pos // 100) * 100
                    end_idx = min(start_idx + 100, len(all_top1_correct))
                    segment_top1_acc = sum(all_top1_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100
                    segment_top5_acc = sum(all_top5_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100
                    # Guard against appending the same segment twice when both
                    # trigger conditions fire at the same position
                    if len(segment_accuracies) <= global_pos // 100:
                        segment_accuracies.append((segment_top1_acc, segment_top5_acc))

    # Save the data - ensure tensors are concatenated and on CPU
    data = {
        "top5_tokens": torch.cat(all_top5_tokens, dim=0).cpu(),
        "reference_tokens": encoded_tokens_tensor[:, :total_length].clone().cpu(),
    }

    torch.save(data, output_file)
    logger.info(f"Saved reference outputs to {output_file}")

    # Print all segment accuracy summaries as a table
    print("\nSegment Accuracy Summaries:")
    print(f"{'Tokens':<15}{'Top-1 Accuracy':<20}{'Top-5 Accuracy':<20}")
    print("-" * 55)
    for i, (top1_acc, top5_acc) in enumerate(segment_accuracies):
        start_token = i * 100 + 1
        end_token = min((i + 1) * 100, total_length)
        print(f"{f'{start_token}-{end_token}':<15}{f'{top1_acc:.2f}%':<20}{f'{top5_acc:.2f}%':<20}")

    # Calculate overall accuracy
    # NOTE(review): this is an unweighted mean over segments — a final partial
    # segment counts the same as a full 100-token one.
    overall_top1_acc = sum(acc[0] for acc in segment_accuracies) / len(segment_accuracies)
    overall_top5_acc = sum(acc[1] for acc in segment_accuracies) / len(segment_accuracies)
    print("-" * 55)
    print(f"{'Overall':<15}{f'{overall_top1_acc:.2f}%':<20}{f'{overall_top5_acc:.2f}%':<20}")
174+
175+
176+
def main():
    """Parse command-line arguments and generate reference outputs for one model."""
    usage_examples = """
Examples:
  # Generate reference for Llama 3.2 1B
  python3 benchmark/tt-xla/generate_reference_outputs.py \\
    --model "meta-llama/Llama-3.2-1B-Instruct" \\
    --output_file "benchmark/tt-xla/reference_outputs/Llama-3.2-1B-Instruct.refpt"

  # Generate with custom length
  python3 benchmark/tt-xla/generate_reference_outputs.py \\
    --model "mistralai/Mistral-7B-Instruct-v0.3" \\
    --output_file "benchmark/tt-xla/reference_outputs/Mistral-7B-Instruct-v0.3.refpt" \\
    --total_length 2048
"""
    cli = argparse.ArgumentParser(
        description="Generate reference outputs for LLM accuracy testing using HuggingFace models.",
        epilog=usage_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Registration order matters: it fixes the ordering in --help output.
    cli.add_argument(
        "--total_length", type=int, default=1024, help="Total length of tokens to process (default: 1024)"
    )
    cli.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Output file path for reference data (e.g., 'benchmark/tt-xla/reference_outputs/ModelName.refpt')",
    )
    cli.add_argument(
        "--model", type=str, required=True, help="HuggingFace model name (e.g., 'meta-llama/Llama-3.2-1B-Instruct')"
    )
    opts = cli.parse_args()

    generate_reference_outputs(total_length=opts.total_length, output_file=opts.output_file, model_name=opts.model)
209+
210+
211+
# Script entry point: parse CLI arguments and write the .refpt reference file.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)