@@ -816,6 +816,83 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
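+  // The dot's LHS is a register operand (dot_op layout) rather than a shared-memory
+  // buffer, and the accumulator feeds a conditional epilogue; the pipeliner should
+  // still overlap the async copies with the MMAs and only fully drain the dot
+  // inside the scf.if (see the CHECK lines below).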
+  // CHECK-LABEL: dot_lhs_in_reg_with_epilogue
+  tt.func @dot_lhs_in_reg_with_epilogue(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: i1) -> tensor<128x16xf32, #mma> {
+    %cst = arith.constant dense<0> : tensor<128x64xi32, #blocked1>
+    %cst1 = arith.constant dense<0> : tensor<64x16xi32, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked>
+    %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>
+    %c0_i64 = arith.constant 0 : i64
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma>
+    %cst_3 = arith.constant dense<0> : tensor<128x64xi32, #blocked1>
+    %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %c1_i32 = arith.constant 1 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr<f16>, i64
+    %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr<f16>, i64
+    %2 = tt.splat %1 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
+    %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
+    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
+    %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+    %10 = tt.splat %0 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked>
+    %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr<f16>, #blocked>, tensor<1x16xi32, #blocked>
+    %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
+    %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
+    %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
+    // CHECK: scf.for
+    // CHECK: ttg.async_wait {{.*}} {num = 2 : i32}
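+    // Four MMAs are issued per pipelined iteration; each wait keeps up to three
+    // dots in flight (pendings = 3).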
+    // CHECK: ttng.warp_group_dot
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
+    // CHECK: ttng.warp_group_dot
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
+    // CHECK: ttng.warp_group_dot
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
+    // CHECK: ttng.warp_group_dot
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
+    // CHECK: ttg.async_copy_global_to_local
+    // CHECK: ttg.async_copy_global_to_local
+    // CHECK: ttg.async_commit_group
+    // CHECK: scf.if
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
+    // CHECK: } else {
+    // CHECK-NOT: ttng.warp_group_dot_wait
+    // CHECK: scf.yield
+    %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %8, %arg6 = %16) -> (tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x16x!tt.ptr<f16>, #blocked>) : i32 {
+      %a_block = tt.load %arg5 : tensor<128x64x!tt.ptr<f16>, #blocked1>
+      %b_block = tt.load %arg6 : tensor<64x16x!tt.ptr<f16>, #blocked>
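+      // A stays in registers: it is converted to the MMA operand layout and scaled
+      // in-register, while only B is staged through shared memory.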
+      %a_dotop = ttg.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %b_smem = ttg.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem>
+      %25 = ttng.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x16xf16, #shared, #smem> -> tensor<128x16xf32, #mma>
+      %26 = tt.addptr %arg5, %cst : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %27 = tt.addptr %arg6, %cst1 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
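+      // Conditional epilogue: the accumulator is only read when %arg2 is set, so a
+      // full wait (pendings = 0) is expected only in the then-branch.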
+      %28 = scf.if %arg2 -> tensor<128x16xf32, #mma> {
+        %29 = arith.addf %25, %25 : tensor<128x16xf32, #mma>
+        scf.yield %29 : tensor<128x16xf32, #mma>
+      } else {
+        scf.yield %25 : tensor<128x16xf32, #mma>
+      }
+      scf.yield %28, %26, %27 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x16x!tt.ptr<f16>, #blocked>
+    }
+    tt.return %17#0 : tensor<128x16xf32, #mma>
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>