@@ -816,83 +816,6 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
816816
817817// -----
818818
819- #blocked = #ttg.blocked <{sizePerThread = [8 , 1 ], threadsPerWarp = [8 , 4 ], warpsPerCTA = [1 , 4 ], order = [0 , 1 ]}>
820- #blocked1 = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [4 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
821- #mma = #ttg.nvidia_mma <{versionMajor = 3 , versionMinor = 0 , warpsPerCTA = [4 , 1 ], instrShape = [16 , 16 , 16 ]}>
822- #shared = #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = true , elementBitWidth = 16 }>
823- #smem = #ttg.shared_memory
824- module attributes {" ttg.target" = " cuda:90" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 32 : i32 } {
825- // CHECK-LABEL: dot_lhs_in_reg_with_epilogue
826- tt.func @dot_lhs_in_reg_with_epilogue (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg2: i1 ) -> tensor <128 x16 xf32 , #mma > {
827- %cst = arith.constant dense <0 > : tensor <128 x64 xi32 , #blocked1 >
828- %cst1 = arith.constant dense <0 > : tensor <64 x16 xi32 , #blocked >
829- %c0_i32 = arith.constant 0 : i32
830- %cst_0 = arith.constant dense <0 > : tensor <1 x16 xi32 , #blocked >
831- %cst_1 = arith.constant dense <0 > : tensor <128 x1 xi32 , #blocked1 >
832- %c0_i64 = arith.constant 0 : i64
833- %cst_2 = arith.constant dense <0.000000e+00 > : tensor <128 x16 xf32 , #mma >
834- %cst_3 = arith.constant dense <0 > : tensor <128 x64 xi32 , #blocked1 >
835- %cst_4 = arith.constant dense <2.0 > : tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
836- %c1_i32 = arith.constant 1 : i32
837- %c8_i32 = arith.constant 8 : i32
838- %0 = tt.addptr %arg0 , %c0_i64 : !tt.ptr <f16 >, i64
839- %1 = tt.addptr %arg1 , %c0_i64 : !tt.ptr <f16 >, i64
840- %2 = tt.splat %1 : !tt.ptr <f16 > -> tensor <128 x1 x!tt.ptr <f16 >, #blocked1 >
841- %3 = tt.addptr %2 , %cst_1 : tensor <128 x1 x!tt.ptr <f16 >, #blocked1 >, tensor <128 x1 xi32 , #blocked1 >
842- %4 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>>
843- %5 = tt.expand_dims %4 {axis = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>> -> tensor <1 x64 xi32 , #blocked1 >
844- %6 = tt.broadcast %3 : tensor <128 x1 x!tt.ptr <f16 >, #blocked1 > -> tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >
845- %7 = tt.broadcast %5 : tensor <1 x64 xi32 , #blocked1 > -> tensor <128 x64 xi32 , #blocked1 >
846- %8 = tt.addptr %6 , %7 : tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <128 x64 xi32 , #blocked1 >
847- %10 = tt.splat %0 : !tt.ptr <f16 > -> tensor <1 x16 x!tt.ptr <f16 >, #blocked >
848- %11 = tt.addptr %10 , %cst_0 : tensor <1 x16 x!tt.ptr <f16 >, #blocked >, tensor <1 x16 xi32 , #blocked >
849- %12 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>>
850- %13 = tt.expand_dims %12 {axis = 1 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>> -> tensor <64 x1 xi32 , #blocked >
851- %14 = tt.broadcast %11 : tensor <1 x16 x!tt.ptr <f16 >, #blocked > -> tensor <64 x16 x!tt.ptr <f16 >, #blocked >
852- %15 = tt.broadcast %13 : tensor <64 x1 xi32 , #blocked > -> tensor <64 x16 xi32 , #blocked >
853- %16 = tt.addptr %14 , %15 : tensor <64 x16 x!tt.ptr <f16 >, #blocked >, tensor <64 x16 xi32 , #blocked >
854- // CHECK: scf.for
855- // CHECK: ttg.async_wait {{.*}} {num = 2 : i32}
856- // CHECK: ttng.warp_group_dot
857- // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
858- // CHECK: ttng.warp_group_dot
859- // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
860- // CHECK: ttng.warp_group_dot
861- // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
862- // CHECK: ttng.warp_group_dot
863- // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
864- // CHECK: ttg.async_copy_global_to_local
865- // CHECK: ttg.async_copy_global_to_local
866- // CHECK: ttg.async_commit_group
867- // CHECK: scf.if
868- // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
869- // CHECK: } else {
870- // CHECK-NOT: ttng.warp_group_dot_wait
871- // CHECK: scf.yield
872- %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args (%arg4 = %cst_2 , %arg5 = %8 , %arg6 = %16 ) -> (tensor <128 x16 xf32 , #mma >, tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >,
873- tensor <64 x16 x!tt.ptr <f16 >, #blocked >) : i32 {
874- %a_block = tt.load %arg5 : tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >
875- %b_block = tt.load %arg6 : tensor <64 x16 x!tt.ptr <f16 >, #blocked >
876- %a_dotop = ttg.convert_layout %a_block : tensor <128 x64 xf16 , #blocked1 > -> tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
877- %a_dotop_mul = arith.mulf %a_dotop , %cst_4 : tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
878- %b_smem = ttg.local_alloc %b_block : (tensor <64 x16 xf16 , #blocked >) -> !ttg.memdesc <64 x16 xf16 , #shared , #smem >
879- %25 = ttng.warp_group_dot %a_dotop_mul , %b_smem , %arg4 : tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * !ttg.memdesc <64 x16 xf16 , #shared , #smem > -> tensor <128 x16 xf32 , #mma >
880- %26 = tt.addptr %arg5 , %cst : tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <128 x64 xi32 , #blocked1 >
881- %27 = tt.addptr %arg6 , %cst1 : tensor <64 x16 x!tt.ptr <f16 >, #blocked >, tensor <64 x16 xi32 , #blocked >
882- %28 = scf.if %arg2 -> tensor <128 x16 xf32 , #mma > {
883- %29 = arith.addf %25 , %25 : tensor <128 x16 xf32 , #mma >
884- scf.yield %29: tensor <128 x16 xf32 , #mma >
885- } else {
886- scf.yield %25: tensor <128 x16 xf32 , #mma >
887- }
888- scf.yield %28 , %26 , %27 : tensor <128 x16 xf32 , #mma >, tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <64 x16 x!tt.ptr <f16 >, #blocked >
889- }
890- tt.return %17#0 : tensor <128 x16 xf32 , #mma >
891- }
892- }
893-
894- // -----
895-
896819#blocked = #ttg.blocked <{sizePerThread = [16 , 1 ], threadsPerWarp = [16 , 2 ], warpsPerCTA = [1 , 8 ], order = [0 , 1 ]}>
897820#blocked1 = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [2 , 16 ], warpsPerCTA = [8 , 1 ], order = [1 , 0 ]}>
898821#blocked2 = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [8 , 1 ], order = [1 , 0 ]}>
0 commit comments