@@ -110,7 +110,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
110110#smem = #ttg.shared_memory
111111#tmem = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 128 , unpacked = true >
112112module attributes {" ttg.num-warps" = 4 : i32 , " ttg.num-ctas" = 1 : i32 , ttg.target = " cuda:100" , " ttg.threads-per-warp" = 32 : i32 } {
113- tt.func @matmul_loop_cast_load (%lb : index , %ub : index , %step : index ,
113+ tt.func @matmul_loop_cast_load (%lb : i32 , %ub : i32 , %step : i32 ,
114114 %A : !tt.ptr <f8E4M3FN > {tt.divisibility = 16 : i32 },
115115 %B : !tt.ptr <f8E4M3FN > {tt.divisibility = 16 : i32 }) -> tensor <128 x128 xf32 , #C > {
116116// CHECK-LABEL: tt.func @matmul_loop_cast_load
@@ -137,7 +137,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.targ
137137 %a_off = arith.constant dense <4 > : tensor <128 x32 xi32 , #AL >
138138 %b_off = arith.constant dense <4 > : tensor <32 x128 xi32 , #BL >
139139
140- %loop:3 = scf.for %iv = %lb to %ub step %step iter_args (%a_ptr = %a_ptr_init , %b_ptr = %b_ptr_init , %prev_c = %c_init ) -> (tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >, tensor <32 x128 x!tt.ptr <f8E4M3FN >, #BL >, tensor <128 x128 xf32 , #C >) {
140+ %loop:3 = scf.for %iv = %lb to %ub step %step iter_args (%a_ptr = %a_ptr_init , %b_ptr = %b_ptr_init , %prev_c = %c_init ) -> (tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >, tensor <32 x128 x!tt.ptr <f8E4M3FN >, #BL >, tensor <128 x128 xf32 , #C >) : i32 {
141141 %a___ = tt.load %a_ptr : tensor <128 x32 x!tt.ptr <f8E4M3FN >, #AL >
142142 %a__ = tt.fp_to_fp %a___ : tensor <128 x32 xf8 E4 M3 FN, #AL > -> tensor <128 x32 xf16 , #AL >
143143 %a_ = ttg.convert_layout %a__ : tensor <128 x32 xf16 , #AL > -> tensor <128 x32 xf16 , #A >
@@ -250,7 +250,7 @@ tt.func private @pipelined_gather(
250250#smem = #ttg.shared_memory
251251
252252module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " cuda:100" , " ttg.threads-per-warp" = 32 : i32 } {
253- tt.func public @block_scale_mxfp_matmul (%lb : index , %ub : index , %step : index , %arg0: !tt.ptr <f8E5M2 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f8E5M2 > {tt.divisibility = 16 : i32 }, %arg3: !tt.ptr <i8 > {tt.divisibility = 16 : i32 }, %arg4: !tt.ptr <i8 > {tt.divisibility = 16 : i32 }) -> tensor <128 x128 xf32 , #blocked4 > {
253+ tt.func public @block_scale_mxfp_matmul (%lb : i32 , %ub : i32 , %step : i32 , %arg0: !tt.ptr <f8E5M2 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f8E5M2 > {tt.divisibility = 16 : i32 }, %arg3: !tt.ptr <i8 > {tt.divisibility = 16 : i32 }, %arg4: !tt.ptr <i8 > {tt.divisibility = 16 : i32 }) -> tensor <128 x128 xf32 , #blocked4 > {
254254 // CHECK: ttg.local_alloc : () -> !ttg.memdesc<2x128x256xf8E5M2
255255 // CHECK: ttg.local_alloc : () -> !ttg.memdesc<2x256x128xf8E5M2
256256 // Do not multibuffer the scale loads, as we cannot pipeline the mma due to tmem.cp not being used
@@ -288,7 +288,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
288288 %arg3_init = tt.addptr %arg3_splat , %57 : tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, tensor <1 x2 x32 x4 x4 xi32 , #blocked2 >
289289 %arg4_init = tt.addptr %arg4_splat , %57 : tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, tensor <1 x2 x32 x4 x4 xi32 , #blocked2 >
290290
291- %99:5 = scf.for %iv = %lb to %ub step %step iter_args (%arg15 = %cst_1 , %arg16 = %arg0_init , %arg17 = %arg1_init , %arg18 = %arg3_init , %arg19 = %arg4_init ) -> (tensor <128 x128 xf32 , #blocked4 >, tensor <128 x256 x!tt.ptr <f8E5M2 >, #blocked >, tensor <256 x128 x!tt.ptr <f8E5M2 >, #blocked1 >, tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >) {
291+ %99:5 = scf.for %iv = %lb to %ub step %step iter_args (%arg15 = %cst_1 , %arg16 = %arg0_init , %arg17 = %arg1_init , %arg18 = %arg3_init , %arg19 = %arg4_init ) -> (tensor <128 x128 xf32 , #blocked4 >, tensor <128 x256 x!tt.ptr <f8E5M2 >, #blocked >, tensor <256 x128 x!tt.ptr <f8E5M2 >, #blocked1 >, tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >) : i32 {
292292 %117 = tt.load %arg16 : tensor <128 x256 x!tt.ptr <f8E5M2 >, #blocked >
293293 %118 = ttg.local_alloc %117 : (tensor <128 x256 xf8 E5 M2 , #blocked >) -> !ttg.memdesc <128 x256 xf8 E5 M2 , #shared , #ttg.shared_memory >
294294 %119 = tt.load %arg17 : tensor <256 x128 x!tt.ptr <f8E5M2 >, #blocked1 >
@@ -338,7 +338,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
338338#smem = #ttg.shared_memory
339339
340340module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " cuda:100" , " ttg.threads-per-warp" = 32 : i32 } {
341- tt.func public @block_scale_mxfp_matmul_tmem_copy (%lb : index , %ub : index , %step : index , %arg0: !tt.ptr <f8E5M2 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f8E5M2 > {tt.divisibility = 16 : i32 }, %arg3: !tt.ptr <i8 > {tt.divisibility = 16 : i32 }, %arg4: !tt.ptr <i8 > {tt.divisibility = 16 : i32 }) -> tensor <128 x128 xf32 , #blocked4 > {
341+ tt.func public @block_scale_mxfp_matmul_tmem_copy (%lb : i32 , %ub : i32 , %step : i32 , %arg0: !tt.ptr <f8E5M2 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f8E5M2 > {tt.divisibility = 16 : i32 }, %arg3: !tt.ptr <i8 > {tt.divisibility = 16 : i32 }, %arg4: !tt.ptr <i8 > {tt.divisibility = 16 : i32 }) -> tensor <128 x128 xf32 , #blocked4 > {
342342 // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x128x256xf8E5M2
343343 // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x256x128xf8E5M2
344344 // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x1x2x32x4x4xi8
@@ -375,7 +375,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
375375 %arg3_init = tt.addptr %arg3_splat , %57 : tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, tensor <1 x2 x32 x4 x4 xi32 , #blocked2 >
376376 %arg4_init = tt.addptr %arg4_splat , %57 : tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, tensor <1 x2 x32 x4 x4 xi32 , #blocked2 >
377377
378- %99:6 = scf.for %iv = %lb to %ub step %step iter_args (%arg15 = %cst_1 , %arg16 = %arg0_init , %arg17 = %arg1_init , %arg18 = %arg3_init , %arg19 = %arg4_init , %init_flag =%false ) -> (tensor <128 x128 xf32 , #blocked4 >, tensor <128 x256 x!tt.ptr <f8E5M2 >, #blocked >, tensor <256 x128 x!tt.ptr <f8E5M2 >, #blocked1 >, tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, i1 ) {
378+ %99:6 = scf.for %iv = %lb to %ub step %step iter_args (%arg15 = %cst_1 , %arg16 = %arg0_init , %arg17 = %arg1_init , %arg18 = %arg3_init , %arg19 = %arg4_init , %init_flag =%false ) -> (tensor <128 x128 xf32 , #blocked4 >, tensor <128 x256 x!tt.ptr <f8E5M2 >, #blocked >, tensor <256 x128 x!tt.ptr <f8E5M2 >, #blocked1 >, tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, tensor <1 x2 x32 x4 x4 x!tt.ptr <i8 >, #blocked2 >, i1 ) : i32 {
379379 %117 = tt.load %arg16 : tensor <128 x256 x!tt.ptr <f8E5M2 >, #blocked >
380380 %118 = ttg.local_alloc %117 : (tensor <128 x256 xf8 E5 M2 , #blocked >) -> !ttg.memdesc <128 x256 xf8 E5 M2 , #shared , #ttg.shared_memory >
381381 %119 = tt.load %arg17 : tensor <256 x128 x!tt.ptr <f8E5M2 >, #blocked1 >
0 commit comments