
Commit 78cff03

Migrate .to(torch_xla.device()) to .to('xla') (#9324)
1 parent ae4f06b commit 78cff03

33 files changed (+192, -193 lines)
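All of the touched files make the same mechanical change: tensors and modules are moved to the XLA device via the `'xla'` device string instead of an explicit `torch_xla.device()` call. A minimal sketch of the two spellings, assuming a working PyTorch/XLA install with at least one XLA device:

```python
import torch
import torch_xla

# Before this commit: resolve the default XLA device object explicitly.
linear_old = torch.nn.Linear(10, 20).to(torch_xla.device())

# After this commit: pass the 'xla' device string, which resolves to the
# same default XLA device.
linear_new = torch.nn.Linear(10, 20).to('xla')

x = torch.randn(4, 10, device='xla')
print(linear_new(x).cpu())
```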

API_GUIDE.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -47,7 +47,7 @@ Or used with neural network modules:
 
 ```python
 l_in = torch.randn(10, device='xla')
-linear = torch.nn.Linear(10, 20).to(torch_xla.device())
+linear = torch.nn.Linear(10, 20).to('xla')
 l_out = linear(l_in)
 print(l_out)
 ```
````

README.md

Lines changed: 4 additions & 4 deletions
```diff
@@ -158,12 +158,12 @@ To update your existing training loop, make the following changes:
   ...
 
 + # Move the model paramters to your XLA device
-+ model.to(torch_xla.device())
++ model.to('xla')
 
   for inputs, labels in train_loader:
 +   with torch_xla.step():
 +     # Transfer data to the XLA device. This happens asynchronously.
-+     inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
++     inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
@@ -196,15 +196,15 @@ If you're using `DistributedDataParallel`, make the following changes:
 + # Rank and world size are inferred from the XLA device runtime
 + dist.init_process_group("xla", init_method='xla://')
 +
-+ model.to(torch_xla.device())
++ model.to('xla')
 + ddp_model = DDP(model, gradient_as_bucket_view=True)
 
 - model = model.to(rank)
 - ddp_model = DDP(model, device_ids=[rank])
 
   for inputs, labels in train_loader:
 +   with torch_xla.step():
-+     inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
++     inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
```
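Pieced together, the updated single-device loop from this README hunk looks roughly like the sketch below. The tiny synthetic dataset, the `nn.Linear` stand-in model, and the SGD optimizer are illustrative assumptions; only the `model.to('xla')`, `torch_xla.step()`, and `.to('xla')` transfer lines come from the diff above.

```python
import torch
import torch.nn as nn
import torch_xla

# Tiny synthetic dataset so the sketch is self-contained.
features = torch.randn(64, 16)
labels_all = torch.randint(0, 2, (64,))
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(features, labels_all), batch_size=8)

model = nn.Linear(16, 2)  # stand-in for a real model
# Move the model parameters to your XLA device
model.to('xla')

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for inputs, labels in train_loader:
  with torch_xla.step():
    # Transfer data to the XLA device. This happens asynchronously.
    inputs, labels = inputs.to('xla'), labels.to('xla')
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()
```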

contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb

Lines changed: 1 addition & 1 deletion
```diff
@@ -172,7 +172,7 @@
 "\n",
 "pipeline = DiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\")\n",
 "# Move the model to the first TPU core\n",
-"pipeline = pipeline.to(torch_xla.device())"
+"pipeline = pipeline.to('xla')"
 ]
 },
 {
```
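For context, the notebook cell above loads a Stable Diffusion pipeline and moves it to the XLA device. A rough end-to-end sketch of that flow, assuming the `diffusers` package and a TPU runtime are available; the prompt and the image handling are illustrative and not part of the notebook diff:

```python
import torch_xla  # registers the 'xla' device with PyTorch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Move the model to the first TPU core
pipeline = pipeline.to('xla')

# The first call triggers XLA compilation, so it is much slower than later calls.
image = pipeline("a watercolor painting of a lighthouse").images[0]
image.save("lighthouse.png")
```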

docs/source/learn/pytorch-on-xla-devices.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -47,7 +47,7 @@ Or used with neural network modules:
 
 ``` python
 l_in = torch.randn(10, device='xla')
-linear = torch.nn.Linear(10, 20).to(torch_xla.device())
+linear = torch.nn.Linear(10, 20).to('xla')
 l_out = linear(l_in)
 print(l_out)
 ```
````

docs/source/perf/amp.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -19,7 +19,7 @@ from torch_xla.amp import syncfree
 import torch_xla.core.xla_model as xm
 
 # Creates model and optimizer in default precision
-model = Net().to(torch_xla.device())
+model = Net().to('xla')
 # Pytorch/XLA provides sync-free optimizers for improved performance
 optimizer = syncfree.SGD(model.parameters(), ...)
 
@@ -106,7 +106,7 @@ from torch_xla.amp import syncfree
 import torch_xla.core.xla_model as xm
 
 # Creates model and optimizer in default precision
-model = Net().to(torch_xla.device())
+model = Net().to('xla')
 # Pytorch/XLA provides sync-free optimizers for improved performance
 optimizer = syncfree.SGD(model.parameters(), ...)
 scaler = GradScaler()
```
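To show where those lines sit in a full step, here is a rough AMP sketch built around them. The `Net` module, the loss, the dummy batch, and the surrounding autocast/scaler calls are assumptions drawn from the usual `torch_xla.amp` recipe (GradScaler applies to XLA:GPU), not part of this diff:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler, syncfree

class Net(nn.Module):  # hypothetical stand-in model
  def __init__(self):
    super().__init__()
    self.fc = nn.Linear(8, 2)

  def forward(self, x):
    return self.fc(x)

# Creates model and optimizer in default precision
model = Net().to('xla')
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

data = torch.randn(4, 8).to('xla')
target = torch.randint(0, 2, (4,)).to('xla')

optimizer.zero_grad()
with autocast(xm.xla_device()):
  output = model(data)
  loss = nn.functional.cross_entropy(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
xm.mark_step()  # materialize the step on the XLA device
```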

docs/source/perf/dynamo.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -23,8 +23,8 @@ import torch
 import torch_xla.core.xla_model as xm
 
 def add(a, b):
-  a_xla = a.to(torch_xla.device())
-  b_xla = b.to(torch_xla.device())
+  a_xla = a.to('xla')
+  b_xla = b.to('xla')
   return a_xla + b_xla
 
 compiled_code = torch.compile(add, backend='openxla')
```
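A short usage sketch for the snippet above: calling the compiled function with ordinary CPU tensors and copying the result back. The call pattern is an assumption about how the doc's example is exercised, not part of the diff:

```python
import torch
import torch_xla  # makes the 'xla' device and the 'openxla' backend available

def add(a, b):
  a_xla = a.to('xla')
  b_xla = b.to('xla')
  return a_xla + b_xla

compiled_add = torch.compile(add, backend='openxla')

a = torch.randn(3)
b = torch.randn(3)
result = compiled_add(a, b)  # traced and compiled through XLA on the first call
print(result.cpu())          # bring the XLA result back to the CPU to inspect it
```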

docs/source/perf/spmd_basic.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -41,7 +41,7 @@ mesh_shape = (num_devices, 1)
 device_ids = np.array(range(num_devices))
 mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))
 
-t = torch.randn(8, 4).to(torch_xla.device())
+t = torch.randn(8, 4).to('xla')
 
 # Mesh partitioning, each device holds 1/8-th of the input
 partition_spec = ('data', 'model')
```
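For context, the hunk above sits in the SPMD walkthrough just before the tensor is actually sharded. A condensed sketch of the surrounding flow; the `use_spmd()`, device-count, and `mark_sharding` calls reflect my reading of the PyTorch/XLA SPMD API and are not part of the lines changed here:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()  # enable SPMD mode before creating XLA tensors

num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))

t = torch.randn(8, 4).to('xla')

# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = ('data', 'model')
xs.mark_sharding(t, mesh, partition_spec)
```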

test/dynamo/test_dynamo.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -66,7 +66,7 @@ def test_random_op_different_result_each_run(self, backend):
     met.clear_all()
     dynamo_random_op = torch.compile(
         self.random_op, backend=backend, fullgraph=True)
-    t = torch.randn(5, 5).to(torch_xla.device())
+    t = torch.randn(5, 5).to('xla')
     dynamo_res_1 = dynamo_random_op(t)
     dynamo_res_2 = dynamo_random_op(t)
     dynamo_res_3 = dynamo_random_op(t)
@@ -783,7 +783,7 @@ def foo(x):
     optfoo = torch.compile(backend=backend)(foo)
 
     t = torch.arange(9)
-    Xt = t.to(torch_xla.device())
+    Xt = t.to('xla')
 
     expected = foo(t)
     actual = optfoo(Xt).cpu()
@@ -803,7 +803,7 @@ def foo(x):
     optfoo = torch.compile(backend=backend)(foo)
 
     t = torch.arange(10)
-    Xt = t.to(torch_xla.device())
+    Xt = t.to('xla')
 
     expected = foo(t)
     actual = optfoo(Xt)
```

test/pjrt/test_dtypes.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -11,7 +11,7 @@ class TestDtypes(parameterized.TestCase):
       torch.bfloat16, torch.complex64)
   def test_float_round_trip(self, dtype: torch.dtype):
     t = torch.randn((3, 3), dtype=dtype)
-    xt = t.to(torch_xla.device())
+    xt = t.to('xla')
     torch.testing.assert_close(xt.cpu(), t)
 
   @parameterized.parameters(
@@ -23,12 +23,12 @@ def test_float_round_trip(self, dtype: torch.dtype):
   )
   def test_int_round_trip(self, dtype: torch.dtype):
     t = torch.randint(0, 128, (3, 3), dtype=dtype)
-    xt = t.to(torch_xla.device())
+    xt = t.to('xla')
     torch.testing.assert_close(xt.cpu(), t)
 
   def test_bool_round_trip(self):
     t = torch.randint(0, 2, (3, 3), dtype=torch.bool)
-    xt = t.to(torch_xla.device())
+    xt = t.to('xla')
     torch.testing.assert_close(xt.cpu(), t)
 
 
```
test/scan/test_scan_layers.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -262,14 +262,14 @@ def forward_scan(
     assert checks > 0
 
   def test_heterogenous_layers(self):
-    layer1 = nn.Linear(128, 128).to(torch_xla.device())
-    layer2 = nn.Sequential(nn.Linear(128, 128).to(torch_xla.device()))
+    layer1 = nn.Linear(128, 128).to('xla')
+    layer2 = nn.Sequential(nn.Linear(128, 128).to('xla'))
     with self.assertRaisesRegex(ValueError, "mismatched keys"):
       scan_layers([layer1, layer2], torch.zeros((128,), device='xla'))
 
   def test_mismatched_shapes(self):
-    layer1 = nn.Linear(128, 128).to(torch_xla.device())
-    layer2 = nn.Linear(128, 129).to(torch_xla.device())
+    layer1 = nn.Linear(128, 128).to('xla')
+    layer2 = nn.Linear(128, 129).to('xla')
     with self.assertRaisesRegex(ValueError, "Shape mismatch"):
       scan_layers([layer1, layer2], torch.zeros((128,), device='xla'))
 
```
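The two tests above cover failure modes (mismatched parameter keys and shapes). For contrast, a rough sketch of the happy path, where `scan_layers` chains layers with identical structure; the import path and the printed shape are my assumptions about the experimental API, not something this diff shows:

```python
import torch
import torch.nn as nn
import torch_xla
# Assumed import path for the experimental helper used by the tests above.
from torch_xla.experimental.scan_layers import scan_layers

# Homogeneous layers: same module structure and parameter shapes, so scanning works.
layers = [nn.Linear(128, 128).to('xla') for _ in range(4)]

x = torch.zeros((128,), device='xla')
out = scan_layers(layers, x)  # roughly equivalent to applying the layers in sequence
print(out.shape)              # expected: torch.Size([128])
```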
