
Add 3 ops: gather_dim, all_reduce_async, paged_flash_multi_latent_attention_decode#7491

Open
svuckovicTT wants to merge 1 commit into main from svuckovic/add-3-ops

Conversation

@svuckovicTT
Contributor

Ticket

#7388, #7389, #7390

Problem description

Various LLM ops exist in the TTNN lib but aren't yet modeled through the forge stack; they are enumerated here: #7341

What's changed

Add 3 ops from the list:

  • ttnn.gather
  • ttnn.all_reduce_async
  • ttnn.paged_flash_multi_latent_attention_decode

The ops are added to the dialects; no fusion patterns are included yet, they will be added in follow-ups.
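For intuition on the first op: `gather_dim` suggests torch.gather-style semantics, where the output takes the index tensor's shape and indices are substituted along one dimension. Below is a minimal pure-Python sketch for the 2-D case; this assumes torch.gather-like behavior, which may differ from the exact `ttnn.gather` contract, and the function name is illustrative.

```python
def gather_dim_2d(data, index, dim):
    """Hypothetical sketch of gather-along-a-dim (torch.gather-style), 2-D only.

    data: 2-D list; index: 2-D list of ints; dim: 0 or 1.
    out has index's shape; out[i][j] reads data with index[i][j] substituted at `dim`.
    """
    rows, cols = len(index), len(index[0])
    out = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            if dim == 0:
                out[i][j] = data[index[i][j]][j]  # replace the row coordinate
            else:
                out[i][j] = data[i][index[i][j]]  # replace the column coordinate
    return out
```

For example, gathering along dim=1 with index `[[0, 0], [1, 0]]` over `[[1, 2], [3, 4]]` yields `[[1, 1], [4, 3]]`.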

Checklist

  • New/Existing tests provide coverage for changes

Copilot AI left a comment

Pull request overview

Adds missing TTNN/LLM operations to the Forge/TTMLIR stack by introducing new TTIR/TTNN ops and wiring them through lowering, serialization, runtime execution, and tests.

Changes:

  • Introduce new ops: ttir.gather_dim → ttnn.gather, ttir.all_reduce_async → ttnn.all_reduce_async, and ttir.paged_flash_multi_latent_attention_decode → ttnn.paged_flash_multi_latent_attention_decode.
  • Extend TTNN flatbuffer schema + TTNNToFlatbuffer emission + runtime executor/IO plumbing to support the new ops.
  • Add golden mappings and MLIR + Python golden tests for the new ops.
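As a rough illustration of what a golden mapping for all_reduce_async computes: an all-reduce combines per-device shards elementwise and leaves every device with the full result. The sketch below assumes a sum reduction and uses plain lists in place of the real GoldenMapTensor machinery, so it is a conceptual model rather than the actual mapping in tools/golden/mapping.py.

```python
def all_reduce_sum(shards):
    """Conceptual sketch of an all-reduce (sum) golden.

    shards: one flat list per device, all the same length.
    Returns the post-all-reduce view: every device holds the elementwise sum.
    """
    reduced = [sum(vals) for vals in zip(*shards)]
    return [list(reduced) for _ in shards]  # replicate the result to all devices
```

The real op also carries parameters such as a cluster axis and memory config (visible in the EmitPy snippet later in this review); this sketch fixes the reduce kind to sum and ignores placement.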

Reviewed changes

Copilot reviewed 40 out of 40 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| tools/ttnn-standalone/ttnn-precompiled.hpp | Adds TTNN header includes for new ops in the standalone build context. |
| tools/golden/mapping.py | Adds golden implementations and mappings for the new ops. |
| tools/builder/ttir/ttir_builder.py | Adds TTIRBuilder construction/parsing/splitting for the new TTIR ops. |
| test/ttmlir/EmitC/TTNN/transformer/paged_flash_multi_latent_attention_decode.mlir | EmitC translation coverage for the new transformer op. |
| test/ttmlir/EmitC/TTNN/gather/gather_dim.mlir | EmitC translation coverage for ttir.gather_dim. |
| test/ttmlir/EmitC/TTNN/gather/gather.mlir | EmitC translation coverage for the gather case. |
| test/ttmlir/EmitC/TTNN/ccl/all_reduce_async.mlir | EmitC translation coverage for ttir.all_reduce_async. |
| test/ttmlir/Dialect/TTNN/paged_flash_multi_latent_attention_decode/simple_paged_flash_mla_decode.mlir | Checks that TTIR→TTNN lowering emits ttnn.paged_flash_multi_latent_attention_decode. |
| test/ttmlir/Dialect/TTNN/gather/simple_gather_dim.mlir | Checks that TTIR→TTNN lowering emits ttnn.gather. |
| test/ttmlir/Dialect/TTNN/gather/simple_gather.mlir | More FileCheck coverage for gather lowering across dims/ranks. |
| test/ttmlir/Dialect/TTNN/all_reduce_async/simple_all_reduce_async.mlir | Checks that TTIR→TTNN lowering emits ttnn.all_reduce_async. |
| test/python/golden/test_ttir_ops.py | Adds Python golden tests for gather_dim, all_reduce_async, and paged flash MLA decode. |
| runtime/lib/ttnn/runtime.cpp | Adds runtime tensor input/output ref extraction for the new OpTypes. |
| runtime/lib/ttnn/program_executor.cpp | Dispatches the new ops to their runtime implementations. |
| runtime/lib/ttnn/operations/transformer/paged_flash_multi_latent_attention_decode.h | Declares the runtime runner for the new transformer op. |
| runtime/lib/ttnn/operations/transformer/paged_flash_multi_latent_attention_decode.cpp | Implements the runtime invocation of TTNN paged flash MLA decode. |
| runtime/lib/ttnn/operations/data_movement/gather.h | Declares the runtime runner for gather. |
| runtime/lib/ttnn/operations/data_movement/gather.cpp | Implements the runtime invocation of ttnn::gather. |
| runtime/lib/ttnn/operations/ccl/all_reduce_async.h | Declares the runtime runner for all_reduce_async. |
| runtime/lib/ttnn/operations/ccl/all_reduce_async.cpp | Implements the runtime invocation of ttnn::experimental::all_reduce_async. |
| runtime/lib/ttnn/operations/CMakeLists.txt | Adds the new runtime op sources to the build. |
| runtime/include/tt/runtime/detail/ttnn/ttnn.h | Adds the TTNN gather include to the runtime umbrella header. |
| lib/Target/TTNN/TTNNToFlatbuffer.cpp | Emits the new TTNN ops into the flatbuffer program format. |
| lib/OpModel/TTNN/TTNNOpModel.cpp | Adds OpModel stubs for the new TTNN ops. |
| lib/Dialect/TTNN/Interfaces/TTNNOpModelInterface.cpp | Adds OpModel interface hooks for the new TTNN ops. |
| lib/Dialect/TTNN/IR/TTNNWorkaroundsPass.cpp | Adds operand layout workarounds for MLA decode inputs. |
| lib/Dialect/TTNN/IR/TTNNOps.cpp | Adds verifiers for the TTNN GatherOp, AllReduceAsyncOp, and MLA decode op. |
| lib/Dialect/TTIR/IR/TTIROps.cpp | Adds verifiers for the TTIR GatherDimOp, AllReduceAsyncOp, and MLA decode op. |
| lib/Conversion/TTNNToEmitPy/TTNNToEmitPy.cpp | Adds EmitPy conversions for the new TTNN ops. |
| lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | Adds EmitC conversions for the new TTNN ops. |
| lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | Lowers the new TTIR ops to the corresponding TTNN ops. |
| include/ttmlir/Target/TTNN/program.fbs | Adds the new ops to the OpType union in the flatbuffer schema. |
| include/ttmlir/Target/TTNN/operations/transformer.fbs | Adds the flatbuffer table for the MLA decode op. |
| include/ttmlir/Target/TTNN/operations/data_movement.fbs | Adds the flatbuffer table for the Gather op. |
| include/ttmlir/Target/TTNN/operations/ccl.fbs | Adds the flatbuffer table for the AllReduceAsync op. |
| include/ttmlir/OpModel/TTNN/TTNNOpModel.h | Declares OpModel specializations for the new TTNN ops. |
| include/ttmlir/OpModel/TTNN/MetalHeaders.h | Adds metal header includes for the new TTNN ops (opmodel build). |
| include/ttmlir/Dialect/TTNN/IR/TTNNWorkaroundsPass.h | Declares the new workaround factory method. |
| include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | Declares the new TTNN ops in ODS. |
| include/ttmlir/Dialect/TTIR/IR/TTIROps.td | Declares the new TTIR ops in ODS. |


```cpp
    srcOp, adaptor, rewriter, this->isGoldenModeEnabled());

llvm::SmallVector<mlir::Attribute> args{
    emitter.emit(srcOp.getInput(), "input_tensor"),
```

Comment on lines +2964 to +2969:

```cpp
llvm::SmallVector<mlir::Attribute> args{
    emitter.emit(srcOp.getInput()),
    emitter.emit(srcOp.getClusterAxis()),
    emitter.emitSubDeviceId(srcOp.getSubDeviceId()),
    emitter.emit(srcOp.getMemoryConfig()),
    emitter.emit(srcOp.getNumLinks()),
```

Comment on lines +6257 to +6261:

```python
value: Optional[GoldenMapTensor] = None,
page_table: Optional[GoldenMapTensor] = None,
attention_mask: Optional[GoldenMapTensor] = None,
cur_pos_tensor: Optional[GoldenMapTensor] = None,
attention_sink: Optional[GoldenMapTensor] = None,
```

Comment on lines +6295 to +6299:

```python
# V is derived from K's first head_dim_v dimensions if not provided.
if v is not None:
    v_unpaged = v  # TODO: unpage V similarly if provided
else:
    v_unpaged = k_unpaged[..., :head_dim_v_val]  # (B, nkv, seq_len, head_dim_v)
```
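The snippet above works on an already-unpaged K. For intuition, here is a hypothetical pure-Python sketch of how a paged KV cache is flattened back into logical order via a page table; the names, layout, and list-based tensors are illustrative and are not the actual GoldenMapTensor-based implementation in the golden mapping.

```python
def unpage_kv(paged_k, page_table, seq_len):
    """Flatten a paged KV cache back into logically contiguous order.

    paged_k: physical block id -> list of rows (one sublist per block)
    page_table: physical block ids for one sequence, in logical order
    seq_len: logical length, used to trim padding in the last block
    """
    out = []
    for block_id in page_table:
        out.extend(paged_k[block_id])  # append this block's rows in order
    return out[:seq_len]  # drop padding past the true sequence length
```

With block size 2 and page table `[2, 0]`, the sequence reads block 2 first, then block 0, trimmed to `seq_len` rows.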
@svuckovicTT force-pushed the svuckovic/add-3-ops branch from 1f6a781 to a659ee9 (March 13, 2026, 20:18)
@svuckovicTT force-pushed the svuckovic/add-3-ops branch from a659ee9 to e354eac (March 13, 2026, 20:20)
@codecov-commenter

Codecov Report

❌ Patch coverage is 60.92896% with 143 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.80%. Comparing base (6943de1) to head (e354eac).
✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| lib/Conversion/TTNNToEmitPy/TTNNToEmitPy.cpp | 3.57% | 54 Missing ⚠️ |
| lib/Dialect/TTNN/IR/TTNNOps.cpp | 57.14% | 21 Missing ⚠️ |
| lib/Dialect/TTIR/IR/TTIROps.cpp | 61.22% | 19 Missing ⚠️ |
| ...b/Dialect/TTNN/Interfaces/TTNNOpModelInterface.cpp | 0.00% | 18 Missing ⚠️ |
| lib/OpModel/TTNN/TTNNOpModel.cpp | 0.00% | 12 Missing ⚠️ |
| lib/Target/TTNN/TTNNToFlatbuffer.cpp | 84.05% | 11 Missing ⚠️ |
| lib/Dialect/TTNN/IR/TTNNWorkaroundsPass.cpp | 71.42% | 8 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #7491      +/-   ##
==========================================
- Coverage   69.82%   69.80%   -0.03%     
==========================================
  Files         419      419              
  Lines       74115    74479     +364     
==========================================
+ Hits        51753    51990     +237     
- Misses      22362    22489     +127     
```

☔ View full report in Codecov by Sentry.


3 participants