@@ -73,6 +73,33 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
+// CHECK: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}>
+// CHECK: #mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: chained_dot
+  tt.func public @chained_dot_wgmma(
+    %arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>,
+    %arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>,
+    %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1>
+    // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma>
+    %d = tt.dot %arg0, %arg1, %cst_0 :
+      tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked>
+    %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked>
+    %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
+    // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1>
+    %r = tt.dot %c, %arg2, %cst_1 :
+      tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1>
+    tt.return %r : tensor<64x128xf32, #blocked1>
+  }
+}
+
+// -----
+
 // CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>