Skip to content

Commit 2329746

Browse files
authored
Fix test_ops_error_message.py and run it on CI. (#9640)
This PR fixes #9622, which extracted error messages checks out of `test_operations.py` into `test_ops_error_message.py`. There were a few problems with that PR, namely: - `unittest.main()` wasn't being run: although calling `python -m pytest` runs it automatically, running it without the `pytest` module did nothing - `onlyOnCPU()` would error if `PJRT_DEVICE` environment variable wasn't set - `test_ops_error_message.py` wasn't being run on CI This PR fixes all the aforementioned PRs.
1 parent c77852e commit 2329746

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ function run_xla_op_tests1 {
150150
run_dynamic "$_TEST_DIR/ds/test_dynamic_shape_models.py" "$@" --verbosity=$VERBOSITY
151151
run_eager_debug "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY
152152
run_test "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY
153+
run_test "$_TEST_DIR/test_ops_error_message.py"
153154
run_test "$_TEST_DIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY
154155
run_pt_xla_debug_level2 "$_TEST_DIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY
155156
run_test_without_functionalization "$_TEST_DIR/test_operations.py" "$@" --verbosity=$VERBOSITY

test/test_ops_error_message.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
def onlyOnCPU(fn):
9-
accelerator = os.environ.get("PJRT_DEVICE").lower()
9+
accelerator = os.environ.get("PJRT_DEVICE", "").lower()
1010
return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)
1111

1212

@@ -158,7 +158,7 @@ def test_mm_raises_error_on_non_matrix_input(self):
158158
b = torch.rand(2, 2, device=device)
159159

160160
def test():
161-
torch.mm(a, b)
161+
return torch.mm(a, b)
162162

163163
self.assertExpectedRaisesInline(
164164
exc_type=RuntimeError,
@@ -172,10 +172,14 @@ def test_mm_raises_error_on_incompatible_shapes(self):
172172
b = torch.rand(8, 2, device=device)
173173

174174
def test():
175-
torch.mm(a, b)
175+
return torch.mm(a, b)
176176

177177
self.assertExpectedRaisesInline(
178178
exc_type=RuntimeError,
179179
callable=test,
180180
expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
181181
)
182+
183+
184+
if __name__ == "__main__":
185+
unittest.main()

0 commit comments

Comments
 (0)