Pull request overview
Adds Automatic Prefix Caching (APC) support to the Llama3 70B Galaxy optimized model path, enabling reuse of cached KV blocks during prefill and introducing tracing/runtime heuristics to avoid performance regressions when caching is not beneficial.
Changes:
- Extend Galaxy Llama3-70B generator/model to accept `start_pos` and run prefix-cached prefill via flexible chunked SDPA + column replication.
- Add SDPA program-config updates (including a fixed flexible-chunk config) and new demo/CI test entries for prefix-caching scenarios.
- Adjust vLLM nightly workflow settings (trace region size, benchmark args, mount mode) for the new tracing/caching behavior.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp` | Improves TT_FATAL message clarity for page-table stick size alignment. |
| `tests/pipeline_reorg/galaxy_demo_tests.yaml` | Adds a Galaxy demo pytest selection for prefix-caching PCC coverage. |
| `models/demos/llama3_70b_galaxy/tt/qwen_model_config.py` | Updates SDPA program-config lambdas to support chunk-start-dependent constraints and adds a flexible-chunk config. |
| `models/demos/llama3_70b_galaxy/tt/model_config.py` | Introduces `SDPA_CHUNK_ALIGN` and a fixed flexible-chunk SDPA config for prefix caching traces. |
| `models/demos/llama3_70b_galaxy/tt/llama_rope.py` | Adds a helper to build prefill rotary matrices via embedding lookup. |
| `models/demos/llama3_70b_galaxy/tt/llama_model.py` | Adds device-side constants for slicing RoPE mats in-trace; extends prefill IO and output processing for paged/prefix paths. |
| `models/demos/llama3_70b_galaxy/tt/llama_decoder.py` | Threads `chunk_start_idx_tensor` through decoder → attention. |
| `models/demos/llama3_70b_galaxy/tt/llama_ccl.py` | Adds ATTN_REPLICATE buffers needed for column replication in prefix-cached prefill. |
| `models/demos/llama3_70b_galaxy/tt/llama_attention.py` | Implements the prefix-cached prefill path using chunked SDPA with `chunk_start_idx_tensor` and column replication via line all-reduce. |
| `models/demos/llama3_70b_galaxy/tt/generator_vllm.py` | Advertises `supports_prefix_caching=True` for vLLM integration. |
| `models/demos/llama3_70b_galaxy/tt/generator.py` | Implements `start_pos` handling, alignment and heuristic skipping, trace capture/replay changes, and prefix-caching page-table shaping. |
| `models/demos/llama3_70b_galaxy/demo/text_demo.py` | Adds prefix-caching demo modes/tests, updates the profiling flow, and adjusts trace-region sizing. |
| `models/demos/llama3_70b_galaxy/demo/demo_decode.py` | Updates the test signature to accept `device_params`. |
| `.github/workflows/vllm-nightly-tests-impl.yaml` | Updates the Galaxy vLLM config (trace size), adds `--random-prefix-len`, and changes the `/mnt/MLPerf` mount mode. |
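The generator changes above include aligning the cached-prefix length (`start_pos`) before chunked SDPA runs. As a rough illustration of that alignment step, here is a minimal sketch; `SDPA_CHUNK_ALIGN` is a real config name from this PR, but its value and the helper below are assumptions for illustration only, not the actual implementation:

```python
# Hypothetical sketch of start_pos alignment for prefix-cached prefill.
# SDPA_CHUNK_ALIGN's value here is an assumption, not the PR's actual value.
SDPA_CHUNK_ALIGN = 128

def align_start_pos(start_pos: int) -> int:
    """Round the cached-prefix length down to an SDPA chunk boundary,
    so the non-cached remainder begins at a chunk-aligned offset."""
    return (start_pos // SDPA_CHUNK_ALIGN) * SDPA_CHUNK_ALIGN

print(align_start_pos(300))  # -> 256: only 256 cached tokens are reused
```

Rounding down means a few cached tokens may be recomputed, but it keeps the chunked SDPA kernel on fixed, trace-friendly chunk boundaries.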
Oh nice. This is a really cool addition. I haven't started reviewing the code yet, but this should be easily extendable to the TT-Transformers codebase as well, right?

mtairum left a comment:
Really awesome addition!
Left a couple of comments.
Do you think there's a way, moving forward, that we could group the prefix-caching functionality into its own class or abstract template of sorts?
I understand that it touches both the generator and the attn module (and by hierarchy the decoder/model), but I do wonder if we could make it a bit more modular for future models.
For the newly added tests, how do they verify that prefix caching works? Do you do a warmup caching run and then run prefill a second time?
From my understanding of the feature, this works best when a user loads a document and keeps asking questions about it. Since that's pretty much how our longer-seqlen tests work (we load the Frankenstein novel from Project Gutenberg, trim it to #tokens == seqlen to prefill, then append an instruction to the LLM), could you do the test with repeat_batches > 1 instead? That parameter resends the prompts in the input_prompts file, meaning we could reuse pretty much the full prefilled cache again.
Also, make sure to test these changes exhaustively, and make sure it doesn't break accuracy nor performance in the old non-prefixed tests.
Ticket
tenstorrent/vllm#268
Problem description
Automatic prefix caching (https://docs.vllm.ai/en/latest/features/automatic_prefix_caching/#enabling-apc-in-vllm) allows reuse of cached KV entries when a new prompt shares a prefix with a previously processed prompt. It significantly reduces time to first token (TTFT) when multiple users share prefixes and when a conversation with a user continues over many turns.
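The core APC idea can be sketched at the block level: KV entries are reusable only for the exact shared prefix, rounded down to whole KV blocks. The function and block size below are illustrative assumptions, not the vLLM or TT implementation:

```python
# Hypothetical sketch of block-level prefix matching for automatic prefix
# caching. The block size of 32 is an assumption for illustration.
def count_cached_prefix_tokens(new_tokens, cached_tokens, block_size=32):
    """Return how many leading tokens of `new_tokens` can reuse cached KV
    entries, rounded down to a whole KV block."""
    shared = 0
    for a, b in zip(new_tokens, cached_tokens):
        if a != b:
            break
        shared += 1
    # Only fully populated KV blocks are eligible for reuse.
    return (shared // block_size) * block_size

# 70 shared tokens with 32-token blocks -> 2 full blocks (64 tokens) reused.
print(count_cached_prefix_tokens(list(range(100)), list(range(70)) + [999] * 30))
```

The rounded-down count corresponds to the `start_pos` that the prefix-cached prefill path resumes from.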
What's changed
Building on previous changes, this PR adds automatic prefix caching support for the Llama 70B Galaxy optimized model.
Related VLLM PR: tenstorrent/vllm#335
Performance
The gain depends on the total seq len and the ratio of cached tokens: the higher the ratio, the better, of course. For longer sequences the benefit is smaller, because prefix caching cannot use ring SDPA, which is faster. The implementation includes a heuristic that skips prefix caching when the performance benefit would be negative, so the runtime ratio of cached vs. non-cached is never over 100%:
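Such a heuristic can be sketched with a crude cost model: cached prefill skips work proportional to the cached ratio, but gives up the faster ring-SDPA path on the remainder. The threshold and speedup factor below are illustrative assumptions, not the actual heuristic in this PR:

```python
# Illustrative sketch of a runtime heuristic that disables prefix caching
# when it would not pay off. The ring-SDPA speedup factor is an assumption.
def should_use_prefix_cache(total_len: int, cached_len: int,
                            ring_sdpa_speedup: float = 1.3) -> bool:
    """Use the cache only if skipping the cached prefix outweighs losing
    the faster ring-SDPA kernel used by the non-cached prefill path."""
    if cached_len == 0:
        return False
    saved_fraction = cached_len / total_len
    # Non-cached prefill over the full sequence = 1.0 unit of time.
    # Cached prefill processes the remainder, but without ring SDPA,
    # i.e. slower per token by `ring_sdpa_speedup`.
    cached_time = (1.0 - saved_fraction) * ring_sdpa_speedup
    return cached_time < 1.0

print(should_use_prefix_cache(1000, 500))  # large cached ratio -> True
print(should_use_prefix_cache(1000, 100))  # small cached ratio -> False
```

Under this model, caching only wins once the cached ratio exceeds the relative ring-SDPA advantage, which matches the observation above that long sequences with few cached tokens see little or no benefit.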

CI
Checklist
Model tests
If your changes cover model-related code, you should run tests corresponding to affected models and platforms (Single card, T3K, Galaxy). "Choose your pipeline" workflows facilitate running multiple kinds of tests in a single run. Each offers
`models-mandatory` and `models-extended` presets. The former includes a minimal set of tests, to be run always. The latter extends that with additional ones - use your best judgement in deciding which is the most appropriate for your PR.
- `models-mandatory` preset (runs: Device perf regressions and Frequent model and ttnn tests); `models-extended` preset (runs: the mandatory tests, plus Demo and Model perf tests)
- `models-mandatory` preset (runs: Unit tests); `models-extended` preset (runs: the mandatory tests, plus Demo and Model perf tests)
- `models-mandatory` preset (runs: Quick tests); `models-extended` preset (runs: the mandatory tests, plus Demo and Model perf tests)