@@ -96,6 +96,38 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}

// -----

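+// check that we don't sink the ttg.local_load past the ttng.arrive_barrier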
+// CHECK-LABEL: sink_convert_idx_1_negative
+// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #{{.*}}, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
+// CHECK: ttng.arrive_barrier
+// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #{{.*}}, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+// CHECK: tt.dot
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @sink_convert_idx_1_negative(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked>) {
+    %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
+    %true = arith.constant true
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    %B = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
+    %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem>
+    %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
+    %A = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
+    %AS = ttg.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem>
+    ttng.arrive_barrier %bar, 2, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
+    %AD = ttg.local_load %AS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+    %12 = tt.dot %AD, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
+    %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
+    tt.store %arg0, %13 : tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
// check that we don't sink convert_layout if it has multi users
// CHECK-LABEL: convert_cannot_sink
// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>