docs/source/learn/pytorch-on-xla-devices.md
device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch, is under active development and this behavior may change in the future.
### Unexpected Tensor Materialization During AOT (Ahead-of-Time) Tracing
While tensor materialization is normal in the JIT workflow, it is not expected during traced inference (for example, [AOT model tracing in AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/inference/trace-vs-xla-lazytensor.html)). When working with traced inference, developers may encounter tensor materialization, which causes graphs to be compiled based on the example input tensor values and leads to unexpected program behavior. Use PyTorch/XLA's debugging flags to identify when unexpected tensor materialization happens, then make the appropriate code changes to avoid it.
A common issue occurs when tensor values are evaluated during model compilation (traced inference). Consider this example:
```python
def forward(self, tensor):
    if tensor[0] == 1:
        return tensor
    else:
        return tensor * 2
```
While this code can compile and run, it may lead to unexpected behavior because:

* The tensor value is accessed during tracing (`tensor[0]`).
* The resulting graph becomes fixed based on the tensor value available during tracing.
* Developers might incorrectly assume the condition will be evaluated dynamically during inference.

The solution for the code above is to use the debugging flags described below to catch the issue and then modify the code; one option is to feed the flag through the model configuration instead of reading it from the input tensor.

See the updated code without tensor materialization, followed by a brief usage sketch:
```python
class TestModel(torch.nn.Module):
    def __init__(self, flag=1):
        super().__init__()
        # the flag should be pre-determined based on the model configuration;
        # it should not be an input of the model during runtime
        self.flag = flag

    def forward(self, tensor):
        if self.flag:
            return tensor
        else:
            return tensor * 2
```
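
As a quick usage sketch (plain eager PyTorch for brevity; `TestModel` is the class defined above), each configuration now resolves the branch from Python-level state rather than from the runtime value of the input tensor:

```python
import torch

# Each configuration resolves the branch from `self.flag` in Python, so a
# traced/compiled graph would be fixed per configuration rather than baked
# from the example input's values.
model_identity = TestModel(flag=1)  # forward returns `tensor` unchanged
model_doubled = TestModel(flag=0)   # forward returns `tensor * 2`

example_input = torch.ones(4)
print(model_identity(example_input))  # tensor([1., 1., 1., 1.])
print(model_doubled(example_input))   # tensor([2., 2., 2., 2.])
```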
#### Debugging Flags
To help catch tensor materialization issues, PyTorch/XLA provides two useful approaches (a combined sketch follows the list):
1. Enable warning messages for tensor materialization:
```python
import os
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
```
2. Disable graph execution to catch issues during development:
```python
import torch_xla
torch_xla._XLAC._set_allow_execution(False)
```
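For example, here is a minimal sketch of combining both approaches (this assumes `torch_xla` is installed and an XLA device is available; the tensor and the value-dependent branch are taken from the first example above):

```python
import os
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'  # surface warnings about tensor materialization

import torch
import torch_xla

# Disallow graph execution so that any attempt to materialize a tensor
# during tracing is reported instead of silently fixing the graph.
torch_xla._XLAC._set_allow_execution(False)

device = torch_xla.device()
tensor = torch.ones(4, device=device)

# The Python `if` needs the host value of `tensor[0]`, which forces the
# graph to execute; with the settings above this should show up as a
# warning or an error instead of being silently baked into the trace.
if tensor[0] == 1:
    out = tensor
else:
    out = tensor * 2
```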
#### Recommendations
Using these flags during development can help identify potential issues early in the development cycle. The recommended approach is to:
* Use `PT_XLA_DEBUG_LEVEL=2` during initial development to identify potential materialization points.
* Apply `_set_allow_execution(False)` when you want to ensure no tensor materialization occurs during tracing.
* When you see warnings or errors related to tensor materialization, look into the code path and make the appropriate changes. The example above moved the flag to the `__init__` function so that it does not depend on the model input at runtime.
For more detailed debugging information, refer to the [PyTorch/XLA troubleshooting guide](https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#pytorchxla-debugging-tool).
## Compilation Caching
The XLA compiler converts the traced HLO into an executable which runs
4. Dynamo decides to compile/execute the graph.
5. User tries to access (often due to logging) the value of a tensor before the `torch_xla.sync()`.
6. User tries to access a tensor value before calling `mark_step`. See [PyTorch on XLA Devices](https://github.com/pytorch/xla/blob/master/docs/source/learn/pytorch-on-xla-devices.md) for more details.
The op executions caused by items 1-4 are expected, and we want to avoid item 5 by either reducing the frequency of accessing tensor values or manually adding a call to `torch_xla.sync()` before accessing them.
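As a rough sketch of the second option (assuming a working `torch_xla` installation; the computation here is just a placeholder), call `torch_xla.sync()` at a point of your choosing before reading values:

```python
import torch
import torch_xla

device = torch_xla.device()

x = torch.randn(8, device=device)
loss = (x * 2).sum()   # still a lazy tensor; nothing has executed yet

torch_xla.sync()       # compile and run the pending graph at a deliberate point
print(loss.item())     # accessing the value now does not trigger an extra execution
```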
Users should expect to see these `Compilation Cause` + `Execution Cause` pairs for the first couple of steps. After the model