
Commit 99dac09

njhill and Yard1 authored
[Core][Doc] Default to multiprocessing for single-node distributed case (#5230)
Co-authored-by: Antoni Baum <[email protected]>
1 parent c4bd03c commit 99dac09

5 files changed: +31 / -14 lines

docs/source/serving/distributed_serving.rst

Lines changed: 5 additions & 5 deletions
@@ -3,11 +3,9 @@
 Distributed Inference and Serving
 =================================
 
-vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:
+vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or Python native multiprocessing. Multiprocessing can be used when deploying on a single node; multi-node inferencing currently requires Ray.
 
-.. code-block:: console
-
-    $ pip install ray
+Multiprocessing will be used by default when not running in a Ray placement group and when there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`; otherwise Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed_executor_backend` argument or the :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray. Ray does not need to be installed for the multiprocessing case.
 
 To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:
 
@@ -25,10 +23,12 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh
     $ --model facebook/opt-13b \
     $ --tensor-parallel-size 4
 
-To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
+To scale vLLM beyond a single machine, install and start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
 
 .. code-block:: console
 
+    $ pip install ray
+
     $ # On head node
     $ ray start --head
 
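As a rough illustration of the documented behaviour (not part of this commit; the model name is only the placeholder already used in the docs above, and the prompt and token count are arbitrary), the backend can be selected explicitly from Python through the LLM class, or the argument can be omitted to get the new default:

    from vllm import LLM, SamplingParams

    # Explicitly request the multiprocessing backend on a single node with
    # at least 4 GPUs; pass "ray" (or omit the argument) to use Ray instead.
    llm = LLM(
        model="facebook/opt-13b",
        tensor_parallel_size=4,
        distributed_executor_backend="mp",
    )

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)

The same choice is exposed on the API server as --distributed-executor-backend with the values mp or ray.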

tests/spec_decode/e2e/conftest.py

Lines changed: 4 additions & 0 deletions
@@ -77,7 +77,11 @@ def __init__(
             swap_space=swap_space,
             enforce_eager=enforce_eager,
             max_seq_len_to_capture=max_seq_len_to_capture,
+            # For now use ray for the distributed back-end, since
+            # we rely on the use of engine_use_ray=True to avoid
+            # reinitializing CUDA in the same process (driver worker)
             engine_use_ray=True,
+            distributed_executor_backend="ray",
             disable_custom_all_reduce=disable_custom_all_reduce,
             **kwargs,
         )

vllm/config.py

Lines changed: 17 additions & 1 deletion
@@ -603,9 +603,25 @@ def __init__(
                 f"'{self.distributed_executor_backend}'.")
 
         if self.distributed_executor_backend is None and self.world_size > 1:
+            # We use multiprocessing by default if world_size fits on the
+            # current node and we aren't in a ray placement group.
+            from torch.cuda import device_count
+
             from vllm.executor import ray_utils
+            backend = "mp"
             ray_found = ray_utils.ray is not None
-            self.distributed_executor_backend = "ray" if ray_found else "mp"
+            if device_count() < self.world_size:
+                if not ray_found:
+                    raise ValueError("Unable to load Ray which is "
+                                     "required for multi-node inference")
+                backend = "ray"
+            elif ray_found:
+                from ray.util import get_current_placement_group
+                if self.placement_group or get_current_placement_group():
+                    backend = "ray"
+            self.distributed_executor_backend = backend
+            logger.info("Defaulting to use %s for distributed inference",
+                        backend)
 
         self._verify_args()
 
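The added logic boils down to a small decision rule. The following standalone sketch restates it outside of vLLM (illustrative only; the function name and parameters are made up for this note, while the real code reads these values from torch.cuda, ray_utils, and the parallel config):

    def choose_default_backend(world_size: int,
                               gpus_on_node: int,
                               ray_installed: bool,
                               in_placement_group: bool) -> str:
        """Restatement of the commit's default backend selection."""
        if gpus_on_node < world_size:
            # Not enough local GPUs: multi-node inference requires Ray.
            if not ray_installed:
                raise ValueError("Unable to load Ray which is "
                                 "required for multi-node inference")
            return "ray"
        if ray_installed and in_placement_group:
            # Already running inside a Ray placement group: keep using Ray.
            return "ray"
        # Single node with enough GPUs and no placement group: multiprocessing.
        return "mp"

    # 4-way tensor parallelism on an 8-GPU node without Ray installed:
    print(choose_default_backend(4, 8, False, False))  # -> "mp"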

vllm/executor/multiproc_gpu_executor.py

Lines changed: 0 additions & 4 deletions
@@ -19,10 +19,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
     """Python multiprocessing-based multi-GPU executor"""
 
     def _init_executor(self) -> None:
-        assert (
-            not self.speculative_config
-        ), "Speculative decoding not yet supported for MultiProcGPU backend."
-
         # Create the parallel GPU workers.
         world_size = self.parallel_config.tensor_parallel_size
 
vllm/executor/multiproc_worker_utils.py

Lines changed: 5 additions & 4 deletions
@@ -65,10 +65,11 @@ def _set_future_result(future: Union[ResultFuture, asyncio.Future],
         future.set_result(result)
         return
     loop = future.get_loop()
-    if result.exception is not None:
-        loop.call_soon_threadsafe(future.set_exception, result.exception)
-    else:
-        loop.call_soon_threadsafe(future.set_result, result.value)
+    if not loop.is_closed():
+        if result.exception is not None:
+            loop.call_soon_threadsafe(future.set_exception, result.exception)
+        else:
+            loop.call_soon_threadsafe(future.set_result, result.value)
 
 
 class ResultHandler(threading.Thread):
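The extra loop.is_closed() guard avoids the RuntimeError ("Event loop is closed") that call_soon_threadsafe raises if a worker result arrives after the owning event loop has been shut down. A minimal standalone reproduction of the pattern (illustrative only; the helper name below is not from vLLM):

    import asyncio
    import threading


    def set_result_threadsafe(future: asyncio.Future, value) -> None:
        """Deliver a value to a future from another thread, dropping it
        silently if the owning event loop has already been closed."""
        loop = future.get_loop()
        if not loop.is_closed():
            loop.call_soon_threadsafe(future.set_result, value)


    async def main() -> None:
        loop = asyncio.get_running_loop()
        fut: asyncio.Future = loop.create_future()
        # A worker thread hands its result back to the event loop.
        threading.Thread(target=set_result_threadsafe, args=(fut, 42)).start()
        print(await fut)  # -> 42


    asyncio.run(main())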
