Skip to content

Commit 07c440e

Browse files
[Tests] Add tests to check type inference
1 parent db1c304 commit 07c440e

File tree

3 files changed

+136
-58
lines changed

3 files changed

+136
-58
lines changed

python/tests/test_cast.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

python/tests/test_type.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import triton
2+
import triton.language as tl
3+
4+
5+
# TODO: functions with no arguments don't work
@triton.jit
def binop_type_check(X):
    # Kernel exercising type inference of binary `+` across every supported
    # rank combination (scalar, 1-D, 2-D).  X itself is unused — it exists
    # only because of the no-argument limitation noted above.
    # 0d-tensor is not allowed.
    # zero_0d = tl.zeros([], dtype=tl.float32)
    zero_1d = tl.zeros([2], dtype=tl.float32)
    zero_2d_21 = tl.zeros([2, 1], dtype=tl.float32)
    zero_2d_22 = tl.zeros([2, 2], dtype=tl.float32)

    # scalar + scalar -> scalar
    a0 = 0.0 + 0.0
    # # scalar + 0D -> 0D
    # a1 = 0.0 + zero_0d
    # a2 = zero_0d + 0.0
    # scalar + 1D -> 1D
    a3 = 0.0 + zero_1d
    a4 = zero_1d + 0.0
    # scalar + 2D -> 2D
    a5 = 0.0 + zero_2d_22
    a6 = zero_2d_22 + 0.0

    # # 0D + 0D -> 0D
    # b1 = zero_0d + zero_0d
    # # 0D + 1D -> 1D
    # b2 = zero_0d + zero_1d
    # b3 = zero_1d + zero_0d
    # # 0D + 2D -> 2D
    # b4 = zero_0d + zero_2d_22
    # b5 = zero_2d_22 + zero_0d

    # 1D + 1D -> 1D
    c1 = zero_1d + zero_1d
    # 1D + 2D -> 2D (broadcasting in both operand orders)
    c2 = zero_1d + zero_2d_21
    c3 = zero_1d + zero_2d_22
    c4 = zero_2d_21 + zero_1d
    c5 = zero_2d_22 + zero_1d

    # 2D + 2D -> 2D (including [2,1] vs [2,2] broadcast, both orders)
    d1 = zero_2d_21 + zero_2d_21
    d2 = zero_2d_22 + zero_2d_22
    d3 = zero_2d_21 + zero_2d_22
    d4 = zero_2d_22 + zero_2d_21

    # Returning every result keeps the ops alive through compilation.
    # return a0, a1, a2, a3, a4, a5, a6, b1, b2, b3, b4, b5, c1, c2, c3, c4, c5, d1, d2, d3, d4
    return a0, a3, a4, a5, a6, c1, c2, c3, c4, c5, d1, d2, d3, d4
51+
52+
53+
def test_binop_type_check():
    """Compile the binop kernel down to TTGIR; successful compilation means
    type inference accepted every operand combination in the kernel."""
    ttgir = triton.compiler._compile(
        binop_type_check,
        signature="*fp32",
        device=0,
        output="ttgir",
    )
    assert ttgir
    # TODO: Check types of the results
60+
61+
62+
@triton.jit
def reduce_type_check(ptr):
    # Kernel exercising type inference of reductions:
    # 1-D -> scalar, and 2-D -> 1-D along either axis.
    # reduce tensor<32xf32> along axis 0 -> scalar f32
    v_32 = tl.load(ptr + tl.arange(0, 32))
    v_scalar = tl.min(v_32, axis=0)
    tl.store(ptr, v_scalar)
    # 2-D load built by broadcasting two aranges -> tensor<64x128xf32>.
    # NOTE(review): offsets are row-index + col-index with no row stride, so
    # addresses overlap — presumably fine here since only inferred types matter.
    v_64x128 = tl.load(ptr + tl.arange(0, 64)[:, None] + tl.arange(0, 128)[None, :])
    # reduce along axis 1: 64x128 -> 64
    v_64 = tl.max(v_64x128, axis=1)
    tl.store(ptr + tl.arange(0, 64), v_64)
    # reduce along axis 0: 64x128 -> 128
    v_128 = tl.max(v_64x128, axis=0)
    tl.store(ptr + tl.arange(0, 128), v_128)
