@@ -270,30 +270,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
270270
271271// CHECK-LABEL: tt.func @matmul_loop_mb
272272// CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}})
273- // Stage 0
274- // CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}}
275- // CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}}
276- // CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]]
277- // CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]]
278- // CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]]
279- // CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]]
280- // CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}}
281- // CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]]
282- // CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]]
283273// Stage 1
284- // CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}}
285- // CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}}
286- // CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}}
287- // CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
288- // CHECK: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]]
289- // CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
290- // CHECK: ttg.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]]
274+ // CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}}
275+ // CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}}
276+ // CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}}
277+ // CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}]
278+ // CHECK: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_31]]
279+ // CHECK: %[[MEMDESC_SUBVIEW_32:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}]
280+ // CHECK: ttg.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_32]]
281+ // Stage 0
282+ // CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG6]], %{{.*}}
283+ // CHECK: %[[MULI_34:.*]] = arith.muli %{{.*}}, %{{.*}}
284+ // CHECK: %[[SUBI_35:.*]] = arith.subi %{{.*}}, %[[MULI_34]]
285+ // CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_35]]
286+ // CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_36]]
287+ // CHECK: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_37]]
288+ // CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG7]], %{{.*}}
289+ // CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_36]]
290+ // CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]]
291291// Stage 2
292292// CHECK: %[[LOCAL_LOAD_42:.*]] = ttg.local_load %[[ARG10]]
293293// CHECK: %[[LOCAL_LOAD_43:.*]] = ttg.local_load %[[ARG11]]
294294// CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}}
295295// CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]]
296- // CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]]
296+ // CHECK: scf.yield %[[ADDPTR_33]], %[[ADDPTR_39]], %[[DOT_45]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_31]], %[[MEMDESC_SUBVIEW_32]], %[[LOAD_38]], %[[LOAD_41]]
297297// CHECK: }
298298
299299 tt.func @matmul_loop_mb (%arg0: index , %arg1: index , %arg2: index , %arg3: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg4: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }) -> tensor <128 x128 xf32 , #mma > {
0 commit comments