Skip to content

Commit f29d8c7

Browse files
authored
[AMD] Properly categorize gfx950 in ISA family (triton-lang#5873)
While also here, audit features and enable them for gfx950 when sutiable.
1 parent 7685e96 commit f29d8c7

File tree

8 files changed

+41
-34
lines changed

8 files changed

+41
-34
lines changed

third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ enum class ISAFamily {
1111
CDNA1,
1212
CDNA2,
1313
CDNA3,
14+
CDNA4,
1415
RDNA1,
1516
RDNA2,
1617
RDNA3,

third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@
22
#include "TargetInfo.h"
33
#include "Utility.h"
44
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
5-
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
65
#include "mlir/IR/PatternMatch.h"
7-
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
86
#include "triton/Dialect/Triton/IR/Dialect.h"
97

108
#include "BufferOpsEmitter.h"
119

12-
using mlir::triton::gpu::appendOrGetExternFuncOp;
13-
using mlir::triton::gpu::getFunctionType;
1410
using namespace triton::AMD;
1511

1612
namespace {
@@ -66,7 +62,8 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr,
6662
}
6763

6864
Value stride = b.int_val(16, 0);
69-
if (targetInfo.getISAFamily() == ISAFamily::CDNA3) {
65+
if (llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4},
66+
targetInfo.getISAFamily())) {
7067
if (blockStride) { // TODO: BufferAtomicRMWOp is unsupported
7168
Value enableSwizzle = b.int_val(16, 16384);
7269
Value mask14b = b.int_val(16, 16383);

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using ::mlir::LLVM::getSharedMemoryBase;
2323
using ::mlir::LLVM::AMD::getVectorSize;
2424
using ::mlir::LLVM::AMD::llLoad;
2525
using ::mlir::LLVM::AMD::llStore;
26+
using ::mlir::triton::AMD::ISAFamily;
2627
using ::mlir::triton::gpu::getTotalElemsPerThread;
2728

2829
namespace {
@@ -408,25 +409,18 @@ struct AsyncCopyGlobalToLocalOpConversion
408409

409410
bool supportsLoadWidth(unsigned bits,
410411
const AMD::TargetInfo &targetInfo) const {
411-
llvm::SmallSetVector<unsigned, 10> supportedWidths;
412-
using mlir::triton::AMD::ISAFamily;
413412
switch (targetInfo.getISAFamily()) {
414413
case ISAFamily::CDNA1:
415414
case ISAFamily::CDNA2:
416415
case ISAFamily::CDNA3:
417-
supportedWidths.insert(8);
418-
supportedWidths.insert(16);
419-
supportedWidths.insert(32);
420-
if (targetInfo.getGPUKind() == llvm::AMDGPU::GPUKind::GK_GFX950) {
421-
supportedWidths.insert(96);
422-
supportedWidths.insert(128);
423-
}
424-
break;
416+
return llvm::is_contained({32, 16, 8}, bits);
417+
case ISAFamily::CDNA4:
418+
return llvm::is_contained({128, 96, 32, 16, 8}, bits);
425419
default:
426-
return false;
420+
break;
427421
}
428422

429-
return supportedWidths.contains(bits);
423+
return false;
430424
}
431425

432426
LogicalResult
@@ -1120,10 +1114,17 @@ struct AtomicCASOpConversion
11201114
}
11211115
};
11221116

1123-
bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) {
1124-
return isaFamily == triton::AMD::ISAFamily::CDNA1 ||
1125-
isaFamily == triton::AMD::ISAFamily::CDNA2 ||
1126-
isaFamily == triton::AMD::ISAFamily::CDNA3;
1117+
bool supportsGlobalAtomicF16PackedAndDpp(ISAFamily isaFamily) {
1118+
switch (isaFamily) {
1119+
case ISAFamily::CDNA1:
1120+
case ISAFamily::CDNA2:
1121+
case ISAFamily::CDNA3:
1122+
case ISAFamily::CDNA4:
1123+
return true;
1124+
default:
1125+
break;
1126+
}
1127+
return false;
11271128
}
11281129

11291130
Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) {
@@ -1284,11 +1285,12 @@ struct AtomicRMWOpConversion
12841285
int numElems = 1;
12851286
Type packF16Ty = vec_ty(valueElemTy, 2);
12861287

1287-
// CDNA3 arch allows to accelerate its atomics with LDS reduction algorithm,
1288-
// which is only applicable for atomics with no return. Otherwise we have to
1289-
// deal with an additional overhead.
1288+
// CDNA3/CDNA4 arch allows to accelerate its atomics with LDS reduction
1289+
// algorithm, which is only applicable for atomics with no return. Otherwise
1290+
// we have to deal with an additional overhead.
12901291
bool enableIntraWaveReduce =
1291-
targetInfo.getISAFamily() == triton::AMD::ISAFamily::CDNA3 &&
1292+
llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4},
1293+
targetInfo.getISAFamily()) &&
12921294
tensorTy && opResult.use_empty();
12931295

