Skip to content

Commit 343052b

Browse files
authored
[language][standard] added tl.squeeze / tl.unsqueeze (#8924)
1 parent 2b9c24e commit 343052b

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

python/test/unit/language/test_standard.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,43 @@ def swizzle2d_kernel(output, size_i, size_j, size_g):
143143
expected_order = torch.tensor([[0, 3, 6, 9, 12, 15, 18], [1, 4, 7, 10, 13, 16, 19], [2, 5, 8, 11, 14, 17, 20],
144144
[21, 23, 25, 27, 29, 31, 33], [22, 24, 26, 28, 30, 32, 34]]).to(device)
145145
assert (output == expected_order).all(), (output, expected_order)
146+
147+
148+
@pytest.mark.interpreter
149+
@pytest.mark.parametrize("shape, dim", [((1, 2, 4), 0), ((2, 1, 4), 1), ((2, 4, 1), 2)])
150+
def test_squeeze(shape, dim, device):
151+
152+
@triton.jit
153+
def triton_squeeze(out_ptr, dim: tl.constexpr, s0: tl.constexpr, s1: tl.constexpr, s2: tl.constexpr):
154+
a = tl.arange(0, 8)
155+
a = tl.reshape(a, (s0, s1, s2))
156+
a = tl.squeeze(a, dim)
157+
a = tl.ravel(a)
158+
tl.store(out_ptr + tl.arange(0, 8), a)
159+
160+
out = torch.empty((8, ), device=device, dtype=torch.int32)
161+
triton_squeeze[(1, )](out, dim, shape[0], shape[1], shape[2])
162+
163+
expected = torch.arange(0, 8, device=device, dtype=torch.int32)
164+
expected = expected.reshape(shape).squeeze(dim).reshape(-1)
165+
assert (out == expected).all()
166+
167+
168+
@pytest.mark.interpreter
169+
@pytest.mark.parametrize("dim", [0, 1, 2])
170+
def test_unsqueeze(dim, device):
171+
172+
@triton.jit
173+
def triton_unsqueeze(out_ptr, dim: tl.constexpr):
174+
a = tl.arange(0, 8)
175+
a = tl.reshape(a, (2, 4))
176+
a = tl.unsqueeze(a, dim)
177+
a = tl.ravel(a)
178+
tl.store(out_ptr + tl.arange(0, 8), a)
179+
180+
out = torch.empty((8, ), device=device, dtype=torch.int32)
181+
triton_unsqueeze[(1, )](out, dim)
182+
183+
expected = torch.arange(0, 8, device=device, dtype=torch.int32)
184+
expected = expected.reshape(2, 4).unsqueeze(dim).reshape(-1)
185+
assert (out == expected).all()

python/triton/language/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
sigmoid,
2020
softmax,
2121
sort,
22+
squeeze,
2223
sum,
2324
swizzle2d,
2425
topk,
26+
unsqueeze,
2527
xor_sum,
2628
zeros,
2729
zeros_like,
@@ -253,6 +255,7 @@
253255
"split",
254256
"sqrt",
255257
"sqrt_rn",
258+
"squeeze",
256259
"static_assert",
257260
"static_print",
258261
"static_range",
@@ -272,6 +275,7 @@
272275
"uint8",
273276
"uint_to_uniform_float",
274277
"umulhi",
278+
"unsqueeze",
275279
"view",
276280
"void",
277281
"where",

python/triton/language/standard.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,3 +534,14 @@ def interleave(a, b):
534534
# understand that if we take the `if` above we definitely don't run this
535535
# `else`.
536536
return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])
537+
538+
539+
@jit
540+
def squeeze(x, dim: core.constexpr):
541+
core.static_assert(x.shape[dim] == 1)
542+
return x.reshape(x.shape[:dim] + x.shape[dim + 1:])
543+
544+
545+
@jit
546+
def unsqueeze(x, dim: core.constexpr):
547+
return x.reshape(x.shape[:dim] + (1, ) + x.shape[dim:])

0 commit comments

Comments
 (0)