Skip to content

Commit df66eb5

Browse files
authored
[AMD][BACKEND] Adjust cache modifier mappings (triton-lang#5852)
Adjust the cache-modifier-to-control-bit mappings to better reflect the expected caching behavior, and rename the misleading `isBufferLoad` parameter to `isLoad`. `.cg` must use the load-specific control bits for every kind of load, not only buffer loads.
1 parent 6afc767 commit df66eb5

File tree

3 files changed: 15 additions and 24 deletions.

test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -193,25 +193,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.sha
193193
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
194194
%2 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = ca: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
195195
// CHECK: llvm.getelementptr
196-
// CHECK: %[[aux_cg:.*]] = llvm.mlir.constant(0 : i32) : i32
196+
// CHECK: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32
197197
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]]
198198
%3 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cg: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
199199
// CHECK: llvm.getelementptr
200-
// CHECK: %[[aux_cs:.*]] = llvm.mlir.constant(3 : i32) : i32
201-
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cs]]
202-
%5 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cs: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
203-
// CHECK: llvm.getelementptr
204-
// CHECK: %[[aux_cv:.*]] = llvm.mlir.constant(9 : i32) : i32
200+
// CHECK: %[[aux_cv:.*]] = llvm.mlir.constant(11 : i32) : i32
205201
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]]
206-
%6 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cv: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
207-
// CHECK: llvm.getelementptr
208-
// CHECK: %[[aux_wb:.*]] = llvm.mlir.constant(0 : i32) : i32
209-
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wb]]
210-
%7 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = wb: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
211-
// CHECK: llvm.getelementptr
212-
// CHECK: %[[aux_wt:.*]] = llvm.mlir.constant(8 : i32) : i32
213-
// CHECK: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wt]]
214-
%8 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = wt: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
202+
%4 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cv: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
215203
tt.return
216204
}
217205
}

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ struct AsyncCopyGlobalToLocalOpConversion
511511

512512
Value cacheModifiers =
513513
b.i32_val(mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget(
514-
op.getCache(), false, targetInfo));
514+
op.getCache(), /*isLoad=*/true, targetInfo));
515515

516516
Value llMask = adaptor.getMask();
517517
SmallVector<Value> maskElems;

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -455,40 +455,43 @@ getCacheModifierFlagsForPredicatedCall(LLVM::CallOp callOp) {
455455
// Load | .ca | 0 | 0 | 0 |
456456
// | .cg | 0 | 1 | 1 |
457457
// | .cs | 0 | 1 | 1 |
458-
// | .cv | 1 | 1 | x |
458+
// | .cv | 1 | 1 | 1 |
459459
// -------+-----+-----+-----+----+--
460460
// Store | .wb | 0 | 0 | 0 |
461461
// | .cg | 0 | 0 | 0 |
462462
// | .cs | 0 | 1 | 1 |
463-
// | .wt | 1 | x | x |
463+
// | .wt | 1 | 1 | 1 |
464464
// -------+-----+-----+-----+----+--
465465
// Atomic | N/A | 0 | 1 | x | Setting sc0 returns the pre-op value
466466
// | N/A | 1 | 0 | x | Setting sc1 performs a system-scope atomic
467467
// -------+-----+-----+-----+----+--
468468
static int32_t
469469
getCtrlBitsForCacheModifierOnGFX_942_950(triton::CacheModifier cm,
470-
bool isBufferLoad) {
470+
bool isLoad) {
471471
const int sc0Bit = 0b1, ntBit = 0b10, sc1Bit = 0b1000;
472472
int32_t aux = 0;
473473
switch (cm) {
474474
case triton::CacheModifier::CA:
475475
aux = 0;
476476
break;
477477
case triton::CacheModifier::CG:
478-
if (isBufferLoad)
478+
if (isLoad)
479479
aux |= sc0Bit | ntBit;
480480
break;
481481
case triton::CacheModifier::CS:
482482
aux |= sc0Bit | ntBit;
483483
break;
484484
case triton::CacheModifier::CV:
485-
aux |= sc0Bit | sc1Bit;
485+
assert(isLoad);
486+
aux |= sc0Bit | sc1Bit | ntBit;
486487
break;
487488
case triton::CacheModifier::WB:
489+
assert(!isLoad);
488490
aux = 0;
489491
break;
490492
case triton::CacheModifier::WT:
491-
aux |= sc1Bit;
493+
assert(!isLoad);
494+
aux |= sc0Bit | sc1Bit | ntBit;
492495
break;
493496
default:
494497
aux = 0;
@@ -521,12 +524,12 @@ static int32_t getDefaultCtrlBitsForCacheModifier(triton::CacheModifier cm) {
521524
// .wb: write-back, writes back data at all cache levels
522525
// .wt: write-through, write data directly to system memory
523526
int32_t getCtrlBitsForCacheModifierOnTarget(
524-
triton::CacheModifier cm, bool isBufferLoad,
527+
triton::CacheModifier cm, bool isLoad,
525528
const mlir::triton::AMD::TargetInfo &targetInfo) {
526529
switch (targetInfo.getGPUKind()) {
527530
case llvm::AMDGPU::GK_GFX942:
528531
case llvm::AMDGPU::GK_GFX950:
529-
return getCtrlBitsForCacheModifierOnGFX_942_950(cm, isBufferLoad);
532+
return getCtrlBitsForCacheModifierOnGFX_942_950(cm, isLoad);
530533
default:
531534
return getDefaultCtrlBitsForCacheModifier(cm);
532535
}

0 commit comments

Comments
 (0)