Skip to content

Commit 57f742b

Browse files
committed
comments
1 parent 98efa53 commit 57f742b

File tree

1 file changed

+104
-109
lines changed
  • include/triton/Conversion/TritonGPUToLLVM

1 file changed

+104
-109
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 104 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include "mlir/Conversion/LLVMCommon/Pattern.h"
77
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
88
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
9-
#include "mlir/IR/ImplicitLocOpBuilder.h"
109
#include "mlir/Interfaces/FunctionInterfaces.h"
1110
#include "triton/Analysis/Utility.h"
1211
#include "triton/Conversion/MLIRTypes.h"
@@ -31,243 +30,247 @@ using namespace mlir;
3130
using namespace mlir::triton;
3231

3332
namespace mlir::triton {
33+
34+
// Returns CTA level thread idx
35+
inline Value getThreadId(OpBuilder &rewriter, Location loc) {
36+
Value tid =
37+
rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
38+
Type i32_ty = rewriter.getIntegerType(32);
39+
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid);
40+
}
41+
3442
struct TritonLLVMOpBuilder {
35-
TritonLLVMOpBuilder(const Location &loc, RewriterBase &builder)
36-
: loc(loc), builder(builder) {}
43+
TritonLLVMOpBuilder(const Location &loc, OpBuilder &builder)
44+
: loc(loc), builder(&builder) {}
3745
// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
3846
// Operators
3947
template <typename... Args> LLVM::SIToFPOp inttofloat(Args &&...args) {
40-
return builder.create<LLVM::SIToFPOp>(loc, std::forward<Args>(args)...);
48+
return builder->create<LLVM::SIToFPOp>(loc, std::forward<Args>(args)...);
4149
}
4250
template <typename... Args> LLVM::IntToPtrOp inttoptr(Args &&...args) {
43-
return builder.create<LLVM::IntToPtrOp>(loc, std::forward<Args>(args)...);
51+
return builder->create<LLVM::IntToPtrOp>(loc, std::forward<Args>(args)...);
4452
}
4553
template <typename... Args> LLVM::PtrToIntOp ptrtoint(Args &&...args) {
46-
return builder.create<LLVM::PtrToIntOp>(loc, std::forward<Args>(args)...);
54+
return builder->create<LLVM::PtrToIntOp>(loc, std::forward<Args>(args)...);
4755
}
4856
template <typename... Args> LLVM::ZExtOp zext(Args &&...args) {
49-
return builder.create<LLVM::ZExtOp>(loc, std::forward<Args>(args)...);
57+
return builder->create<LLVM::ZExtOp>(loc, std::forward<Args>(args)...);
5058
}
5159
template <typename... Args> LLVM::SExtOp sext(Args &&...args) {
52-
return builder.create<LLVM::SExtOp>(loc, std::forward<Args>(args)...);
60+
return builder->create<LLVM::SExtOp>(loc, std::forward<Args>(args)...);
5361
}
5462
template <typename... Args> LLVM::FPExtOp fpext(Args &&...args) {
55-
return builder.create<LLVM::FPExtOp>(loc, std::forward<Args>(args)...);
63+
return builder->create<LLVM::FPExtOp>(loc, std::forward<Args>(args)...);
5664
}
5765
template <typename... Args> LLVM::FPTruncOp fptrunc(Args &&...args) {
58-
return builder.create<LLVM::FPTruncOp>(loc, std::forward<Args>(args)...);
66+
return builder->create<LLVM::FPTruncOp>(loc, std::forward<Args>(args)...);
5967
}
6068
template <typename... Args> LLVM::TruncOp trunc(Args &&...args) {
61-
return builder.create<LLVM::TruncOp>(loc, std::forward<Args>(args)...);
69+
return builder->create<LLVM::TruncOp>(loc, std::forward<Args>(args)...);
6270
}
6371
template <typename... Args> LLVM::UDivOp udiv(Args &&...args) {
64-
return builder.create<LLVM::UDivOp>(loc, std::forward<Args>(args)...);
72+
return builder->create<LLVM::UDivOp>(loc, std::forward<Args>(args)...);
6573
}
6674
template <typename... Args> LLVM::SDivOp sdiv(Args &&...args) {
67-
return builder.create<LLVM::SDivOp>(loc, std::forward<Args>(args)...);
75+
return builder->create<LLVM::SDivOp>(loc, std::forward<Args>(args)...);
6876
}
6977
template <typename... Args> LLVM::URemOp urem(Args &&...args) {
70-
return builder.create<LLVM::URemOp>(loc, std::forward<Args>(args)...);
78+
return builder->create<LLVM::URemOp>(loc, std::forward<Args>(args)...);
7179
}
7280
template <typename... Args> LLVM::AddOp add(Args &&...args) {
73-
return builder.create<LLVM::AddOp>(loc, std::forward<Args>(args)...);
81+
return builder->create<LLVM::AddOp>(loc, std::forward<Args>(args)...);
7482
}
7583
template <typename... Args> LLVM::SubOp sub(Args &&...args) {
76-
return builder.create<LLVM::SubOp>(loc, std::forward<Args>(args)...);
84+
return builder->create<LLVM::SubOp>(loc, std::forward<Args>(args)...);
7785
}
7886
template <typename... Args> LLVM::FAddOp fadd(Args &&...args) {
79-
return builder.create<LLVM::FAddOp>(loc, std::forward<Args>(args)...);
87+
return builder->create<LLVM::FAddOp>(loc, std::forward<Args>(args)...);
8088
}
8189
template <typename... Args> LLVM::MulOp mul(Args &&...args) {
82-
return builder.create<LLVM::MulOp>(loc, std::forward<Args>(args)...);
90+
return builder->create<LLVM::MulOp>(loc, std::forward<Args>(args)...);
8391
}
8492
template <typename... Args> LLVM::FMulOp fmul(Args &&...args) {
85-
return builder.create<LLVM::FMulOp>(loc, std::forward<Args>(args)...);
93+
return builder->create<LLVM::FMulOp>(loc, std::forward<Args>(args)...);
8694
}
8795
template <typename... Args> LLVM::FMAOp fma(Args &&...args) {
88-
return builder.create<LLVM::FMAOp>(loc, std::forward<Args>(args)...);
96+
return builder->create<LLVM::FMAOp>(loc, std::forward<Args>(args)...);
8997
}
9098
template <typename... Args> LLVM::FNegOp neg(Args &&...args) {
91-
return builder.create<LLVM::FNegOp>(loc, std::forward<Args>(args)...);
99+
return builder->create<LLVM::FNegOp>(loc, std::forward<Args>(args)...);
92100
}
93101
template <typename... Args> LLVM::SMaxOp smax(Args &&...args) {
94-
return builder.create<LLVM::SMaxOp>(loc, std::forward<Args>(args)...);
102+
return builder->create<LLVM::SMaxOp>(loc, std::forward<Args>(args)...);
95103
}
96104
template <typename... Args> LLVM::UMaxOp umax(Args &&...args) {
97-
return builder.create<LLVM::UMaxOp>(loc, std::forward<Args>(args)...);
105+
return builder->create<LLVM::UMaxOp>(loc, std::forward<Args>(args)...);
98106
}
99107
template <typename... Args> LLVM::MaxNumOp fmax(Args &&...args) {
100-
return builder.create<LLVM::MaxNumOp>(loc, std::forward<Args>(args)...);
108+
return builder->create<LLVM::MaxNumOp>(loc, std::forward<Args>(args)...);
101109
}
102110
template <typename... Args> LLVM::SMinOp smin(Args &&...args) {
103-
return builder.create<LLVM::SMinOp>(loc, std::forward<Args>(args)...);
111+
return builder->create<LLVM::SMinOp>(loc, std::forward<Args>(args)...);
104112
}
105113
template <typename... Args> LLVM::UMinOp umin(Args &&...args) {
106-
return builder.create<LLVM::UMinOp>(loc, std::forward<Args>(args)...);
114+
return builder->create<LLVM::UMinOp>(loc, std::forward<Args>(args)...);
107115
}
108116
template <typename... Args> LLVM::MinNumOp fmin(Args &&...args) {
109-
return builder.create<LLVM::MinNumOp>(loc, std::forward<Args>(args)...);
117+
return builder->create<LLVM::MinNumOp>(loc, std::forward<Args>(args)...);
110118
}
111119
template <typename... Args> LLVM::ShlOp shl(Args &&...args) {
112-
return builder.create<LLVM::ShlOp>(loc, std::forward<Args>(args)...);
120+
return builder->create<LLVM::ShlOp>(loc, std::forward<Args>(args)...);
113121
}
114122
template <typename... Args> LLVM::LShrOp lshr(Args &&...args) {
115-
return builder.create<LLVM::LShrOp>(loc, std::forward<Args>(args)...);
123+
return builder->create<LLVM::LShrOp>(loc, std::forward<Args>(args)...);
116124
}
117125
template <typename... Args> LLVM::AShrOp ashr(Args &&...args) {
118-
return builder.create<LLVM::AShrOp>(loc, std::forward<Args>(args)...);
126+
return builder->create<LLVM::AShrOp>(loc, std::forward<Args>(args)...);
119127
}
120128
template <typename... Args> LLVM::AndOp and_(Args &&...args) {
121-
return builder.create<LLVM::AndOp>(loc, std::forward<Args>(args)...);
129+
return builder->create<LLVM::AndOp>(loc, std::forward<Args>(args)...);
122130
}
123131
template <typename... Args> LLVM::XOrOp xor_(Args &&...args) {
124-
return builder.create<LLVM::XOrOp>(loc, std::forward<Args>(args)...);
132+
return builder->create<LLVM::XOrOp>(loc, std::forward<Args>(args)...);
125133
}
126134
template <typename... Args> LLVM::OrOp or_(Args &&...args) {
127-
return builder.create<LLVM::OrOp>(loc, std::forward<Args>(args)...);
135+
return builder->create<LLVM::OrOp>(loc, std::forward<Args>(args)...);
128136
}
129137
LLVM::BitcastOp bitcast(Value val, Type type) {
130-
return builder.create<LLVM::BitcastOp>(loc, type, val);
138+
return builder->create<LLVM::BitcastOp>(loc, type, val);
131139
}
132140
template <typename... Args>
133141
LLVM::AddrSpaceCastOp addrspacecast(Args &&...args) {
134-
return builder.create<LLVM::AddrSpaceCastOp>(loc,
135-
std::forward<Args>(args)...);
142+
return builder->create<LLVM::AddrSpaceCastOp>(loc,
143+
std::forward<Args>(args)...);
136144
}
137145
template <typename... Args> LLVM::GEPOp gep(Args &&...args) {
138-
return builder.create<LLVM::GEPOp>(loc, std::forward<Args>(args)...);
146+
return builder->create<LLVM::GEPOp>(loc, std::forward<Args>(args)...);
139147
}
140148
template <typename... Args> LLVM::InsertValueOp insert_val(Args &&...args) {
141-
return builder.create<LLVM::InsertValueOp>(loc,
142-
std::forward<Args>(args)...);
149+
return builder->create<LLVM::InsertValueOp>(loc,
150+
std::forward<Args>(args)...);
143151
}
144152
template <typename... Args> LLVM::ExtractValueOp extract_val(Args &&...args) {
145-
return builder.create<LLVM::ExtractValueOp>(loc,
146-
std::forward<Args>(args)...);
153+
return builder->create<LLVM::ExtractValueOp>(loc,
154+
std::forward<Args>(args)...);
147155
}
148156
template <typename... Args>
149157
LLVM::InsertElementOp insert_element(Args &&...args) {
150-
return builder.create<LLVM::InsertElementOp>(loc,
151-
std::forward<Args>(args)...);
158+
return builder->create<LLVM::InsertElementOp>(loc,
159+
std::forward<Args>(args)...);
152160
}
153161
template <typename... Args>
154162
LLVM::ExtractElementOp extract_element(Args &&...args) {
155-
return builder.create<LLVM::ExtractElementOp>(loc,
156-
std::forward<Args>(args)...);
163+
return builder->create<LLVM::ExtractElementOp>(loc,
164+
std::forward<Args>(args)...);
157165
}
158166
template <typename... Args> LLVM::LoadOp load(Args &&...args) {
159-
return builder.create<LLVM::LoadOp>(loc, std::forward<Args>(args)...);
167+
return builder->create<LLVM::LoadOp>(loc, std::forward<Args>(args)...);
160168
}
161169
template <typename... Args> LLVM::StoreOp store(Args &&...args) {
162-
return builder.create<LLVM::StoreOp>(loc, std::forward<Args>(args)...);
170+
return builder->create<LLVM::StoreOp>(loc, std::forward<Args>(args)...);
163171
}
164-
template <typename... Args> LLVM::FCmpOp fcmp_ogt(Value lhs, Value rhs) {
165-
return builder.create<LLVM::FCmpOp>(loc, builder.getI1Type(),
166-
LLVM::FCmpPredicate::ogt, lhs, rhs);
172+
LLVM::FCmpOp fcmp_ogt(Value lhs, Value rhs) {
173+
return builder->create<LLVM::FCmpOp>(loc, builder->getI1Type(),
174+
LLVM::FCmpPredicate::ogt, lhs, rhs);
167175
}
168-
template <typename... Args> LLVM::FCmpOp fcmp_olt(Value lhs, Value rhs) {
169-
return builder.create<LLVM::FCmpOp>(loc, builder.getI1Type(),
170-
LLVM::FCmpPredicate::olt, lhs, rhs);
176+
LLVM::FCmpOp fcmp_olt(Value lhs, Value rhs) {
177+
return builder->create<LLVM::FCmpOp>(loc, builder->getI1Type(),
178+
LLVM::FCmpPredicate::olt, lhs, rhs);
171179
}
172-
template <typename... Args> LLVM::FCmpOp fcmp_eq(Value lhs, Value rhs) {
173-
return builder.create<LLVM::FCmpOp>(loc, builder.getI1Type(),
174-
LLVM::FCmpPredicate::oeq, lhs, rhs);
180+
LLVM::FCmpOp fcmp_eq(Value lhs, Value rhs) {
181+
return builder->create<LLVM::FCmpOp>(loc, builder->getI1Type(),
182+
LLVM::FCmpPredicate::oeq, lhs, rhs);
175183
}
176184
template <typename... Args> LLVM::ICmpOp icmp_eq(Args &&...args) {
177-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
178-
std::forward<Args>(args)...);
185+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
186+
std::forward<Args>(args)...);
179187
}
180188
template <typename... Args> LLVM::ICmpOp icmp_ne(Args &&...args) {
181-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne,
182-
std::forward<Args>(args)...);
189+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne,
190+
std::forward<Args>(args)...);
183191
}
184192
template <typename... Args> LLVM::ICmpOp icmp_slt(Args &&...args) {
185-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt,
186-
std::forward<Args>(args)...);
193+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt,
194+
std::forward<Args>(args)...);
187195
}
188196
template <typename... Args> LLVM::ICmpOp icmp_sle(Args &&...args) {
189-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle,
190-
std::forward<Args>(args)...);
197+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle,
198+
std::forward<Args>(args)...);
191199
}
192200
template <typename... Args> LLVM::ICmpOp icmp_sgt(Args &&...args) {
193-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt,
194-
std::forward<Args>(args)...);
201+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt,
202+
std::forward<Args>(args)...);
195203
}
196204
template <typename... Args> LLVM::ICmpOp icmp_sge(Args &&...args) {
197-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge,
198-
std::forward<Args>(args)...);
205+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge,
206+
std::forward<Args>(args)...);
199207
}
200208
template <typename... Args> LLVM::ICmpOp icmp_ult(Args &&...args) {
201-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult,
202-
std::forward<Args>(args)...);
209+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult,
210+
std::forward<Args>(args)...);
203211
}
204212
template <typename... Args> LLVM::ICmpOp icmp_ule(Args &&...args) {
205-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule,
206-
std::forward<Args>(args)...);
213+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule,
214+
std::forward<Args>(args)...);
207215
}
208216
template <typename... Args> LLVM::ICmpOp icmp_ugt(Args &&...args) {
209-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt,
210-
std::forward<Args>(args)...);
217+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt,
218+
std::forward<Args>(args)...);
211219
}
212220
template <typename... Args> LLVM::ICmpOp icmp_uge(Args &&...args) {
213-
return builder.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge,
214-
std::forward<Args>(args)...);
221+
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge,
222+
std::forward<Args>(args)...);
215223
}
216224
template <typename... Args> LLVM::SelectOp select(Args &&...args) {
217-
return builder.create<LLVM::SelectOp>(loc, std::forward<Args>(args)...);
225+
return builder->create<LLVM::SelectOp>(loc, std::forward<Args>(args)...);
218226
}
219227
template <typename... Args> LLVM::AddressOfOp address_of(Args &&...args) {
220-
return builder.create<LLVM::AddressOfOp>(loc, std::forward<Args>(args)...);
228+
return builder->create<LLVM::AddressOfOp>(loc, std::forward<Args>(args)...);
221229
}
222230
mlir::gpu::BarrierOp barrier() {
223-
return builder.create<mlir::gpu::BarrierOp>(loc);
231+
return builder->create<mlir::gpu::BarrierOp>(loc);
224232
}
225233
template <typename... Args> LLVM::UndefOp undef(Args &&...args) {
226-
return builder.create<LLVM::UndefOp>(loc, std::forward<Args>(args)...);
234+
return builder->create<LLVM::UndefOp>(loc, std::forward<Args>(args)...);
227235
}
228236
template <typename... Args> LLVM::ZeroOp null(Args &&...args) {
229-
return builder.create<LLVM::ZeroOp>(loc, std::forward<Args>(args)...);
237+
return builder->create<LLVM::ZeroOp>(loc, std::forward<Args>(args)...);
230238
}
231239
template <typename... Args> LLVM::CallOp call(Args &&...args) {
232-
return builder.create<LLVM::CallOp>(loc, std::forward<Args>(args)...);
240+
return builder->create<LLVM::CallOp>(loc, std::forward<Args>(args)...);
233241
}
234242
// Constants
235243
Value int_val(short bitwidth, int64_t val) {
236-
Type ty = builder.getIntegerType(bitwidth);
237-
return builder.create<LLVM::ConstantOp>(loc, ty,
238-
builder.getIntegerAttr(ty, val));
244+
Type ty = builder->getIntegerType(bitwidth);
245+
return builder->create<LLVM::ConstantOp>(loc, ty,
246+
builder->getIntegerAttr(ty, val));
239247
}
240248
Value i1_val(int64_t val) { return int_val(1, val); }
241249
Value true_val() { return int_val(1, true); }
242250
Value false_val() { return int_val(1, false); }
243251
Value f16_val(float v) {
244-
auto type = type::f16Ty(builder.getContext());
245-
return builder.create<LLVM::ConstantOp>(loc, type,
246-
builder.getF16FloatAttr(v));
252+
auto type = type::f16Ty(builder->getContext());
253+
return builder->create<LLVM::ConstantOp>(loc, type,
254+
builder->getF16FloatAttr(v));
247255
}
248256
Value f32_val(float v) {
249-
auto type = type::f32Ty(builder.getContext());
250-
return builder.create<LLVM::ConstantOp>(loc, type,
251-
builder.getF32FloatAttr(v));
257+
auto type = type::f32Ty(builder->getContext());
258+
return builder->create<LLVM::ConstantOp>(loc, type,
259+
builder->getF32FloatAttr(v));
252260
}
253261
Value f64_val(double v) {
254-
auto type = type::f64Ty(builder.getContext());
255-
return builder.create<LLVM::ConstantOp>(loc, type,
256-
builder.getF64FloatAttr(v));
262+
auto type = type::f64Ty(builder->getContext());
263+
return builder->create<LLVM::ConstantOp>(loc, type,
264+
builder->getF64FloatAttr(v));
257265
}
258266
Value i8_val(int64_t val) { return int_val(8, val); }
259267
Value i16_val(int64_t val) { return int_val(16, val); }
260268
Value i32_val(int64_t val) { return int_val(32, val); }
261269
Value i64_val(int64_t val) { return int_val(64, val); }
262-
Value tid_val() {
263-
Value tid =
264-
builder.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
265-
Type i32_ty = builder.getIntegerType(32);
266-
return builder.create<arith::IndexCastOp>(loc, i32_ty, tid);
267-
}
270+
Value tid_val() { return getThreadId(*this->builder, loc); }
268271

269272
Location loc;
270-
RewriterBase &builder;
273+
OpBuilder *builder;
271274
};
272275
} // namespace mlir::triton
273276

@@ -657,14 +660,6 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
657660

658661
} // namespace LLVM
659662

660-
/* ------------------------------------ */
661-
// Returns CTA level thread idx
662-
inline Value getThreadId(RewriterBase &rewriter, Location loc) {
663-
Value tid =
664-
rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
665-
return rewriter.create<arith::IndexCastOp>(loc, i32_ty, tid);
666-
}
667-
668663
// -----------------------------------------------------------------------
669664
// Shared memory utilities
670665
// -----------------------------------------------------------------------

0 commit comments

Comments
 (0)