diff --git a/test/quantized_ops/test_dot_general.py b/test/quantized_ops/test_dot_general.py
index 68418945a06..fa4b57d3640 100644
--- a/test/quantized_ops/test_dot_general.py
+++ b/test/quantized_ops/test_dot_general.py
@@ -1,17 +1,11 @@
 import os
 import re
-
 import torch
 import torch_xla
 import unittest
-
 device = torch_xla.device()
-
 torch.manual_seed(12345)
-
-
 class DotGeneralTest(unittest.TestCase):
-
   def test_dot_general(self):
     x = torch.rand(10, 3, 4).to(torch.bfloat16)
     w = torch.rand(10, 4, 5).to(torch.bfloat16)
@@ -21,7 +15,6 @@ def test_dot_general(self):
     xla_out = torch_xla._XLAC._xla_dot_general(xla_x, xla_w,
                                                (([2], [1]), ([0], [0])))
     self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))
-
   def test_dot_general_negative_dim(self):
     x = torch.rand(10, 3, 4).to(torch.bfloat16)
     w = torch.rand(10, 4, 5).to(torch.bfloat16)
@@ -31,7 +24,6 @@ def test_dot_general_negative_dim(self):
     xla_out = torch_xla._XLAC._xla_dot_general(xla_x, xla_w,
                                                (([-1], [-2]), ([0], [0])))
     self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))
-
   def test_dot_general_linear(self):
     x = torch.rand(10, 3, 4).to(torch.bfloat16)
     w = torch.rand(5, 4).to(torch.bfloat16)
@@ -40,7 +32,6 @@ def test_dot_general_linear(self):
     xla_w = w.to(device)
     xla_out = torch_xla._XLAC._xla_dot_general(xla_x, xla_w, (([2], [1]), ()))
     self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))
-
   def test_dot_general_int32_dtype(self):
     x = torch.randint(-15, 15, (10, 3, 4)).to(torch.int32)
     w = torch.randint(-15, 15, (10, 4, 5)).to(torch.int32)
@@ -55,11 +46,9 @@ def test_dot_general_int32_dtype(self):
         xla_w, (([2], [1]), ([0], [0])),
         preferred_element_type=torch.int32)
     self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))
-
   def test_raises_error_on_non_xla_tensor(self):
     lhs = torch.rand(10, 3, 4, dtype=torch.bfloat16)
     rhs = torch.rand(10, 4, 5, dtype=torch.bfloat16)
-
     def test(args, non_xla_tensor_arg):
       arg_number_to_str = ["first", "second"]
       position = arg_number_to_str[non_xla_tensor_arg]
@@ -70,12 +59,10 @@ def test(args, non_xla_tensor_arg):
           f"Expected input tensor ({position} argument) to be an actual XLA tensor. "
           f"Got: CPUBFloat16Type. Consider moving it ({position} argument) to XLA."
       )
-      self.assertEqual(str(err), error_message)
-
+      msg = str(err)
+      self.assertIn(error_message, msg)
     test((lhs, rhs.to(device)), non_xla_tensor_arg=0)
     test((lhs.to(device), rhs), non_xla_tensor_arg=1)
-
-
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
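
For reviewers unfamiliar with the dimension-numbers tuples exercised above, here is a minimal CPU-only sketch of what they assert, cross-checked against plain torch ops. It is not part of the patch, and it assumes the usual XLA dot_general layout ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)), which is what the expected values in these tests imply:

import torch

# (([2], [1]), ([0], [0])): contract dim 2 of the left operand with dim 1 of
# the right, treating dim 0 of both as a batch dimension -> batched matmul.
x = torch.rand(10, 3, 4)
w = torch.rand(10, 4, 5)
assert torch.allclose(torch.einsum('bij,bjk->bik', x, w), torch.bmm(x, w))

# (([2], [1]), ()): contract dim 2 of the input with dim 1 of a 2D weight,
# with no batch dimensions -> x @ w2.T, i.e. a bias-free linear layer.
w2 = torch.rand(5, 4)
assert torch.allclose(
    torch.einsum('bij,kj->bik', x, w2),
    torch.nn.functional.linear(x, w2))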