Skip to content

Commit 715f6b1

Browse files
authored
[NFC] std::move bases into LLs and LLs into LinearEncoding (#8921)
1 parent 14373ae commit 715f6b1

File tree

22 files changed

+84
-76
lines changed

22 files changed

+84
-76
lines changed

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -94,8 +94,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
         bases_inv[d][i] = {0};
       }
     }
-    auto invBroadcast =
-        LinearLayout(bases_inv, invReg.getOutDims(), /*isSurjective=*/false);
+    auto invBroadcast = LinearLayout(std::move(bases_inv), invReg.getOutDims(),
+                                     /*isSurjective=*/false);
     auto cvt = llReg.compose(invBroadcast);

     // Deduplicate the result values

include/triton/Tools/LinearLayout.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -582,7 +582,7 @@ class LinearLayout {
     auto value = std::move(it->second);
     bases.erase(it);
     bases.insert({newDim, std::move(value)});
-    return LinearLayout(bases, getOutDims(),
+    return LinearLayout(std::move(bases), getOutDims(),
                         /*requireSurjective=*/isSurjective());
   }

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1224,8 +1224,6 @@ delinearize(RewriterBase &rewriter, Location loc,
             ArrayRef<int64_t> shape, StringAttr dimName, Value linear) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   auto ll = triton::gpu::toLinearLayout(shape, layout);
-  auto linearLayout =
-      triton::gpu::LinearEncodingAttr::get(rewriter.getContext(), ll);
   assert(ll.hasInDim(dimName));
   int32_t freeVarMask = ll.getFreeVariableMasks()[dimName];
   auto isRepresentative = b.true_val();
@@ -1237,6 +1235,8 @@ delinearize(RewriterBase &rewriter, Location loc,
     linear = pext_i32(rewriter, loc, linear, nonFreeVarMask);
   }

+  auto linearLayout = triton::gpu::LinearEncodingAttr::get(
+      rewriter.getContext(), std::move(ll));
   auto orderDim = linearLayout.orderPerDim(dimName, linearLayout.getOrder());
   auto shapeDim = linearLayout.basesPerDim(dimName);
   auto multiDim = delinearize(rewriter, loc, linear, shapeDim, orderDim);

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -389,7 +389,7 @@ struct TritonSplitOpPattern : public OpConversionPattern<triton::SplitOp> {
           append(defaultEnc.getThreadsPerWarp(), 1),
           append(defaultEnc.getWarpsPerCTA(), 1),
           prepend(defaultEnc.getOrder(), rank - 1),
-          CGAEncodingAttr::get(getContext(), layout));
+          CGAEncodingAttr::get(getContext(), std::move(layout)));
       srcTy = srcTy.cloneWithEncoding(srcEnc);
       src = ConvertLayoutOp::create(rewriter, op.getLoc(), srcTy, src);
     }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 22 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -413,7 +413,7 @@ CGAEncodingAttr CGAEncodingAttr::getDefault(MLIRContext *ctx, int rank) {
   LinearLayout::BasesT bases;
   bases[kBlock] = {};
   auto dims = standardOutDimNames(ctx, rank);
-  return get(ctx, LinearLayout(bases, dims));
+  return get(ctx, LinearLayout(std::move(bases), dims));
 }

 CGAEncodingAttr CGAEncodingAttr::fromSplitParams(MLIRContext *ctx,
@@ -438,18 +438,18 @@ CGAEncodingAttr CGAEncodingAttr::fromSplitParams(MLIRContext *ctx,
   }

   layout = layout.transposeOuts(outDimNames);
-  return CGAEncodingAttr::get(ctx, layout);
+  return CGAEncodingAttr::get(ctx, std::move(layout));
 }

 SmallVector<unsigned> CGAEncodingAttr::getCTAsPerCGA() const {
-  auto ll = getLinearLayout();
+  const auto &ll = getLinearLayout();
   auto rank = ll.getNumOutDims();
   return basesPerDimImpl(ll.getBases(), StringAttr::get(getContext(), "block"),
                          rank, /*skipBroadcast=*/false);
 }

 SmallVector<unsigned> CGAEncodingAttr::getCTASplitNum() const {
-  auto ll = getLinearLayout();
+  const auto &ll = getLinearLayout();
   auto rank = ll.getNumOutDims();
   return basesPerDimImpl(ll.getBases(), StringAttr::get(getContext(), "block"),
                          rank);
@@ -996,7 +996,7 @@ basesPerDimImpl(const LinearLayout::BasesT &namedBases, StringAttr dimName,

 SmallVector<unsigned>
 LinearEncodingAttr::basesPerDim(StringAttr dimName, bool skipBroadcast) const {
-  auto ll = getLinearLayout();
+  const auto &ll = getLinearLayout();
   auto rank = ll.getNumOutDims();
   return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast);
 }
@@ -1066,7 +1066,7 @@ SmallVector<unsigned> LinearEncodingAttr::getThreadOrder() const {

 SmallVector<unsigned> LinearEncodingAttr::getSizePerThread() const {
   auto rank = getOrder().size();
-  auto ll = getLinearLayout();
+  const auto &ll = getLinearLayout();
   auto ctx = getContext();
   auto kRegister = StringAttr::get(ctx, "register");
   auto splitNum = getCGALayout().getCTASplitNum();
@@ -1144,7 +1144,7 @@ LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
 SmallVector<unsigned>
 LinearEncodingAttr::getContig(const char *inDim,
                               SmallVector<unsigned int> lowerContig) const {
-  auto ll = getLinearLayout();
+  const auto &ll = getLinearLayout();
   const auto &bases =
       ll.getBases().find(StringAttr::get(getContext(), inDim))->second;
   auto order = getOrder();
@@ -1517,7 +1517,7 @@ SmallVector<unsigned> SliceEncodingAttr::getRepOrder() const {
 CGAEncodingAttr SliceEncodingAttr::getCGALayout() const {
   auto layout = ::getCGALayout(getParent()).getLinearLayout();
   layout = removeStandardDim(layout, getDim());
-  return CGAEncodingAttr::get(getContext(), layout);
+  return CGAEncodingAttr::get(getContext(), std::move(layout));
 }

 template <class T>
@@ -1749,7 +1749,7 @@ Attribute SharedLinearEncodingAttr::parse(AsmParser &parser, Type type) {
 SmallVector<unsigned>
 SharedLinearEncodingAttr::basesPerDim(StringAttr dimName,
                                       bool skipBroadcast) const {
-  auto ll = getLinearLayout();
+  const auto &ll = getLinearLayout();
   auto rank = ll.getNumOutDims();
   return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast);
 }
@@ -1761,7 +1761,7 @@ SharedLinearEncodingAttr::orderPerDim(StringAttr dimName,
 }

 SmallVector<unsigned> SharedLinearEncodingAttr::getOrder() const {
-  auto ll = getLinearLayout();
+  const auto &ll = getLinearLayout();
   auto rank = ll.getNumOutDims();
   SmallVector<unsigned> defaultOrder(rank);
   std::iota(defaultOrder.rbegin(), defaultOrder.rend(), 0);
@@ -1774,7 +1774,7 @@ CGAEncodingAttr SharedLinearEncodingAttr::getCGALayout() const {
 }
 LinearLayout
 SharedLinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
-  auto ll = getLinearLayout();
+  const auto &ll = getLinearLayout();
   auto outDimNames = llvm::to_vector(ll.getOutDimNames());
   assert(shape.size() == outDimNames.size());
   // We don't support automatic broadcasting for shared linear layouts
@@ -1997,7 +1997,7 @@ PaddedSharedEncodingAttr PaddedSharedEncodingAttr::get(
       identityStandardND(kOffset, SmallVector<unsigned>(shape), order);
   linearComponent = combineCtaCgaWithShape(linearComponent, cgaLayout, shape);

-  return get(context, intervalPads, linearComponent);
+  return get(context, intervalPads, std::move(linearComponent));
 }

 PaddedSharedEncodingAttr PaddedSharedEncodingAttr::get(
@@ -2010,7 +2010,7 @@ PaddedSharedEncodingAttr PaddedSharedEncodingAttr::get(
     intervals.push_back(interval);
     paddings.push_back(padding);
   }
-  return get(context, intervals, paddings, linearComponent);
+  return get(context, intervals, paddings, std::move(linearComponent));
 }

 SmallVector<unsigned>
@@ -2454,7 +2454,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
 }

 CGAEncodingAttr DotOperandEncodingAttr::getCGALayout() const {
-  auto layout = ::getCGALayout(getParent()).getLinearLayout();
+  const auto &layout = ::getCGALayout(getParent()).getLinearLayout();
   auto bases = layout.getBases();
   auto kBlock = StringAttr::get(getContext(), "block");
   auto &blockBases = bases[kBlock];
@@ -2465,7 +2465,8 @@ CGAEncodingAttr DotOperandEncodingAttr::getCGALayout() const {
   }
   auto dims = layout.getOutDims();
   dims[kDim].second = 1;
-  return CGAEncodingAttr::get(getContext(), LinearLayout(bases, dims, true));
+  return CGAEncodingAttr::get(getContext(),
+                              LinearLayout(std::move(bases), dims, true));
 }
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -3071,7 +3072,7 @@ struct TritonGPUInferLayoutInterface
     LinearLayout ll =
         inferReshapeLinearLayout(cast<TensorOrMemDesc>(srcTy), dstShape);

-    dstEnc = LinearEncodingAttr::get(srcEnc.getContext(), ll);
+    dstEnc = LinearEncodingAttr::get(srcEnc.getContext(), std::move(ll));
     return success();
   }

@@ -3119,7 +3120,7 @@
         enc.getContext(), append(enc.getSizePerThread(), 2),
         append(enc.getThreadsPerWarp(), 1), append(enc.getWarpsPerCTA(), 1),
         appendMajorDim(enc.getOrder()),
-        CGAEncodingAttr::get(enc.getContext(), ctall));
+        CGAEncodingAttr::get(enc.getContext(), std::move(ctall)));
     return success();
   }

@@ -3136,7 +3137,7 @@
         tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc);

     assert(result.succeeded());
