Skip to content

Commit e3e9925

Browse files
committed
Addressing PR comments
- Add benchmark/tt-xla/scripts/generate_reference_outputs.py script - Remove ground truth from warmup - Add utility function for initializing accuracy testing: benchmark/tt-xla/utils.py::initialize_accuracy_testing
1 parent bea32e9 commit e3e9925

File tree

4 files changed

+271
-25
lines changed

4 files changed

+271
-25
lines changed

benchmark/tt-xla/llm_benchmark.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
create_benchmark_result,
3333
compute_pcc,
3434
build_xla_export_name,
35+
initialize_accuracy_testing,
3536
)
36-
from token_accuracy import TokenAccuracy
3737

3838
xr.set_device_type("TT")
3939

@@ -388,27 +388,10 @@ def benchmark_llm_torch_xla(
388388
token_accuracy = None
389389
custom_input_prompt = None
390390
if accuracy_testing:
391-
if model_name_for_accuracy is None:
392-
raise ValueError("model_name_for_accuracy must be provided when accuracy_testing=True")
393-
394-
# Use half the cache for prefill, half for decode
395-
# This ensures we fit within hardware constraints
396-
max_prefill = max_cache_len // 2
397-
max_decode = max_cache_len // 2
398-
399-
token_accuracy = TokenAccuracy(
400-
model_name=model_name_for_accuracy,
401-
max_prefill_tokens=max_prefill,
402-
max_decode_tokens=max_decode,
403-
)
404-
405-
# Get Tale of Two Cities text from reference data
406-
custom_input_prompt = token_accuracy.prepare_ref_tokens(tokenizer)
407-
print(
408-
f"Using reference text for accuracy testing:"
409-
f"\n Max prefill: {max_prefill} tokens"
410-
f"\n Max decode: {max_decode} tokens"
411-
f"\n Text preview: {custom_input_prompt[:100]}..."
391+
token_accuracy, custom_input_prompt = initialize_accuracy_testing(
392+
model_name_for_accuracy=model_name_for_accuracy,
393+
max_cache_len=max_cache_len,
394+
tokenizer=tokenizer,
412395
)
413396

414397
# Construct inputs, including static cache
@@ -477,7 +460,6 @@ def benchmark_llm_torch_xla(
477460
# Warmup run
478461
print("Warming up...")
479462
warmup_tokens = min(MIN_STEPS, max_tokens_to_generate)
480-
ground_truth_for_warmup = token_accuracy.reference_tokens[:warmup_tokens] if accuracy_testing else None
481463
_, _, _ = generate_and_benchmark(
482464
compiled_model,
483465
input_args,
@@ -488,7 +470,6 @@ def benchmark_llm_torch_xla(
488470
verbose=False,
489471
is_multichip=is_multichip,
490472
mesh=mesh,
491-
ground_truth_tokens=ground_truth_for_warmup,
492473
)
493474

494475
# Reconstruct inputs for the actual benchmark run
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Generate reference outputs for LLM accuracy testing.
7+
8+
This script loads a HuggingFace model, runs it on the "Tale of Two Cities" text corpus,
9+
and generates a .refpt file containing reference tokens and top-5 predictions for each position.
10+
11+
The .refpt files are used by the TokenAccuracy class for measuring TOP1 and TOP5 accuracy
12+
during model inference testing.
13+
14+
Usage:
15+
python3 <path-to-script>/generate_reference_outputs.py \\
16+
--model "meta-llama/Llama-3.2-1B-Instruct" \\
17+
--output_file "<output-dir>/Llama-3.2-1B-Instruct.refpt" \\
18+
--total_length 1024
19+
20+
Output format (.refpt file):
21+
{
22+
'reference_tokens': torch.Tensor, # Shape: [1, total_length]
23+
'top5_tokens': torch.Tensor, # Shape: [total_length, 5]
24+
}
25+
"""
26+
27+
import argparse
28+
import bz2
29+
import os
30+
31+
import torch
32+
import transformers
33+
from loguru import logger
34+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
35+
36+
37+
def generate_reference_outputs(total_length, output_file, model_name):
    """
    Generate reference outputs for accuracy testing using HuggingFace models.

    Runs ``model_name`` over the first ``total_length`` tokens of the bundled
    "Tale of Two Cities" corpus, records the top-5 next-token predictions at
    every position, prints a per-position prediction table plus per-100-token
    accuracy summaries, and saves a ``.refpt`` file (torch.save format) at
    ``output_file`` containing the reference tokens, the top-5 predictions,
    and the library versions used to produce them.

    Args:
        total_length: Number of tokens to process from Tale of Two Cities.
        output_file: Path to save the .refpt file.
        model_name: HuggingFace model name (e.g. 'meta-llama/Llama-3.2-1B-Instruct').

    Raises:
        FileNotFoundError: If the bundled corpus file cannot be located.
    """
    # Set device for the input tokens; the model itself is placed by device_map="auto".
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load model and tokenizer from HuggingFace
    config = AutoConfig.from_pretrained(model_name)

    # Qwen only: add rope scaling to the config, for long context support.
    # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct#processing-long-texts
    if "Qwen" in model_name:
        config.rope_scaling = {"factor": 4.0, "original_max_position_embeddings": 32768, "type": "yarn"}

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config, device_map="auto")
    model.eval()

    # Load the book text - look in ../reference_outputs relative to script location.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    prompt_file = os.path.join(parent_dir, "reference_outputs", "tale-of-two-cities.txt.bz2")

    if not os.path.exists(prompt_file):
        raise FileNotFoundError(
            f"Tale of Two Cities text file not found: {prompt_file}\n"
            f"Please ensure tale-of-two-cities.txt.bz2 exists in the reference_outputs directory."
        )

    logger.info(f"Loading text from {prompt_file}")
    with bz2.open(prompt_file, "rt", encoding="utf-8") as f:
        text = f.read()

    # Encode text to tokens, truncated to the requested length.
    encoded_tokens = tokenizer.encode(text, add_special_tokens=True)[:total_length]
    encoded_tokens_tensor = torch.tensor(encoded_tokens, device=device).unsqueeze(0)  # Shape [1, seq_len] on device

    logger.info(f"Processing {len(encoded_tokens)} tokens")
    logger.info(f"Model: {model_name}")
    logger.info(f"Output file: {output_file}")

    print(f"{'Progress':<15}{'Correct':<8}{'Actual':<15}{'Top 5 Predictions':<75}")
    print("-" * 113)

    # Initialize lists to store results
    all_top1_correct = []
    all_top5_correct = []
    all_top5_tokens = []
    segment_accuracies = []
    chunk_size = 1024

    # Strip characters that would break the fixed-width table layout.
    # Hoisted out of the token loop (the original rebuilt a lambda per token).
    def _sanitize(s):
        return s.replace("\n", "").replace("\r", "").replace("\x0c", "")

    with torch.no_grad():
        for chunk_start in range(0, total_length - 1, chunk_size):
            chunk_end = min(chunk_start + chunk_size, total_length)
            # Inputs are positions [chunk_start, chunk_end); targets are shifted by one.
            chunk_tokens = encoded_tokens_tensor[:, chunk_start:chunk_end]
            chunk_next_tokens = encoded_tokens[chunk_start + 1 : chunk_end + 1]
            actual_chunk_size = min(len(chunk_tokens[0]), len(chunk_next_tokens))

            # Trim input chunk so inputs and targets line up one-to-one.
            chunk_tokens = chunk_tokens[:, :actual_chunk_size]

            # Process chunk using the HuggingFace model. chunk_tokens is already
            # on `device`, so no extra .to(device) transfer is needed here.
            outputs = model(chunk_tokens)
            logits = outputs.logits

            # Compute top-5 predictions directly on the logits: softmax is
            # strictly monotonic, so the top-k token indices are identical to
            # top-k over probabilities — the extra softmax pass is unnecessary.
            _, chunk_top5_tokens = torch.topk(logits, k=5, dim=-1)  # Shape: [1, chunk_size, 5]
            chunk_top5_tokens = chunk_top5_tokens.squeeze(0)  # Shape: [chunk_size, 5]

            # Targets on the same device as the predictions.
            chunk_next_tokens_tensor = torch.tensor(chunk_next_tokens[:actual_chunk_size], device=device)

            # TOP1: best prediction matches; TOP5: target appears anywhere in the top 5.
            chunk_top1_correct = chunk_top5_tokens[:, 0] == chunk_next_tokens_tensor
            chunk_top5_correct = torch.any(chunk_top5_tokens == chunk_next_tokens_tensor.unsqueeze(1), dim=1)

            # Store results
            all_top1_correct.extend(chunk_top1_correct.tolist())
            all_top5_correct.extend(chunk_top5_correct.tolist())
            all_top5_tokens.append(chunk_top5_tokens)

            # Print predictions for this chunk
            for i in range(len(chunk_next_tokens)):
                global_pos = chunk_start + i
                next_token = chunk_next_tokens[i]

                actual_token = _sanitize(tokenizer.decode([next_token]))
                top5_tokens = [_sanitize(tokenizer.decode([t.item()])) for t in chunk_top5_tokens[i]]
                # "x" = top-1 hit, "-" = top-5 hit, " " = miss.
                correct = "x" if chunk_top1_correct[i] else ("-" if chunk_top5_correct[i] else " ")
                top5_str = " ".join(f"{t:<14}" for t in top5_tokens)

                progress_str = f"{global_pos+1}/{total_length-1}"
                print(f"{progress_str:<15}{correct:<8}{actual_token:<15}{top5_str}")

                # Calculate and store segment accuracies every 100 tokens
                # (plus the final, possibly partial, segment).
                if (global_pos + 1) % 100 == 0 or global_pos == total_length - 2:
                    start_idx = (global_pos // 100) * 100
                    end_idx = min(start_idx + 100, len(all_top1_correct))
                    segment_top1_acc = sum(all_top1_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100
                    segment_top5_acc = sum(all_top5_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100
                    # Guard against appending the same segment twice when both
                    # trigger conditions fire for one position.
                    if len(segment_accuracies) <= global_pos // 100:
                        segment_accuracies.append((segment_top1_acc, segment_top5_acc))

    # Save the data - ensure tensors are concatenated and on CPU.
    # NOTE(review): if the corpus is shorter than total_length, reference_tokens
    # will be shorter than [1, total_length] — presumably consumers handle this;
    # verify against TokenAccuracy.
    data = {
        "top5_tokens": torch.cat(all_top5_tokens, dim=0).cpu(),
        "reference_tokens": encoded_tokens_tensor[:, :total_length].clone().cpu(),
        "library_versions": {
            "torch": torch.__version__,
            "transformers": transformers.__version__,
        },
    }

    torch.save(data, output_file)
    logger.info(f"Saved reference outputs to {output_file}")
    logger.info(f"Library versions: torch={torch.__version__}, transformers={transformers.__version__}")

    # Print all segment accuracy summaries as a table
    print("\nSegment Accuracy Summaries:")
    print(f"{'Tokens':<15}{'Top-1 Accuracy':<20}{'Top-5 Accuracy':<20}")
    print("-" * 55)
    for i, (top1_acc, top5_acc) in enumerate(segment_accuracies):
        start_token = i * 100 + 1
        end_token = min((i + 1) * 100, total_length)
        print(f"{f'{start_token}-{end_token}':<15}{f'{top1_acc:.2f}%':<20}{f'{top5_acc:.2f}%':<20}")

    # Overall accuracy is the unweighted mean over segments (a final partial
    # segment counts the same as a full one).
    overall_top1_acc = sum(acc[0] for acc in segment_accuracies) / len(segment_accuracies)
    overall_top5_acc = sum(acc[1] for acc in segment_accuracies) / len(segment_accuracies)
    print("-" * 55)
    print(f"{'Overall':<15}{f'{overall_top1_acc:.2f}%':<20}{f'{overall_top5_acc:.2f}%':<20}")
182+
183+
184+
def _build_arg_parser():
    """Construct the command-line parser for this script."""
    arg_parser = argparse.ArgumentParser(
        description="Generate reference outputs for LLM accuracy testing using HuggingFace models.",
        epilog="""
Examples:
  # Generate reference for Llama 3.2 1B
  python3 generate_reference_outputs.py \\
      --model "meta-llama/Llama-3.2-1B-Instruct" \\
      --output_file "../reference_outputs/Llama-3.2-1B-Instruct.refpt"

  # Generate with custom length
  python3 generate_reference_outputs.py \\
      --model "mistralai/Mistral-7B-Instruct-v0.3" \\
      --output_file "../reference_outputs/Mistral-7B-Instruct-v0.3.refpt" \\
      --total_length 2048
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument(
        "--total_length", type=int, default=1024, help="Total length of tokens to process (default: 1024)"
    )
    arg_parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Output file path for reference data (e.g., '../reference_outputs/ModelName.refpt')",
    )
    arg_parser.add_argument(
        "--model", type=str, required=True, help="HuggingFace model name (e.g., 'meta-llama/Llama-3.2-1B-Instruct')"
    )
    return arg_parser


def main():
    """Parse CLI arguments and run reference-output generation."""
    cli_args = _build_arg_parser().parse_args()
    generate_reference_outputs(
        total_length=cli_args.total_length,
        output_file=cli_args.output_file,
        model_name=cli_args.model,
    )


if __name__ == "__main__":
    main()

benchmark/tt-xla/token_accuracy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
1+
# SPDX-FileCopyrightText: (c) 2026 Tenstorrent AI ULC
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

benchmark/tt-xla/utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,3 +464,48 @@ def move_to_cpu(data):
464464
moved = [move_to_cpu(item) for item in data]
465465
return type(data)(moved)
466466
return data
467+
468+
469+
def initialize_accuracy_testing(model_name_for_accuracy: str, max_cache_len: int, tokenizer):
    """
    Set up token-accuracy measurement for an LLM benchmark run.

    Builds a ``TokenAccuracy`` tracker for *model_name_for_accuracy*, splitting
    the cache budget evenly between prefill and decode, and obtains the
    reference ("Tale of Two Cities") text to use as the benchmark prompt.

    Args:
        model_name_for_accuracy: Model name for .refpt file lookup.
        max_cache_len: Maximum cache length (determines prefill and decode splits).
        tokenizer: HuggingFace tokenizer instance.

    Returns:
        Tuple of (token_accuracy, custom_input_prompt)
            - token_accuracy: TokenAccuracy instance
            - custom_input_prompt: Reference text string for benchmarking

    Raises:
        ValueError: If model_name_for_accuracy is None.
    """
    # Imported here (not at module top) so importing utils does not require
    # token_accuracy to be importable.
    from token_accuracy import TokenAccuracy

    if model_name_for_accuracy is None:
        raise ValueError("model_name_for_accuracy must be provided when accuracy_testing=True")

    # Split the cache budget in half: one share for prefill, one for decode.
    # This ensures we fit within hardware constraints.
    prefill_budget, decode_budget = max_cache_len // 2, max_cache_len // 2

    accuracy_tracker = TokenAccuracy(
        model_name=model_name_for_accuracy,
        max_prefill_tokens=prefill_budget,
        max_decode_tokens=decode_budget,
    )

    # Get Tale of Two Cities text from reference data
    reference_prompt = accuracy_tracker.prepare_ref_tokens(tokenizer)
    print(
        f"Using reference text for accuracy testing:"
        f"\n Max prefill: {prefill_budget} tokens"
        f"\n Max decode: {decode_budget} tokens"
        f"\n Text preview: {reference_prompt[:100]}..."
    )

    return accuracy_tracker, reference_prompt

0 commit comments

Comments
 (0)