Skip to content

Commit 5784490

Browse files
authored
[NFC] Prepare utilities for warp specialization lowering on AMD (#8968)
Exposes additional utilities needed for warp specialization lowering on AMD.
1 parent 7cd4e29 commit 5784490

File tree

4 files changed

+62
-32
lines changed

4 files changed

+62
-32
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -468,6 +468,10 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
468468
// Hardware Indices
469469
// -----------------------------------------------------------------------
470470

471+
// If an operation is contained within a warp specialize region, this returns
472+
// the warp ID offset of that warpgroup.
473+
std::optional<int> getWarpGroupStartWarpId(Block *block);
474+
471475
// If an operation is contained within a warp specialize region, this returns
472476
// the thread ID offset of that warpgroup.
473477
std::optional<int> getWarpGroupStartThreadId(Block *block);

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 14 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -301,7 +301,7 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
301301
return outIndices;
302302
}
303303

304-
std::optional<int> getWarpGroupStartThreadId(Block *block) {
304+
std::optional<int> getWarpGroupStartWarpId(Block *block) {
305305
using namespace triton::gpu;
306306

307307
// Look for an enclosing `ttg.warp_specialize` op.
@@ -317,9 +317,19 @@ std::optional<int> getWarpGroupStartThreadId(Block *block) {
317317
std::optional<ArrayRef<int32_t>> startIds = ws.getWarpGroupStartIds();
318318
assert(startIds && "cannot get warp group ID before warp group allocation");
319319
int32_t warpStartId = (*startIds)[idx];
320-
int threadsPerWarp =
321-
TritonGPUDialect::getThreadsPerWarp(ws->getParentOfType<ModuleOp>());
322-
return warpStartId * threadsPerWarp;
320+
return warpStartId;
321+
}
322+
323+
// If an operation is contained within a warp specialize region, returns the
// thread ID offset of that warpgroup; std::nullopt otherwise.
std::optional<int> getWarpGroupStartThreadId(Block *block) {
  using namespace triton::gpu;

  // Thread offset = warp offset scaled by the module's threads-per-warp.
  auto startWarp = getWarpGroupStartWarpId(block);
  if (!startWarp)
    return std::nullopt;

  auto module = block->getParentOp()->getParentOfType<ModuleOp>();
  return *startWarp * TritonGPUDialect::getThreadsPerWarp(module);
}
324334

325335
Value getThreadId(OpBuilder &rewriter, Location loc) {
third_party/amd/include/TritonAMDGPUToLLVM/TypeConverter.h (new file — path inferred from the `#include "TritonAMDGPUToLLVM/TypeConverter.h"` added below; verify against the repository)

Lines changed: 43 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,43 @@
1+
#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_TYPECONVERTER_H
2+
#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_TYPECONVERTER_H
3+
4+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
5+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
6+
#include "triton/Conversion/MLIRTypes.h"
7+
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
8+
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
9+
#include "triton/Dialect/Triton/IR/Types.h"
10+
#include "triton/Dialect/TritonGPU/IR/Types.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::triton;
14+
15+
class TritonAMDGPUToLLVMTypeConverter : public TritonGPUToLLVMTypeConverter {
16+
public:
17+
TritonAMDGPUToLLVMTypeConverter(MLIRContext *ctx,
18+
const LowerToLLVMOptions &options,
19+
const TargetInfoBase &targetInfo,
20+
const DataLayoutAnalysis *analysis = nullptr)
21+
: TritonGPUToLLVMTypeConverter(ctx, options, targetInfo, analysis) {
22+
addConversion([&](TensorDescType type) -> std::optional<Type> {
23+
return convertTensorDescType(type);
24+
});
25+
}
26+
27+
Type convertTensorDescType(triton::TensorDescType type) {
28+
auto ctx = type.getContext();
29+
auto blockType = type.getBlockType();
30+
auto shape = blockType.getShape();
31+
32+
// Determine the number of dwords based on tensor dimensions
33+
// 2D tensors: group0 (4) + group1 (8) = 12 dwords
34+
// 3D-5D tensors: group0 (4) + group1 (8) + group2 (4) + group3 (4) = 20
35+
// dwords
36+
int numDwords = (shape.size() > 2) ? (4 + 8 + 4 + 4) : (4 + 8);
37+
38+
auto types = SmallVector<Type>(numDwords, IntegerType::get(ctx, 32));
39+
return LLVM::LLVMStructType::getLiteral(ctx, types);
40+
}
41+
};
42+
43+
#endif

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 28 deletions
Original file line number · Diff line number · Diff line change
@@ -4,6 +4,7 @@
44
#include "PatternTritonGPUOpToLLVM.h"
55
#include "TargetInfo.h"
66
#include "TritonAMDGPUToLLVM/MembarUtility.h"
7+
#include "TritonAMDGPUToLLVM/TypeConverter.h"
78
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
89
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
910
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
@@ -63,34 +64,6 @@ class TritonLLVMConversionTarget : public ConversionTarget {
6364
}
6465
};
6566

66-
class TritonAMDGPUToLLVMTypeConverter : public TritonGPUToLLVMTypeConverter {
67-
public:
68-
TritonAMDGPUToLLVMTypeConverter(MLIRContext *ctx,
69-
const LowerToLLVMOptions &options,
70-
const TargetInfoBase &targetInfo,
71-
const DataLayoutAnalysis *analysis = nullptr)
72-
: TritonGPUToLLVMTypeConverter(ctx, options, targetInfo, analysis) {
73-
addConversion([&](TensorDescType type) -> std::optional<Type> {
74-
return convertTensorDescType(type);
75-
});
76-
}
77-
78-
Type convertTensorDescType(triton::TensorDescType type) {
79-
auto ctx = type.getContext();
80-
auto blockType = type.getBlockType();
81-
auto shape = blockType.getShape();
82-
83-
// Determine the number of dwords based on tensor dimensions
84-
// 2D tensors: group0 (4) + group1 (8) = 12 dwords
85-
// 3D-5D tensors: group0 (4) + group1 (8) + group2 (4) + group3 (4) = 20
86-
// dwords
87-
int numDwords = (shape.size() > 2) ? (4 + 8 + 4 + 4) : (4 + 8);
88-
89-
auto types = SmallVector<Type>(numDwords, IntegerType::get(ctx, 32));
90-
return LLVM::LLVMStructType::getLiteral(ctx, types);
91-
}
92-
};
93-
9467
struct ConvertTritonAMDGPUToLLVM
9568
: public triton::impl::ConvertTritonAMDGPUToLLVMBase<
9669
ConvertTritonAMDGPUToLLVM> {

0 commit comments

Comments
 (0)