// RUN: triton-opt %s --tritongpu-hoist-tmem-alloc --tritongpu-partition-scheduling -allow-unregistered-dialect | FileCheck %s

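// Flash-attention-style forward loop used to exercise partition scheduling:
// the CHECK lines below verify which ttg.partition each op is assigned to.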
#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#load_blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>

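// NVMMA shared-memory encodings for the K/V tiles; the transposed form is
// used for K^T as the RHS of the first MMA.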
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#shared_T = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>

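// Tensor-memory encodings: the f32 accumulator is stored unpacked, while the
// f16 LHS operand of the second MMA is packed.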
#smem = #ttg.shared_memory
#tmem_acc = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
#tmem_lhs = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = false>
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {

// CHECK-LABEL: @attention_forward
tt.func public @attention_forward(
  %Q_shared: !ttg.memdesc<256x64xf16, #shared, #smem>,
  %K_desc: !tt.tensordesc<tensor<64x64xf16, #shared>>,
  %V_desc: !tt.tensordesc<tensor<64x64xf16, #shared>>,
  %qk_scale: f32,
  %n_tiles: i32
) {
  %true = arith.constant true
  %false = arith.constant false
  %c0_i32 = arith.constant 0 : i32
  %c64_i32 = arith.constant 64 : i32

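  // Initial online-softmax state: row maxima start at -inf, row sums at 1.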
  %neg_inf = arith.constant dense<0xFF800000> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
  %zero = arith.constant dense<0.0> : tensor<256x64xf32, #blocked>
  %one = arith.constant dense<1.0> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>

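  // The Q*K^T accumulator lives in tensor memory and is allocated once,
  // ahead of the loop.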
  %QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)

  %loop_outs:4 = scf.for %i = %c0_i32 to %n_tiles step %c64_i32 iter_args(
    %l_i = %one,
    %acc = %zero,
    %m_i = %neg_inf,
    %e_i = %one
  ) -> (
    tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>,
    tensor<256x64xf32, #blocked>,
    tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>,
    tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
  ) : i32 {

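    // Load the next K tile through its tensor descriptor and stage it in
    // shared memory.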
    %K = tt.descriptor_load %K_desc[%i, %c0_i32] : !tt.tensordesc<tensor<64x64xf16, #shared>> -> tensor<64x64xf16, #load_blocked>
    %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>

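    // First GEMM on the tcgen05 MMA unit: S = Q * K^T, written to tensor
    // memory (accumulate flag is false, so the previous contents are ignored).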
    %K_trans = ttg.memdesc_trans %K_shared {order = array<i32: 1, 0>} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem>
    %QK_mma_tok = ttng.tc_gen5_mma %Q_shared, %K_trans, %QK_tmem[%QK_tok], %false, %true : !ttg.memdesc<256x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared_T, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>

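    // Online softmax: read S back from tensor memory, subtract the new row
    // maxima, and exponentiate; alpha rescales state from earlier iterations.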
    %QK, %QK_load_tok = ttng.tmem_load %QK_tmem[%QK_mma_tok] : !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked>
    %row_max = "compute_row_max"(%QK, %qk_scale) : (tensor<256x64xf32, #blocked>, f32) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %QK_adj = "sub_row_max"(%QK, %row_max, %qk_scale) : (tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, f32) -> tensor<256x64xf32, #blocked>
    // CHECK: [[SOFTMAX:%.*]] = math.exp2 {{.*}} {ttg.partition = 0 : i32} : tensor<256x64xf32
    %softmax = math.exp2 %QK_adj : tensor<256x64xf32, #blocked>
    %diff = arith.subf %m_i, %row_max : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %alpha = math.exp2 %diff : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>

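    // Row sums of the softmax tile feed the rescaled running sum l_i.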
    %l_ij = "tt.reduce"(%softmax) <{axis = 1 : i32}> ({
    ^bb0(%arg29: f32, %arg30: f32):
      %68 = arith.addf %arg29, %arg30 : f32
      tt.reduce.return %68 : f32
    }) : (tensor<256x64xf32, #blocked>) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %l_i_scaled = arith.mulf %l_i, %alpha : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %next_l_i = arith.addf %l_i_scaled, %l_ij : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>

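    // Broadcast alpha across columns and rescale the running accumulator.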
    %alpha_0 = tt.expand_dims %alpha {axis = 1 : i32} : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xf32, #blocked>
    %alpha_1 = tt.broadcast %alpha_0 : tensor<256x1xf32, #blocked> -> tensor<256x64xf32, #blocked>

    %acc_corrected = arith.mulf %acc, %alpha_1 : tensor<256x64xf32, #blocked>

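    // Extra consumers of the softmax and the accumulator; the CHECK lines
    // verify which partitions the scheduler assigns them to.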
    // CHECK: [[X:%.*]] = arith.addf [[SOFTMAX]], [[SOFTMAX]] {ttg.partition = 0 : i32}
    %x = arith.addf %softmax, %softmax : tensor<256x64xf32, #blocked>
    // CHECK-NEXT: [[ACC_X:%.*]] = arith.addf %{{.*}}, [[X]] {ttg.partition = 3 : i32}
    %acc_x = arith.addf %acc, %x : tensor<256x64xf32, #blocked>
    %e = "sum"(%acc_x) : (tensor<256x64xf32, #blocked>) -> tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %next_e_i = arith.addf %e_i, %e : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>

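    // Load the V tile, downcast P to f16, and run the second GEMM with the
    // accumulate flag set: O = P * V + alpha-corrected accumulator.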
    %V = tt.descriptor_load %V_desc[%i, %c0_i32] : !tt.tensordesc<tensor<64x64xf16, #shared>> -> tensor<64x64xf16, #load_blocked>
    %V_shared = ttg.local_alloc %V : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
    %P = arith.truncf %softmax : tensor<256x64xf32, #blocked> to tensor<256x64xf16, #blocked>

    %P_tmem = ttng.tmem_alloc %P : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #tmem_lhs, #ttng.tensor_memory>
    %acc_tmem, %acc_tok = ttng.tmem_alloc %acc_corrected : (tensor<256x64xf32, #blocked>) -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)
    %PV_mma_tok = ttng.tc_gen5_mma %P_tmem, %V_shared, %acc_tmem[%acc_tok], %true, %true : !ttg.memdesc<256x64xf16, #tmem_lhs, #ttng.tensor_memory>, !ttg.memdesc<64x64xf16, #shared, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>
    %O, %O_tok = ttng.tmem_load %acc_tmem[%PV_mma_tok] : !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked>

    scf.yield %next_l_i, %O, %row_max, %next_e_i : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
  } {tt.warp_specialize}

  "use"(%loop_outs#0, %loop_outs#1, %loop_outs#2) : (tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<256x64xf32, #blocked>, tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> ()

  tt.return
}

}