
Introduce accuracy testing for LLMs in tt-forge #859

Open

dgolubovicTT wants to merge 2 commits into main from dgolubovic/add-accuracy-llm-tests

Conversation

Contributor

@dgolubovicTT dgolubovicTT commented Feb 4, 2026

Closes tenstorrent/tt-xla#3355

Overview: LLM Accuracy Testing with TOP1/TOP5 Metrics

This PR introduces token-level accuracy testing for LLM benchmarks, measuring how well TT device predictions match CPU reference outputs.

Accuracy Metrics:

  • TOP1 Accuracy: Percentage of tokens where the device's prediction matches the CPU reference model's top prediction (argmax)
  • TOP5 Accuracy: Percentage of tokens where the device's prediction appears in the CPU reference model's top-5 predictions
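As a minimal sketch of these two metrics (function and variable names are hypothetical, not taken from this PR), per decode position:

```python
def token_accuracy(device_tokens, ref_top1, ref_top5):
    """TOP1/TOP5 accuracy of device predictions vs. CPU reference predictions.

    device_tokens: token id the device predicted at each decode position
    ref_top1:      CPU reference argmax token at each position
    ref_top5:      CPU reference top-5 token ids at each position
    """
    n = len(device_tokens)
    top1 = sum(d == r for d, r in zip(device_tokens, ref_top1)) / n
    top5 = sum(d in r5 for d, r5 in zip(device_tokens, ref_top5)) / n
    return top1, top5

# 3 of 4 device tokens match the reference argmax; all 4 appear in its top-5.
dev = [11, 42, 7, 99]
r1 = [11, 42, 7, 100]
r5 = [[11, 1, 2, 3, 4], [42, 5, 6, 7, 8],
      [7, 9, 10, 11, 12], [100, 99, 101, 102, 103]]
print(token_accuracy(dev, r1, r5))  # (0.75, 1.0)
```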

Test Corpus:
We use "A Tale of Two Cities" by Charles Dickens as our reference text. The text is split into two parts:

  • Prefill phase (first half): Input context fed to the model
  • Decode phase (second half): Ground truth tokens for validation with teacher forcing

Critical Requirement - Sequence Length Matching:
The CPU reference outputs (.refpt files) must be generated with the same total_length as the input_sequence_length used during device testing. Mismatched sequence lengths degrade accuracy even with teacher forcing, because the CPU reference model and the compiled device model must see the same input sequence size (context). If you change input_sequence_length in tests, you must regenerate all reference outputs.

NOTE: We currently use an approach nearly identical to tt-metal's accuracy testing, to allow a 1-1 comparison: the same text corpus (A Tale of Two Cities) and the same top1/top5 metrics. However, since we run LLM accuracy tests only on n150, we use a total sequence length of 128 due to memory constraints, i.e. 64 tokens for prefill and 64 for decode.

Introduce accuracy testing to test_llms.py

  • Add an --accuracy-testing argument to the test_llms.py tests that we track in benchmark testing.
  • Add a --batch-size argument to accuracy tests in test_llms.py, because the default batch size of 32 does not fit on device
    with the larger input sequence length that accuracy testing requires; with batch size 32, the 7B and 8B models failed with OOM errors.
  • Run accuracy tests in a separate job, run-n150-accuracy-benchmarks, in the perf-benchmark-experimental workflow

Generating ground truth .refpt files (generate_reference_outputs.py)

Add a generate_reference_outputs.py script that loads a HuggingFace model, runs it on the "Tale of Two Cities" text corpus, and generates a .refpt file containing reference tokens and top-5 predictions for each position.
The script generates reference top1/top5 token predictions for LLM accuracy benchmarking:

  • Loads HuggingFace models on CPU
  • Processes "Tale of Two Cities" text corpus with teacher forcing
  • Outputs .refpt files containing reference tokens and top-k predictions
  • Used by TokenAccuracy class to validate TOP1/TOP5 accuracy during benchmarks
  • Added a reference_outputs/ directory of .refpt files, with a README that explains how reference files are created and used
  • Added generate_all_reference_outputs.sh, which runs generate_reference_outputs.py for a list of models and dumps reference outputs to benchmark/tt-xla/reference_outputs

The script ensures reproducibility through eval mode, disabled dropout, greedy decoding, and a StaticCache matching the benchmark environment. Reference files must be regenerated if input_sequence_length changes.

Usage:
python3 scripts/generate_reference_outputs.py \
    --model "meta-llama/Llama-3.2-1B-Instruct" \
    --output_file "reference_outputs/Llama-3.2-1B-Instruct.refpt" \
    --total_length 128
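The core of the reference pass can be sketched as follows. A toy logits function stands in for the CPU HuggingFace model, and the function names are illustrative, not the PR's actual API:

```python
def top_k_ids(logits, k=5):
    # Indices of the k largest logits; index 0 is the greedy (argmax) token.
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

def generate_reference(logits_fn, ground_truth_ids, k=5):
    """Teacher-forced reference pass: feed each ground-truth token and record
    the model's top-k next-token predictions for that position."""
    return [top_k_ids(logits_fn(tok), k) for tok in ground_truth_ids]

# Toy stand-in for the CPU model over a vocab of 8: strongly prefers tok + 1.
def toy_logits(tok, vocab=8):
    return [3.0 if i == (tok + 1) % vocab else float(-i) for i in range(vocab)]

ref = generate_reference(toy_logits, [2, 5], k=5)
print(ref[0][0], ref[1][0])  # greedy predictions: 3 6
```

The real script records these top-k lists (plus torch/transformers versions) into the .refpt file so device runs can be scored offline.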

Adding shared utility for decode (decode_utils.py)

Centralize LLM decode operations used by reference output generation and accuracy testing:

  • Teacher forcing generation with ground truth tokens
  • Reference top-k prediction generation for .refpt files
  • Static cache and accuracy testing initialization
  • Logits extraction and top-k token utilities

Prevents implementation drift between reference generation and benchmark paths
by sharing the same decode logic, tokenization, and cache semantics.
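The shared teacher-forced decode loop is, in spirit (names hypothetical; the real utility also manages the static cache and device transfers):

```python
def greedy_argmax(logits):
    # Greedy decoding: index of the largest logit.
    return max(range(len(logits)), key=logits.__getitem__)

def teacher_forced_decode(step_fn, ground_truth_ids):
    """Always feed the ground-truth token (teacher forcing) and collect the
    model's greedy prediction at each step, so the reference-generation path
    and the benchmark path see identical inputs."""
    return [greedy_argmax(step_fn(tok)) for tok in ground_truth_ids]

# Toy step function over a vocab of 4 that echoes its input token.
preds = teacher_forced_decode(
    lambda tok: [1.0 if i == tok else 0.0 for i in range(4)], [0, 3, 1])
print(preds)  # [0, 3, 1]
```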

TokenAccuracy class for validating LLM inference quality (token_accuracy.py)

  • Loads precomputed reference data from .refpt files (tokens, top1/top5 predictions)
  • Validates torch/transformers versions match reference file for reproducibility
  • Splits reference tokens into prefill (input) and decode (ground truth) windows
  • Computes TOP1/TOP5 accuracy by comparing model predictions against reference
  • Provides teacher forcing tokens for deterministic decode loops
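A stripped-down sketch of the class (the flat dict schema here is an assumption for illustration; the real .refpt files also carry top-k predictions per position and version metadata):

```python
class TokenAccuracy:
    """Splits reference tokens into prefill/decode windows and scores
    device predictions. Schema assumed: {"tokens": [...], "top5": [[...], ...]}."""

    def __init__(self, ref, input_sequence_length):
        half = input_sequence_length // 2
        self.prefill_ids = ref["tokens"][:half]            # fed as input context
        self.ground_truth = ref["tokens"][half:input_sequence_length]  # teacher forcing targets
        self.top5 = ref["top5"][half:input_sequence_length]

    def score(self, device_tokens):
        n = len(device_tokens)
        top1 = sum(d == g for d, g in zip(device_tokens, self.ground_truth)) / n
        top5 = sum(d in t5 for d, t5 in zip(device_tokens, self.top5)) / n
        return top1, top5

ref = {"tokens": [1, 2, 3, 4], "top5": [[1], [2], [3, 9], [5]]}
acc = TokenAccuracy(ref, input_sequence_length=4)
print(acc.prefill_ids, acc.score([3, 5]))  # [1, 2] (0.5, 1.0)
```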

Slight refactoring

  • Simplify static cache initialization using init_static_cache helper
  • Remove unused variables is_multichip and mesh from generate_and_benchmark function.

@dgolubovicTT dgolubovicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch from f76d178 to e1c8e5f on February 5, 2026 08:53
Collaborator

@odjuricicTT odjuricicTT left a comment


Looks great! A few minor changes requested.

Comment on lines 284 to 388
},
{
"name": "llama_3_2_1b_instruct_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_2_1b_accuracy"
},
{
"name": "llama_3_2_3b_instruct_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_2_3b_accuracy"
},
{
"name": "llama_3_1_8b_instruct_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_1_8b_accuracy"
},
{
"name": "mistral_7b_accuracy",
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1 protobuf sentencepiece",
"pytest": "benchmark/tt-xla/test_llms.py::test_mistral_7b_accuracy"
},
{
"name": "qwen_2_5_7b_instruct_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_7b_accuracy"
},
{
"name": "google_gemma-1.1-2b-it_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_gemma_1_1_2b_accuracy"
},
{
"name": "google_gemma-2-2b-it_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_gemma_2_2b_accuracy"
},
{
"name": "microsoft_phi-1_accuracy",
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_phi1_accuracy"
},
{
"name": "microsoft_phi-1_5_accuracy",
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_phi1_5_accuracy"
},
{
"name": "microsoft_phi-2_accuracy",
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_phi2_accuracy"
},
{
"name": "tiiuae_falcon3-1b-base_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_1b_accuracy"
},
{
"name": "tiiuae_falcon3-3b-base_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_3b_accuracy"
},
{
"name": "tiiuae_falcon3-7b-base_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_7b_accuracy"
},
{
"name": "qwen_2_5_0_5b_instruct_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_0_5b_accuracy"
},
{
"name": "qwen_2_5_1_5b_instruct_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_1_5b_accuracy"
},
{
"name": "qwen_2_5_3b_instruct_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_3b_accuracy"
},
{
"name": "qwen_3_0_6b_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_0_6b_accuracy"
},
{
"name": "qwen_3_1_7b_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_1_7b_accuracy"
},
{
"name": "qwen_3_4b_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_4b_accuracy"
},
{
"name": "qwen_3_8b_accuracy",
"pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_8b_accuracy"
},
{
"name": "ministral_8b_accuracy",
"pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
"pytest": "benchmark/tt-xla/test_llms.py::test_ministral_8b_accuracy"
Collaborator


Can we reduce this to one model size per architecture (e.g. 1 llama, 1 qwen, etc.)? This will run on the perf runners by default, which are currently overloaded.

Going forward we might want to make this a separate job and run it on shared runners by default.

Contributor Author


For a start, I really want to track accuracy on all models that we want to improve and run mixed precision on. I don't think we should jump to reducing reasonable tests before trying to solve this another way.

Contributor


Agreed @odjuricicTT
We should make it a separate run, like p150 and llmbox, and default it to shared runners.

Collaborator


That makes sense. I propose the following:

  • Remove all of these additional tests.
  • Implement accuracy testing as a parameter to all existing tests (e.g. like num_layers)
  • Add a new CI job that calls all tests and passes e.g. an --accuracy param
  • Make sure that this job runs on the n150 shared runners

Contributor Author


Done, with slight corrections from @vvukomanTT and @vkovacevicTT

arch,
required_pcc,
accuracy_testing: bool = False,
model_name_for_accuracy: str = None,
Collaborator


There are some refactors for naming saved files related to models in this PR: https://github.com/tenstorrent/tt-forge/pull/847/changes

Please take a look and try to reuse the same logic / variables for naming.

Contributor Author


Talked with @mvasiljevicTT offline. Her model name comes from the pytest test ID; my model name is the real HuggingFace model name. The function that generates these files, generate_reference_outputs, takes the HuggingFace model name as an argument, loads the model, and creates a .refpt file with the same name. It is therefore completely independent of the test name, so it would be hacky to force display_name (from Marina's PR).

@dgolubovicTT dgolubovicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch from e1c8e5f to bea32e9 on February 6, 2026 14:56
@dgolubovicTT dgolubovicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch from e3e9925 to 054bb95 on February 9, 2026 11:05
Comment on lines 228 to 229
predicted_token = next_token_ids[0].item() # Assuming batch_size=1
predicted_tokens.append(predicted_token)
Collaborator


This seems redundant; I'd rather recompute this outside of this function using the logits. Would that work for your case?

The reasons are that it's not OK to assume the batch size here, and we want to minimize what gets measured into e2e time if it's not needed.

Contributor Author


Added support for batch_size != 1, so I think we are good now.

@dgolubovicTT dgolubovicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch 5 times, most recently from 1cf2179 to 29c9c3b on February 12, 2026 10:32
Contributor

@vkovacevicTT vkovacevicTT left a comment


Looks good now!
Here is the run with latest changes: https://github.com/tenstorrent/tt-forge/actions/runs/21955566618

@dgolubovicTT dgolubovicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch 3 times, most recently from 4bc2f39 to 60750d0 on February 13, 2026 22:25
@vkovacevicTT vkovacevicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch from 60750d0 to a89fa14 on February 18, 2026 15:30
Teacher forcing for accuracy testing:
- Add teacher forcing support to generate_and_benchmark() for accuracy testing mode
- Route to teacher_forced_generate() when ground_truth_tokens are provided
- Update construct_inputs() to support pre-tokenized input and custom prompts
@dgolubovicTT dgolubovicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch from a89fa14 to 6c3215f on February 23, 2026 16:37
@dgolubovicTT dgolubovicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch from 6c3215f to 7aa4058 on February 23, 2026 16:55
…and cache_position layout drift

Teacher forcing was feeding a per-step scalar token (ground_truth_tokens[step].to(device)).
On XLA-style backends this commonly takes the scalar-constant path, which can specialize the
compiled program on the token value. In decode this produces many unique programs (one per
token) and can blow instruction/L1 caches.

Fix by slicing on CPU to a stable-shaped tensor [1,1] each step and transferring it as runtime
data. Expand to [batch,1] and materialize a contiguous buffer to avoid broadcast/stride issues.

cache_position updates done on-device produced an si32 buffer with a different (non-tiled)
layout than the compiled model expects (tiled si32), leading to TTIR to TTNN compilation failure
on Gemma. Fix by round-tripping cache_position through CPU: normalize to shape [1] via
reshape(-1)[-1:], increment on host, then re-upload so the device import path restores the
expected layout.
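Both fixes can be illustrated on CPU with plain torch (the .to(device) transfers and the compiled model are elided; tensor values here are toy examples):

```python
import torch

batch = 4
ground_truth_tokens = torch.tensor([7, 11, 13])  # toy decode targets
step = 1

# A per-step scalar (ground_truth_tokens[step]) can take the scalar-constant
# path on XLA-style backends, specializing one compiled program per token
# value. Instead, slice on CPU to a stable [1, 1] shape...
token = ground_truth_tokens[step : step + 1].view(1, 1)
# ...then expand to [batch, 1] and materialize a contiguous buffer before
# transferring it to the device as runtime data.
token = token.expand(batch, 1).contiguous()
print(token.shape)  # torch.Size([4, 1])

# cache_position round-trip: normalize to shape [1] on host via
# reshape(-1)[-1:], increment, and re-upload so the device import path
# restores the expected (tiled si32) layout.
cache_position = torch.tensor([[41]], dtype=torch.int32)
cache_position = cache_position.reshape(-1)[-1:] + 1
print(cache_position)  # tensor([42], dtype=torch.int32)
```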
@dgolubovicTT dgolubovicTT force-pushed the dgolubovic/add-accuracy-llm-tests branch from 7aa4058 to 64ce899 on February 23, 2026 16:57

Successfully merging this pull request may close these issues.

Add top1/top5 metrics tests in benchmark LLMs
