|
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

"""
Generate reference outputs for LLM accuracy testing.

This script loads a HuggingFace model, runs it on the "Tale of Two Cities" text corpus,
and generates a .refpt file containing reference tokens and top-5 predictions for each position.

The .refpt files are used by the TokenAccuracy class for measuring TOP1 and TOP5 accuracy
during model inference testing.

Usage:
    python3 benchmark/tt-xla/generate_reference_outputs.py \\
        --model "meta-llama/Llama-3.2-1B-Instruct" \\
        --output_file "benchmark/tt-xla/reference_outputs/Llama-3.2-1B-Instruct.refpt" \\
        --total_length 1024

Output format (.refpt file):
    {
        'reference_tokens': torch.Tensor,  # Shape: [1, total_length]
        'top5_tokens': torch.Tensor,       # Shape: [total_length - 1, 5] (one row per predicted position)
    }
"""
| 26 | + |
| 27 | +import argparse |
| 28 | +import bz2 |
| 29 | +import os |
| 30 | + |
| 31 | +import torch |
| 32 | +from loguru import logger |
| 33 | +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
| 34 | + |
| 35 | + |
def generate_reference_outputs(total_length: int, output_file: str, model_name: str) -> None:
    """
    Generate reference outputs for accuracy testing using HuggingFace models.

    Runs `model_name` over the first `total_length` tokens of the bundled
    "Tale of Two Cities" corpus and saves a .refpt file (via torch.save)
    containing:
      - 'reference_tokens': the encoded input tokens, shape [1, <=total_length]
      - 'top5_tokens': top-5 predicted token ids for each predicted position,
        shape [total_length - 1, 5] (position i predicts token i + 1)

    Also prints a per-token prediction table and per-100-token accuracy
    segments to stdout for eyeballing model quality.

    Args:
        total_length: Number of tokens to process from Tale of Two Cities
        output_file: Path to save .refpt file
        model_name: HuggingFace model name (e.g., 'meta-llama/Llama-3.2-1B-Instruct')

    Raises:
        FileNotFoundError: if the compressed corpus file is missing.
    """
    # Set device (single device for input tensors; the model itself is
    # placed by device_map="auto" below)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load model and tokenizer from HuggingFace
    config = AutoConfig.from_pretrained(model_name)

    # Qwen only: add rope scaling to the config, for long context support.
    # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct#processing-long-texts
    if "Qwen" in model_name:
        config.rope_scaling = {"factor": 4.0, "original_max_position_embeddings": 32768, "type": "yarn"}

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config, device_map="auto")
    model.eval()

    # Load the book text (bz2-compressed corpus shipped next to this script)
    current_file_path = os.path.abspath(__file__)
    current_file_dir = os.path.dirname(current_file_path)
    prompt_file = os.path.join(current_file_dir, "reference_outputs", "tale-of-two-cities.txt.bz2")

    if not os.path.exists(prompt_file):
        raise FileNotFoundError(
            f"Tale of Two Cities text file not found: {prompt_file}\n"
            f"Please ensure the file exists in the reference_outputs directory."
        )

    logger.info(f"Loading text from {prompt_file}")
    with bz2.open(prompt_file, "rt", encoding="utf-8") as f:
        text = f.read()

    # Encode text to tokens and truncate to the requested length.
    # NOTE(review): if the corpus encodes to fewer than total_length tokens,
    # the saved 'reference_tokens' will be shorter than documented.
    encoded_tokens = tokenizer.encode(text, add_special_tokens=True)[:total_length]
    encoded_tokens_tensor = torch.tensor(encoded_tokens, device=device).unsqueeze(0)  # Shape [1, seq_len] on device

    logger.info(f"Processing {len(encoded_tokens)} tokens")
    logger.info(f"Model: {model_name}")
    logger.info(f"Output file: {output_file}")

    # Header for the per-token prediction table printed below
    print(f"{'Progress':<15}{'Correct':<8}{'Actual':<15}{'Top 5 Predictions':<75}")
    print("-" * 113)

    # Initialize lists to store results
    all_top1_correct = []  # bool per predicted position: argmax == actual next token
    all_top5_correct = []  # bool per predicted position: actual token in top-5
    all_top5_tokens = []  # per-chunk tensors of top-5 token ids, concatenated at the end
    segment_accuracies = []  # (top1%, top5%) per 100-token segment
    chunk_size = 1024

    with torch.no_grad():
        # Iterate over predicted positions 0..total_length-2 in chunks.
        # NOTE(review): each chunk is run as an independent forward pass
        # (no KV cache / context carried across chunks), so predictions at
        # the start of chunks after the first lack preceding context —
        # confirm this matches the consumer's expectations.
        for chunk_start in range(0, total_length - 1, chunk_size):
            chunk_end = min(chunk_start + chunk_size, total_length)
            # Get input and target chunks; targets are inputs shifted by one
            chunk_tokens = encoded_tokens_tensor[:, chunk_start:chunk_end]
            chunk_next_tokens = encoded_tokens[chunk_start + 1 : chunk_end + 1]
            # In the final chunk there is no target for the last input token,
            # so the usable size is bounded by the shorter of the two
            actual_chunk_size = min(len(chunk_tokens[0]), len(chunk_next_tokens))

            # Trim input chunk if needed
            chunk_tokens = chunk_tokens[:, :actual_chunk_size]

            # Process chunk using HuggingFace model
            outputs = model(chunk_tokens.to(device))
            logits = outputs.logits

            # Compute top-5 predictions (softmax is monotonic, so topk on
            # probs matches topk on logits; kept for clarity)
            probs = torch.softmax(logits, dim=-1)
            _, chunk_top5_tokens = torch.topk(probs, k=5, dim=-1)  # Shape: [1, chunk_size, 5]
            chunk_top5_tokens = chunk_top5_tokens.squeeze(0)  # Shape: [chunk_size, 5]

            # Get next tokens tensor
            chunk_next_tokens_tensor = torch.tensor(
                chunk_next_tokens[:actual_chunk_size], device=device
            )  # Move to same device

            # Calculate correctness
            chunk_top1_correct = chunk_top5_tokens[:, 0] == chunk_next_tokens_tensor
            chunk_top5_correct = torch.any(chunk_top5_tokens == chunk_next_tokens_tensor.unsqueeze(1), dim=1)

            # Store results
            all_top1_correct.extend(chunk_top1_correct.tolist())
            all_top5_correct.extend(chunk_top5_correct.tolist())
            all_top5_tokens.append(chunk_top5_tokens)

            # Print predictions for this chunk
            for i in range(len(chunk_next_tokens)):
                global_pos = chunk_start + i
                next_token = chunk_next_tokens[i]

                # Strip newlines/form feeds so each prediction stays on one table row
                sanitize = lambda x: x.replace("\n", "").replace("\r", "").replace("\x0c", "")
                actual_token = sanitize(tokenizer.decode([next_token]))
                top5_tokens = [sanitize(tokenizer.decode([t.item()])) for t in chunk_top5_tokens[i]]
                # 'x' = top-1 hit, '-' = top-5 hit only, ' ' = miss
                correct = "x" if chunk_top1_correct[i] else ("-" if chunk_top5_correct[i] else " ")
                top5_str = " ".join(f"{t:<14}" for t in top5_tokens)

                progress_str = f"{global_pos+1}/{total_length-1}"
                print(f"{progress_str:<15}{correct:<8}{actual_token:<15}{top5_str}")

                # Calculate and store segment accuracies every 100 tokens
                # (also at the final predicted position, total_length - 2)
                if (global_pos + 1) % 100 == 0 or global_pos == total_length - 2:
                    start_idx = (global_pos // 100) * 100
                    end_idx = min(start_idx + 100, len(all_top1_correct))
                    segment_top1_acc = sum(all_top1_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100
                    segment_top5_acc = sum(all_top5_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100
                    # Guard against appending the same segment twice when the
                    # final position coincides with a 100-token boundary
                    if len(segment_accuracies) <= global_pos // 100:
                        segment_accuracies.append((segment_top1_acc, segment_top5_acc))

    # Save the data - ensure tensors are concatenated and on CPU.
    # 'top5_tokens' has total_length - 1 rows (one per predicted position).
    data = {
        "top5_tokens": torch.cat(all_top5_tokens, dim=0).cpu(),
        "reference_tokens": encoded_tokens_tensor[:, :total_length].clone().cpu(),
    }

    torch.save(data, output_file)
    logger.info(f"Saved reference outputs to {output_file}")

    # Print all segment accuracy summaries as a table
    print("\nSegment Accuracy Summaries:")
    print(f"{'Tokens':<15}{'Top-1 Accuracy':<20}{'Top-5 Accuracy':<20}")
    print("-" * 55)
    for i, (top1_acc, top5_acc) in enumerate(segment_accuracies):
        start_token = i * 100 + 1
        end_token = min((i + 1) * 100, total_length)
        print(f"{f'{start_token}-{end_token}':<15}{f'{top1_acc:.2f}%':<20}{f'{top5_acc:.2f}%':<20}")

    # Calculate overall accuracy (unweighted mean of segment accuracies;
    # a short final segment counts the same as a full one)
    overall_top1_acc = sum(acc[0] for acc in segment_accuracies) / len(segment_accuracies)
    overall_top5_acc = sum(acc[1] for acc in segment_accuracies) / len(segment_accuracies)
    print("-" * 55)
    print(f"{'Overall':<15}{f'{overall_top1_acc:.2f}%':<20}{f'{overall_top5_acc:.2f}%':<20}")
| 175 | + |
def main():
    """
    Command-line entry point.

    Parses --model, --output_file and --total_length, then delegates to
    generate_reference_outputs().
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate reference outputs for LLM accuracy testing using HuggingFace models.",
        epilog="""
Examples:
  # Generate reference for Llama 3.2 1B
  python3 benchmark/tt-xla/generate_reference_outputs.py \\
    --model "meta-llama/Llama-3.2-1B-Instruct" \\
    --output_file "benchmark/tt-xla/reference_outputs/Llama-3.2-1B-Instruct.refpt"

  # Generate with custom length
  python3 benchmark/tt-xla/generate_reference_outputs.py \\
    --model "mistralai/Mistral-7B-Instruct-v0.3" \\
    --output_file "benchmark/tt-xla/reference_outputs/Mistral-7B-Instruct-v0.3.refpt" \\
    --total_length 2048
    """,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Number of corpus tokens to run the model over
    arg_parser.add_argument(
        "--total_length",
        type=int,
        default=1024,
        help="Total length of tokens to process (default: 1024)",
    )
    # Destination path for the serialized reference data
    arg_parser.add_argument(
        "--output_file",
        type=str,
        required=True,
        help="Output file path for reference data (e.g., 'benchmark/tt-xla/reference_outputs/ModelName.refpt')",
    )
    # HuggingFace hub identifier of the model to evaluate
    arg_parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="HuggingFace model name (e.g., 'meta-llama/Llama-3.2-1B-Instruct')",
    )
    cli_args = arg_parser.parse_args()

    generate_reference_outputs(
        total_length=cli_args.total_length,
        output_file=cli_args.output_file,
        model_name=cli_args.model,
    )
| 210 | + |
# Entry point when executed as a script (no effect on import).
if __name__ == "__main__":
    main()