Skip to content

Commit 1353f3b

Browse files
dccipytorchmergebot
authored andcommitted
[mps/inductor] Add support for fmod(). (pytorch#144449)
397 -> 395 tests failing. `static_cast<>` is because there are several overloads of `fmod()` that's otherwise ambiguous. I wonder if we should take in account NaN propagation (maybe it's not tested). Pull Request resolved: pytorch#144449 Approved by: https://github.com/malfet
1 parent 9631d1a commit 1353f3b

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

test/inductor/test_mps_basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class MPSBasicTests(TestCase):
4343
test_addmm = CommonTemplate.test_addmm
4444
test_cat_empty = CommonTemplate.test_cat_empty
4545
test_floordiv = CommonTemplate.test_floordiv
46+
test_fmod = CommonTemplate.test_fmod
47+
test_fmod_zero_dim = CommonTemplate.test_fmod_zero_dim
4648
test_inf = CommonTemplate.test_inf
4749
test_isinf = CommonTemplate.test_isinf
4850
test_isinf2 = CommonTemplate.test_isinf2

torch/_inductor/codegen/mps.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ def floor(x: CSEVariable) -> str:
206206
def sign(x: CSEVariable) -> str:
207207
return f"metal::sign({x})"
208208

209+
@staticmethod
210+
def fmod(a: CSEVariable, b: CSEVariable) -> str:
211+
typecast_a = f"static_cast<decltype({a}+{b})>({a})"
212+
typecast_b = f"static_cast<decltype({a}+{b})>({b})"
213+
return f"metal::fmod({typecast_a}, {typecast_b})"
214+
209215

210216
class MetalKernel(SIMDKernel):
211217
overrides = MetalOverrides # type: ignore[assignment]

0 commit comments

Comments
 (0)