-    dstEnc = LinearEncodingAttr::get(ctx, newLl);
+    dstEnc = LinearEncodingAttr::get(ctx, std::move(newLl));
     return success();
   }

@@ -3167,7 +3168,7 @@
         ArrayRef(enc.getSizePerThread()).drop_back(1),
         ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
         ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder),
-        CGAEncodingAttr::get(enc.getContext(), ctall));
+        CGAEncodingAttr::get(enc.getContext(), std::move(ctall)));
     return success();
   }

@@ -3191,7 +3192,7 @@
     SmallVector<int64_t> dstShape(shape.begin(), shape.end());
     dstShape.pop_back();
     newLl = newLl.reshapeOuts(standardOutDimPairs(ctx, dstShape));
-    dstEnc = LinearEncodingAttr::get(ctx, newLl);
+    dstEnc = LinearEncodingAttr::get(ctx, std::move(newLl));
     return success();
   }

@@ -3254,7 +3255,7 @@
     auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc);
     if (!result.succeeded())
       return result;
-    outEnc = LinearEncodingAttr::get(ctx, newLl);
+    outEnc = LinearEncodingAttr::get(ctx, std::move(newLl));
     return success();
   }
 };

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1077,7 +1077,8 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
   if (isM64TwoCTA) {
     auto bases = ret.getBases();
     std::swap(bases[kRow].back(), bases[kCol].back());
-    ret = LinearLayout(bases, ret.getOutDims(), ret.isSurjective());
+    ret =
+        LinearLayout(std::move(bases), ret.getOutDims(), ret.isSurjective());
   }
   auto split = LinearLayout::identity1D(splitM, kCol, dims[0]);
   return ret * split;
