Skip to content

Commit 69cd4af

Browse files
committed
comments
1 parent 98efa53 commit 69cd4af

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+327
-332
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 120 additions & 125 deletions
Large diffs are not rendered by default.

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
1818
matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
1919
ConversionPatternRewriter &rewriter) const override {
2020
auto loc = op.getLoc();
21-
auto b = TritonLLVMOpBuilder(loc, rewriter);
21+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
2222
auto ctx = rewriter.getContext();
2323
auto typeConverter = getTypeConverter();
2424
auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter);

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
1414
ConversionPatternRewriter &rewriter) const override {
1515
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
1616
auto loc = op.getLoc();
17-
auto b = TritonLLVMOpBuilder(loc, rewriter);
17+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
1818
if (funcOp->hasAttr("nvvm.kernel")) {
1919
// A GPU kernel
2020
if (op.getNumOperands() > 0) {
@@ -79,7 +79,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
7979
// Get the last argument of the caller, which is the current stack pointer
8080
// of shared memory and append it to the operands of the callOp.
8181
auto loc = callOp.getLoc();
82-
auto b = TritonLLVMOpBuilder(loc, rewriter);
82+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
8383
auto caller = callOp->getParentOfType<FunctionOpInterface>();
8484
auto promotedOperands = this->getTypeConverter()->promoteOperands(
8585
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ struct ConvertLayoutOpConversion
6262
ArrayRef<unsigned> origRepShape,
6363
ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
6464
Value smemBase) const {
65-
auto b = TritonLLVMOpBuilder(loc, rewriter);
65+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
6666
auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
6767
auto layout = type.getEncoding();
6868
auto rank = type.getRank();
@@ -147,7 +147,7 @@ struct ConvertLayoutOpConversion
147147
ConversionPatternRewriter &rewriter,
148148
const TargetInfoBase &targetInfo) const {
149149
auto loc = op.getLoc();
150-
auto b = TritonLLVMOpBuilder(loc, rewriter);
150+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
151151
auto typeConverter = getTypeConverter();
152152
RankedTensorType srcTy = op.getSrc().getType();
153153
RankedTensorType dstTy = op.getType();
@@ -357,7 +357,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
357357
ConversionPatternRewriter &rewriter) const {
358358
MLIRContext *ctx = op.getContext();
359359
auto loc = op.getLoc();
360-
auto b = TritonLLVMOpBuilder(loc, rewriter);
360+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
361361
auto srcTy = op.getSrc().getType();
362362
auto dstTy = op.getType();
363363

@@ -446,7 +446,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
446446
ConversionPatternRewriter &rewriter) const {
447447
MLIRContext *ctx = op.getContext();
448448
auto loc = op.getLoc();
449-
auto b = TritonLLVMOpBuilder(loc, rewriter);
449+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
450450

451451
StringAttr kRegister = str_attr("register");
452452
StringAttr kLane = str_attr("lane");
@@ -651,7 +651,7 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
651651
ConversionPatternRewriter &rewriter) const {
652652
MLIRContext *ctx = op.getContext();
653653
Location loc = op.getLoc();
654-
auto b = TritonLLVMOpBuilder(loc, rewriter);
654+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
655655
StringAttr kRegister = str_attr("register");
656656
StringAttr kLane = str_attr("lane");
657657
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; }
3737
SmallVector<Value> swizzleIndices(ConversionPatternRewriter &rewriter,
3838
Location loc, SmallVector<Value> rawIndices,
3939
SharedEncodingAttr layout) {
40-
auto b = TritonLLVMOpBuilder(loc, rewriter);
40+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
4141
const auto &order = layout.getOrder();
4242
auto rank = order.size();
4343

@@ -81,7 +81,7 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc,
8181
unsigned kIdx, unsigned nonKIdx, unsigned bIdx,
8282
const DimIdx &dim, int vecDim,
8383
ArrayRef<unsigned> opOrder) {
84-
auto b = TritonLLVMOpBuilder(loc, rewriter);
84+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
8585
auto vecTy = cast<VectorType>(vec.getType());
8686
auto vectorSize = vecTy.getNumElements();
8787
auto elemTy = vecTy.getElementType();
@@ -118,7 +118,7 @@ Value getUnswizzledFirstElemOffset(ConversionPatternRewriter &rewriter,
118118
Location loc, unsigned B, unsigned NonK,
119119
Value bTileOffset, Value nonKTileOffset,
120120
Value bStride, Value nonKStride) {
121-
auto b = TritonLLVMOpBuilder(loc, rewriter);
121+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
122122
auto bOffset = b.mul(b.urem(bTileOffset, b.i32_val(B)), bStride);
123123
auto nonKOffset = b.mul(b.urem(nonKTileOffset, b.i32_val(NonK)), nonKStride);
124124
Value threadIdDependantOffset = b.add(bOffset, nonKOffset);
@@ -157,7 +157,7 @@ Value computeSwizzledOffset(ConversionPatternRewriter &rewriter, Location loc,
157157
SharedEncodingAttr sharedLayout,
158158
ArrayRef<int64_t> opTensorShape,
159159
ArrayRef<Value> strides) {
160-
auto b = TritonLLVMOpBuilder(loc, rewriter);
160+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
161161
Value offset = b.i32_val(0);
162162
// Compute unswizzled multi dim coordinates in shared memory object
163163
SmallVector<Value> elemMultiDimIndices(3);
@@ -190,7 +190,7 @@ Value computeNonSwizzledOffset(ConversionPatternRewriter &rewriter,
190190
unsigned shapePerCTABTile,
191191
unsigned shapePerCTANonKTile,
192192
ArrayRef<Value> strides) {
193-
auto b = TritonLLVMOpBuilder(loc, rewriter);
193+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
194194
SmallVector<Value> offsetIndices(3);
195195
offsetIndices[dim.batch] =
196196
b.i32_val((i.bTile * shapePerCTABTile + i.b) % tensorShape[dim.batch]);
@@ -219,7 +219,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
219219
Value thread, Location loc,
220220
const LLVMTypeConverter *typeConverter,
221221
ConversionPatternRewriter &rewriter, const int dotOpNo) {
222-
auto tb = TritonLLVMOpBuilder(loc, rewriter);
222+
auto tb = TritonLLVMOpBuilder(loc, &rewriter);
223223
if (!verifyCTALayout(dLayout.getCTALayout()))
224224
return Value();
225225

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ struct AddPtrOpConversion : public ConvertOpToLLVMPattern<AddPtrOp> {
4040
matchAndRewrite(AddPtrOp op, OpAdaptor adaptor,
4141
ConversionPatternRewriter &rewriter) const override {
4242
Location loc = op->getLoc();
43-
auto b = TritonLLVMOpBuilder(loc, rewriter);
43+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
4444
auto resultTy = op.getType();
4545
auto typeConverter = getTypeConverter();
4646
auto resultTensorTy = dyn_cast<RankedTensorType>(resultTy);
@@ -248,7 +248,7 @@ struct ElementwiseInlineAsmOpConversion
248248
MultipleOperandsRange operands,
249249
ConversionPatternRewriter &rewriter,
250250
Location loc) const {
251-
auto b = TritonLLVMOpBuilder(loc, rewriter);
251+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
252252
SmallVector<Value> packedOperands;
253253
unsigned numPackedElements = op.getPackedElement();
254254
for (int i = 0, e = op.getNumOperands(); i < e; i++) {
@@ -279,7 +279,7 @@ struct ElementwiseInlineAsmOpConversion
279279
ConversionPatternRewriter &rewriter,
280280
MultipleOperandsRange operands, Location loc) const {
281281
auto ctx = op->getContext();
282-
auto b = TritonLLVMOpBuilder(loc, rewriter);
282+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
283283

284284
if (operands.size() % op.getPackedElement() != 0)
285285
llvm::report_fatal_error("Inline asm op has more packed elements than "
@@ -354,7 +354,7 @@ struct ElementwiseInlineAsmOpConversion
354354
matchAndRewrite(ElementwiseInlineAsmOp op, OpAdaptor adaptor,
355355
ConversionPatternRewriter &rewriter) const override {
356356
Location loc = op->getLoc();
357-
auto b = TritonLLVMOpBuilder(loc, rewriter);
357+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
358358

359359
// Layout is unpackedOperands[operand][elem].
360360
SmallVector<SmallVector<Value>> unpackedOperands;
@@ -448,7 +448,7 @@ struct AbsFOpConversion
448448
ConversionPatternRewriter &rewriter,
449449
Type elemTy, MultipleOperandsRange operands,
450450
Location loc) const {
451-
auto b = TritonLLVMOpBuilder(loc, rewriter);
451+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
452452
if (llvm::isa<IntegerType>(elemTy)) {
453453
// Mask out the sign bit
454454
auto num_bits =

lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor,
5050

5151
static Value convertIndexToI32(Location loc, Value index,
5252
ConversionPatternRewriter &rewriter) {
53-
auto b = TritonLLVMOpBuilder(loc, rewriter);
53+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
5454
unsigned idxWidth = index.getType().getIntOrFloatBitWidth();
5555
// The LL index computations are performed with 32 bit integers. If the
5656
// indices are something else, cast them to i32.
@@ -66,7 +66,7 @@ static Value convertIndexToI32(Location loc, Value index,
6666
void GatherOpConversion::emitGatherInShared(
6767
GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
6868
Location loc = op.getLoc();
69-
auto b = TritonLLVMOpBuilder(loc, rewriter);
69+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
7070
RankedTensorType srcType = op.getSrc().getType();
7171

7272
// Compute the src subtensor shape owned by this CTA.
@@ -190,7 +190,7 @@ void GatherOpConversion::emitWarpLocalGather(
190190
GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
191191
MLIRContext *ctx = op.getContext();
192192
Location loc = op.getLoc();
193-
auto b = TritonLLVMOpBuilder(loc, rewriter);
193+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
194194
RankedTensorType srcType = op.getSrc().getType();
195195
RankedTensorType idxType = op.getIndices().getType();
196196

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ static SmallVector<Value> computeWarpLevelHistogram(
1818
Location loc, RankedTensorType srcType, SmallVector<Value> &srcValues,
1919
int numBins, int numThreadPerWarp, Value threadId,
2020
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) {
21-
auto b = TritonLLVMOpBuilder(loc, rewriter);
21+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
2222
assert(numBins % numThreadPerWarp == 0 &&
2323
"numBins must be divisible by numThreadPerWarp");
2424
Value zero = b.i32_val(0);
@@ -88,7 +88,7 @@ static SmallVector<Value> computeCrossWarpHistogram(
8888
Value baseSharedMemPtr, const SmallVector<Value> &warpLevelHistogram,
8989
int numBins, int numThreadPerWarp, const SmallVector<Value> &indices,
9090
Value threadId, int numWarps) {
91-
auto b = TritonLLVMOpBuilder(loc, rewriter);
91+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
9292
SmallVector<Value> histogramValues;
9393
unsigned numWarpsWithUniqueData =
9494
mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(),

lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ struct MakeRangeOpConversion
1818
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
1919
ConversionPatternRewriter &rewriter) const override {
2020
Location loc = op->getLoc();
21-
auto b = TritonLLVMOpBuilder(loc, rewriter);
21+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
2222
RankedTensorType ty = op.getType();
2323
auto shape = ty.getShape();
2424
auto layout = ty.getEncoding();

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ struct GlobalScratchAllocOpConversion
4040
matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor,
4141
ConversionPatternRewriter &rewriter) const override {
4242
Location loc = op.getLoc();
43-
auto b = TritonLLVMOpBuilder(loc, rewriter);
43+
auto b = TritonLLVMOpBuilder(loc, &rewriter);
4444

4545
auto opOffsetAttr = op->getAttrOfType<mlir::IntegerAttr>(
4646
"ttg.global_scratch_memory_offset");

0 commit comments

Comments
 (0)