
Commit ae4f06b

Migrate uses of `import torch_xla as xla` to `import torch_xla` (#9325)
1 parent f921e5e commit ae4f06b

5 files changed: +28 −27 lines changed


CONTRIBUTING.md

Lines changed: 4 additions & 4 deletions
@@ -166,7 +166,7 @@ commands on your Linux machine directly, outside of the container.
 installed correctly:
 
 ```bash
-python -c 'import torch_xla as xla; print(xla.device())'
+python -c 'import torch_xla; print(torch_xla.device())'
 # Output: xla:0
 ```
 
@@ -375,11 +375,11 @@ First, for the `pytorch` repo:
 cd $WORKSPACE_DIR/pytorch
 # Fetch the latest changes from upstream.
 git fetch upstream
-git checkout main
+git checkout main
 # Merge the changes from upstream/main into your local branch.
 git merge upstream/main
 # Update submodules to match the latest changes.
-git submodule update --recursive
+git submodule update --recursive
 # Push the updated branch to your fork on GitHub.
 git push origin main
 ```
@@ -389,7 +389,7 @@ Next, for the `vision` repo:
 ```bash
 cd $WORKSPACE_DIR/vision
 git fetch upstream
-git checkout main
+git checkout main
 git merge upstream/main
 git push origin main
 ```
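As an aside, the one-liner above can also be kept around as a small script; a minimal sketch of the same new-style check, assuming `torch_xla` is installed and falling back to the CPU runtime when no accelerator is attached (the file name and the CPU fallback are illustrative, not part of this commit):

```python
# sanity_check.py -- hypothetical helper, not part of this commit.
import os

# Assumption: fall back to the CPU PJRT runtime so the check runs on any
# machine, even one without a TPU or GPU attached.
os.environ.setdefault("PJRT_DEVICE", "CPU")

import torch_xla

# torch_xla.device() returns the default XLA device, e.g. xla:0.
print(torch_xla.device())
```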

README.md

Lines changed: 4 additions & 4 deletions
@@ -184,7 +184,7 @@ If you're using `DistributedDataParallel`, make the following changes:
 ```diff
  import torch.distributed as dist
 -import torch.multiprocessing as mp
-+import torch_xla as xla
++import torch_xla
 +import torch_xla.distributed.xla_backend
 
  def _mp_fn(rank):
@@ -203,8 +203,8 @@ If you're using `DistributedDataParallel`, make the following changes:
 - ddp_model = DDP(model, device_ids=[rank])
 
   for inputs, labels in train_loader:
-+   with xla.step():
-+     inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
++   with torch_xla.step():
++     inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
      optimizer.zero_grad()
      outputs = ddp_model(inputs)
      loss = loss_fn(outputs, labels)
@@ -213,7 +213,7 @@ If you're using `DistributedDataParallel`, make the following changes:
 
  if __name__ == '__main__':
 -  mp.spawn(_mp_fn, args=(), nprocs=world_size)
-+  xla.launch(_mp_fn, args=())
++  torch_xla.launch(_mp_fn, args=())
 ```
 
 Additional information on PyTorch/XLA, including a description of its semantics
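Putting the three README hunks together, a self-contained sketch of the migrated pattern might look like the following. The toy model, synthetic data, and the `gradient_as_bucket_view=True` choice are illustrative assumptions, not text from this commit:

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' backend
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def _mp_fn(rank):
  # Rank and world size are inferred from the XLA runtime.
  dist.init_process_group('xla', init_method='xla://')

  # Toy model and data, standing in for a real model and train_loader.
  model = nn.Linear(8, 2).to(torch_xla.device())
  ddp_model = DDP(model, gradient_as_bucket_view=True)
  optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
  loss_fn = nn.MSELoss()

  for _ in range(3):
    with torch_xla.step():
      inputs = torch.randn(4, 8).to(torch_xla.device())
      labels = torch.randn(4, 2).to(torch_xla.device())
      optimizer.zero_grad()
      loss = loss_fn(ddp_model(inputs), labels)
      loss.backward()
      optimizer.step()


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```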

docs/source/contribute/configure-environment.md

Lines changed: 3 additions & 3 deletions
@@ -95,12 +95,12 @@ pip install numpy torch torch_xla[tpu] \
 Create a file `test.py`:
 
 ``` python
-import torch_xla as xla
+import torch_xla
 
 # Optional
-xla.runtime.set_device_type("TPU")
+torch_xla.runtime.set_device_type("TPU")
 
-print("XLA devices:", xla.real_devices())
+print("XLA devices:", torch_xla.real_devices())
 ```
 
 Run the test script from your terminal:
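For what it's worth, the same snippet can be exercised without a TPU by pointing the runtime at CPU instead; a minimal variant of `test.py` under that assumption:

```python
import torch_xla

# Assumption: target the CPU PJRT runtime instead of TPU so the script
# runs on a machine without an accelerator.
torch_xla.runtime.set_device_type("CPU")

# Should print something like: XLA devices: ['CPU:0']
print("XLA devices:", torch_xla.real_devices())
```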

test/test_devices.py

Lines changed: 15 additions & 14 deletions
@@ -4,7 +4,7 @@
 import torch
 from torch import nn
 from torch.utils.data import TensorDataset, DataLoader
-import torch_xla as xla
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
 import torch_xla.debug.metrics as met
@@ -24,35 +24,35 @@ def setUp(self):
                             (0, torch.device('xla:0')),
                             (3, torch.device('xla:3')))
   def test_device(self, index, expected):
-    device = xla.device(index)
+    device = torch_xla.device(index)
     self.assertEqual(device, expected)
 
   def test_devices(self):
-    self.assertEqual(xla.devices(),
+    self.assertEqual(torch_xla.devices(),
                      [torch.device(f'xla:{i}') for i in range(4)])
 
   def test_real_devices(self):
-    self.assertEqual(xla.real_devices(), [f'CPU:{i}' for i in range(4)])
+    self.assertEqual(torch_xla.real_devices(), [f'CPU:{i}' for i in range(4)])
 
   def test_device_count(self):
-    self.assertEqual(xla.device_count(), 4)
+    self.assertEqual(torch_xla.device_count(), 4)
 
   def test_sync(self):
-    torch.ones((3, 3), device=xla.device())
-    xla.sync()
+    torch.ones((3, 3), device=torch_xla.device())
+    torch_xla.sync()
 
     self.assertEqual(met.counter_value('MarkStep'), 1)
 
   def test_step(self):
-    with xla.step():
-      torch.ones((3, 3), device=xla.device())
+    with torch_xla.step():
+      torch.ones((3, 3), device=torch_xla.device())
 
     self.assertEqual(met.counter_value('MarkStep'), 2)
 
   def test_step_exception(self):
     with self.assertRaisesRegex(RuntimeError, 'Expected error'):
-      with xla.step():
-        torch.ones((3, 3), device=xla.device())
+      with torch_xla.step():
+        torch.ones((3, 3), device=torch_xla.device())
         raise RuntimeError('Expected error')
 
     self.assertEqual(met.counter_value('MarkStep'), 2)
@@ -69,7 +69,7 @@ def __init__(self):
       def forward(self, x):
         return self.linear(x)
 
-    model = TrivialModel().to(xla.device())
+    model = TrivialModel().to(torch_xla.device())
 
     batch_size = 16
     num_samples = 100
@@ -85,8 +85,9 @@ def forward(self, x):
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 
     for inputs, labels in loader:
-      with xla.step():
-        inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
+      with torch_xla.step():
+        inputs, labels = inputs.to(torch_xla.device()), labels.to(
+            torch_xla.device())
         optimizer.zero_grad()
         outputs = model(inputs)
         loss = loss_fn(outputs, labels)
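The `test_sync`/`test_step` cases above hinge on the `MarkStep` counter; a condensed illustration of the same semantics outside the test harness, assuming a CPU XLA runtime (the `or 0` guard and the exact delta are hedged assumptions about a fresh process):

```python
import torch
import torch_xla
import torch_xla.debug.metrics as met

# Lazy XLA tensors only materialize at a sync point. Both an explicit
# torch_xla.sync() and exiting a torch_xla.step() block cut the pending
# graph, which shows up as an increment of the MarkStep counter.
before = met.counter_value('MarkStep') or 0

torch.ones((3, 3), device=torch_xla.device())
torch_xla.sync()

with torch_xla.step():
  torch.ones((3, 3), device=torch_xla.device())

after = met.counter_value('MarkStep') or 0
print(after - before)  # expected: 2 (one per sync point)
```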

test/tpu/tpu_info/test_cli.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import multiprocessing
 import os
 from typing import Dict, Optional
-import torch_xla as xla
+import torch_xla
 import torch_xla.runtime as xr
 import torch_xla.distributed.xla_multiprocessing as xmp
 from tpu_info import cli, device, metrics
@@ -30,7 +30,7 @@ def _init_tpu_and_wait(
 ):
   if env:
     os.environ.update(**env)
-  xla.device()
+  torch_xla.device()
   q.put(os.getpid())
   done.wait()
