|
8 | 8 |
|
9 | 9 | #include <algorithm> |
10 | 10 | #include <limits> |
| 11 | +#include <numeric> |
| 12 | + |
| 13 | +using ::mlir::triton::gpu::BlockedEncodingAttr; |
| 14 | +using ::mlir::triton::gpu::MmaEncodingAttr; |
| 15 | +using ::mlir::triton::gpu::SharedEncodingAttr; |
11 | 16 |
|
12 | 17 | namespace mlir { |
13 | 18 |
|
14 | 19 | //===----------------------------------------------------------------------===// |
15 | 20 | // Shared Memory Allocation Analysis |
16 | 21 | //===----------------------------------------------------------------------===// |
17 | 22 | namespace triton { |
| 23 | + |
| 24 | +SmallVector<unsigned> |
| 25 | +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, |
| 26 | + unsigned &outVec) { |
| 27 | + auto srcTy = op.src().getType().cast<RankedTensorType>(); |
| 28 | + auto dstTy = op.result().getType().cast<RankedTensorType>(); |
| 29 | + Attribute srcLayout = srcTy.getEncoding(); |
| 30 | + Attribute dstLayout = dstTy.getEncoding(); |
| 31 | + assert(srcLayout && dstLayout && |
| 32 | + "Unexpect layout in getScratchConfigForCvtLayout()"); |
| 33 | + unsigned rank = dstTy.getRank(); |
| 34 | + SmallVector<unsigned> paddedRepShape(rank); |
| 35 | + // TODO: move to TritonGPUAttrDefs.h.inc |
| 36 | + auto getShapePerCTA = [&](const Attribute &layout, unsigned d) -> unsigned { |
| 37 | + if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) { |
| 38 | + return blockedLayout.getSizePerThread()[d] * |
| 39 | + blockedLayout.getThreadsPerWarp()[d] * |
| 40 | + blockedLayout.getWarpsPerCTA()[d]; |
| 41 | + } else { |
| 42 | + assert(0 && "Unimplemented usage of getShapePerCTA"); |
| 43 | + return 0; |
| 44 | + } |
| 45 | + }; |
| 46 | + if (srcLayout.isa<BlockedEncodingAttr>() && |
| 47 | + dstLayout.isa<BlockedEncodingAttr>()) { |
| 48 | + auto srcBlockedLayout = srcLayout.cast<BlockedEncodingAttr>(); |
| 49 | + auto dstBlockedLayout = dstLayout.cast<BlockedEncodingAttr>(); |
| 50 | + auto inOrd = srcBlockedLayout.getOrder(); |
| 51 | + auto outOrd = dstBlockedLayout.getOrder(); |
| 52 | + // TODO: Fix the legacy issue that ourOrd[0] == 0 always means |
| 53 | + // that we cannot do vectorization. |
| 54 | + inVec = outOrd[0] == 0 ? 1 |
| 55 | + : inOrd[0] == 0 ? 1 |
| 56 | + : srcBlockedLayout.getSizePerThread()[inOrd[0]]; |
| 57 | + outVec = |
| 58 | + outOrd[0] == 0 ? 1 : dstBlockedLayout.getSizePerThread()[outOrd[0]]; |
| 59 | + unsigned pad = std::max(inVec, outVec); |
| 60 | + for (unsigned d = 0; d < rank; ++d) { |
| 61 | + paddedRepShape[d] = std::max( |
| 62 | + std::min<unsigned>(srcTy.getShape()[d], getShapePerCTA(srcLayout, d)), |
| 63 | + std::min<unsigned>(dstTy.getShape()[d], |
| 64 | + getShapePerCTA(dstLayout, d))); |
| 65 | + } |
| 66 | + paddedRepShape[outOrd[0]] += pad; |
| 67 | + } |
| 68 | + return paddedRepShape; |
| 69 | +} |
| 70 | + |
18 | 71 | class AllocationAnalysis { |
19 | 72 | public: |
20 | 73 | AllocationAnalysis(Operation *operation, Allocation *allocation) |
@@ -73,6 +126,27 @@ class AllocationAnalysis { |
73 | 126 | tensorType.getElementTypeBitWidth() / 8; |
74 | 127 | allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes); |
75 | 128 | } |
| 129 | + } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) { |
| 130 | + auto srcTy = cvtLayout.src().getType().cast<RankedTensorType>(); |
| 131 | + auto dstTy = cvtLayout.result().getType().cast<RankedTensorType>(); |
| 132 | + auto srcEncoding = srcTy.getEncoding(); |
| 133 | + auto dstEncoding = dstTy.getEncoding(); |
| 134 | + if (srcEncoding.isa<SharedEncodingAttr>() || |
| 135 | + dstEncoding.isa<SharedEncodingAttr>()) { |
| 136 | + // Only blocked -> blocked conversion requires for scratch allocation |
| 137 | + return; |
| 138 | + } |
| 139 | + // ConvertLayoutOp with both input/output non-shared_layout |
| 140 | + // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's |
| 141 | + // also possible to realize it with other approaches in restricted |
| 142 | + // conditions, such as warp-shuffle |
| 143 | + unsigned inVec = 0; |
| 144 | + unsigned outVec = 0; |
| 145 | + auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec); |
| 146 | + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, |
| 147 | + std::multiplies{}); |
| 148 | + auto bytes = elems * srcTy.getElementTypeBitWidth() / 8; |
| 149 | + allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes); |
76 | 150 | } |
77 | 151 | } |
78 | 152 |
|
|
0 commit comments