@@ -158,12 +158,12 @@ To update your existing training loop, make the following changes:
  ...
 
 + # Move the model parameters to your XLA device
-+ model.to(torch_xla.device())
++ model.to('xla')
 
  for inputs, labels in train_loader:
 +   with torch_xla.step():
 +     # Transfer data to the XLA device. This happens asynchronously.
-+     inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
++     inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
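Putting the new pattern together, the updated single-device loop looks roughly like the sketch below. This is a minimal reconstruction, not the guide's verbatim example: the toy model, loss function, optimizer, and data are placeholders added here so the snippet runs on its own, and the trailing `loss.backward()` / `optimizer.step()` calls follow the usual training-step pattern rather than lines shown in this hunk.

```python
import torch
import torch.nn as nn
import torch_xla

# Toy model and data, included only so the sketch is self-contained (not from the guide).
model = nn.Linear(16, 2)

# Move the model parameters to the XLA device.
model.to('xla')

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
train_loader = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(4)]

for inputs, labels in train_loader:
    with torch_xla.step():
        # Transfer data to the XLA device. This happens asynchronously.
        inputs, labels = inputs.to('xla'), labels.to('xla')
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
```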
@@ -196,15 +196,15 @@ If you're using `DistributedDataParallel`, make the following changes:
 + # Rank and world size are inferred from the XLA device runtime
 + dist.init_process_group("xla", init_method='xla://')
 +
-+ model.to(torch_xla.device())
++ model.to('xla')
 + ddp_model = DDP(model, gradient_as_bucket_view=True)
 
 - model = model.to(rank)
 - ddp_model = DDP(model, device_ids=[rank])
 
  for inputs, labels in train_loader:
 +   with torch_xla.step():
-+     inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
++     inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
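For the `DistributedDataParallel` case, the per-process function after this change would look roughly like the following sketch. The toy model, optimizer, and data are again placeholders, and launching via `torch_xla.launch()` is an assumption made here for completeness; the surrounding guide governs the actual launcher and model setup.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
import torch_xla
from torch.nn.parallel import DistributedDataParallel as DDP

def _mp_fn(index):
    # Rank and world size are inferred from the XLA device runtime.
    dist.init_process_group("xla", init_method='xla://')

    # Toy model and data, included only so the sketch is self-contained (not from the guide).
    model = nn.Linear(16, 2).to('xla')
    ddp_model = DDP(model, gradient_as_bucket_view=True)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-3)
    train_loader = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(4)]

    for inputs, labels in train_loader:
        with torch_xla.step():
            # Transfer data to the XLA device. This happens asynchronously.
            inputs, labels = inputs.to('xla'), labels.to('xla')
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

if __name__ == '__main__':
    # Assumes torch_xla.launch() is available to spawn one process per local XLA device.
    torch_xla.launch(_mp_fn)
```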