Skip to content

Commit 551f104

Browse files
dccipytorchmergebot
authored andcommitted
[mps/inductor] Add support for sign(). (pytorch#144298)
Drive-by fix of a test name while I was at it. Pull Request resolved: pytorch#144298 Approved by: https://github.com/malfet
1 parent a3ab27b commit 551f104

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

test/inductor/test_mps_basic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,12 @@ def test_acos(self):
7474
def test_atanh(self):
7575
self.common(lambda x: x.atanh(), (torch.rand(1024),))
7676

77-
def floor(self):
77+
def test_floor(self):
7878
self.common(lambda x: x.floor(), (torch.rand(1024),))
7979

80+
def test_sign(self):
81+
self.common(lambda x: x.sign(), (torch.rand(1024),))
82+
8083
def test_sliced_input(self):
8184
self.common(
8285
lambda x: x[:, ::2].sin() + x[:, 1::2].cos(), (torch.rand(32, 1024),)

torch/_inductor/codegen/mps.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,10 @@ def floordiv(a: CSEVariable, b: CSEVariable) -> str:
194194
def floor(x: CSEVariable) -> str:
195195
return f"metal::floor({x})"
196196

197+
@staticmethod
198+
def sign(x: CSEVariable) -> str:
199+
return f"metal::sign({x})"
200+
197201

198202
class MetalKernel(SIMDKernel):
199203
overrides = MetalOverrides # type: ignore[assignment]

0 commit comments

Comments
 (0)