-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Expand file tree
/
Copy pathTargetInfo.cpp
More file actions
141 lines (127 loc) · 5.25 KB
/
TargetInfo.cpp
File metadata and controls
141 lines (127 loc) · 5.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h"
#include "Dialect/ProtonGPU/IR/Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "llvm/Support/MathExtras.h"
namespace mlir::triton::proton::gpu::AMD {
// Emits a read of the AMDGPU real-time counter and scales the raw clock count
// to nanoseconds. Returns the resulting i64 value.
Value TargetInfo::globalTime(ConversionPatternRewriter &rewriter,
                             Location loc) const {
  // The clock-generator runs at 100 MHz ==> 10 ns per clock.
  // Reference: Section 3.4.11 in the RDNA4 ISA manual
  // https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna4-instruction-set-architecture.pdf
  constexpr int kNsPerClock = 10;

  auto b = TritonLLVMOpBuilder(loc, rewriter);
  auto realtimeCall = LLVM::createLLVMIntrinsicCallOp(
      rewriter, loc, "llvm.amdgcn.s.memrealtime", i64_ty, {});
  Value clockTicks = realtimeCall.getResult(0);
  Value nsPerClock = b.i64_val(kNsPerClock);
  return b.mul(clockTicks, nsPerClock);
}
// Reads the XCC_ID field of the XCC_ID hardware register through an inline
// `s_getreg_b32` and returns it as an i32 value.
//
// Reference:
// https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/include/hip/amd_detail/amd_device_functions.h#L898
// XCC_ID register bit structure for gfx940-942, gfx950:
//   XCC_ID  3:0  XCC the wave is assigned to.
static Value getXCCID(ConversionPatternRewriter &rewriter, Location loc) {
  GCNBuilder asmBuilder;
  auto &readHwReg = *asmBuilder.create("s_getreg_b32");
  // Destination operand, constrained with "=s".
  auto dstReg = asmBuilder.newOperand("=s");
  // Field selector: HW_REG_XCC_ID_OFFSET=0, HW_REG_XCC_ID_SIZE=4
  auto fieldSelector =
      asmBuilder.newConstantOperand("hwreg(HW_REG_XCC_ID, 0, 4)");
  readHwReg(dstReg, fieldSelector);
  return asmBuilder.launch(rewriter, loc, i32_ty, false);
}
// Reads the CU_ID field of the HW_ID hardware register through an inline
// `s_getreg_b32` and returns it as an i32 value.
//
// HW_ID register bit structure for GCN and CDNA:
//   CU_ID  11:8  Compute Unit the wave is assigned to.
static Value getCUID(ConversionPatternRewriter &rewriter, Location loc) {
  GCNBuilder asmBuilder;
  auto &readHwReg = *asmBuilder.create("s_getreg_b32");
  // Destination operand, constrained with "=s".
  auto dstReg = asmBuilder.newOperand("=s");
  // Field selector: HW_ID_CU_ID_OFFSET=8, HW_ID_CU_ID_SIZE=4
  auto fieldSelector =
      asmBuilder.newConstantOperand("hwreg(HW_REG_HW_ID, 8, 4)");
  readHwReg(dstReg, fieldSelector);
  return asmBuilder.launch(rewriter, loc, i32_ty, false);
}
// Reads the SE_ID field of the HW_ID hardware register through an inline
// `s_getreg_b32` and returns it as an i32 value.
//
// HW_ID register bit structure for gfx940-942, gfx950:
//   SE_ID  15:13  Shader Engine the wave is assigned to.
static Value getSEID(ConversionPatternRewriter &rewriter, Location loc) {
  GCNBuilder asmBuilder;
  auto &readHwReg = *asmBuilder.create("s_getreg_b32");
  // Destination operand, constrained with "=s".
  auto dstReg = asmBuilder.newOperand("=s");
  // Field selector: HW_ID_SE_ID_OFFSET=13, HW_ID_SE_ID_SIZE=3
  auto fieldSelector =
      asmBuilder.newConstantOperand("hwreg(HW_REG_HW_ID, 13, 3)");
  readHwReg(dstReg, fieldSelector);
  return asmBuilder.launch(rewriter, loc, i32_ty, false);
}
// Returns the number of active CUs per XCD for the given GPU.
//   gfx942: 8 XCDs, 40 CUs per XCD of which 38 are active (304 CUs total).
//   gfx950: 8 XCDs, 36 CUs per XCD of which 32 are active (256 CUs total).
// Any other GPU kind is treated as unreachable.
static uint32_t getCU_PER_XCD(llvm::AMDGPU::GPUKind GPUKind) {
  if (GPUKind == llvm::AMDGPU::GK_GFX942)
    return 38;
  if (GPUKind == llvm::AMDGPU::GK_GFX950)
    return 32;
  llvm_unreachable("unsupported arch");
}
// Returns the per-shader-engine CU multiplier for the given GPU. The value is
// 10 for both gfx942 and gfx950; any other GPU kind is treated as unreachable.
static uint32_t getCU_PER_SE(llvm::AMDGPU::GPUKind GPUKind) {
  switch (GPUKind) {
  case llvm::AMDGPU::GK_GFX942:
  case llvm::AMDGPU::GK_GFX950:
    return 10;
  default:
    llvm_unreachable("unsupported arch");
  }
}
// Emits IR that computes an identifier for the compute unit (CU) the current
// wave is running on and returns it as an i32 value.
//
// Only gfx942 and gfx950 are supported for now; any other arch triggers a
// fatal error.
//
// For XCC based architectures, a unique CU id for a wave is formed as:
//   global_cu_id = xcc_id * CU_PER_XCD + se_id * CU_PER_SE + cu_id (local)
Value TargetInfo::processorId(ConversionPatternRewriter &rewriter,
                              Location loc) const {
  const llvm::AMDGPU::GPUKind GPUKind =
      llvm::AMDGPU::parseArchAMDGCN(this->arch);
  // Reject unsupported targets up front, before any IR is emitted.
  if (GPUKind != llvm::AMDGPU::GK_GFX942 &&
      GPUKind != llvm::AMDGPU::GK_GFX950)
    llvm::report_fatal_error("unsupported arch");

  auto b = TritonLLVMOpBuilder(loc, rewriter);

  // Hardware register reads, emitted in a fixed order.
  Value xcc_id = getXCCID(rewriter, loc);
  Value cu_id = getCUID(rewriter, loc); // local CU ID
  Value se_id = getSEID(rewriter, loc);

  const uint32_t CU_PER_XCD = getCU_PER_XCD(GPUKind);
  const uint32_t CU_PER_SE = getCU_PER_SE(GPUKind);

  // NOTE(review): SE_ID is read as a 3-bit field, so se_id * CU_PER_SE can
  // exceed CU_PER_XCD; confirm the resulting ids cannot collide across XCDs.
  // Named intermediates keep the order of the emitted ops deterministic
  // instead of depending on argument evaluation order.
  Value xccOffset = b.mul(xcc_id, b.i32_val(CU_PER_XCD));
  Value seOffset = b.mul(se_id, b.i32_val(CU_PER_SE));
  return b.add(b.add(xccOffset, seOffset), cu_id);
}
// Maps a memory-space attribute to the numeric address space used for it:
//   triton::gpu::SharedMemorySpaceAttr -> 3
//   proton::gpu::GlobalMemorySpaceAttr -> 1
// Any other attribute is a fatal error.
int TargetInfo::getAddressSpace(Attribute addressSpace) const {
  if (mlir::isa<triton::gpu::SharedMemorySpaceAttr>(addressSpace))
    return 3;
  if (mlir::isa<proton::gpu::GlobalMemorySpaceAttr>(addressSpace))
    return 1;
  llvm::report_fatal_error("Only support SharedMemorySpace, "
                           "and GlobalMemorySpace for now");
}
// Returns the address space used for the internal buffer index pointer.
int TargetInfo::getIndexPtrAddrSpace() const {
  // The internal buffer index is private to each thread, so the thread local
  // address space is used on AMD GPUs. For a detailed discussion see:
  // https://llvm.org/docs/AMDGPUUsage.html#address-spaces
  constexpr int kThreadLocalAddrSpace = 5;
  return kThreadLocalAddrSpace;
}
} // namespace mlir::triton::proton::gpu::AMD