44import torch
55from torch import nn
66from torch .utils .data import TensorDataset , DataLoader
7- import torch_xla as xla
7+ import torch_xla
88import torch_xla .core .xla_model as xm
99import torch_xla .runtime as xr
1010import torch_xla .debug .metrics as met
@@ -24,35 +24,35 @@ def setUp(self):
2424 (0 , torch .device ('xla:0' )),
2525 (3 , torch .device ('xla:3' )))
2626 def test_device (self , index , expected ):
27- device = xla .device (index )
27+ device = torch_xla .device (index )
2828 self .assertEqual (device , expected )
2929
3030 def test_devices (self ):
31- self .assertEqual (xla .devices (),
31+ self .assertEqual (torch_xla .devices (),
3232 [torch .device (f'xla:{ i } ' ) for i in range (4 )])
3333
3434 def test_real_devices (self ):
35- self .assertEqual (xla .real_devices (), [f'CPU:{ i } ' for i in range (4 )])
35+ self .assertEqual (torch_xla .real_devices (), [f'CPU:{ i } ' for i in range (4 )])
3636
3737 def test_device_count (self ):
38- self .assertEqual (xla .device_count (), 4 )
38+ self .assertEqual (torch_xla .device_count (), 4 )
3939
4040 def test_sync (self ):
41- torch .ones ((3 , 3 ), device = xla .device ())
42- xla .sync ()
41+ torch .ones ((3 , 3 ), device = torch_xla .device ())
42+ torch_xla .sync ()
4343
4444 self .assertEqual (met .counter_value ('MarkStep' ), 1 )
4545
4646 def test_step (self ):
47- with xla .step ():
48- torch .ones ((3 , 3 ), device = xla .device ())
47+ with torch_xla .step ():
48+ torch .ones ((3 , 3 ), device = torch_xla .device ())
4949
5050 self .assertEqual (met .counter_value ('MarkStep' ), 2 )
5151
5252 def test_step_exception (self ):
5353 with self .assertRaisesRegex (RuntimeError , 'Expected error' ):
54- with xla .step ():
55- torch .ones ((3 , 3 ), device = xla .device ())
54+ with torch_xla .step ():
55+ torch .ones ((3 , 3 ), device = torch_xla .device ())
5656 raise RuntimeError ('Expected error' )
5757
5858 self .assertEqual (met .counter_value ('MarkStep' ), 2 )
@@ -69,7 +69,7 @@ def __init__(self):
6969 def forward (self , x ):
7070 return self .linear (x )
7171
72- model = TrivialModel ().to (xla .device ())
72+ model = TrivialModel ().to (torch_xla .device ())
7373
7474 batch_size = 16
7575 num_samples = 100
@@ -85,8 +85,9 @@ def forward(self, x):
8585 optimizer = torch .optim .SGD (model .parameters (), lr = 0.01 )
8686
8787 for inputs , labels in loader :
88- with xla .step ():
89- inputs , labels = inputs .to (xla .device ()), labels .to (xla .device ())
88+ with torch_xla .step ():
89+ inputs , labels = inputs .to (torch_xla .device ()), labels .to (
90+ torch_xla .device ())
9091 optimizer .zero_grad ()
9192 outputs = model (inputs )
9293 loss = loss_fn (outputs , labels )
0 commit comments