72+
73+
74+
def test_reduce_type_check():
    """Compile the reduce kernel down to TTGIR; successful compilation means
    reduction result types were inferred for every axis in the kernel."""
    ttgir = triton.compiler._compile(
        reduce_type_check,
        signature="*fp32",
        device=0,
        output="ttgir",
    )
    assert ttgir
    # TODO: Check types of the results

test/Conversion/triton_ops.mlir

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
5555
}
5656

5757
func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %mask : i1) {
58-
// Test if Load/Store ops can handle scalar values (see #XXX)
58+
// Test if Load/Store ops can handle scalar values
5959
%other = arith.constant 0.0e+0 : f32
6060

6161
// load scalar
@@ -75,3 +75,58 @@ func @load_store_ops_scalar(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %ma
7575
tt.store %ptr, %c, %mask : f32
7676
return
7777
}
78+
79+
// tt.reduce type inference: reducing along `axis` drops that dimension,
// collapsing a rank-1 tensor to a scalar.
func @reduce_ops_infer(%ptr: !tt.ptr<f32>, %v : tensor<1x2x4xf32>) {
  // Test if reduce ops infer types correctly

  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32>
  %a = tt.reduce %v {redOp = 1 : i32, axis = 0 : i32} : tensor<1x2x4xf32> -> tensor<2x4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x4xf32>
  %b = tt.reduce %v {redOp = 1 : i32, axis = 1 : i32} : tensor<1x2x4xf32> -> tensor<1x4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1x2xf32>
  %c = tt.reduce %v {redOp = 1 : i32, axis = 2 : i32} : tensor<1x2x4xf32> -> tensor<1x2xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<1xf32>
  %e = tt.reduce %b {redOp = 1 : i32, axis = 1 : i32} : tensor<1x4xf32> -> tensor<1xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<4xf32>
  %f = tt.reduce %a {redOp = 1 : i32, axis = 0 : i32} : tensor<2x4xf32> -> tensor<4xf32>
  // CHECK: %{{.*}} = tt.reduce %{{.*}} -> f32
  %g = tt.reduce %f {redOp = 1 : i32, axis = 0 : i32} : tensor<4xf32> -> f32

  // Avoid optimizations for c, e, and g (the stores keep them live)
  %ptr1x2 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x2x!tt.ptr<f32>>
  %ptr1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x!tt.ptr<f32>>
  tt.store %ptr1x2, %c : tensor<1x2xf32>
  tt.store %ptr1, %e : tensor<1xf32>
  tt.store %ptr, %g : f32
  return
}
103+
104+
// tt.dot type inference: result shape is MxN for an MxK * KxN operand pair.
func @dot_ops_infer(%ptr: !tt.ptr<f32>, %v : f32) {
  // Test if dot ops infer types correctly
  %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32>
  %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32>
  %v128x1 = tt.splat %v : (f32) -> tensor<128x1xf32>
  %v1x128 = tt.splat %v : (f32) -> tensor<1x128xf32>

  %zero128x128 = arith.constant dense<0.00e+00> : tensor<128x128xf32>
  %zero32x32 = arith.constant dense<0.00e+00> : tensor<32x32xf32>
  %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32>

  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
  %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32>
  %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32>
  %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32>
  // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32>
  %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32>

  // Stores keep r1-r4 live so the dots are not optimized away.
  %ptr128x128 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x128x!tt.ptr<f32>>
  %ptr32x32 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<32x32x!tt.ptr<f32>>
  %ptr1x1 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<1x1x!tt.ptr<f32>>
  tt.store %ptr128x128, %r1 : tensor<128x128xf32>
  tt.store %ptr32x32, %r2 : tensor<32x32xf32>
  tt.store %ptr128x128, %r3 : tensor<128x128xf32>
  tt.store %ptr1x1, %r4 : tensor<1x1xf32>
  return
}

0 commit comments

Comments
 (0)