@@ -143,3 +143,43 @@ def swizzle2d_kernel(output, size_i, size_j, size_g):
143143 expected_order = torch .tensor ([[0 , 3 , 6 , 9 , 12 , 15 , 18 ], [1 , 4 , 7 , 10 , 13 , 16 , 19 ], [2 , 5 , 8 , 11 , 14 , 17 , 20 ],
144144 [21 , 23 , 25 , 27 , 29 , 31 , 33 ], [22 , 24 , 26 , 28 , 30 , 32 , 34 ]]).to (device )
145145 assert (output == expected_order ).all (), (output , expected_order )
146+
147+
148+ @pytest .mark .interpreter
149+ @pytest .mark .parametrize ("shape, dim" , [((1 , 2 , 4 ), 0 ), ((2 , 1 , 4 ), 1 ), ((2 , 4 , 1 ), 2 )])
150+ def test_squeeze (shape , dim , device ):
151+
152+ @triton .jit
153+ def triton_squeeze (out_ptr , dim : tl .constexpr , s0 : tl .constexpr , s1 : tl .constexpr , s2 : tl .constexpr ):
154+ a = tl .arange (0 , 8 )
155+ a = tl .reshape (a , (s0 , s1 , s2 ))
156+ a = tl .squeeze (a , dim )
157+ a = tl .ravel (a )
158+ tl .store (out_ptr + tl .arange (0 , 8 ), a )
159+
160+ out = torch .empty ((8 , ), device = device , dtype = torch .int32 )
161+ triton_squeeze [(1 , )](out , dim , shape [0 ], shape [1 ], shape [2 ])
162+
163+ expected = torch .arange (0 , 8 , device = device , dtype = torch .int32 )
164+ expected = expected .reshape (shape ).squeeze (dim ).reshape (- 1 )
165+ assert (out == expected ).all ()
166+
167+
168+ @pytest .mark .interpreter
169+ @pytest .mark .parametrize ("dim" , [0 , 1 , 2 ])
170+ def test_unsqueeze (dim , device ):
171+
172+ @triton .jit
173+ def triton_unsqueeze (out_ptr , dim : tl .constexpr ):
174+ a = tl .arange (0 , 8 )
175+ a = tl .reshape (a , (2 , 4 ))
176+ a = tl .unsqueeze (a , dim )
177+ a = tl .ravel (a )
178+ tl .store (out_ptr + tl .arange (0 , 8 ), a )
179+
180+ out = torch .empty ((8 , ), device = device , dtype = torch .int32 )
181+ triton_unsqueeze [(1 , )](out , dim )
182+
183+ expected = torch .arange (0 , 8 , device = device , dtype = torch .int32 )
184+ expected = expected .reshape (2 , 4 ).unsqueeze (dim ).reshape (- 1 )
185+ assert (out == expected ).all ()
0 commit comments