Linalg in Triton #1842
Replies: 11 comments 10 replies
-
@ingomueller-net sorry for misspelling your screenname.
-
@manbearian Sure, that would be great! We're a bit confused about the program you mentioned in #1797 . Could you please explain it in more detail?
-
I think both of your questions are related to the model we're proposing. In the proposed model we differentiate between Triton pointers and Triton values/tensors/blocks. We translate the pointers into unranked memrefs and the values into proper tensors. We haven't encountered any function signature mismatch: unranked memrefs are just lowered as raw pointers, so they exactly match the original kernel. Our tiling and fusion code runs on the linalg operations and ignores the loads/stores. Since the loads/stores represent the boundary between data of unknown size in shared memory and data of known size in local memory, this works out for us.
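A rough way to picture the split described above (this is an illustrative sketch only — the names and the function are hypothetical, not the actual lowering):

```python
# Illustrative sketch, not the actual pass: a kernel argument is a buffer
# of unknown size (the unranked-memref side of the model), while each load
# produces a block of statically known size (the ranked-tensor side) that
# tiling/fusion can reason about.
def load_block(buf, offset, block_size):
    """Cross the boundary: unknown-size memory -> known-size value."""
    block = buf[offset:offset + block_size]
    # Masked/partial loads at the buffer edge would need extra handling.
    assert len(block) == block_size
    return block
```

Everything downstream of such a load sees only fixed-shape values, which is why the tiling/fusion code can ignore the loads/stores themselves.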
-
Yes, thank you for your explanations, that makes sense, but I have a minor question regarding the function signature. In the memref-to-llvm conversion pass, an unranked memref is converted to {rank, raw_ptr}, which is why I asked about the function signature; I don't know if I missed something.
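For reference, the layout MLIR's memref-to-llvm lowering uses for an unranked memref can be pictured like this (a simplified Python stand-in for the LLVM struct, for illustration only):

```python
from dataclasses import dataclass

# Simplified picture of MLIR's unranked-memref descriptor after
# memref-to-llvm conversion: a (rank, pointer) pair, where the pointer
# refers to a ranked descriptor whose shape is only known at runtime.
@dataclass
class UnrankedMemRefDescriptor:
    rank: int        # number of dimensions, known only at runtime
    descriptor: int  # opaque pointer to the ranked memref descriptor
```

So with the default conversion, a function taking an unranked memref carries two values rather than a single bare pointer — which is the signature difference being asked about here.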
-
Regarding the pointer continuity analysis: we obtained the analysis algorithm by expanding Triton's own
The meaning of
We use
By doing this, we think there are the following advantages:
I'm really looking forward to your feedback or suggestions on the proposal.
-
Hi: Thanks Ian, Nhat and folks at Microsoft for this triton-to-linalg contribution. Very useful.
MEMREF
For instance, in MLIR it is possible to lower a ranked memref to a bare pointer using e.g.
MEMCOPY/ALLOC
-
Hi, folks. @javedabsar1 , let me address your first point:
Is it that the model doesn't make sense for your HW target? Or that it doesn't work with your optimizer (e.g., copies might be inserted later in your optimizer flow)? Thanks! And please keep the conversation going!
-
We have three things planned to update "soon"-ish
i don't have a time-frame for these yet, but i hope to land something in August with some of this.
-
@manbearian How are the Microsoft folks getting Triton IR to feed into the
My solution is a bit general purpose, so if the Microsoft folks have a more elegant integration of Triton IR compilation to
-
With what is in this branch i think working from the Triton IR (ttir) makes the most sense. If you look at the tests we've checked in, this is what they're doing (none of them invoke the Python code path). In order to generate the TTIR we do something like this:

```python
ret = triton.compile(kernel, signature="*fp32,i32,*fp32,i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
print(ret.asm["ttir"])
```

But, as you mentioned, i believe this invokes the full pipeline. We have an end-to-end test compiler that we're using for this work internally and it creates its own pass pipeline. We pass an extra argument to the
-
FYI: we've republished the work, with the PR redone as a plug-in: #2374
-
@sethbrin and @ingomueller-net and @nhat-nguyen
Hi folks,
Nhat and i work on ML compilation at Microsoft. Our team is very interested in adding a lowering from Triton IR to Linalg IR to the Triton compiler. Our goal is to create a lowering to a common MLIR dialect that teams can use to build, or leverage existing, code generators and analyses for Triton. Our goal is not to use the Linalg dialect in cases where the TritonGPU dialect is used, but rather to offer it as an alternative path.
i'd like to use this thread to talk about how to converge our efforts to add the linalg dialect to Triton.
For reference here are three approaches from our respective groups:
#1797
#1542
https://github.com/iree-org/iree-llvm-sandbox/tree/main/lib/Conversion/TritonToLLVM
The Microsoft contribution in PR #1797 has two parts: first, a pointer analysis pass that identifies contiguous loads/stores, and second, the actual conversion between dialects. The pointer analysis is somewhat orthogonal to the Linalg dialect lowering, but i believe it is beneficial to non-GPU architectures. This pointer analysis can fail (since not all loads/stores are contiguous); in that case a loop or scatter/gather is required, but that is not yet implemented in our approach.
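To make the contiguity idea concrete, here is a hypothetical sketch (not the actual analysis, which reasons symbolically over the IR rather than over concrete offsets): a block of pointer offsets of the form base + stride * arange(0, N) describes a contiguous load/store exactly when the stride is 1.

```python
# Hypothetical illustration of the property the pointer analysis looks
# for -- the real pass works symbolically on Triton IR, not on concrete
# offset lists.
def block_offsets(base, stride, n):
    """Offsets of the form base + stride * arange(0, n)."""
    return [base + stride * i for i in range(n)]

def is_contiguous(offsets):
    """A block access is contiguous when consecutive offsets differ by 1."""
    return all(b - a == 1 for a, b in zip(offsets, offsets[1:]))
```

A unit-stride block passes the check and can become a single bulk transfer; a strided block fails it and would need the loop or scatter/gather fallback mentioned above.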
Can we collaborate using this branch: https://github.com/openai/triton/tree/triton-to-linalg ?