Skip to content

Commit 5cd2b34

Browse files
dccimalfet
authored andcommitted
[inductor] Adjust test_log_fp64 to only run when float64 is supported. (pytorch#145686)
Pull Request resolved: pytorch#145686 Approved by: https://github.com/malfet, https://github.com/jansel Co-authored-by: Nikita Shulga <[email protected]>
1 parent ed01514 commit 5cd2b34

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

test/inductor/test_mps_basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class MPSBasicTests(TestCase):
5858
test_inf = CommonTemplate.test_inf
5959
test_isinf = CommonTemplate.test_isinf
6060
test_isinf2 = CommonTemplate.test_isinf2
61+
test_log_fp64 = CommonTemplate.test_log_fp64
6162
test_low_memory_max_pool = CommonTemplate.test_low_memory_max_pool
6263
test_max_min = CommonTemplate.test_max_min
6364
test_max_pool2d2 = CommonTemplate.test_max_pool2d2

test/inductor/test_torchinductor.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6471,10 +6471,17 @@ def test_log_fp64(self):
64716471
def fn(x):
64726472
return torch.log(x), torch.log2(x)
64736473

6474-
self.common(
6475-
fn,
6476-
(torch.randn([1024], dtype=torch.float64) + 10,),
6474+
_dtype = torch.float64
6475+
ctx = (
6476+
contextlib.nullcontext()
6477+
if self.is_dtype_supported(_dtype)
6478+
else self.assertRaises(TypeError)
64776479
)
6480+
with ctx:
6481+
self.common(
6482+
fn,
6483+
(torch.randn([1024], dtype=_dtype) + 10,),
6484+
)
64786485

64796486
def test_bitwise(self):
64806487
def fn(x, y):

0 commit comments

Comments
 (0)