Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ba17a9e
Basic lowering AsyncCommitGroup and AsyncWait
AlexAUT Jan 17, 2025
2327587
WIP lowering of AsyncCopy
AlexAUT Jan 17, 2025
1d4edf6
Added layout checks for asynccopy lowering
AlexAUT Jan 24, 2025
ead4915
Support direct to lds
AlexAUT Jan 27, 2025
3141ba4
Enable non working masking
AlexAUT Jan 27, 2025
644aa1e
Add support to enable disable direct to lds with env var AMDGCN_USE_D…
AlexAUT Jan 28, 2025
7c9bab1
Fix masking and others for direct to lds
AlexAUT Jan 28, 2025
cb823d0
Fix when AsyncCopy is lowered without a mask
AlexAUT Jan 28, 2025
c097616
Use ROCDL instead of intrinsics
AlexAUT Jan 28, 2025
1a9f1e0
Cleanup and simplify AsyncCopy lowering
AlexAUT Jan 28, 2025
a20b686
CacheModifiers for AsyncCopy
AlexAUT Jan 28, 2025
97d677d
Add lit test for AsyncCopy
AlexAUT Jan 28, 2025
30352ad
Split AsyncCopy Lit for gfx950
AlexAUT Jan 28, 2025
fe8619d
Add const to getCtrlBitsForCacheModifierOnTarget
AlexAUT Jan 28, 2025
7941a30
Cleanup StreamPipeliner changes
AlexAUT Jan 28, 2025
def9313
Revert stream pipeline related changes
AlexAUT Jan 28, 2025
318caa2
Add missing CDNA1 to AsyncCopy support list
AlexAUT Jan 28, 2025
6600138
Cleanup
AlexAUT Jan 28, 2025
ea02c3c
Replace macros for llvm ops with TritonLLVMOpBuilder
AlexAUT Jan 29, 2025
13419bb
Fix wrong value in supported bit width for global.to.lds
AlexAUT Jan 30, 2025
ca8b441
Addressing review comments
AlexAUT Jan 31, 2025
6aa3554
Unified async ops lit tests
AlexAUT Jan 31, 2025
04fad93
Emit correct wmcnt wait instead of waiting on all cnts
AlexAUT Jan 31, 2025
f6cbe22
Add tests for AsyncWait/AsyncCommitGroup
AlexAUT Jan 31, 2025
3d30f43
Limit AsyncWait conversion to gfx9
AlexAUT Feb 3, 2025
0c382db
Add AsyncCopy lowering lit test with masking and other values
AlexAUT Feb 3, 2025
f560aeb
Added async copy lit tests with cache modifiers
AlexAUT Feb 5, 2025
d6b0d02
Merge branch 'main' into global_to_lds_lowering
AlexAUT Feb 5, 2025
d90ffbe
Adjust to shared encoding changes
AlexAUT Feb 5, 2025
5356802
Fix a few small issues
antiagainst Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 217 additions & 0 deletions test/Conversion/amd/async_ops_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s --check-prefix=GFX942

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // NOTE: the RUN lines register only the GFX950/GFX942 FileCheck prefixes, so
  // plain CHECK directives would never be evaluated. Use GFX942 here so this
  // function is actually verified.
  // GFX942-LABEL: async_copy
  tt.func public @async_copy(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                             %arg1: i32 {tt.divisibility = 16 : i32},
                             %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // We need the splat to allow the AxisAnalysis to work during lowering
    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
    // GFX942-COUNT-8: rocdl.global.load.lds
    // GFX942-NOT: rocdl.global.load.lds
    %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // NOTE: only the GFX950/GFX942 prefixes are registered by the RUN lines;
  // plain CHECK directives are dead, so use GFX942.
  // GFX942-LABEL: async_copy_vectorized_2xf16
  tt.func public @async_copy_vectorized_2xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                              %arg1: i32 {tt.divisibility = 16 : i32},
                                              %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // We need the index calculation so AxisAnalysis sees that we can vectorize the load
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

    // Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds
    // GFX942-COUNT-4: rocdl.global.load.lds
    // GFX942-NOT: rocdl.global.load.lds
    %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // GFX950-LABEL: async_copy_vectorized_8xf16
  tt.func public @async_copy_vectorized_8xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                              %arg1: i32 {tt.divisibility = 16 : i32},
                                              %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // We need the index calculation so AxisAnalysis sees that we can vectorize the load
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

    // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds
    // FileCheck directives are case-sensitive: "GFX950-next" would be silently
    // ignored, so spell the suffix as NEXT to actually enforce a single load.
    // GFX950: rocdl.global.load.lds
    // GFX950-NEXT: llvm.return

    // GFX942 does not support vectorization > 4bytes
    // expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}}
    %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // NOTE: only the GFX950/GFX942 prefixes are registered by the RUN lines;
  // plain CHECK directives are dead, so use GFX942.
  // GFX942-LABEL: async_wait
  tt.func public @async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                             %arg1: i32 {tt.divisibility = 16 : i32},
                             %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // The waitcnt packs all counters into one i32; bits 15:14 and 3:0 hold the
    // vmcnt to wait on (all other counter bits stay 1, i.e. "don't wait").
    // E.g. -49168 = 0xFFFF3FF0 encodes vmcnt = 0; -1 = 0xFFFFFFFF encodes 63.
    // GFX942: rocdl.waitcnt -49168
    // GFX942: rocdl.barrier
    ttg.async_wait {num = 0 : i32}
    // GFX942: rocdl.waitcnt -49167
    // GFX942: rocdl.barrier
    ttg.async_wait {num = 1 : i32}
    // GFX942: rocdl.waitcnt -2
    // GFX942: rocdl.barrier
    ttg.async_wait {num = 62 : i32}
    // GFX942: rocdl.waitcnt -1
    // GFX942: rocdl.barrier
    ttg.async_wait {num = 63 : i32}
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // NOTE: only the GFX950/GFX942 prefixes are registered by the RUN lines;
  // plain CHECK directives are dead, so use GFX942. The commit group lowers to
  // nothing but the function boilerplate on this target.
  // GFX942-LABEL: async_commit_group
  tt.func public @async_commit_group(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                     %arg1: i32 {tt.divisibility = 16 : i32},
                                     %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // GFX942-NEXT: llvm.mlir.constant(0 : i32) : i32
    // GFX942-NEXT: llvm.return
    ttg.async_commit_group
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // NOTE: only the GFX950/GFX942 prefixes are registered by the RUN lines;
  // plain CHECK directives are dead, so use GFX942.
  // GFX942-LABEL: async_copy_mask_other
  tt.func public @async_copy_mask_other(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                        %arg1: i32 {tt.divisibility = 16 : i32},
                                        %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>,
                                        %arg3: i32 {tt.divisibility = 16 : i32}) {
    // We need the splat to allow the AxisAnalysis to work during lowering
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked>
    %c0_i32 = arith.constant 0 : i32
    %c32_i32 = arith.constant 32 : i32
    %c31_i32 = arith.constant 31 : i32
    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
    %29 = arith.addi %arg3, %c31_i32 : i32
    %30 = arith.divsi %29, %c32_i32 : i32
    %31 = arith.cmpi sgt, %30, %c0_i32 : i32

    %51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %65 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked>
    %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked>
    %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>

    %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked>
    %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked>

    // Each thread needs to load 4 elements and we load 1 (sizePerThread) per global.load.lds
    // Note that mask/other alignment is 1 so we need 4 conditionals

    // GFX942: llvm.cond_br
    // GFX942: rocdl.global.load.lds
    // GFX942-NEXT: llvm.br
    // GFX942: _predicated_store

    // GFX942: llvm.cond_br
    // GFX942: rocdl.global.load.lds
    // GFX942-NEXT: llvm.br
    // GFX942: _predicated_store

    // GFX942: llvm.cond_br
    // GFX942: rocdl.global.load.lds
    // GFX942-NEXT: llvm.br
    // GFX942: _predicated_store

    // GFX942: llvm.cond_br
    // GFX942: rocdl.global.load.lds
    // GFX942-NEXT: llvm.br
    // GFX942: _predicated_store

    %2 = ttg.async_copy_global_to_local %1, %arg2 mask %67 other %cst_0 : tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
// Verifies that each Triton cache modifier is lowered to the expected aux
// (cache-policy) bits operand of rocdl.global.load.lds on gfx942.
// Per the checks below: ca/cg/wb map to aux = 0, cs -> 3, cv -> 9, wt -> 8.
// NOTE(review): the exact aux encodings are target-specific control bits
// (see getCtrlBitsForCacheModifierOnTarget in the lowering) — confirm against
// the gfx942 ISA docs if these values ever change.
// GFX942-LABEL: async_copy_cache_mods
tt.func public @async_copy_cache_mods(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: i32 {tt.divisibility = 16 : i32},
%arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>) {
// We need the splat to allow the AxisAnalysis to work during lowering
%1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
// Each thread needs to load 1 element and we load 1 (sizePerThread) per global.load.lds

// GFX942: llvm.getelementptr
// GFX942: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
// GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
%2 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = ca: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
// GFX942: llvm.getelementptr
// GFX942: %[[aux_cg:.*]] = llvm.mlir.constant(0 : i32) : i32
// GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]]
%3 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cg: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
// GFX942: llvm.getelementptr
// GFX942: %[[aux_cs:.*]] = llvm.mlir.constant(3 : i32) : i32
// GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cs]]
%5 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cs: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
// GFX942: llvm.getelementptr
// GFX942: %[[aux_cv:.*]] = llvm.mlir.constant(9 : i32) : i32
// GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]]
%6 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = cv: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
// GFX942: llvm.getelementptr
// GFX942: %[[aux_wb:.*]] = llvm.mlir.constant(0 : i32) : i32
// GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wb]]
%7 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = wb: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
// GFX942: llvm.getelementptr
// GFX942: %[[aux_wt:.*]] = llvm.mlir.constant(8 : i32) : i32
// GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wt]]
%8 = ttg.async_copy_global_to_local %1, %arg2 cacheModifier = wt: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
tt.return
}
}
Loading
Loading