@@ -148,17 +148,6 @@ def fn_simple(self, x, y):
     b = torch.sin(y)
     return a + b
 
-  def _choose_proper_device(self, initialize_on_cuda):
-    if not initialize_on_cuda:
-      return torch_xla.device()
-
-    assert initialize_on_cuda
-    if xr.device_type() != "CUDA" or not torch.cuda.is_available():
-      self.skipTest(
-          "Skip this test because it requires xr.device_type()=='CUDA' and torch.cuda.is_available()."
-      )
-    return "cuda:0"
-
   @skipOnNeuron
   def test_simple_model(self):
     device = torch_xla.device()
@@ -193,71 +182,23 @@ def test_simple_model(self):
     # Dynamo has to sync the input since they are intermedate IR(xla_xy and xla_y3)
     self.assertEqual(met.counter_value('DynamoSyncInputExecuteTime'), 1)
 
-  # Tests that the dynamo bridge automatically moves tensors to XLA device,
-  # then back to the original device.
-  @unittest.skipIf(xr.device_type() != "CUDA" or not torch.cuda.is_available(),
-                   f"GPU tests should only run on GPU devices.")
-  @parameterized.parameters(
-      "0",
-      "1",
-  )
-  def test_simple_model_automoves_tensors(self, zero_copy_enabled):
-    x = torch.tensor(100.0, requires_grad=True, device="cuda:0")
-    y = torch.tensor(200.0, requires_grad=True, device="cuda:0")
-    original_device = x.device
-    eager_result = self.fn_simple(x, y)
-
-    # Since all tests run in the same process, have to reset the metrics report.
-    met.clear_all()
-    torch._dynamo.reset()
-
-    fn_simple_dynamo = torch.compile(self.fn_simple, backend="openxla")
-    res_xla_dynamo = fn_simple_dynamo(x, y)
-    self.assertIn('xla::add', met.counter_names())
-    self.assertTrue(res_xla_dynamo.device == original_device)
-    self.assertTrue(torch.allclose(eager_result, res_xla_dynamo))
-
-    # verify that tracing is skipped in following runs
-    met.clear_counters()
-    res_xla_dynamo_reused = fn_simple_dynamo(x, y)
-    self.assertNotIn('xla::add', met.counter_names())
-    self.assertTrue(res_xla_dynamo_reused.device == original_device)
-    self.assertTrue(torch.allclose(eager_result, res_xla_dynamo_reused))
-
-    # verify that dynamo can handle different inputs
-    res_xla_dynamo_different = fn_simple_dynamo(x + y, y * 3)
-    res_cpu_3 = self.fn_simple(x + y, y * 3)
-    self.assertTrue(res_xla_dynamo_different.device == original_device)
-    self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_different))
-
-    # There should not be any fallbacks.
-    self.assertEqual(torch_xla._XLAC._get_executed_fallback_ops(), [])
-
-  @parameterized.parameters(
-      True,
-      False,
-  )
-  def test_fn_without_input(self, initialize_on_cuda):
+  def test_fn_without_input(self):
 
     def fn_without_input(device):
       constant = 0.835
       expanded = torch.full((4, 4), constant, device=device)
       arange = torch.arange(16, device=device).reshape(4, 4)
       return expanded + arange
 
-    device = self._choose_proper_device(initialize_on_cuda)
+    device = torch_xla.device()
 
     compiled_fn = torch.compile(fn_without_input, backend='openxla')
     res_cpu = fn_without_input('cpu')
     res_xla_dynamo = compiled_fn(device)
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
 
-  @parameterized.parameters(
-      (True, 'openxla'),
-      (False, dynamo_backend2.dynamo_backend),
-      (False, 'openxla'),
-  )
-  def test_simple_model_with_in_place_ops(self, initialize_on_cuda, backend):
+  @parameterized.parameters('openxla', dynamo_backend2.dynamo_backend)
+  def test_simple_model_with_in_place_ops(self, backend):
 
     class TestModel(nn.Module):
 
@@ -279,7 +220,7 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
         output = input_tensor + self.self_tensor
         return output
 
-    device = self._choose_proper_device(initialize_on_cuda)
+    device = torch_xla.device()
 
     torch._dynamo.reset()
     met.clear_all()
@@ -306,18 +247,14 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
           op_name=in_place_op)
       self.assertTrue(torch.allclose(res_cpu, res_device_dynamo.cpu()))
 
-  @parameterized.parameters(
-      (True, 'openxla'),
-      (False, dynamo_backend2.dynamo_backend),
-      (False, 'openxla'),
-  )
-  def test_einsum(self, initialize_on_cuda, backend):
+  @parameterized.parameters('openxla', dynamo_backend2.dynamo_backend)
+  def test_einsum(self, backend):
     # einsum currently does not have meta function to compute the shape hence
     # will fallback to XLA with FakeTensor as input to infer the output shape.
     def einsum_mm(a, b):
       return torch.einsum('ijkl,ijlm->ijkm', a, b)
 
-    device = self._choose_proper_device(initialize_on_cuda)
+    device = torch_xla.device()
     a = torch.randn(4, 4, 4, 4).to(device)
     b = torch.randn(4, 4, 4, 4).to(device)
     torch_xla.sync()
@@ -328,16 +265,10 @@ def einsum_mm(a, b):
     self.assertTrue(
         torch.allclose(res_device_non_dynamo.cpu(), res_device_dynamo.cpu()))
 
-  @parameterized.parameters(
-      True,
-      False,
-  )
-  def test_simple_model_with_different_input_shape(self, initialize_on_cuda):
+  def test_simple_model_with_different_input_shape(self):
     met.clear_all()
-    device = self._choose_proper_device(initialize_on_cuda)
-    # We need to make `dim` depend on `initialize_on_cuda` because the XLA compilation cache
-    # does not clean itself between the parameterized tests.
-    dim = 5 + int(initialize_on_cuda)
+    device = torch_xla.device()
+    dim = 5
     device_x = torch.randn(dim, dim).to(device)
     device_y = torch.randn(dim, dim).to(device)
     new_dim = 2 * dim
@@ -369,13 +300,9 @@ def get_loader(self, device, sample_count, batch_size=4):
 
   @skipOnTpu
   @skipOnNeuron
-  @parameterized.parameters(
-      (True, 'openxla'),
-      (False, dynamo_backend2.dynamo_backend),
-      (False, 'openxla'),
-  )
-  def test_resnet18(self, initialize_on_cuda, backend):
-    device = self._choose_proper_device(initialize_on_cuda)
+  @parameterized.parameters('openxla', dynamo_backend2.dynamo_backend)
+  def test_resnet18(self, backend):
+    device = torch_xla.device()
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
     loader = self.get_loader(device, sample_count, batch_size=4)
     resnet18 = torchvision.models.resnet18()