
Commit e6c3c9f

dgolubovicTT authored and vkovacevicTT committed
Introduce accuracy testing to test_llms.py:
- Add an --accuracy-testing argument to the test_llms.py tests tracked in benchmark testing.
- Add a --batch-size argument to the accuracy tests in test_llms.py, because the default batch size of 32 does not fit on device with the larger input sequence length that accuracy testing requires; with batch size 32 at that sequence length, the 7B and 8B models failed with OOM errors.
- Run accuracy tests in a separate job, run-n150-accuracy-benchmarks, of the perf-benchmark-experimental workflow.

Teacher forcing for accuracy testing:
- Add teacher forcing support to generate_and_benchmark() for accuracy testing mode.
- Route to teacher_forced_generate() when ground_truth_tokens is provided.
- Update construct_inputs() to support pre-tokenized input and custom prompts.

Generating ground truth .refpt files (generate_reference_outputs.py):
- Add generate_reference_outputs.py to create ground truth .refpt files for accuracy testing.
- Generate reference top1/top5 token predictions for LLM accuracy benchmarking:
  - Loads HuggingFace models on CPU for deterministic inference.
  - Processes the "Tale of Two Cities" text corpus with teacher forcing.
  - Outputs .refpt files containing reference tokens and top-k predictions.
  - Used by the TokenAccuracy class to validate TOP1/TOP5 accuracy during benchmarks.
- Ensures reproducibility through eval mode, disabled dropout, greedy decoding, and a StaticCache matching the benchmark environment. Reference files must be regenerated if input_sequence_length changes.

Usage:
  python3 scripts/generate_reference_outputs.py \
    --model "meta-llama/Llama-3.2-1B-Instruct" \
    --output_file "reference_outputs/Llama-3.2-1B-Instruct.refpt" \
    --total_length 128

Adding a shared utility for decode (decode_utils.py):
Centralize the LLM decode operations used by reference output generation and accuracy testing:
- Teacher forcing generation with ground truth tokens
- Reference top-k prediction generation for .refpt files
- Static cache and accuracy testing initialization
- Logits extraction and top-k token utilities
Sharing the same decode logic, tokenization, and cache semantics prevents implementation drift between the reference generation and benchmark paths.

TokenAccuracy class for validating LLM inference quality (token_accuracy.py):
- Loads precomputed reference data from .refpt files (tokens, top1/top5 predictions)
- Validates that torch/transformers versions match the reference file for reproducibility
- Splits reference tokens into prefill (input) and decode (ground truth) windows
- Computes TOP1/TOP5 accuracy by comparing model predictions against the reference
- Provides teacher forcing tokens for deterministic decode loops

Slight refactoring:
- Simplify static cache initialization using the init_static_cache helper
- Remove the unused variables is_multichip and mesh from generate_and_benchmark()
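Because teacher forcing feeds the ground-truth token at every decode step, the model's logits at each step line up one-to-one with the reference next tokens, and TOP1/TOP5 accuracy reduces to a per-step comparison. A minimal sketch of that comparison (illustrative only — not the repo's TokenAccuracy implementation; tensor shapes and names are assumptions):

```python
import torch

def topk_accuracy(logits, reference_tokens):
    """Compare per-step teacher-forced logits against reference tokens.

    logits:           [num_steps, vocab_size] model outputs, one row per decode step
    reference_tokens: [num_steps] ground-truth next tokens from the .refpt file
    """
    # TOP1: greedy prediction must equal the reference token.
    top1 = logits.argmax(dim=-1)
    top1_acc = (top1 == reference_tokens).float().mean().item()
    # TOP5: the reference token must appear among the model's five
    # highest-scoring tokens at that step (assumes vocab_size >= 5).
    top5 = logits.topk(5, dim=-1).indices
    top5_acc = (top5 == reference_tokens.unsqueeze(-1)).any(dim=-1).float().mean().item()
    return top1_acc, top5_acc
```

With distinct logit values the result is deterministic, which is why the reference generation pins eval mode and greedy decoding.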
1 parent 4db9918 commit e6c3c9f

34 files changed (+1518 −103 lines)

.github/workflows/call-perf-test.yml

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ jobs:
             python benchmark/benchmark.py -p ${{ matrix.build.project}} -m ${{ matrix.build.name }} -bs ${{ matrix.build.bs }} -df ${{ matrix.build.df }} -lp ${{ matrix.build.lp }} ${{ matrix.build.input_sequence_length && format('-isl {0}', matrix.build.input_sequence_length) }} -ts ${{ matrix.build.ts }} -o ${{ steps.strings.outputs.perf_report_json_file }} ${{ inputs.run_id_source && format('-r {0}', inputs.run_id_source) }}
           else
             # Run with pytest
-            pytest -svv "${{ matrix.build.pytest }}" --output-file=${{ steps.strings.outputs.perf_report_json_file }}
+            pytest -svv "${{ matrix.build.pytest }}" ${{ matrix.build.accuracy-testing && '--accuracy-testing true' || '' }} ${{ matrix.build['batch-size'] && format('--batch-size {0}', matrix.build['batch-size']) || '' }} --output-file=${{ steps.strings.outputs.perf_report_json_file }}
           fi

       - name: Dump stablehlo to report

.github/workflows/perf-bench-matrix.json

Lines changed: 131 additions & 0 deletions
@@ -305,6 +305,137 @@
       "name": "unet_for_conditional_generation",
       "pyreq": "accelerate datasets diffusers==0.36.0 loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
       "pytest": "benchmark/tt-xla/test_encoders.py::test_unet_for_conditional_generation"
+    },
+    {
+      "name": "llama_3_2_1b_instruct_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_2_1b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "llama_3_2_3b_instruct_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_2_3b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "llama_3_1_8b_instruct_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_1_8b",
+      "accuracy-testing": true,
+      "batch-size": 16
+    },
+    {
+      "name": "mistral_7b_accuracy",
+      "pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1 protobuf sentencepiece",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_mistral_7b",
+      "accuracy-testing": true,
+      "batch-size": 8
+    },
+    {
+      "name": "qwen_2_5_7b_instruct_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_7b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "google_gemma-1.1-2b-it_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_gemma_1_1_2b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "google_gemma-2-2b-it_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_gemma_2_2b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "microsoft_phi-1_accuracy",
+      "pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_phi1",
+      "accuracy-testing": true
+    },
+    {
+      "name": "microsoft_phi-1_5_accuracy",
+      "pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_phi1_5",
+      "accuracy-testing": true
+    },
+    {
+      "name": "microsoft_phi-2_accuracy",
+      "pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_phi2",
+      "accuracy-testing": true
+    },
+    {
+      "name": "tiiuae_falcon3-1b-base_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_1b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "tiiuae_falcon3-3b-base_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_3b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "tiiuae_falcon3-7b-base_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_falcon3_7b",
+      "accuracy-testing": true,
+      "batch-size": 4
+    },
+    {
+      "name": "qwen_2_5_0_5b_instruct_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_0_5b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "qwen_2_5_1_5b_instruct_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_1_5b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "qwen_2_5_3b_instruct_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_qwen_2_5_3b",
+      "accuracy-testing": true,
+      "batch-size": 16
+    },
+    {
+      "name": "qwen_3_0_6b_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_0_6b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "qwen_3_1_7b_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_1_7b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "qwen_3_4b_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_4b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "qwen_3_8b_accuracy",
+      "pyreq": "datasets loguru pytest requests tabulate timm torch==2.9.0 torchvision==0.24.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_qwen_3_8b",
+      "accuracy-testing": true
+    },
+    {
+      "name": "ministral_8b_accuracy",
+      "pyreq": "datasets loguru pytest requests torch==2.9.0 tqdm transformers==4.57.1",
+      "pytest": "benchmark/tt-xla/test_llms.py::test_ministral_8b",
+      "accuracy-testing": true,
+      "batch-size": 16
     }
   ]
 }
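For illustration, this is roughly how one of the matrix entries above turns into the pytest invocation assembled by call-perf-test.yml (a sketch of the shell step's logic; the real output path comes from the workflow's steps.strings output, and `perf_report.json` here is a stand-in):

```python
import json
import shlex

def pytest_command(entry):
    """Reconstruct the pytest command line for one matrix entry:
    the accuracy-testing flag and batch-size are appended only when
    the entry defines them, mirroring the conditionals in the workflow."""
    cmd = ["pytest", "-svv", entry["pytest"]]
    if entry.get("accuracy-testing"):
        cmd += ["--accuracy-testing", "true"]
    if "batch-size" in entry:
        cmd += ["--batch-size", str(entry["batch-size"])]
    cmd.append("--output-file=perf_report.json")
    return shlex.join(cmd)

entry = json.loads("""{
  "name": "llama_3_1_8b_instruct_accuracy",
  "pytest": "benchmark/tt-xla/test_llms.py::test_llama_3_1_8b",
  "accuracy-testing": true,
  "batch-size": 16
}""")
print(pytest_command(entry))
```

Entries without `batch-size` (the 1B–4B models) keep the default batch of 32; the 7B/8B entries override it to avoid the OOM noted in the commit message.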

.github/workflows/perf-benchmark-experimental.yml

Lines changed: 32 additions & 0 deletions
@@ -15,6 +15,8 @@ jobs:
     outputs:
       matrix_p150: ${{ steps.set-perf-benchmarks.outputs.matrix_p150 }}
       matrix_p150_skip: ${{ steps.set-perf-benchmarks.outputs.matrix_p150_skip }}
+      matrix_n150_accuracy: ${{ steps.set-perf-benchmarks.outputs.matrix_n150_accuracy }}
+      matrix_n150_accuracy_skip: ${{ steps.set-perf-benchmarks.outputs.matrix_n150_accuracy_skip }}
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -28,6 +30,7 @@ jobs:
         id: set-perf-benchmarks
         shell: bash
         run: |
+          # Filter for regular p150 tests
           result=$(python .github/workflows/filter-test-matrix.py \
             .github/workflows/perf-bench-matrix.json \
             "tt-forge")
@@ -44,6 +47,25 @@ jobs:
           echo "matrix_p150=$matrix_p150" >> $GITHUB_OUTPUT
           echo "matrix_p150_skip=$matrix_p150_skip" >> $GITHUB_OUTPUT

+          # Filter for n150 accuracy tests
+          # Call filter-test-matrix.py with --sh-runner flag to map n150 to shared runners
+          result_sh=$(python .github/workflows/filter-test-matrix.py \
+            .github/workflows/perf-bench-matrix.json \
+            "tt-forge" \
+            --sh-runner)
+
+          # Filter by: runs-on contains "n150" AND accuracy-testing == true
+          matrix_n150_accuracy=$(echo $result_sh | jq -r -c '.matrix | map(select((."runs-on" | contains("n150")) and (.["accuracy-testing"] == true)))')
+
+          matrix_n150_accuracy_skip="false"
+
+          if [ "$matrix_n150_accuracy" == "[]" ]; then
+            matrix_n150_accuracy_skip="true"
+          fi
+
+          echo "matrix_n150_accuracy=$matrix_n150_accuracy" >> $GITHUB_OUTPUT
+          echo "matrix_n150_accuracy_skip=$matrix_n150_accuracy_skip" >> $GITHUB_OUTPUT
+
   run-p150-perf-benchmarks:
     needs: filter-tests
     if: ${{ needs.filter-tests.outputs.matrix_p150_skip == 'false' }}
@@ -53,9 +75,19 @@ jobs:
       matrix: ${{ needs.filter-tests.outputs.matrix_p150 }}
       docker-image: "ghcr.io/tenstorrent/tt-xla-slim:nightly-latest"

+  run-n150-accuracy-benchmarks:
+    needs: filter-tests
+    if: ${{ needs.filter-tests.outputs.matrix_n150_accuracy_skip == 'false' }}
+    secrets: inherit
+    uses: ./.github/workflows/call-perf-test.yml
+    with:
+      matrix: ${{ needs.filter-tests.outputs.matrix_n150_accuracy }}
+      docker-image: "ghcr.io/tenstorrent/tt-xla-slim:nightly-latest"
+
   produce-data:
     needs:
       - run-p150-perf-benchmarks
+      - run-n150-accuracy-benchmarks
     if: always()
     runs-on: ubuntu-latest
     steps:
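The jq filter in that workflow step can be mirrored in Python for local debugging (a sketch; the `runs-on` values shown are hypothetical, since that field is attached by filter-test-matrix.py rather than stored in the matrix JSON above):

```python
def filter_n150_accuracy(matrix):
    """Python equivalent of the workflow's jq filter: keep entries whose
    runs-on contains "n150" and whose accuracy-testing flag is true."""
    return [
        entry for entry in matrix
        if "n150" in entry.get("runs-on", "")
        and entry.get("accuracy-testing") is True
    ]

matrix = [
    {"name": "llama_accuracy", "runs-on": "n150-shared", "accuracy-testing": True},
    {"name": "p150_perf", "runs-on": "p150", "accuracy-testing": True},
    {"name": "n150_perf_only", "runs-on": "n150-shared"},
]
print([e["name"] for e in filter_n150_accuracy(matrix)])
```

An empty result corresponds to the `"[]"` check in the shell step, which sets matrix_n150_accuracy_skip to "true" so the job is skipped.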

benchmark/tt-xla/conftest.py

Lines changed: 13 additions & 0 deletions
@@ -162,6 +162,13 @@ def pytest_addoption(parser):
         type=make_validator_boolean("--experimental-compile"),
         help="Enable experimental compile flag (true/false). Overrides config value.",
     )
+    parser.addoption(
+        "--accuracy-testing",
+        action="store",
+        default=None,
+        type=make_validator_boolean("--accuracy-testing"),
+        help="Enable accuracy testing mode (true/false). Uses reference data for TOP1/TOP5 accuracy.",
+    )


 @pytest.fixture
@@ -217,3 +224,9 @@ def task(request):
 @pytest.fixture
 def experimental_compile(request):
     return request.config.getoption("--experimental-compile")
+
+
+@pytest.fixture
+def accuracy_testing(request):
+    value = request.config.getoption("--accuracy-testing")
+    return value if value is not None else False
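The diff passes make_validator_boolean as the option's `type=` callable, a helper defined elsewhere in conftest.py. A plausible sketch of what such a validator does (an assumption about its behavior, not the actual implementation):

```python
def make_validator_boolean(option_name):
    """Hypothetical reconstruction: build a parser for a true/false-valued
    pytest option that rejects anything other than a boolean-like string."""
    def validate(value):
        lowered = str(value).strip().lower()
        if lowered in ("true", "1", "yes"):
            return True
        if lowered in ("false", "0", "no"):
            return False
        # Raising here makes pytest report an option-parsing error
        # instead of silently accepting a bad value.
        raise ValueError(f"{option_name} expects true/false, got {value!r}")
    return validate
```

Because the option's default is None rather than False, the accuracy_testing fixture can distinguish "flag not passed" from "explicitly disabled" before normalizing to a boolean.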
