66#include " mlir/Conversion/LLVMCommon/Pattern.h"
77#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
88#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
9- #include " mlir/IR/ImplicitLocOpBuilder.h"
109#include " mlir/Interfaces/FunctionInterfaces.h"
1110#include " triton/Analysis/Utility.h"
1211#include " triton/Conversion/MLIRTypes.h"
@@ -31,243 +30,247 @@ using namespace mlir;
3130using namespace mlir ::triton;
3231
3332namespace mlir ::triton {
33+
34+ // Returns CTA level thread idx
35+ inline Value getThreadId (OpBuilder &rewriter, Location loc) {
36+ Value tid =
37+ rewriter.create <::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
38+ Type i32_ty = rewriter.getIntegerType (32 );
39+ return rewriter.create <arith::IndexCastOp>(loc, i32_ty, tid);
40+ }
41+
3442struct TritonLLVMOpBuilder {
35- TritonLLVMOpBuilder (const Location &loc, RewriterBase &builder)
36- : loc(loc), builder(builder) {}
43+ TritonLLVMOpBuilder (const Location &loc, OpBuilder &builder)
44+ : loc(loc), builder(& builder) {}
3745 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
3846 // Operators
3947 template <typename ... Args> LLVM::SIToFPOp inttofloat (Args &&...args) {
40- return builder. create <LLVM::SIToFPOp>(loc, std::forward<Args>(args)...);
48+ return builder-> create <LLVM::SIToFPOp>(loc, std::forward<Args>(args)...);
4149 }
4250 template <typename ... Args> LLVM::IntToPtrOp inttoptr (Args &&...args) {
43- return builder. create <LLVM::IntToPtrOp>(loc, std::forward<Args>(args)...);
51+ return builder-> create <LLVM::IntToPtrOp>(loc, std::forward<Args>(args)...);
4452 }
4553 template <typename ... Args> LLVM::PtrToIntOp ptrtoint (Args &&...args) {
46- return builder. create <LLVM::PtrToIntOp>(loc, std::forward<Args>(args)...);
54+ return builder-> create <LLVM::PtrToIntOp>(loc, std::forward<Args>(args)...);
4755 }
4856 template <typename ... Args> LLVM::ZExtOp zext (Args &&...args) {
49- return builder. create <LLVM::ZExtOp>(loc, std::forward<Args>(args)...);
57+ return builder-> create <LLVM::ZExtOp>(loc, std::forward<Args>(args)...);
5058 }
5159 template <typename ... Args> LLVM::SExtOp sext (Args &&...args) {
52- return builder. create <LLVM::SExtOp>(loc, std::forward<Args>(args)...);
60+ return builder-> create <LLVM::SExtOp>(loc, std::forward<Args>(args)...);
5361 }
5462 template <typename ... Args> LLVM::FPExtOp fpext (Args &&...args) {
55- return builder. create <LLVM::FPExtOp>(loc, std::forward<Args>(args)...);
63+ return builder-> create <LLVM::FPExtOp>(loc, std::forward<Args>(args)...);
5664 }
5765 template <typename ... Args> LLVM::FPTruncOp fptrunc (Args &&...args) {
58- return builder. create <LLVM::FPTruncOp>(loc, std::forward<Args>(args)...);
66+ return builder-> create <LLVM::FPTruncOp>(loc, std::forward<Args>(args)...);
5967 }
6068 template <typename ... Args> LLVM::TruncOp trunc (Args &&...args) {
61- return builder. create <LLVM::TruncOp>(loc, std::forward<Args>(args)...);
69+ return builder-> create <LLVM::TruncOp>(loc, std::forward<Args>(args)...);
6270 }
6371 template <typename ... Args> LLVM::UDivOp udiv (Args &&...args) {
64- return builder. create <LLVM::UDivOp>(loc, std::forward<Args>(args)...);
72+ return builder-> create <LLVM::UDivOp>(loc, std::forward<Args>(args)...);
6573 }
6674 template <typename ... Args> LLVM::SDivOp sdiv (Args &&...args) {
67- return builder. create <LLVM::SDivOp>(loc, std::forward<Args>(args)...);
75+ return builder-> create <LLVM::SDivOp>(loc, std::forward<Args>(args)...);
6876 }
6977 template <typename ... Args> LLVM::URemOp urem (Args &&...args) {
70- return builder. create <LLVM::URemOp>(loc, std::forward<Args>(args)...);
78+ return builder-> create <LLVM::URemOp>(loc, std::forward<Args>(args)...);
7179 }
7280 template <typename ... Args> LLVM::AddOp add (Args &&...args) {
73- return builder. create <LLVM::AddOp>(loc, std::forward<Args>(args)...);
81+ return builder-> create <LLVM::AddOp>(loc, std::forward<Args>(args)...);
7482 }
7583 template <typename ... Args> LLVM::SubOp sub (Args &&...args) {
76- return builder. create <LLVM::SubOp>(loc, std::forward<Args>(args)...);
84+ return builder-> create <LLVM::SubOp>(loc, std::forward<Args>(args)...);
7785 }
7886 template <typename ... Args> LLVM::FAddOp fadd (Args &&...args) {
79- return builder. create <LLVM::FAddOp>(loc, std::forward<Args>(args)...);
87+ return builder-> create <LLVM::FAddOp>(loc, std::forward<Args>(args)...);
8088 }
8189 template <typename ... Args> LLVM::MulOp mul (Args &&...args) {
82- return builder. create <LLVM::MulOp>(loc, std::forward<Args>(args)...);
90+ return builder-> create <LLVM::MulOp>(loc, std::forward<Args>(args)...);
8391 }
8492 template <typename ... Args> LLVM::FMulOp fmul (Args &&...args) {
85- return builder. create <LLVM::FMulOp>(loc, std::forward<Args>(args)...);
93+ return builder-> create <LLVM::FMulOp>(loc, std::forward<Args>(args)...);
8694 }
8795 template <typename ... Args> LLVM::FMAOp fma (Args &&...args) {
88- return builder. create <LLVM::FMAOp>(loc, std::forward<Args>(args)...);
96+ return builder-> create <LLVM::FMAOp>(loc, std::forward<Args>(args)...);
8997 }
9098 template <typename ... Args> LLVM::FNegOp neg (Args &&...args) {
91- return builder. create <LLVM::FNegOp>(loc, std::forward<Args>(args)...);
99+ return builder-> create <LLVM::FNegOp>(loc, std::forward<Args>(args)...);
92100 }
93101 template <typename ... Args> LLVM::SMaxOp smax (Args &&...args) {
94- return builder. create <LLVM::SMaxOp>(loc, std::forward<Args>(args)...);
102+ return builder-> create <LLVM::SMaxOp>(loc, std::forward<Args>(args)...);
95103 }
96104 template <typename ... Args> LLVM::UMaxOp umax (Args &&...args) {
97- return builder. create <LLVM::UMaxOp>(loc, std::forward<Args>(args)...);
105+ return builder-> create <LLVM::UMaxOp>(loc, std::forward<Args>(args)...);
98106 }
99107 template <typename ... Args> LLVM::MaxNumOp fmax (Args &&...args) {
100- return builder. create <LLVM::MaxNumOp>(loc, std::forward<Args>(args)...);
108+ return builder-> create <LLVM::MaxNumOp>(loc, std::forward<Args>(args)...);
101109 }
102110 template <typename ... Args> LLVM::SMinOp smin (Args &&...args) {
103- return builder. create <LLVM::SMinOp>(loc, std::forward<Args>(args)...);
111+ return builder-> create <LLVM::SMinOp>(loc, std::forward<Args>(args)...);
104112 }
105113 template <typename ... Args> LLVM::UMinOp umin (Args &&...args) {
106- return builder. create <LLVM::UMinOp>(loc, std::forward<Args>(args)...);
114+ return builder-> create <LLVM::UMinOp>(loc, std::forward<Args>(args)...);
107115 }
108116 template <typename ... Args> LLVM::MinNumOp fmin (Args &&...args) {
109- return builder. create <LLVM::MinNumOp>(loc, std::forward<Args>(args)...);
117+ return builder-> create <LLVM::MinNumOp>(loc, std::forward<Args>(args)...);
110118 }
111119 template <typename ... Args> LLVM::ShlOp shl (Args &&...args) {
112- return builder. create <LLVM::ShlOp>(loc, std::forward<Args>(args)...);
120+ return builder-> create <LLVM::ShlOp>(loc, std::forward<Args>(args)...);
113121 }
114122 template <typename ... Args> LLVM::LShrOp lshr (Args &&...args) {
115- return builder. create <LLVM::LShrOp>(loc, std::forward<Args>(args)...);
123+ return builder-> create <LLVM::LShrOp>(loc, std::forward<Args>(args)...);
116124 }
117125 template <typename ... Args> LLVM::AShrOp ashr (Args &&...args) {
118- return builder. create <LLVM::AShrOp>(loc, std::forward<Args>(args)...);
126+ return builder-> create <LLVM::AShrOp>(loc, std::forward<Args>(args)...);
119127 }
120128 template <typename ... Args> LLVM::AndOp and_ (Args &&...args) {
121- return builder. create <LLVM::AndOp>(loc, std::forward<Args>(args)...);
129+ return builder-> create <LLVM::AndOp>(loc, std::forward<Args>(args)...);
122130 }
123131 template <typename ... Args> LLVM::XOrOp xor_ (Args &&...args) {
124- return builder. create <LLVM::XOrOp>(loc, std::forward<Args>(args)...);
132+ return builder-> create <LLVM::XOrOp>(loc, std::forward<Args>(args)...);
125133 }
126134 template <typename ... Args> LLVM::OrOp or_ (Args &&...args) {
127- return builder. create <LLVM::OrOp>(loc, std::forward<Args>(args)...);
135+ return builder-> create <LLVM::OrOp>(loc, std::forward<Args>(args)...);
128136 }
129137 LLVM::BitcastOp bitcast (Value val, Type type) {
130- return builder. create <LLVM::BitcastOp>(loc, type, val);
138+ return builder-> create <LLVM::BitcastOp>(loc, type, val);
131139 }
132140 template <typename ... Args>
133141 LLVM::AddrSpaceCastOp addrspacecast (Args &&...args) {
134- return builder. create <LLVM::AddrSpaceCastOp>(loc,
135- std::forward<Args>(args)...);
142+ return builder-> create <LLVM::AddrSpaceCastOp>(loc,
143+ std::forward<Args>(args)...);
136144 }
137145 template <typename ... Args> LLVM::GEPOp gep (Args &&...args) {
138- return builder. create <LLVM::GEPOp>(loc, std::forward<Args>(args)...);
146+ return builder-> create <LLVM::GEPOp>(loc, std::forward<Args>(args)...);
139147 }
140148 template <typename ... Args> LLVM::InsertValueOp insert_val (Args &&...args) {
141- return builder. create <LLVM::InsertValueOp>(loc,
142- std::forward<Args>(args)...);
149+ return builder-> create <LLVM::InsertValueOp>(loc,
150+ std::forward<Args>(args)...);
143151 }
144152 template <typename ... Args> LLVM::ExtractValueOp extract_val (Args &&...args) {
145- return builder. create <LLVM::ExtractValueOp>(loc,
146- std::forward<Args>(args)...);
153+ return builder-> create <LLVM::ExtractValueOp>(loc,
154+ std::forward<Args>(args)...);
147155 }
148156 template <typename ... Args>
149157 LLVM::InsertElementOp insert_element (Args &&...args) {
150- return builder. create <LLVM::InsertElementOp>(loc,
151- std::forward<Args>(args)...);
158+ return builder-> create <LLVM::InsertElementOp>(loc,
159+ std::forward<Args>(args)...);
152160 }
153161 template <typename ... Args>
154162 LLVM::ExtractElementOp extract_element (Args &&...args) {
155- return builder. create <LLVM::ExtractElementOp>(loc,
156- std::forward<Args>(args)...);
163+ return builder-> create <LLVM::ExtractElementOp>(loc,
164+ std::forward<Args>(args)...);
157165 }
158166 template <typename ... Args> LLVM::LoadOp load (Args &&...args) {
159- return builder. create <LLVM::LoadOp>(loc, std::forward<Args>(args)...);
167+ return builder-> create <LLVM::LoadOp>(loc, std::forward<Args>(args)...);
160168 }
161169 template <typename ... Args> LLVM::StoreOp store (Args &&...args) {
162- return builder. create <LLVM::StoreOp>(loc, std::forward<Args>(args)...);
170+ return builder-> create <LLVM::StoreOp>(loc, std::forward<Args>(args)...);
163171 }
164- template < typename ... Args> LLVM::FCmpOp fcmp_ogt (Value lhs, Value rhs) {
165- return builder. create <LLVM::FCmpOp>(loc, builder. getI1Type (),
166- LLVM::FCmpPredicate::ogt, lhs, rhs);
172+ LLVM::FCmpOp fcmp_ogt (Value lhs, Value rhs) {
173+ return builder-> create <LLVM::FCmpOp>(loc, builder-> getI1Type (),
174+ LLVM::FCmpPredicate::ogt, lhs, rhs);
167175 }
168- template < typename ... Args> LLVM::FCmpOp fcmp_olt (Value lhs, Value rhs) {
169- return builder. create <LLVM::FCmpOp>(loc, builder. getI1Type (),
170- LLVM::FCmpPredicate::olt, lhs, rhs);
176+ LLVM::FCmpOp fcmp_olt (Value lhs, Value rhs) {
177+ return builder-> create <LLVM::FCmpOp>(loc, builder-> getI1Type (),
178+ LLVM::FCmpPredicate::olt, lhs, rhs);
171179 }
172- template < typename ... Args> LLVM::FCmpOp fcmp_eq (Value lhs, Value rhs) {
173- return builder. create <LLVM::FCmpOp>(loc, builder. getI1Type (),
174- LLVM::FCmpPredicate::oeq, lhs, rhs);
180+ LLVM::FCmpOp fcmp_eq (Value lhs, Value rhs) {
181+ return builder-> create <LLVM::FCmpOp>(loc, builder-> getI1Type (),
182+ LLVM::FCmpPredicate::oeq, lhs, rhs);
175183 }
176184 template <typename ... Args> LLVM::ICmpOp icmp_eq (Args &&...args) {
177- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
178- std::forward<Args>(args)...);
185+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
186+ std::forward<Args>(args)...);
179187 }
180188 template <typename ... Args> LLVM::ICmpOp icmp_ne (Args &&...args) {
181- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne,
182- std::forward<Args>(args)...);
189+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne,
190+ std::forward<Args>(args)...);
183191 }
184192 template <typename ... Args> LLVM::ICmpOp icmp_slt (Args &&...args) {
185- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt,
186- std::forward<Args>(args)...);
193+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt,
194+ std::forward<Args>(args)...);
187195 }
188196 template <typename ... Args> LLVM::ICmpOp icmp_sle (Args &&...args) {
189- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle,
190- std::forward<Args>(args)...);
197+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle,
198+ std::forward<Args>(args)...);
191199 }
192200 template <typename ... Args> LLVM::ICmpOp icmp_sgt (Args &&...args) {
193- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt,
194- std::forward<Args>(args)...);
201+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt,
202+ std::forward<Args>(args)...);
195203 }
196204 template <typename ... Args> LLVM::ICmpOp icmp_sge (Args &&...args) {
197- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge,
198- std::forward<Args>(args)...);
205+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge,
206+ std::forward<Args>(args)...);
199207 }
200208 template <typename ... Args> LLVM::ICmpOp icmp_ult (Args &&...args) {
201- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult,
202- std::forward<Args>(args)...);
209+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult,
210+ std::forward<Args>(args)...);
203211 }
204212 template <typename ... Args> LLVM::ICmpOp icmp_ule (Args &&...args) {
205- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule,
206- std::forward<Args>(args)...);
213+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule,
214+ std::forward<Args>(args)...);
207215 }
208216 template <typename ... Args> LLVM::ICmpOp icmp_ugt (Args &&...args) {
209- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt,
210- std::forward<Args>(args)...);
217+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt,
218+ std::forward<Args>(args)...);
211219 }
212220 template <typename ... Args> LLVM::ICmpOp icmp_uge (Args &&...args) {
213- return builder. create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge,
214- std::forward<Args>(args)...);
221+ return builder-> create <LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge,
222+ std::forward<Args>(args)...);
215223 }
216224 template <typename ... Args> LLVM::SelectOp select (Args &&...args) {
217- return builder. create <LLVM::SelectOp>(loc, std::forward<Args>(args)...);
225+ return builder-> create <LLVM::SelectOp>(loc, std::forward<Args>(args)...);
218226 }
219227 template <typename ... Args> LLVM::AddressOfOp address_of (Args &&...args) {
220- return builder. create <LLVM::AddressOfOp>(loc, std::forward<Args>(args)...);
228+ return builder-> create <LLVM::AddressOfOp>(loc, std::forward<Args>(args)...);
221229 }
222230 mlir::gpu::BarrierOp barrier () {
223- return builder. create <mlir::gpu::BarrierOp>(loc);
231+ return builder-> create <mlir::gpu::BarrierOp>(loc);
224232 }
225233 template <typename ... Args> LLVM::UndefOp undef (Args &&...args) {
226- return builder. create <LLVM::UndefOp>(loc, std::forward<Args>(args)...);
234+ return builder-> create <LLVM::UndefOp>(loc, std::forward<Args>(args)...);
227235 }
228236 template <typename ... Args> LLVM::ZeroOp null (Args &&...args) {
229- return builder. create <LLVM::ZeroOp>(loc, std::forward<Args>(args)...);
237+ return builder-> create <LLVM::ZeroOp>(loc, std::forward<Args>(args)...);
230238 }
231239 template <typename ... Args> LLVM::CallOp call (Args &&...args) {
232- return builder. create <LLVM::CallOp>(loc, std::forward<Args>(args)...);
240+ return builder-> create <LLVM::CallOp>(loc, std::forward<Args>(args)...);
233241 }
234242 // Constants
235243 Value int_val (short bitwidth, int64_t val) {
236- Type ty = builder. getIntegerType (bitwidth);
237- return builder. create <LLVM::ConstantOp>(loc, ty,
238- builder. getIntegerAttr (ty, val));
244+ Type ty = builder-> getIntegerType (bitwidth);
245+ return builder-> create <LLVM::ConstantOp>(loc, ty,
246+ builder-> getIntegerAttr (ty, val));
239247 }
240248 Value i1_val (int64_t val) { return int_val (1 , val); }
241249 Value true_val () { return int_val (1 , true ); }
242250 Value false_val () { return int_val (1 , false ); }
243251 Value f16_val (float v) {
244- auto type = type::f16Ty (builder. getContext ());
245- return builder. create <LLVM::ConstantOp>(loc, type,
246- builder. getF16FloatAttr (v));
252+ auto type = type::f16Ty (builder-> getContext ());
253+ return builder-> create <LLVM::ConstantOp>(loc, type,
254+ builder-> getF16FloatAttr (v));
247255 }
248256 Value f32_val (float v) {
249- auto type = type::f32Ty (builder. getContext ());
250- return builder. create <LLVM::ConstantOp>(loc, type,
251- builder. getF32FloatAttr (v));
257+ auto type = type::f32Ty (builder-> getContext ());
258+ return builder-> create <LLVM::ConstantOp>(loc, type,
259+ builder-> getF32FloatAttr (v));
252260 }
253261 Value f64_val (double v) {
254- auto type = type::f64Ty (builder. getContext ());
255- return builder. create <LLVM::ConstantOp>(loc, type,
256- builder. getF64FloatAttr (v));
262+ auto type = type::f64Ty (builder-> getContext ());
263+ return builder-> create <LLVM::ConstantOp>(loc, type,
264+ builder-> getF64FloatAttr (v));
257265 }
258266 Value i8_val (int64_t val) { return int_val (8 , val); }
259267 Value i16_val (int64_t val) { return int_val (16 , val); }
260268 Value i32_val (int64_t val) { return int_val (32 , val); }
261269 Value i64_val (int64_t val) { return int_val (64 , val); }
262- Value tid_val () {
263- Value tid =
264- builder.create <::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
265- Type i32_ty = builder.getIntegerType (32 );
266- return builder.create <arith::IndexCastOp>(loc, i32_ty, tid);
267- }
270+ Value tid_val () { return getThreadId (*this ->builder , loc); }
268271
269272 Location loc;
270- RewriterBase & builder;
273+ OpBuilder * builder;
271274};
272275} // namespace mlir::triton
273276
@@ -657,14 +660,6 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
657660
658661} // namespace LLVM
659662
660- /* ------------------------------------ */
661- // Returns CTA level thread idx
662- inline Value getThreadId (RewriterBase &rewriter, Location loc) {
663- Value tid =
664- rewriter.create <::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
665- return rewriter.create <arith::IndexCastOp>(loc, i32_ty, tid);
666- }
667-
668663// -----------------------------------------------------------------------
669664// Shared memory utilities
670665// -----------------------------------------------------------------------
0 commit comments