Add 3 ops: gather_dim, all_reduce_async, paged_flash_multi_latent_attention_decode #7491
svuckovicTT wants to merge 1 commit into main
Conversation
Pull request overview
Adds missing TTNN/LLM operations to the Forge/TTMLIR stack by introducing new TTIR/TTNN ops and wiring them through lowering, serialization, runtime execution, and tests.
Changes:
- Introduce new ops: ttir.gather_dim → ttnn.gather, ttir.all_reduce_async → ttnn.all_reduce_async, and ttir.paged_flash_multi_latent_attention_decode → ttnn.paged_flash_multi_latent_attention_decode.
- Extend the TTNN flatbuffer schema, TTNNToFlatbuffer emission, and runtime executor/IO plumbing to support the new ops.
- Add golden mappings and MLIR + Python golden tests for the new ops.
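The new ttir.gather_dim op follows torch.gather-style indexing: each output element is picked from the input along one dimension according to an index tensor of the same rank. The following pure-Python sketch of that semantics for the 2-D case is illustrative only; it is not the PR's actual golden implementation in tools/golden/mapping.py.

```python
def gather_dim(data, index, dim):
    """torch.gather-style semantics on 2-D nested lists:
    dim=0: out[i][j] = data[index[i][j]][j]
    dim=1: out[i][j] = data[i][index[i][j]]
    """
    rows, cols = len(index), len(index[0])
    if dim == 0:
        return [[data[index[i][j]][j] for j in range(cols)] for i in range(rows)]
    if dim == 1:
        return [[data[i][index[i][j]] for j in range(cols)] for i in range(rows)]
    raise ValueError("only 2-D inputs are supported in this sketch")

data = [[1, 2], [3, 4]]
print(gather_dim(data, [[0, 0], [1, 0]], dim=1))  # [[1, 1], [4, 3]]
```

Note the output shape follows the index tensor, not the input, which is why the lowering to ttnn.gather must carry the dim attribute through.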
Reviewed changes
Copilot reviewed 40 out of 40 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| tools/ttnn-standalone/ttnn-precompiled.hpp | Adds TTNN header includes for new ops in standalone build context. |
| tools/golden/mapping.py | Adds golden implementations and mappings for new ops. |
| tools/builder/ttir/ttir_builder.py | Adds TTIRBuilder construction/parsing/splitting for new TTIR ops. |
| test/ttmlir/EmitC/TTNN/transformer/paged_flash_multi_latent_attention_decode.mlir | EmitC translation coverage for new transformer op. |
| test/ttmlir/EmitC/TTNN/gather/gather_dim.mlir | EmitC translation coverage for ttir.gather_dim. |
| test/ttmlir/EmitC/TTNN/gather/gather.mlir | EmitC translation coverage for gather case. |
| test/ttmlir/EmitC/TTNN/ccl/all_reduce_async.mlir | EmitC translation coverage for ttir.all_reduce_async. |
| test/ttmlir/Dialect/TTNN/paged_flash_multi_latent_attention_decode/simple_paged_flash_mla_decode.mlir | Checks TTIR→TTNN lowering emits ttnn.paged_flash_multi_latent_attention_decode. |
| test/ttmlir/Dialect/TTNN/gather/simple_gather_dim.mlir | Checks TTIR→TTNN lowering emits ttnn.gather. |
| test/ttmlir/Dialect/TTNN/gather/simple_gather.mlir | More FileCheck coverage for gather lowering across dims/ranks. |
| test/ttmlir/Dialect/TTNN/all_reduce_async/simple_all_reduce_async.mlir | Checks TTIR→TTNN lowering emits ttnn.all_reduce_async. |
| test/python/golden/test_ttir_ops.py | Adds Python golden tests for gather_dim, all_reduce_async, and paged flash MLA decode. |
| runtime/lib/ttnn/runtime.cpp | Adds runtime tensor input/output ref extraction for new OpTypes. |
| runtime/lib/ttnn/program_executor.cpp | Dispatches new ops to their runtime implementations. |
| runtime/lib/ttnn/operations/transformer/paged_flash_multi_latent_attention_decode.h | Declares runtime runner for new transformer op. |
| runtime/lib/ttnn/operations/transformer/paged_flash_multi_latent_attention_decode.cpp | Implements runtime invocation of TTNN paged flash MLA decode. |
| runtime/lib/ttnn/operations/data_movement/gather.h | Declares runtime runner for gather. |
| runtime/lib/ttnn/operations/data_movement/gather.cpp | Implements runtime invocation of ttnn::gather. |
| runtime/lib/ttnn/operations/ccl/all_reduce_async.h | Declares runtime runner for all_reduce_async. |
| runtime/lib/ttnn/operations/ccl/all_reduce_async.cpp | Implements runtime invocation of ttnn::experimental::all_reduce_async. |
| runtime/lib/ttnn/operations/CMakeLists.txt | Adds new runtime op sources to build. |
| runtime/include/tt/runtime/detail/ttnn/ttnn.h | Adds TTNN gather include to runtime umbrella header. |
| lib/Target/TTNN/TTNNToFlatbuffer.cpp | Emits new TTNN ops into flatbuffer program format. |
| lib/OpModel/TTNN/TTNNOpModel.cpp | Adds OpModel stubs for new TTNN ops. |
| lib/Dialect/TTNN/Interfaces/TTNNOpModelInterface.cpp | Adds OpModel interface hooks for new TTNN ops. |
| lib/Dialect/TTNN/IR/TTNNWorkaroundsPass.cpp | Adds operand layout workarounds for MLA decode inputs. |
| lib/Dialect/TTNN/IR/TTNNOps.cpp | Adds verifiers for TTNN GatherOp, AllReduceAsyncOp, and MLA decode op. |
| lib/Dialect/TTIR/IR/TTIROps.cpp | Adds verifiers for TTIR GatherDimOp, AllReduceAsyncOp, and MLA decode op. |
| lib/Conversion/TTNNToEmitPy/TTNNToEmitPy.cpp | Adds EmitPy conversions for new TTNN ops. |
| lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | Adds EmitC conversions for new TTNN ops. |
| lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | Lowers new TTIR ops to corresponding TTNN ops. |
| include/ttmlir/Target/TTNN/program.fbs | Adds new ops to OpType union for flatbuffer schema. |
| include/ttmlir/Target/TTNN/operations/transformer.fbs | Adds flatbuffer table for MLA decode op. |
| include/ttmlir/Target/TTNN/operations/data_movement.fbs | Adds flatbuffer table for Gather op. |
| include/ttmlir/Target/TTNN/operations/ccl.fbs | Adds flatbuffer table for AllReduceAsync op. |
| include/ttmlir/OpModel/TTNN/TTNNOpModel.h | Declares OpModel specializations for new TTNN ops. |
| include/ttmlir/OpModel/TTNN/MetalHeaders.h | Adds metal header includes for new TTNN ops (opmodel build). |
| include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundsPass.h | Declares new workaround factory method. |
| include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | Declares new TTNN ops in ODS. |
| include/ttmlir/Dialect/TTIR/IR/TTIROps.td | Declares new TTIR ops in ODS. |
srcOp, adaptor, rewriter, this->isGoldenModeEnabled());
llvm::SmallVector<mlir::Attribute> args{
    emitter.emit(srcOp.getInput(), "input_tensor"),

llvm::SmallVector<mlir::Attribute> args{
    emitter.emit(srcOp.getInput()),
    emitter.emit(srcOp.getClusterAxis()),
    emitter.emitSubDeviceId(srcOp.getSubDeviceId()),
    emitter.emit(srcOp.getMemoryConfig()),
    emitter.emit(srcOp.getNumLinks()),
value: Optional[GoldenMapTensor] = None,
page_table: Optional[GoldenMapTensor] = None,
attention_mask: Optional[GoldenMapTensor] = None,
cur_pos_tensor: Optional[GoldenMapTensor] = None,
attention_sink: Optional[GoldenMapTensor] = None,
# V is derived from K's first head_dim_v dimensions if not provided.
if v is not None:
    v_unpaged = v  # TODO: unpage V similarly if provided
else:
    v_unpaged = k_unpaged[..., :head_dim_v_val]  # (B, nkv, seq_len, head_dim_v)
Force-pushed: 1f6a781 to a659ee9, then a659ee9 to e354eac.
Codecov Report
❌ Patch coverage is flagged; additional details and impacted files:

@@ Coverage Diff @@
## main #7491 +/- ##
==========================================
- Coverage 69.82% 69.80% -0.03%
==========================================
Files 419 419
Lines 74115 74479 +364
==========================================
+ Hits 51753 51990 +237
- Misses 22362 22489 +127

☔ View full report in Codecov by Sentry.
Ticket
#7388, #7389, #7390
Problem description
Various LLM ops exist in the TTNN library but aren't modeled through the Forge stack; they are enumerated in #7341.
What's changed
Add 3 ops from the list:
- ttnn.gather
- ttnn.all_reduce_async
- ttnn.paged_flash_multi_latent_attention_decode

The ops are added to the dialects, but no fusion patterns are included; those will be added subsequently.
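For reference, the semantics behind the second op: an all-reduce combines each device's shard of a tensor (here, elementwise sum) so that every device in the cluster ends up holding the same reduced result. The following is a minimal single-process sketch of that semantics only; the real ttnn.all_reduce_async runs across a mesh and, as the EmitC conversion in this PR wires up, also takes cluster-axis, sub-device, memory-config, and num-links arguments.

```python
def all_reduce(shards):
    """Sum-all-reduce over a list of per-device shards (flat lists of
    floats, all the same length). Every device receives an identical
    copy of the elementwise sum.
    """
    reduced = [sum(vals) for vals in zip(*shards)]
    return [list(reduced) for _ in shards]  # one copy per participating device

out = all_reduce([[1.0, 2.0], [3.0, 4.0]])
# every device now holds [4.0, 6.0]
```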
Checklist