
Commit d8f31f2

[Doc] add debugging tips (#5409)
1 parent 640052b commit d8f31f2

File tree

- docs/source/getting_started/debugging.rst
- docs/source/index.rst

2 files changed: +37 -0 lines changed
docs/source/getting_started/debugging.rst

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
.. _debugging:

Debugging Tips
===============

Debugging hang/crash issues
---------------------------

When a vLLM instance hangs or crashes, it is very difficult to debug the issue. Here are some tips to help:

- Set the environment variable ``export VLLM_LOGGING_LEVEL=DEBUG`` to turn on more logging.
- Set the environment variable ``export CUDA_LAUNCH_BLOCKING=1`` to know exactly which CUDA kernel is causing the trouble.
- Set the environment variable ``export NCCL_DEBUG=TRACE`` to turn on more logging for NCCL.
- Set the environment variable ``export VLLM_TRACE_FUNCTION=1`` so that all function calls in vLLM are recorded. Inspect these log files to tell which function crashes or hangs. **Note: it will generate a lot of logs and slow down the system. Only use it for debugging purposes.** (A sketch of setting these variables from a Python script is shown after this list.)
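
If you launch vLLM from a Python script rather than from the shell, here is a minimal sketch of setting the same variables programmatically, assuming they are set before ``vllm`` and CUDA are initialized; the model name is only a placeholder:

.. code-block:: python

    # hedged example: export the debugging variables from Python before importing vllm,
    # since some of them (e.g. CUDA_LAUNCH_BLOCKING) must be set before CUDA is initialized
    import os

    os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["NCCL_DEBUG"] = "TRACE"
    os.environ["VLLM_TRACE_FUNCTION"] = "1"

    from vllm import LLM, SamplingParams

    # "facebook/opt-125m" is just a small placeholder model for illustration
    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)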

With more logging, hopefully you can find the root cause of the issue.

Here are some common issues that can cause hangs:

- The network setup is incorrect, so the vLLM instance cannot get the correct IP address. You should find a log line such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl``. The IP address should be the correct one. If it is not, override it by setting the environment variable ``export VLLM_HOST_IP=your_ip_address``. (A quick way to check which IP address the machine resolves to is sketched after the sanity-check script below.)
- The hardware/driver setup is incorrect, and GPU communication cannot be established. You can run the sanity check script below to see if GPU communication is working correctly.

.. code-block:: python

    # save it as `test.py`, and run it with `NCCL_DEBUG=TRACE torchrun --nproc-per-node=8 test.py`
    # adjust `--nproc-per-node` to the number of GPUs you want to use.
    import torch
    import torch.distributed as dist

    # torchrun provides the rendezvous information needed to form the NCCL process group
    dist.init_process_group(backend="nccl")
    # each rank creates a tensor of ones on its own GPU
    data = torch.FloatTensor([1,] * 128).to(f"cuda:{dist.get_rank()}")
    # sum the tensors across all ranks
    dist.all_reduce(data, op=dist.ReduceOp.SUM)
    torch.cuda.synchronize()
    # after the all-reduce, every element should equal the world size
    value = data.mean().item()
    assert value == dist.get_world_size()
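
For the network issue above (a wrong IP address), a quick way to check which IP address the host resolves to is sketched below. This is a generic trick, used here only for illustration rather than vLLM's exact detection logic, and ``8.8.8.8`` is just an arbitrary routable address; compare the printed address with the one in the vLLM debug log and set ``VLLM_HOST_IP`` if they disagree.

.. code-block:: python

    # rough check of the outbound IP address this host would use;
    # a UDP "connect" sends no packets, it only selects a local interface
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        print("this host resolves to:", s.getsockname()[0])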

If the problem persists, feel free to open an `issue <https://github.com/vllm-project/vllm/issues/new/choose>`_ on GitHub, with a detailed description of the issue, your environment, and the logs.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ Documentation
   getting_started/neuron-installation
   getting_started/cpu-installation
   getting_started/quickstart
+  getting_started/debugging
   getting_started/examples/examples_index

.. toctree::
