Skip to content

Commit 15bfd0c

Browse files
authored
[BACKEND] Support of ConvertLayoutOp from blocked to blocked and SliceLayout with blocked parent (#658)
1 parent 13669b4 commit 15bfd0c

File tree

17 files changed

+1025
-191
lines changed

17 files changed

+1025
-191
lines changed

bin/triton-translate.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,6 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
6464
return nullptr;
6565
}
6666

67-
mlir::PassManager pm(module->getContext());
68-
applyPassManagerCLOptions(pm);
69-
70-
pm.addPass(createConvertTritonGPUToLLVMPass());
71-
72-
if (failed(pm.run(module->getOperation()))) {
73-
llvm::errs() << "Pass execution failed";
74-
return nullptr;
75-
}
76-
7767
return module;
7868
}
7969

include/triton/Analysis/Allocation.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ namespace mlir {
1414

1515
namespace triton {
1616
class AllocationAnalysis;
17-
}
17+
18+
SmallVector<unsigned>
19+
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
20+
unsigned &outVec);
21+
22+
} // namespace triton
1823

1924
/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h
2025
/// A class that represents an interval, specified using a start and an end

include/triton/Analysis/Utility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
#define TRITON_ANALYSIS_UTILITY_H
33

44
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
5+
#include <algorithm>
6+
#include <numeric>
57
#include <string>
8+
69
namespace mlir {
710

811
bool isSharedEncoding(Value value);
@@ -11,6 +14,12 @@ bool maybeSharedAllocationOp(Operation *op);
1114

1215
std::string getValueOperandName(Value value, AsmState &state);
1316

17+
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
18+
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
19+
}
20+
21+
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
22+
1423
} // namespace mlir
1524

1625
#endif // TRITON_ANALYSIS_UTILITY_H

include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ class TritonLLVMConversionTarget : public ConversionTarget {
1818
mlir::LLVMTypeConverter &typeConverter);
1919
};
2020

// Conversion target used when lowering function-level ops (signatures,
// calls) to the LLVM dialect, kept separate from TritonLLVMConversionTarget
// above. Constructor is declared here and defined out of line.
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
  // Held by reference: the converter must outlive this target.
  mlir::LLVMTypeConverter &typeConverter;

public:
  explicit TritonLLVMFunctionConversionTarget(
      MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter);
};
2129
namespace triton {
2230

2331
// Names for identifying different NVVM annotations. It is used as attribute

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,16 @@
1616
#define GET_OP_CLASSES
1717
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
1818

19+
namespace mlir {
20+
namespace triton {
21+
namespace gpu {
22+
23+
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
24+
25+
unsigned getShapePerCTA(const Attribute &layout, unsigned d);
26+
27+
} // namespace gpu
28+
} // namespace triton
29+
} // namespace mlir
30+
1931
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ Then, attaching $\mathcal{L} to a tensor $T$ would mean that:
3131

3232
Right now, Triton implements two classes of layouts: shared, and distributed.
3333
}];
34+
35+
code extraBaseClassDeclaration = [{
36+
unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
37+
}];
3438
}
3539

3640
//===----------------------------------------------------------------------===//
@@ -64,6 +68,8 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
6468
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
6569
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
6670
);
71+
72+
let extraClassDeclaration = extraBaseClassDeclaration;
6773
}
6874

6975
//===----------------------------------------------------------------------===//
@@ -93,6 +99,8 @@ Then the data of A would be distributed as follow between the 16 CUDA threads:
9399
L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
94100
{4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
95101
}];
102+
103+
let extraClassDeclaration = extraBaseClassDeclaration;
96104
}
97105

98106
//===----------------------------------------------------------------------===//
@@ -171,11 +179,10 @@ for
171179
}]>
172180
];
173181

174-
let extraClassDeclaration = [{
182+
let extraClassDeclaration = extraBaseClassDeclaration # [{
175183
SliceEncodingAttr squeeze(int axis);
176184
}];
177185

178-
179186
let parameters = (
180187
ins
181188
ArrayRefParameter<"unsigned">:$sizePerThread,
@@ -282,6 +289,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
282289
"unsigned":$version,
283290
ArrayRefParameter<"unsigned">:$warpsPerCTA
284291
);
292+
293+
let extraClassDeclaration = extraBaseClassDeclaration;
285294
}
286295

287296
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
@@ -311,6 +320,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
311320
// TODO: constraint here to only take distributed encodings
312321
"Attribute":$parent
313322
);
323+
324+
let extraClassDeclaration = extraBaseClassDeclaration;
314325
}
315326

316327

include/triton/tools/sys/getenv.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#ifndef TDL_TOOLS_SYS_GETENV_HPP
2323
#define TDL_TOOLS_SYS_GETENV_HPP
2424

25+
#include <algorithm>
2526
#include <cstdlib>
2627
#include <string>
2728

@@ -37,6 +38,14 @@ inline std::string getenv(const char *name) {
3738
return result;
3839
}
3940

// Returns true iff the environment variable `env` is set to a truthy value.
// Accepted truthy spellings (case-insensitive): "on", "true", "1".
// An unset variable, empty value, or any other value yields false.
inline bool getBoolEnv(const std::string &env) {
  const char *s = std::getenv(env.c_str());
  std::string str(s ? s : "");
  // Lower-case over ASCII explicitly instead of std::tolower: this header
  // includes <algorithm>/<cstdlib>/<string> but not <cctype>, so calling
  // std::tolower relied on a transitive include. The result is identical for
  // the ASCII spellings compared below.
  std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
    return (c >= 'A' && c <= 'Z') ? char(c - 'A' + 'a') : char(c);
  });
  return (str == "on" || str == "true" || str == "1");
}
4049
} // namespace tools
4150