@@ -1103,7 +1104,7 @@
     }
     bases[kRow].push_back({16, 0});
     bases[kRow].push_back({32, 0});
-    tile = LinearLayout(bases, dims);
+    tile = LinearLayout(std::move(bases), dims);
   } else {
     tile *= LinearLayout::identity1D(blockM, kRow, dims[0]) *
             LinearLayout::identity1D(blockN, kCol, dims[1]);
@@ -1251,7 +1252,8 @@ LinearLayout getLayoutWithinBlock(const LinearLayout &layout) {
   assert(layout.hasInDim(kBlock));
   auto bases = layout.getBases();
   bases[kBlock] = {};
-  return LinearLayout(bases, llvm::to_vector<4>(layout.getOutDimNames()));
+  return LinearLayout(std::move(bases),
+                      llvm::to_vector<4>(layout.getOutDimNames()));
 }

 LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -611,15 +611,16 @@ static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
     for (auto [interval, padding] : llvm::zip(intervals, paddings)) {
       intervalPads.emplace_back(interval, padding);
     }
-    dstEnc = PaddedSharedEncodingAttr::get(ctx, intervalPads, dst);
+    dstEnc = PaddedSharedEncodingAttr::get(ctx, intervalPads, std::move(dst));
     return success();
   }

   // Generic LL case
   auto sharedEnc = cast<SharedEncodingTrait>(srcEnc);
   auto srcLL = toLinearLayout(srcShape, srcEnc);
   auto dstLL = reshapeLayout(ctx, srcLL, dstShape);
