
Commit 3bc9ece

doc: update pytorch-on-xla-devices and troubleshoot doc for tensor synchronization issue (#9258)

1 parent 1bd3042 commit 3bc9ece

2 files changed: +80 −8 lines

docs/source/learn/pytorch-on-xla-devices.md

Lines changed: 66 additions & 0 deletions
@@ -346,6 +346,72 @@ device is unavailable the load will fail. PyTorch/XLA, like all of
PyTorch, is under active development and this behavior may change in the
future.

### Unexpected Tensor Materialization During AOT (Ahead-of-Time) Tracing

While tensor materialization is normal in the JIT workflow, it is not expected during traced inference (i.e. [AOT model tracing in AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/inference/trace-vs-xla-lazytensor.html)). When working with traced inference, developers may encounter tensor materialization, which causes graphs to be compiled based on the example input's tensor values and leads to unexpected program behavior. PyTorch/XLA's debugging flags can identify when unexpected tensor materialization happens so that appropriate code changes can be made to avoid it.
356+
A common issue occurs when tensor values are evaluated during model compilation (traced inference). Consider this example:

```python
def forward(self, tensor):
    if tensor[0] == 1:
        return tensor
    else:
        return tensor * 2
```
364+
365+
While this code can compile and run, it may lead to unexpected behavior because:

* The tensor value is accessed during tracing (``tensor[0]``).
* The resulting graph becomes fixed based on the tensor value available during tracing.
* Developers might incorrectly assume the condition will be evaluated dynamically during inference.

The solution for the code above is to use the debugging flags below to catch the issue and then modify the code; one option is to feed the flag through the model configuration. See the updated code without tensor materialization:
373+
```python
class TestModel(torch.nn.Module):
    def __init__(self, flag=1):
        super().__init__()
        # The flag should be pre-determined from the model configuration;
        # it should not be an input to the model at runtime.
        self.flag = flag

    def forward(self, tensor):
        if self.flag:
            return tensor
        else:
            return tensor * 2
```
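To make the failure mode concrete, here is a framework-free sketch of how a Python `if` on a traced value bakes one branch into the recorded graph. `TracedValue`, `trace`, and `run` are illustrative toys, not PyTorch/XLA APIs:

```python
# Toy tracer illustrating how a Python `if` on a traced value bakes
# one branch into the recorded graph. Illustrative only; this is NOT
# how PyTorch/XLA is implemented.

class TracedValue:
    def __init__(self, concrete, graph):
        self.concrete = concrete  # example-input value seen at trace time
        self.graph = graph        # list of recorded ops

    def __mul__(self, k):
        self.graph.append(("mul", k))
        return TracedValue(self.concrete * k, self.graph)

    def __eq__(self, other):
        # Materialization: the comparison forces the concrete value out,
        # so the branch decision happens at trace time, not at run time.
        self.graph.append(("materialize",))
        return self.concrete == other


def forward(x):
    if x == 1:  # evaluated against the *example* input during tracing
        return x
    else:
        return x * 2


def trace(fn, example):
    graph = []
    fn(TracedValue(example, graph))
    return graph  # this fixed graph is replayed for every future input


def run(graph, value):
    for op in graph:
        if op[0] == "mul":
            value = value * op[1]
    return value


# Traced with example input 1: the `if` is taken, so "mul" is never recorded.
g = trace(forward, 1)
# Every later input follows the baked-in branch, even inputs != 1:
print(run(g, 5))  # 5, although eager forward(5) would return 10
```

Re-tracing with an example input of, say, 3 would record the `mul` branch instead, which is exactly why the graph must not depend on example tensor values.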
#### Debugging Flags

To help catch tensor materialization issues, PyTorch/XLA provides two useful approaches:

1. Enable warning messages for tensor materialization:
```python
import os
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
```

2. Disable graph execution to catch issues during development:
```python
import torch_xla
torch_xla._XLAC._set_allow_execution(False)
```
#### Recommendations

Using these flags during development can help identify potential issues early in the development cycle. The recommended approach is to:

* Use ``PT_XLA_DEBUG_LEVEL=2`` during initial development to identify potential materialization points.
* Apply ``_set_allow_execution(False)`` when you want to ensure no tensor materialization occurs during tracing.
* When you see warnings or errors related to tensor materialization, look into the code path and make the appropriate changes. The example above moved the flag into `__init__`, which does not depend on the model input at runtime.

For more detailed debugging information, refer to the [XLA troubleshoot](https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#pytorchxla-debugging-tool).
## Compilation Caching

The XLA compiler converts the traced HLO into an executable which runs

docs/source/learn/troubleshoot.md

Lines changed: 14 additions & 8 deletions
@@ -137,19 +137,25 @@ Execution Analysis: ------------------------------------------------------------
Execution Analysis: ================================================================================
```

Some common causes of compilation/execution are:
1. User manually calls `torch_xla.sync()`.
2. The [parallel loader](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/distributed/parallel_loader.py#L49-L51) calls `torch_xla.sync()` every x (configurable) batches.
3. Exiting a [profiler StepTrace region](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/debug/profiler.py#L165-L171).
4. Dynamo decides to compile/execute the graph.
5. User tries to access (often due to logging) the value of a tensor before the `torch_xla.sync()`.
6. User tries to access a tensor value before calling `mark_step`. See [PyTorch on XLA Devices](https://github.com/pytorch/xla/blob/master/docs/source/learn/pytorch-on-xla-devices.md) for more details.

The op executions caused by items 1-4 are expected, and we want to avoid items 5 and 6 by either reducing the frequency of accessing tensor values or manually adding a call to `torch_xla.sync()` before accessing them.
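Cause 5 frequently comes from per-step logging (e.g. reading a loss value every step), and the usual mitigation is to read values only every N steps. A framework-free sketch of that pattern, where `training_steps` and `log_every` are illustrative names and the float loss stands in for a lazy XLA tensor:

```python
def training_steps(num_steps, log_every=10):
    """Run num_steps toy steps, reading the loss value only every log_every steps."""
    logged = []
    for step in range(num_steps):
        loss = 1.0 / (step + 1)  # stand-in for a lazy XLA tensor
        # Accessing the value (e.g. loss.item()) would force execution;
        # doing it only every `log_every` steps bounds the forced syncs.
        if step % log_every == 0:
            logged.append((step, loss))  # one materialization per log_every steps
    return logged

print(len(training_steps(100)))  # 10 forced reads instead of 100
```

The same throttling idea applies to metrics reporting, checkpoint conditions, or any other host-side read of a device tensor.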
Users should expect to see this `Compilation Cause` + `Execution Cause` pair for the first couple of steps. After the model
