@@ -194,3 +194,136 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[0, 32], [0, 64], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = []}>
+// CHECK-LABEL: wmma_dot_scaled_mxfp8_bf16
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @wmma_dot_scaled_mxfp8_bf16(
+    %arg0: tensor<32x128x!tt.ptr<f8E4M3FN>, #blocked4>,
+    %arg1: tensor<32x4x!tt.ptr<i8>, #blocked2>,
+    %arg2: tensor<128x32x!tt.ptr<bf16>, #blocked>,
+    %output: tensor<32x32x!tt.ptr<f32>, #blocked>
+  ) {
+    // CHECK: tt.load %arg1 {amdg.decomposed_dot_scaled_source = true} : tensor<32x4x!tt.ptr<i8>, #blocked1>
+    // CHECK: %[[SCALE:.*]] = tt.reshape {{.*}} : tensor<32x4x32xi8, #blocked3> -> tensor<32x128xi8, #linear>
+    // CHECK: %[[CVT0:.*]] = ttg.convert_layout %[[SCALE]] : tensor<32x128xi8, #linear> -> tensor<32x128xi8, #blocked>
+    // CHECK: %[[UPCASTED:.*]] = amdg.scaled_upcast_fp8 {{.*}} scale %[[CVT0]] : tensor<32x128xf8E4M3FN, #blocked>, tensor<32x128xi8, #blocked> -> tensor<32x128xbf16, #blocked>
+    // CHECK: %[[SEL:.*]] = arith.select {{.*}}, {{.*}}, %[[UPCASTED]]
+    // CHECK: %[[CVT1:.*]] = ttg.convert_layout %[[SEL]] : tensor<32x128xbf16, #blocked> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
+    // CHECK: %[[OPND0:.*]] = ttg.convert_layout %[[CVT1]] : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    // CHECK: tt.dot %[[OPND0]]
+    %a = tt.load %arg0 : tensor<32x128x!tt.ptr<f8E4M3FN>, #blocked4>
+    %scale = tt.load %arg1 : tensor<32x4x!tt.ptr<i8>, #blocked2>
+    %b = tt.load %arg2 : tensor<128x32x!tt.ptr<bf16>, #blocked>
+    %c = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
+    %res = tt.dot_scaled %a scale %scale, %b, %c lhs = e4m3 rhs = bf16 {fastMath = false} : tensor<32x128xf8E4M3FN, #blocked4>, tensor<32x4xi8, #blocked2> * tensor<128x32xbf16, #blocked> -> tensor<32x32xf32, #blocked>
+
+    tt.store %output, %res : tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[32, 0], [64, 0]], block = []}>
+// CHECK-LABEL: wmma_dot_scaled_f16_mxfp8
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @wmma_dot_scaled_f16_mxfp8(
+    %arg0: tensor<32x128x!tt.ptr<f16>, #blocked4>,
+    %arg1: tensor<32x4x!tt.ptr<i8>, #blocked2>,
+    %arg2: tensor<128x32x!tt.ptr<f8E5M2>, #blocked>,
+    %output: tensor<32x32x!tt.ptr<f32>, #blocked>
+  ) {
+    // CHECK: %[[TRANS:.*]] = tt.trans {{.*}} {order = array<i32: 0, 2, 1>} : tensor<4x32x32xi8, #blocked4> -> tensor<4x32x32xi8, #blocked5>
+    // CHECK: %[[SCALE:.*]] = tt.reshape %[[TRANS]] : tensor<4x32x32xi8, #blocked5> -> tensor<128x32xi8, #linear>
+    // CHECK: %[[CVT0:.*]] = ttg.convert_layout %[[SCALE]] : tensor<128x32xi8, #linear> -> tensor<128x32xi8, #blocked2>
+    // CHECK: %[[UPCASTED:.*]] = amdg.scaled_upcast_fp8 {{.*}} scale %[[CVT0]] : tensor<128x32xf8E5M2, #blocked2>, tensor<128x32xi8, #blocked2> -> tensor<128x32xf16, #blocked2>
+    // CHECK: %[[SEL:.*]] = arith.select {{.*}}, %cst, %[[UPCASTED]] : tensor<128x32xi1, #blocked2>, tensor<128x32xf16, #blocked2>
+    // CHECK: %[[CVT1:.*]] = ttg.convert_layout %[[SEL]] : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>>
+    // CHECK: %[[OPND1:.*]] = ttg.convert_layout %[[CVT1]] : tensor<128x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    // CHECK: = tt.dot {{.*}}, %[[OPND1]]
+    %a = tt.load %arg0 : tensor<32x128x!tt.ptr<f16>, #blocked4>
+    %scale = tt.load %arg1 : tensor<32x4x!tt.ptr<i8>, #blocked2>
+    %b = tt.load %arg2 : tensor<128x32x!tt.ptr<f8E5M2>, #blocked>
+    %c = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
+    %res = tt.dot_scaled %a, %b scale %scale, %c lhs = fp16 rhs = e5m2 {fastMath = false} : tensor<32x128xf16, #blocked4> * tensor<128x32xf8E5M2, #blocked>, tensor<32x4xi8, #blocked2> -> tensor<32x32xf32, #blocked>
+
+    tt.store %output, %res : tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[0, 32], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
+// CHECK-LABEL: wmma_dot_scaled_mxfp4_bf16
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @wmma_dot_scaled_mxfp4_bf16(
+    %arg0: tensor<16x32x!tt.ptr<i8>, #blocked5>,
+    %arg1: tensor<16x2x!tt.ptr<i8>, #blocked2>,
+    %arg2: tensor<64x16x!tt.ptr<bf16>, #blocked>,
+    %output: tensor<16x16x!tt.ptr<f32>, #blocked>
+  ) {
+    // CHECK: tt.load %arg1 {amdg.decomposed_dot_scaled_source = true} : tensor<16x2x!tt.ptr<i8>, #blocked1>
+    // CHECK: %[[SCALE:.*]] = tt.reshape {{.*}} : tensor<16x2x32xi8, #blocked3> -> tensor<16x64xi8, #linear>
+    // CHECK: %[[CVT0:.*]] = ttg.convert_layout %[[SCALE]] : tensor<16x64xi8, #linear> -> tensor<16x64xi8, #blocked>
+    // CHECK: %[[UPCASTED:.*]] = amdg.scaled_upcast_fp4 {{.*}} scale %[[CVT0]] {axis = 1 : i32} : tensor<16x32xi8, #blocked>, tensor<16x64xi8, #blocked> -> tensor<16x64xbf16, #blocked>
+    // CHECK: %[[SEL:.*]] = arith.select {{.*}}, %{{.*}}, %[[UPCASTED]] : tensor<16x64xi1, #blocked>, tensor<16x64xbf16, #blocked>
+    // CHECK: %[[CVT1:.*]] = ttg.convert_layout %[[SEL]] : tensor<16x64xbf16, #blocked> -> tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
+    // CHECK: %[[OPND0:.*]] = ttg.convert_layout %[[CVT1]] : tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    // CHECK: tt.dot %[[OPND0]]
+    %a = tt.load %arg0 : tensor<16x32x!tt.ptr<i8>, #blocked5>
+    %scale = tt.load %arg1 : tensor<16x2x!tt.ptr<i8>, #blocked2>
+    %b = tt.load %arg2 : tensor<64x16x!tt.ptr<bf16>, #blocked>
+    %c = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked>
+    %res = tt.dot_scaled %a scale %scale, %b, %c lhs = e2m1 rhs = bf16 {fastMath = false} : tensor<16x32xi8, #blocked5>, tensor<16x2xi8, #blocked2> * tensor<64x16xbf16, #blocked> -> tensor<16x16xf32, #blocked>
+
+    tt.store %output, %res : tensor<16x16x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [32, 0]], warp = [[0, 0], [0, 0]], block = []}>
+// CHECK-LABEL: wmma_dot_scaled_fp16_mxfp4
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @wmma_dot_scaled_fp16_mxfp4(
+    %arg0: tensor<16x64x!tt.ptr<f16>, #blocked5>,
+    %arg1: tensor<16x2x!tt.ptr<i8>, #blocked2>,
+    %arg2: tensor<32x16x!tt.ptr<i8>, #blocked>,
+    %output: tensor<16x16x!tt.ptr<f32>, #blocked>
+  ) {
+    // CHECK: tt.load %arg1 {amdg.decomposed_dot_scaled_source = true} : tensor<16x2x!tt.ptr<i8>, #blocked1>
+    // CHECK: %[[SCALE:.*]] = tt.reshape {{.*}} : tensor<2x32x16xi8, #blocked5> -> tensor<64x16xi8, #linear>
+    // CHECK: %[[CVT0:.*]] = ttg.convert_layout %[[SCALE]] : tensor<64x16xi8, #linear> -> tensor<64x16xi8, #blocked2>
+    // CHECK: %[[UPCASTED:.*]] = amdg.scaled_upcast_fp4 {{.*}} scale %[[CVT0]] {axis = 0 : i32} : tensor<32x16xi8, #blocked2>, tensor<64x16xi8, #blocked2> -> tensor<64x16xf16, #blocked2>
+    // CHECK: %[[SEL:.*]] = arith.select {{.*}}, %cst, %[[UPCASTED]] : tensor<64x16xi1, #blocked2>, tensor<64x16xf16, #blocked2>
+    // CHECK: %[[CVT1:.*]] = ttg.convert_layout %[[SEL]] : tensor<64x16xf16, #blocked2> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>>
+    // CHECK: %[[OPND1:.*]] = ttg.convert_layout %[[CVT1]] : tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    // CHECK: tt.dot {{.*}}, %[[OPND1]]
+    %a = tt.load %arg0 : tensor<16x64x!tt.ptr<f16>, #blocked5>
+    %scale = tt.load %arg1 : tensor<16x2x!tt.ptr<i8>, #blocked2>
+    %b = tt.load %arg2 : tensor<32x16x!tt.ptr<i8>, #blocked>
+    %c = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked>
+    %res = tt.dot_scaled %a, %b scale %scale, %c lhs = fp16 rhs = e2m1 {fastMath = false} : tensor<16x64xf16, #blocked5> * tensor<32x16xi8, #blocked>, tensor<16x2xi8, #blocked2> -> tensor<16x16xf32, #blocked>
+
+    tt.store %output, %res : tensor<16x16x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}