4251
} // namespace triton

lib/Analysis/Allocation.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,66 @@
88

99
#include <algorithm>
1010
#include <limits>
11+
#include <numeric>
12+
13+
using ::mlir::triton::gpu::BlockedEncodingAttr;
14+
using ::mlir::triton::gpu::MmaEncodingAttr;
15+
using ::mlir::triton::gpu::SharedEncodingAttr;
1116

1217
namespace mlir {
1318

1419
//===----------------------------------------------------------------------===//
1520
// Shared Memory Allocation Analysis
1621
//===----------------------------------------------------------------------===//
1722
namespace triton {
23+
24+
SmallVector<unsigned>
25+
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
26+
unsigned &outVec) {
27+
auto srcTy = op.src().getType().cast<RankedTensorType>();
28+
auto dstTy = op.result().getType().cast<RankedTensorType>();
29+
Attribute srcLayout = srcTy.getEncoding();
30+
Attribute dstLayout = dstTy.getEncoding();
31+
assert(srcLayout && dstLayout &&
32+
"Unexpect layout in getScratchConfigForCvtLayout()");
33+
unsigned rank = dstTy.getRank();
34+
SmallVector<unsigned> paddedRepShape(rank);
35+
// TODO: move to TritonGPUAttrDefs.h.inc
36+
auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned {
37+
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
38+
return blockedLayout.getSizePerThread()[d] *
39+
blockedLayout.getThreadsPerWarp()[d] *
40+
blockedLayout.getWarpsPerCTA()[d];
41+
} else {
42+
assert(0 && "Unimplemented usage of getShapePerCTA");
43+
return 0;
44+
}
45+
};
46+
if (srcLayout.isa<BlockedEncodingAttr>() &&
47+
dstLayout.isa<BlockedEncodingAttr>()) {
48+
auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>();
49+
auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>();
50+
auto inOrd = srcBlockedLayout.getOrder();
51+
auto outOrd = dstBlockedLayout.getOrder();
52+
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
53+
// that we cannot do vectorization.
54+
inVec = outOrd[0] == 0 ? 1
55+
: inOrd[0] == 0 ? 1
56+
: srcBlockedLayout.getSizePerThread()[inOrd[0]];
57+
outVec =
58+
outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]];
59+
unsigned pad = std::max(inVec, outVec);
60+
for (unsigned d = 0; d < rank; ++d) {
61+
paddedRepShape[d] = std::max(
62+
std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)),
63+
std::min<unsigned>(dstTy.getShape()[d],
64+
getShapePerCTA(dstLayout, d)));
65+
}
66+
paddedRepShape[outOrd[0]] += pad;
67+
}
68+
return paddedRepShape;
69+
}
70+
1871
class AllocationAnalysis {
1972
public:
2073
AllocationAnalysis(Operation *operation, Allocation *allocation)
@@ -73,6 +126,27 @@ class AllocationAnalysis {
73126
tensorType.getElementTypeBitWidth() / 8;
74127
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
75128
}
129+
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
130+
auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>();
131+
auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>();
132+
auto srcEncoding = srcTy.getEncoding();
133+
auto dstEncoding = dstTy.getEncoding();
134+
if (srcEncoding.isa<SharedEncodingAttr>() ||
135+
dstEncoding.isa<SharedEncodingAttr>()) {
136+
// Only blocked -> blocked conversion requires for scratch allocation
137+
return;
138+
}
139+
// ConvertLayoutOp with both input/output non-shared_layout
140+
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
141+
// also possible to realize it with other approaches in restricted
142+
// conditions, such as warp-shuffle
143+
unsigned inVec = 0;
144+
unsigned outVec = 0;
145+
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
146+
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
147+
std::multiplies{});
148+
auto bytes = elems * srcTy.getElementTypeBitWidth() / 8;
149+
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
76150
}
77151
}
78152

lib/Analysis/AxisInfo.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "mlir/Analysis/DataFlowAnalysis.h"
2+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23
#include "llvm/Support/raw_ostream.h"
34
#include <iostream>
45

@@ -46,6 +47,11 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
4647
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
4748
if (attr)
4849
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
50+
} else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op)) {
51+
Attribute attr =
52+
fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
53+
if (attr)
54+
divHint = attr.cast<IntegerAttr>().getValue().getZExtValue();
4955
}
5056
}
5157
DimVectorT contiguity(rank, 1);
@@ -203,6 +209,13 @@ ChangeResult AxisInfoAnalysis::visitOperation(
203209
}
204210
curr = AxisInfo(contiguity, divisibility, constancy);
205211
}
212+
// UnrealizedConversionCast
213+
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
214+
// in the process of a PartialConversion, where UnrealizedConversionCast
215+
// may exist
216+
if (llvm::isa<mlir::UnrealizedConversionCastOp>(op)) {
217+
curr = operands[0]->getValue();
218+
}
206219
if (curr.getRank() == 0) {
207220
return markAllPessimisticFixpoint(op->getResults());
208221
}

0 commit comments

Comments
 (0)