-  dstEnc = SharedLinearEncodingAttr::get(ctx, dstLL, sharedEnc.getAlignment());
+  dstEnc = SharedLinearEncodingAttr::get(ctx, std::move(dstLL),
+                                         sharedEnc.getAlignment());
   return success();
 }

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -731,7 +731,7 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern<triton::DotScaledOp> {

     auto ll = triton::gpu::getSM120DotScaledScaleLayout(
         ctx, shape, opIdx, mmaWarps, blocked.getCGALayout());
-    auto newEnc = triton::gpu::LinearEncodingAttr::get(ctx, ll);
+    auto newEnc = triton::gpu::LinearEncodingAttr::get(ctx, std::move(ll));
     auto newTy = RankedTensorType::get(shape, ty.getElementType(), newEnc);
     return ConvertLayoutOp::create(rewriter, scale.getLoc(), newTy, scale);
   };

lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -555,7 +555,7 @@ class TritonGPUOptimizeThreadLocalityPass
     auto *ctx = kBlocked.getContext();
     auto dim = standardOutDimNames(ctx, rank + 1)[rank];
     ctaLl *= LinearLayout::identity1D(1, kBlocked, dim);
-    auto ctaLayout3d = CGAEncodingAttr::get(ctx, ctaLl);
+    auto ctaLayout3d = CGAEncodingAttr::get(ctx, std::move(ctaLl));
     auto blocked3d = triton::gpu::BlockedEncodingAttr::get(
         reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d,
         order3d, ctaLayout3d);

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -449,7 +449,7 @@ Value mlir::triton::createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type,
       std::vector<std::vector<int32_t>>(llvm::Log2_32(numCTAs), {0});
   auto dims = standardOutDimNames(ctx, 1);
   auto barrierCGALayout =
-      ttg::CGAEncodingAttr::get(ctx, LinearLayout(bases, dims));
+      ttg::CGAEncodingAttr::get(ctx, LinearLayout(std::move(bases), dims));
   auto barrierEncoding =
       ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCGALayout);
   ttg::MemDescType memDescType = ttg::MemDescType::get(

0 commit comments

Comments (0)