|
| 1 | +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx950 matrix-instruction-size=0' | FileCheck %s --check-prefixes CHECK |
| 2 | + |
| 3 | +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> |
| 4 | +// CHECK{LITERAL}: #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 0], [32, 0]], block = []}> |
| 5 | +// CHECK{LITERAL}: #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[32, 0], [0, 0]], block = []}> |
| 6 | +// CHECK-LABEL: mfma_dot_scaled_mxfp4_mxfp4 |
| 7 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { |
| 8 | + tt.func public @mfma_dot_scaled_mxfp4_mxfp4( |
| 9 | + %arg0: tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, |
| 10 | + %arg1: tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, |
| 11 | + %arg2: tensor<128x4xi8>, |
| 12 | + %arg3: tensor<128x4xi8>, |
| 13 | + %arg4: tensor<128x128x!tt.ptr<f32>, #blocked> |
| 14 | + ) { |
| 15 | + // CHECK-NOT: arith.constant dense<127> : tensor<128x4xi8, #linear> |
| 16 | + // CHECK-NOT: arith.constant dense<127> : tensor<128x4xi8, #linear1> |
| 17 | + // CHECK-NOT: tt.fp_to_fp |
| 18 | + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> |
| 19 | + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> |
| 20 | + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> |
| 21 | + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x4xi8, #linear> |
| 22 | + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x4xi8, #linear1> |
| 23 | + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e2m1 rhs = e2m1 |
| 24 | + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> |
| 25 | + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, tensor<128x4xi8> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<128x4xi8> -> tensor<128x128xf32, #blocked> |
| 26 | + tt.store %arg4, %1 : tensor<128x128x!tt.ptr<f32>, #blocked> |
| 27 | + tt.return |
| 28 | + } |
| 29 | +} |
| 30 | + |
| 31 | +// ----- |
| 32 | + |
| 33 | +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> |
| 34 | +// CHECK-LABEL: mfma_dot_scaled_mxfp4_fp4 |
| 35 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { |
| 36 | + tt.func public @mfma_dot_scaled_mxfp4_fp4( |
| 37 | + %arg0: tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, |
| 38 | + %arg1: tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, |
| 39 | + %arg2: tensor<128x4xi8>, |
| 40 | + %arg3: tensor<128x128x!tt.ptr<f32>, #blocked> |
| 41 | + ) { |
| 42 | + // CHECK-NOT: tt.fp_to_fp |
| 43 | + // CHECK: %[[CST1:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear> |
| 44 | + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x4xi8, #linear1> |
| 45 | + // CHECK: tt.dot_scaled {{.*}} scale %[[SCALE0]], {{.*}} scale %[[CST1]], {{.*}} lhs = e2m1 rhs = e2m1 |
| 46 | + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> |
| 47 | + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, tensor<128x4xi8> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> |
| 48 | + tt.store %arg3, %1 : tensor<128x128x!tt.ptr<f32>, #blocked> |
| 49 | + tt.return |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +// ----- |
| 54 | + |
| 55 | +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> |
| 56 | +// CHECK-LABEL: mfma_dot_scaled_fp4_mxfp4 |
| 57 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { |
| 58 | + tt.func public @mfma_dot_scaled_fp4_mxfp4( |
| 59 | + %arg0: tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, |
| 60 | + %arg1: tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, |
| 61 | + %arg2: tensor<128x4xi8>, |
| 62 | + %arg3: tensor<128x128x!tt.ptr<f32>, #blocked> |
| 63 | + ) { |
| 64 | + // CHECK-NOT: tt.fp_to_fp |
| 65 | + // CHECK: %[[CST0:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear> |
| 66 | + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x4xi8, #linear1> |
| 67 | + // CHECK: tt.dot_scaled {{.*}} scale %[[CST0]], {{.*}} scale %[[SCALE1]], {{.*}} lhs = e2m1 rhs = e2m1 |
| 68 | + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> |
| 69 | + %1 = tt.dot_scaled %arg0, %arg1 scale %arg2, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, tensor<128x4xi8> -> tensor<128x128xf32, #blocked> |
| 70 | + tt.store %arg3, %1 : tensor<128x128x!tt.ptr<f32>, #blocked> |
| 71 | + tt.return |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +// ----- |
| 76 | + |
| 77 | +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> |
| 78 | +// CHECK-LABEL: mfma_dot_scaled_fp4_fp4 |
| 79 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { |
| 80 | + tt.func public @mfma_dot_scaled_fp4_fp4( |
| 81 | + %arg0: tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, |
| 82 | + %arg1: tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, |
| 83 | + %arg2: tensor<128x128x!tt.ptr<f32>, #blocked> |
| 84 | + ) { |
| 85 | + // CHECK-NOT: tt.fp_to_fp |
| 86 | + // CHECK-DAG: %[[CST0:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear> |
| 87 | + // CHECK-DAG: %[[CST1:.+]] = arith.constant dense<127> : tensor<128x4xi8, #linear1> |
| 88 | + // CHECK: tt.dot_scaled {{.*}} scale %[[CST1]], {{.*}} scale %[[CST0]], {{.*}} lhs = e2m1 rhs = e2m1 |
| 89 | + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> |
| 90 | + %1 = tt.dot_scaled %arg0, %arg1, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> |
| 91 | + tt.store %arg2, %1 : tensor<128x128x!tt.ptr<f32>, #blocked> |
| 92 | + tt.return |
| 93 | + } |
| 94 | +} |
0 commit comments