12941296
// TODO: support data types less than 32 bits
@@ -1648,17 +1650,15 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
16481650
LogicalResult
16491651
matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor,
16501652
ConversionPatternRewriter &rewriter) const override {
1651-
1652-
using mlir::triton::AMD::ISAFamily;
1653-
16541653
switch (targetInfo.getISAFamily()) {
16551654
case ISAFamily::CDNA1:
16561655
case ISAFamily::CDNA2:
16571656
case ISAFamily::CDNA3:
1657+
case ISAFamily::CDNA4:
16581658
break;
16591659
default:
16601660
return rewriter.notifyMatchFailure(
1661-
op, "Only supported on target architecture");
1661+
op, "Only supported on CDNA target architecture");
16621662
}
16631663

16641664
auto loc = op->getLoc();

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "TargetInfo.h"
22
#include "SchedInstructions.h"
33
#include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
4+
#include "TritonAMDGPUToLLVM/TargetUtils.h"
45
#include "Utility.h"
56
#include "mlir/Dialect/Arith/IR/Arith.h"
67
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -193,8 +194,9 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
193194
if (numLaneToReduce != 64)
194195
return false;
195196

196-
if (auto family = getISAFamily();
197-
family != ISAFamily::CDNA3 && family != ISAFamily::CDNA2) {
197+
if (!llvm::is_contained(
198+
{ISAFamily::CDNA2, ISAFamily::CDNA3, ISAFamily::CDNA4},
199+
getISAFamily())) {
198200
return false;
199201
}
200202

third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) {
1212
// CDNA ISA cases
1313
switch (kind) {
1414
case llvm::AMDGPU::GK_GFX950:
15+
return ISAFamily::CDNA4;
1516
case llvm::AMDGPU::GK_GFX942:
1617
case llvm::AMDGPU::GK_GFX941:
1718
case llvm::AMDGPU::GK_GFX940:
@@ -40,6 +41,7 @@ bool supportsVDot(llvm::StringRef arch) {
4041
case AMD::ISAFamily::CDNA1:
4142
case AMD::ISAFamily::CDNA2:
4243
case AMD::ISAFamily::CDNA3:
44+
case AMD::ISAFamily::CDNA4:
4345
case AMD::ISAFamily::RDNA2:
4446
case AMD::ISAFamily::RDNA3:
4547
return true;

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,11 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
145145
Value offset = b.i32_val(0x401F);
146146
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
147147
} else {
148-
if (isaFamily != ISAFamily::CDNA2 && isaFamily != ISAFamily::CDNA3) {
149-
// DPP is only supportted for CDNA2 and CDNA3 right now, so we fallback
150-
// to ds_swizzle for other archs.
148+
if (!llvm::is_contained(
149+
{ISAFamily::CDNA2, ISAFamily::CDNA3, ISAFamily::CDNA4},
150+
isaFamily)) {
151+
// DPP is only supported for CDNA2/CDNA3/CDNA4 right now, so we fallback
152+
// to ds_swizzle for other architectures.
151153
//
152154
// This map facilates the butterfly shuffle pattern for a stride less
153155
// than 16. The pattern stride is the key of the map.

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ int getMfmaVersion(ISAFamily isaFamily) {
2424
return 2;
2525
case ISAFamily::CDNA3:
2626
return 3;
27+
case ISAFamily::CDNA4:
28+
return 4;
2729
default:
2830
break;
2931
}

third_party/amd/python/triton_amd.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ void init_triton_amd(py::module &&m) {
273273
m.def("has_matrix_core_feature", [](const std::string &arch) {
274274
using mlir::triton::AMD::ISAFamily;
275275
switch (mlir::triton::AMD::deduceISAFamily(arch)) {
276+
case ISAFamily::CDNA4:
276277
case ISAFamily::CDNA3:
277278
case ISAFamily::CDNA2:
278279
case ISAFamily::CDNA1:

0 commit comments

Comments
 (0)