Commit c77852e

Move torch ops error message tests into a new file. (#9622)
In summary, this PR:

- Moves the tests that check the error messages of PyTorch operations into `test_ops_error_message.py`
- Introduces `expecttest` as a dependency in the `requirements.in` file
- Introduces `expecttest` to those tests, so as to avoid copy-and-pasting error messages

The `expecttest` Python package was, in fact, already an implicit dependency of our tests because of the following PyTorch testing library import: https://github.com/pytorch/xla/blob/f6ff30d3c2cd837e940aaa70b61faf948aa805f7/test/test_operations.py#L33
1 parent 8efa568 commit c77852e
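For context, `expecttest` lets a test keep its expected string inline in the source and regenerate it automatically, instead of copy-and-pasting error messages by hand. The snippet below is a minimal sketch of that workflow; it is illustrative and not taken from this PR, and the class and test names are made up.

```python
# Minimal sketch of the expecttest workflow (illustrative, not from this PR).
import expecttest


class ExampleTest(expecttest.TestCase):

  def test_inline_expectation(self):
    # Compares an actual string against an inline expected string.
    # Re-running the test with EXPECTTEST_ACCEPT=1 rewrites the expected
    # string in this source file, so it never has to be copied by hand.
    self.assertExpectedInline(str(1 + 1), """2""")
```

Running the test file with `EXPECTTEST_ACCEPT=1` set in the environment updates the inline expected strings in place, which is what makes hand-copying error messages unnecessary.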

9 files changed (+206 −164 lines)

requirements.in

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+expecttest
 filelock
 fsspec
 jinja2

requirements_lock_3_10.txt

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.14.0 \
     --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
     --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a

requirements_lock_3_11.txt

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.14.0 \
     --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
     --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a

requirements_lock_3_12.txt

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.18.0 \
     --hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \
     --hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de

requirements_lock_3_13.txt

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.18.0 \
     --hash=sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2 \
     --hash=sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de

requirements_lock_3_8.txt

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.14.0 \
     --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
     --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a

requirements_lock_3_9.txt

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 #
 # bazel run //:requirements.update
 #
+expecttest==0.3.0 \
+    --hash=sha256:60f88103086e1754240b42175f622be83b6ffeac419434691ee5a5be819d0544 \
+    --hash=sha256:6e8512fb86523ada1f94fd1b14e280f924e379064bb8a29ee399950e513eeccd
+    # via -r requirements.in
 filelock==3.14.0 \
     --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \
     --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a

test/test_operations.py

Lines changed: 0 additions & 164 deletions
@@ -88,11 +88,6 @@ def skipIfFunctionalizationDisabled(reason):
   return _skipIfFunctionalization(value=True, reason=reason)


-def onlyOnCPU(fn):
-  accelerator = os.environ.get("PJRT_DEVICE").lower()
-  return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)
-
-
 def onlyIfXLAExperimentalContains(feat):
   experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
   return unittest.skipIf(feat not in experimental,
@@ -2372,165 +2367,6 @@ def test_isneginf_no_fallback(self):
     t = t.to(torch.float16)
     self._test_no_fallback(torch.isneginf, (t,))

-  def test_add_broadcast_error(self):
-    a = torch.rand(2, 2, 4, 4, device="xla")
-    b = torch.rand(2, 2, device="xla")
-
-    expected_regex = (
-        r"Shapes are not compatible for broadcasting: f32\[2,2,4,4\] vs. f32\[2,2\]. "
-        r"Expected dimension 2 of shape f32\[2,2,4,4\] \(4\) to match dimension "
-        r"0 of shape f32\[2,2\] \(2\). .*")
-
-    with self.assertRaisesRegex(RuntimeError, expected_regex):
-      torch.add(a, b)
-      torch_xla.sync()
-
-  @onlyOnCPU
-  def test_construct_large_tensor_raises_error(self):
-    with self.assertRaisesRegex(RuntimeError,
-                                r"Out of memory allocating \d+ bytes"):
-      # When eager-mode is enabled, OOM is triggered here.
-      a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
-      b = a.sum()
-      # OOM is raised when we try to bring data from the device.
-      b.cpu()
-
-  def test_cat_raises_error_on_incompatible_shapes(self):
-    a = torch.rand(2, 2, device=torch_xla.device())
-    b = torch.rand(5, 1, device=torch_xla.device())
-
-    try:
-      torch.cat([a, b])
-    except RuntimeError as e:
-      expected_error = (
-          "cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] "
-          "at dimension 0. Expected shapes to be equal (except at dimension 0) "
-          "or that either of them was a 1D empty tensor of size (0,).")
-      self.assertEqual(str(e), expected_error)
-
-  def test_div_raises_error_on_invalid_rounding_mode(self):
-    a = torch.rand(2, 2, device=torch_xla.device())
-
-    try:
-      torch.div(a, 2, rounding_mode="bad")
-    except RuntimeError as e:
-      expected_error = (
-          "div(): invalid rounding mode `bad`. Expected it to be either "
-          "'trunc', 'floor', or be left unspecified.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_flip_raises_error_on_duplicated_dims(self):
-    a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
-    dims = [0, 0, 0, 1, 2, 3, -1]
-    dims_suggestion = [0, 1, 2, 3]
-
-    try:
-      torch.flip(a, dims=dims)
-    except RuntimeError as e:
-      expected_error = (
-          "flip(): expected each dimension to appear at most once. Found "
-          "dimensions: 0 (3 times), 3 (2 times). Consider changing dims "
-          f"from {dims} to {dims_suggestion}.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_full_raises_error_on_negative_size(self):
-    shape = [2, -2, 2]
-    try:
-      torch.full(shape, 1.5, device="xla")
-    except RuntimeError as e:
-      expected_error = (
-          "full(): expected concrete sizes (i.e. non-symbolic) to be "
-          f"positive values. However found negative ones: {shape}.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_gather_raises_error_on_rank_mismatch(self):
-    S = 2
-
-    input = torch.arange(4, device=torch_xla.device()).view(S, S)
-    index = torch.randint(0, S, (S, S, S), device=torch_xla.device())
-    dim = 1
-
-    try:
-      torch.gather(input, dim, index)
-    except RuntimeError as e:
-      expected_error = (
-          "gather(): expected rank of input (2) and index (3) tensors "
-          "to be the same.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_gather_raises_error_on_invalid_index_size(self):
-    S = 2
-    X = S + 2
-
-    input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S)
-    index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device())
-    dim = 1
-
-    try:
-      torch.gather(input, dim, index)
-    except RuntimeError as e:
-      expected_error = (
-          f"gather(): expected sizes of index [{X}, {S}, {X}, {S}] to be "
-          f"smaller or equal those of input [{S}, {S}, {S}, {S}] on all "
-          f"dimensions, except on dimension {dim}. "
-          "However, that's not true on dimensions [0, 2].")
-      self.assertEqual(str(e), expected_error)
-
-  def test_random__raises_error_on_empty_interval(self):
-    a = torch.empty(10, device=torch_xla.device())
-    from_ = 3
-    to_ = 1
-
-    try:
-      a.random_(from_, to_)
-    except RuntimeError as e:
-      expected_error = (
-          f"random_(): expected `from` ({from_}) to be smaller than "
-          f"`to` ({to_}).")
-      self.assertEqual(str(e), expected_error)
-
-  def test_random__raises_error_on_value_out_of_type_value_range(self):
-    a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16)
-    from_ = 3
-    to_ = 65504 + 1
-
-    try:
-      a.random_(from_, to_)
-    except RuntimeError as e:
-      expected_error = (
-          f"random_(): expected `to` to be within the range "
-          f"[-65504, 65504]. However got value {to_}, which is greater "
-          "than the upper bound.")
-      self.assertEqual(str(e), expected_error)
-
-  def test_mm_raises_error_on_non_matrix_input(self):
-    device = torch_xla.device()
-    a = torch.rand(2, 2, 2, device=device)
-    b = torch.rand(2, 2, device=device)
-
-    try:
-      torch.mm(a, b)
-    except RuntimeError as e:
-      expected_error = (
-          "mm(): expected the first input tensor f32[2,2,2] to be a "
-          "matrix (i.e. a 2D tensor).")
-      self.assertEqual(str(e), expected_error)
-
-  def test_mm_raises_error_on_incompatible_shapes(self):
-    device = torch_xla.device()
-    a = torch.rand(2, 5, device=device)
-    b = torch.rand(8, 2, device=device)
-
-    try:
-      torch.mm(a, b)
-    except RuntimeError as e:
-      expected_error = (
-          "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. "
-          "Expected the size of dimension 1 of the first input tensor (5) "
-          "to be equal the size of dimension 0 of the second input "
-          "tensor (8).")
-      self.assertEqual(str(e), expected_error)
-

 class MNISTComparator(nn.Module):

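The diff above only shows the deletions from `test_operations.py`; the new `test_ops_error_message.py` is not part of this excerpt. As a rough illustration of what the commit message describes, a moved test might look like the sketch below, where an `expecttest` inline expectation replaces the hand-copied error string. The class name, structure, and runner block here are assumptions, not taken from the diff; only the tested `torch.cat` behavior comes from the deleted code above.

```python
# Hypothetical sketch of a moved test in test_ops_error_message.py.
# Class name and overall structure are assumptions; the tested behavior
# (the torch.cat error message) comes from the deleted code above.
import unittest

import expecttest
import torch
import torch_xla


class TestOpsErrorMessage(expecttest.TestCase):

  def test_cat_raises_error_on_incompatible_shapes(self):
    device = torch_xla.device()
    a = torch.rand(2, 2, device=device)
    b = torch.rand(5, 1, device=device)

    try:
      torch.cat([a, b])
    except RuntimeError as e:
      # Running the test with EXPECTTEST_ACCEPT=1 fills in or updates the
      # inline expected string automatically, instead of copy-and-pasting it.
      self.assertExpectedInline(
          str(e),
          """cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,)."""
      )


if __name__ == "__main__":
  unittest.main()
```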
