@@ -88,11 +88,6 @@ def skipIfFunctionalizationDisabled(reason):
88
88
return _skipIfFunctionalization (value = True , reason = reason )
89
89
90
90
91
- def onlyOnCPU (fn ):
92
- accelerator = os .environ .get ("PJRT_DEVICE" ).lower ()
93
- return unittest .skipIf (accelerator != "cpu" , "PJRT_DEVICE=CPU required" )(fn )
94
-
95
-
96
91
def onlyIfXLAExperimentalContains (feat ):
97
92
experimental = os .environ .get ("XLA_EXPERIMENTAL" , "" ).split (":" )
98
93
return unittest .skipIf (feat not in experimental ,
@@ -2372,165 +2367,6 @@ def test_isneginf_no_fallback(self):
2372
2367
t = t .to (torch .float16 )
2373
2368
self ._test_no_fallback (torch .isneginf , (t ,))
2374
2369
2375
- def test_add_broadcast_error (self ):
2376
- a = torch .rand (2 , 2 , 4 , 4 , device = "xla" )
2377
- b = torch .rand (2 , 2 , device = "xla" )
2378
-
2379
- expected_regex = (
2380
- r"Shapes are not compatible for broadcasting: f32\[2,2,4,4\] vs. f32\[2,2\]. "
2381
- r"Expected dimension 2 of shape f32\[2,2,4,4\] \(4\) to match dimension "
2382
- r"0 of shape f32\[2,2\] \(2\). .*" )
2383
-
2384
- with self .assertRaisesRegex (RuntimeError , expected_regex ):
2385
- torch .add (a , b )
2386
- torch_xla .sync ()
2387
-
2388
- @onlyOnCPU
2389
- def test_construct_large_tensor_raises_error (self ):
2390
- with self .assertRaisesRegex (RuntimeError ,
2391
- r"Out of memory allocating \d+ bytes" ):
2392
- # When eager-mode is enabled, OOM is triggered here.
2393
- a = torch .rand (1024 , 1024 , 1024 , 1024 , 1024 , device = torch_xla .device ())
2394
- b = a .sum ()
2395
- # OOM is raised when we try to bring data from the device.
2396
- b .cpu ()
2397
-
2398
- def test_cat_raises_error_on_incompatible_shapes (self ):
2399
- a = torch .rand (2 , 2 , device = torch_xla .device ())
2400
- b = torch .rand (5 , 1 , device = torch_xla .device ())
2401
-
2402
- try :
2403
- torch .cat ([a , b ])
2404
- except RuntimeError as e :
2405
- expected_error = (
2406
- "cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] "
2407
- "at dimension 0. Expected shapes to be equal (except at dimension 0) "
2408
- "or that either of them was a 1D empty tensor of size (0,)." )
2409
- self .assertEqual (str (e ), expected_error )
2410
-
2411
- def test_div_raises_error_on_invalid_rounding_mode (self ):
2412
- a = torch .rand (2 , 2 , device = torch_xla .device ())
2413
-
2414
- try :
2415
- torch .div (a , 2 , rounding_mode = "bad" )
2416
- except RuntimeError as e :
2417
- expected_error = (
2418
- "div(): invalid rounding mode `bad`. Expected it to be either "
2419
- "'trunc', 'floor', or be left unspecified." )
2420
- self .assertEqual (str (e ), expected_error )
2421
-
2422
- def test_flip_raises_error_on_duplicated_dims (self ):
2423
- a = torch .rand (2 , 2 , 2 , 2 , device = torch_xla .device ())
2424
- dims = [0 , 0 , 0 , 1 , 2 , 3 , - 1 ]
2425
- dims_suggestion = [0 , 1 , 2 , 3 ]
2426
-
2427
- try :
2428
- torch .flip (a , dims = dims )
2429
- except RuntimeError as e :
2430
- expected_error = (
2431
- "flip(): expected each dimension to appear at most once. Found "
2432
- "dimensions: 0 (3 times), 3 (2 times). Consider changing dims "
2433
- f"from { dims } to { dims_suggestion } ." )
2434
- self .assertEqual (str (e ), expected_error )
2435
-
2436
- def test_full_raises_error_on_negative_size (self ):
2437
- shape = [2 , - 2 , 2 ]
2438
- try :
2439
- torch .full (shape , 1.5 , device = "xla" )
2440
- except RuntimeError as e :
2441
- expected_error = (
2442
- "full(): expected concrete sizes (i.e. non-symbolic) to be "
2443
- f"positive values. However found negative ones: { shape } ." )
2444
- self .assertEqual (str (e ), expected_error )
2445
-
2446
- def test_gather_raises_error_on_rank_mismatch (self ):
2447
- S = 2
2448
-
2449
- input = torch .arange (4 , device = torch_xla .device ()).view (S , S )
2450
- index = torch .randint (0 , S , (S , S , S ), device = torch_xla .device ())
2451
- dim = 1
2452
-
2453
- try :
2454
- torch .gather (input , dim , index )
2455
- except RuntimeError as e :
2456
- expected_error = (
2457
- "gather(): expected rank of input (2) and index (3) tensors "
2458
- "to be the same." )
2459
- self .assertEqual (str (e ), expected_error )
2460
-
2461
- def test_gather_raises_error_on_invalid_index_size (self ):
2462
- S = 2
2463
- X = S + 2
2464
-
2465
- input = torch .arange (16 , device = torch_xla .device ()).view (S , S , S , S )
2466
- index = torch .randint (0 , S , (X , S , X , S ), device = torch_xla .device ())
2467
- dim = 1
2468
-
2469
- try :
2470
- torch .gather (input , dim , index )
2471
- except RuntimeError as e :
2472
- expected_error = (
2473
- f"gather(): expected sizes of index [{ X } , { S } , { X } , { S } ] to be "
2474
- f"smaller or equal those of input [{ S } , { S } , { S } , { S } ] on all "
2475
- f"dimensions, except on dimension { dim } . "
2476
- "However, that's not true on dimensions [0, 2]." )
2477
- self .assertEqual (str (e ), expected_error )
2478
-
2479
- def test_random__raises_error_on_empty_interval (self ):
2480
- a = torch .empty (10 , device = torch_xla .device ())
2481
- from_ = 3
2482
- to_ = 1
2483
-
2484
- try :
2485
- a .random_ (from_ , to_ )
2486
- except RuntimeError as e :
2487
- expected_error = (
2488
- f"random_(): expected `from` ({ from_ } ) to be smaller than "
2489
- f"`to` ({ to_ } )." )
2490
- self .assertEqual (str (e ), expected_error )
2491
-
2492
- def test_random__raises_error_on_value_out_of_type_value_range (self ):
2493
- a = torch .empty (10 , device = torch_xla .device (), dtype = torch .float16 )
2494
- from_ = 3
2495
- to_ = 65504 + 1
2496
-
2497
- try :
2498
- a .random_ (from_ , to_ )
2499
- except RuntimeError as e :
2500
- expected_error = (
2501
- f"random_(): expected `to` to be within the range "
2502
- f"[-65504, 65504]. However got value { to_ } , which is greater "
2503
- "than the upper bound." )
2504
- self .assertEqual (str (e ), expected_error )
2505
-
2506
- def test_mm_raises_error_on_non_matrix_input (self ):
2507
- device = torch_xla .device ()
2508
- a = torch .rand (2 , 2 , 2 , device = device )
2509
- b = torch .rand (2 , 2 , device = device )
2510
-
2511
- try :
2512
- torch .mm (a , b )
2513
- except RuntimeError as e :
2514
- expected_error = (
2515
- "mm(): expected the first input tensor f32[2,2,2] to be a "
2516
- "matrix (i.e. a 2D tensor)." )
2517
- self .assertEqual (str (e ), expected_error )
2518
-
2519
- def test_mm_raises_error_on_incompatible_shapes (self ):
2520
- device = torch_xla .device ()
2521
- a = torch .rand (2 , 5 , device = device )
2522
- b = torch .rand (8 , 2 , device = device )
2523
-
2524
- try :
2525
- torch .mm (a , b )
2526
- except RuntimeError as e :
2527
- expected_error = (
2528
- "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. "
2529
- "Expected the size of dimension 1 of the first input tensor (5) "
2530
- "to be equal the size of dimension 0 of the second input "
2531
- "tensor (8)." )
2532
- self .assertEqual (str (e ), expected_error )
2533
-
2534
2370
2535
2371
class MNISTComparator (nn .Module ):
2536
2372
0 commit comments