@@ -55,7 +55,7 @@ func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
 }
 
 func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask: i1) {
-  // Test if Load/Store ops can handle scalar values (see #XXX)
+  // Test if Load/Store ops can handle scalar values
   %other = arith.constant 0.0e+0 : f32
 
   // load scalar
@@ -75,3 +75,58 @@ func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ma
   tt.store %ptr, %c, %mask : f32
   return
 }
+
+func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v: tensor<1x2x4xf32>) {
+  // Test if reduce ops infer types correctly
+
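+  // tt.reduce drops the reduced axis from the result shape: reducing
+  // tensor<1x2x4xf32> along axis 0 gives tensor<2x4xf32>, and reducing a
+  // 1-D tensor gives a plain scalar (f32).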
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
+  %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
+  %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
+  %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
+  %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
+  %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
+  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
+  %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32
+
+  // Avoid optimizations for c, e, and g
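+  // (the stores below keep %c, %e, and %g live so the reduces above cannot be dropped)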
+  %ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>
+  %ptr1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x!tt.ptr<f32>>
+  tt.store %ptr1x2, %c : tensor<1x2xf32>
+  tt.store %ptr1, %e : tensor<1xf32>
+  tt.store %ptr, %g : f32
+  return
+}
+
+func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v: f32) {
+  // Test if dot ops infer types correctly
+  %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
+  %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
+  %v128x1 = tt.splat %v : (f32) -> tensor<128x1xf32>
+  %v1x128 = tt.splat %v : (f32) -> tensor<1x128xf32>
+
+  %zero128x128 = arith.constant dense<0.00e+00> : tensor<128x128xf32>
+  %zero32x32 = arith.constant dense<0.00e+00> : tensor<32x32xf32>
+  %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>
+
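+  // tt.dot multiplies an MxK operand by a KxN operand and adds an MxN
+  // accumulator, so each result below takes the accumulator's MxN shape.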
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
+  %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
+  %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
+  %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
+  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
+  %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>
+
+  %ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
+  %ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
+  %ptr1x1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x1x!tt.ptr<f32>>
+  tt.store %ptr128x128, %r1 : tensor<128x128xf32>
+  tt.store %ptr32x32, %r2 : tensor<32x32xf32>
+  tt.store %ptr128x128, %r3 : tensor<128x128xf32>
+  tt.store %ptr1x1, %r4 : tensor<1x1xf32>
+  return
+}