diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 37a41bf6de71..f0cfad7fa3a5 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -137,6 +137,9 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + enable_elastic_ep: bool = False + """Enable elastic expert parallelism with stateless NCCL groups for DP/EP.""" + enable_dbo: bool = False """Enable microbatching for the model executor.""" @@ -188,6 +191,21 @@ class is dynamically inherited by the worker class. This is used to inject Set to be private as it's not intended to be configured by users. """ + _stateless_world_group_port_list: list[int] = field(default_factory=list) + """List of open ports for stateless world group when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + """ + + _stateless_dp_group_port_list: list[int] = field(default_factory=list) + """List of open ports for stateless DP groups when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + """ + + _stateless_ep_group_port_list: list[int] = field(default_factory=list) + """List of open ports for stateless EP groups when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + """ + decode_context_parallel_size: int = 1 """Number of decode context parallel groups, because the world size does not change by dcp, it simply reuse the GPUs of TP group, and tp_size @@ -235,7 +253,16 @@ def get_next_dp_init_port(self) -> int: return answer - def stateless_init_dp_group(self) -> ProcessGroup: + def get_next_stateless_world_group_port(self) -> list[int]: + return self._stateless_world_group_port_list.pop(0) + + def get_next_stateless_dp_group_port(self) -> list[int]: + return self._stateless_dp_group_port_list.pop(0) + + def get_next_stateless_ep_group_port(self) -> list[int]: + return self._stateless_ep_group_port_list.pop(0) + + def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup: # NOTE: In high-concurrency scenarios multiple processes # can pick the same (currently free) port through a race # condition when calling `get_open_port()`. When the first @@ -258,7 +285,8 @@ def stateless_init_dp_group(self) -> ProcessGroup: self.get_next_dp_init_port(), self.data_parallel_rank, self.data_parallel_size, - backend="gloo") + backend="gloo", + return_store=return_store) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. 
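# A minimal, hypothetical sketch (not the vLLM implementation) of the port
# bookkeeping introduced above: a batch of ports is reserved once, carved into
# chunks of three, and each stateless group later pops exactly one chunk, the
# way the get_next_stateless_*_group_port() helpers do. `reserved` stands in
# for the result of get_open_ports_list().
def carve_into_chunks(reserved: list[int], chunk_size: int = 3) -> list[list[int]]:
    # Group a flat list of pre-reserved ports into per-group chunks of 3.
    return [reserved[i:i + chunk_size] for i in range(0, len(reserved), chunk_size)]

reserved = list(range(29500, 29509))          # 9 ports -> 3 stateless groups
chunks = carve_into_chunks(reserved)
assert chunks[0] == [29500, 29501, 29502]
first_group_ports = chunks.pop(0)             # analogous to popping the next port chunk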
if "EADDRINUSE" in str(e): @@ -351,6 +379,24 @@ def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size + # Initialize stateless group ports for elastic EP + if self.enable_elastic_ep: + num_world_groups = 1 + num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size) + num_ep_groups = max(1, self.world_size_across_dp // (self.data_parallel_size * self.tensor_parallel_size)) + + total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3 + + if not self._stateless_world_group_port_list: + all_ports = get_open_ports_list(total_ports_needed + 5) + self._data_parallel_master_port_list = all_ports[-5:] + all_ports = all_ports[:-5] + self._stateless_world_group_port_list = [all_ports[i:i+3] for i in range(0, num_world_groups * 3, 3)] + start_idx = num_world_groups * 3 + self._stateless_dp_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)] + start_idx += num_dp_groups * 3 + self._stateless_ep_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)] + if self.data_parallel_size_local > self.data_parallel_size: raise ValueError( f"data_parallel_size_local ({self.data_parallel_size_local}) " diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 149df73d8667..f6fd3bb89d0d 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -23,8 +23,8 @@ class NaiveAll2AllManager(All2AllManagerBase): debugging. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def naive_multicast(self, x: torch.Tensor, cu_tokens_across_dp_cpu: torch.Tensor): @@ -76,8 +76,8 @@ class AgRsAll2AllManager(All2AllManagerBase): all-gather (dispatch) and reduce-scatter (combine). """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -113,14 +113,16 @@ class PPLXAll2AllManager(All2AllManagerBase): All2All communication based on PPLX kernels. """ - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_pplx( ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." 
# noqa - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) - if self.internode: - # inter-node communication needs nvshmem, - # intra-node communication uses p2p mapping directly + self.nvshmem_initialized = False + self.handle_cache = Cache() + + def get_handle(self, kwargs): + if self.internode and not self.nvshmem_initialized: from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, nvshmem_get_unique_id, nvshmem_init) @@ -129,15 +131,18 @@ def __init__(self, cpu_group): "rank=%d, world size=%d", self.rank, self.world_size) uid = nvshmem_get_unique_id( ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() - dist.broadcast(uid, - src=dist.get_process_group_ranks(self.cpu_group)[0], - group=self.cpu_group) + + if self.tcp_store_group is not None: + uid = self.tcp_store_group.broadcast_obj(uid, src=0) + else: + dist.broadcast(uid, + src=dist.get_process_group_ranks(self.cpu_group)[0], + group=self.cpu_group) + logger.debug("PPLX NVSHMEM UID = %s", uid) nvshmem_init(uid, self.rank, self.world_size) + self.nvshmem_initialized = True - self.handle_cache = Cache() - - def get_handle(self, kwargs): import pplx_kernels as pplx return self.handle_cache.get_or_create( kwargs, pplx.AllToAll.internode @@ -166,10 +171,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): All2All communication based on DeepEP High-Throughput kernels. """ - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_deep_ep( ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) self.handle_cache = Cache() # This is the DeepEP default. Stick to it till we can establish @@ -195,8 +200,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): All2All communication based on DeepEP High-Throughput kernels. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. @@ -243,8 +248,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): All2All communication based on DeepEP Low-Latency kernels. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def _make_all2all_kwargs( self, @@ -265,7 +270,8 @@ def _make_all2all_kwargs( import deep_ep # Defaults for internode and intranode are taken from DeepEP tests. 
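# Hedged sketch of the init-once UID exchange pattern used above: rank 0
# produces an opaque unique id, every rank obtains it through a store-style
# broadcast instead of torch.distributed.broadcast, and initialization runs at
# most once. `StoreLike`, `make-believe uid`, and `ensure_nvshmem_like_init`
# are illustrative stand-ins, not vLLM or pplx_kernels APIs.
from typing import Any, Optional

class StoreLike:
    """Toy single-process stand-in for a TCP-store group's broadcast_obj."""
    def __init__(self) -> None:
        self._published: Optional[Any] = None
    def broadcast_obj(self, obj: Any, src: int = 0) -> Any:
        if obj is not None:          # the source rank publishes its value
            self._published = obj
        return self._published       # every rank reads the published value

def ensure_nvshmem_like_init(rank: int, store: StoreLike, state: dict) -> None:
    if state.get("initialized"):
        return                       # lazy: only the first get_handle() pays the cost
    uid = b"opaque-uid" if rank == 0 else None
    uid = store.broadcast_obj(uid, src=0)
    # nvshmem_init(uid, rank, world_size) would run here in the real code path
    state["initialized"] = True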
- num_nvl_bytes = 1024 * 1024 * 1024 + # num_nvl_bytes = 1024 * 1024 * 1024 + num_nvl_bytes = 0 num_qps_per_rank = num_local_experts num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, @@ -278,7 +284,8 @@ def _make_all2all_kwargs( num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, - num_qps_per_rank=num_qps_per_rank) + num_qps_per_rank=num_qps_per_rank, + allow_mnnvl=True) def get_handle(self, kwargs): """ diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 01f59b44a0e6..b539835279bd 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -29,8 +29,9 @@ def get_or_create(self, kwargs, func): class All2AllManagerBase: - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): self.cpu_group = cpu_group + self.tcp_store_group = tcp_store_group # compute some common properties from vllm.distributed.parallel_state import (get_dp_group, @@ -44,12 +45,15 @@ def __init__(self, cpu_group): # when we create this object self.dp_rank = self.dp_group.rank_in_group self.dp_world_size = self.dp_group.world_size - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) + self.rank = cpu_group.rank() + self.world_size = cpu_group.size() # all2all communication often has separate implementations for # intra-node and inter-node communication - self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + if tcp_store_group is None: + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + else: + self.internode = not all(in_the_same_node_as(tcp_store_group, source_rank=0)) def get_handle(self, kwargs): # get a handle for the all2all communication, @@ -83,18 +87,34 @@ def __init__(self, cpu_group: ProcessGroup, device: Optional[torch.device] = None, device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + unique_name: str = "", + global_ranks: Optional[list[int]] = None, + global_world_size: Optional[int] = None): self.device = device or torch.device("cpu") self.cpu_group = cpu_group self.device_group = device_group self.unique_name = unique_name - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) - self.ranks = dist.get_process_group_ranks(cpu_group) - self.global_rank = dist.get_rank() - self.global_world_size = dist.get_world_size() - self.rank_in_group = dist.get_group_rank(self.cpu_group, - self.global_rank) + + # Check if this is a stateless process group + from torch.distributed.distributed_c10d import _world + is_stateless = _world.pg_map.get(cpu_group, None) is None + + if is_stateless: + # For stateless groups, we can't use torch.distributed methods + self.rank = cpu_group.rank() + self.world_size = cpu_group.size() + self.ranks = global_ranks + self.global_rank = self.ranks[self.rank] + self.global_world_size = global_world_size + self.rank_in_group = self.rank + else: + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, + self.global_rank) use_ep = False from vllm.config import get_current_vllm_config diff --git 
a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index b2bf3bc3cc2e..664813f9a143 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -11,6 +11,7 @@ from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase +from ..utils import StatelessProcessGroup logger = init_logger(__name__) @@ -21,8 +22,12 @@ def __init__(self, cpu_group: ProcessGroup, device: Optional[torch.device] = None, device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): - super().__init__(cpu_group, device, device_group, unique_name) + unique_name: str = "", + global_ranks: Optional[list[int]] = None, + global_world_size: Optional[int] = None, + tcp_store_group: Optional[StatelessProcessGroup] = None): + super().__init__(cpu_group, device, device_group, unique_name, + global_ranks, global_world_size) if "tp" not in unique_name: # only tp uses custom allreduce use_custom_allreduce = False @@ -32,7 +37,7 @@ def __init__(self, use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE # ep does not use pynccl - use_pynccl = "ep" not in unique_name + use_pynccl = ("ep" not in unique_name) or (tcp_store_group is not None) self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce @@ -50,7 +55,7 @@ def __init__(self, self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, + group=self.cpu_group if tcp_store_group is None else tcp_store_group, device=self.device, ) @@ -85,23 +90,23 @@ def __init__(self, all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": from .all2all import NaiveAll2AllManager - self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + self.all2all_manager = NaiveAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group) logger.info("Using naive all2all manager.") elif all2all_backend == "allgather_reducescatter": from .all2all import AgRsAll2AllManager - self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + self.all2all_manager = AgRsAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group) logger.info("Using AllGather-ReduceScatter all2all manager.") elif all2all_backend == "pplx": from .all2all import PPLXAll2AllManager - self.all2all_manager = PPLXAll2AllManager(self.cpu_group) + self.all2all_manager = PPLXAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group) logger.info("Using PPLX all2all manager.") elif all2all_backend == "deepep_high_throughput": from .all2all import DeepEPHTAll2AllManager - self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) + self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group) logger.info("Using DeepEP High-Throughput all2all manager.") elif all2all_backend == "deepep_low_latency": from .all2all import DeepEPLLAll2AllManager - self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) + self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group) logger.info("Using DeepEP Low-Latency all2all manager.") else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") @@ -229,6 +234,18 @@ def recv(self, torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast a tensor from source rank to all ranks.""" + if 
self.world_size == 1: + return tensor + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.broadcast(tensor, src) + return tensor + else: + raise ValueError("No PyNCCL communicator found") + def destroy(self): if self.pynccl_comm is not None: self.pynccl_comm = None @@ -296,3 +313,10 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: assert self.all2all_manager is not None hidden_states = self.all2all_manager.combine(hidden_states) return hidden_states + + def batch_isend_irecv(self, p2p_ops: list): + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.batch_isend_irecv(p2p_ops) + else: + raise ValueError("No PyNCCL communicator found") diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 3e4d0d250af9..2e3740cec250 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -6,7 +6,7 @@ # ===================== import region ===================== import torch import torch.distributed as dist -from torch.distributed import ProcessGroup, ReduceOp +from torch.distributed import ProcessGroup, ReduceOp, P2POp from vllm.distributed.device_communicators.pynccl_wrapper import ( NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, @@ -248,8 +248,12 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None): f"but the input tensor is on {tensor.device}") if stream is None: stream = current_stream() + if tensor.dtype in [torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz]: + nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8) + else: + nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype) self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), dst, + nccl_dtype, dst, self.comm, cudaStream_t(stream.cuda_stream)) def recv(self, tensor: torch.Tensor, src: int, stream=None): @@ -260,8 +264,12 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): f"but the input tensor is on {tensor.device}") if stream is None: stream = current_stream() + if tensor.dtype in [torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz]: + nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8) + else: + nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype) self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, + nccl_dtype, src, self.comm, cudaStream_t(stream.cuda_stream)) def broadcast(self, tensor: torch.Tensor, src: int, stream=None): @@ -288,3 +296,16 @@ def group_start(self): def group_end(self): self.nccl.ncclGroupEnd() + + def batch_isend_irecv(self, p2p_ops: list, stream=None): + if self.disabled: + return + if stream is None: + stream = current_stream() + self.group_start() + for op in p2p_ops: + if op.op.__name__ == "isend": + self.send(op.tensor, op.group_peer, stream) + elif op.op.__name__ == "irecv": + self.recv(op.tensor, op.group_peer, stream) + self.group_end() diff --git a/vllm/distributed/elastic_ep/__init__.py b/vllm/distributed/elastic_ep/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py new file mode 100644 index 000000000000..515f9b5817ae --- /dev/null +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -0,0 +1,440 @@ +import gc 
+import weakref +from typing import Optional, Iterable, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import P2POp + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel, set_current_vllm_config, get_current_vllm_config +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_tp_group, + get_standby_dp_group, + get_standby_ep_group, +) +from vllm.distributed.parallel_state import ( + create_standby_groups, + switch_to_standby_groups, + prepare_communication_buffer_for_model, +) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEParallelConfig +from vllm.v1.engine import ReconfigureRankType, ReconfigureDistributedRequest + + +logger = init_logger(__name__) + + +def batch_transfer_weights( + model: nn.Module, + is_sender: bool, + peer_rank: int, + dp_group: StatelessGroupCoordinator, + expert_weights: Sequence[Iterable[torch.Tensor]], +) -> None: + device_comm = dp_group.device_communicator + if device_comm is None: + raise ValueError("No device communicator found") + + expert_weights_set = set() + for weight_group in expert_weights: + for weight in weight_group: + expert_weights_set.add(weight.data_ptr()) + + state_dict = model.state_dict() + all_params = [] + + for _, param in state_dict.items(): + if param.data_ptr() not in expert_weights_set: + all_params.append(param.data) + + if all_params: + p2p_ops = [] + for param in all_params: + op = object.__new__(P2POp) + if is_sender: + op.op = torch.distributed.isend + op.tensor = param + else: + op.op = torch.distributed.irecv + op.tensor = param + op.group_peer = peer_rank + p2p_ops.append(op) + + device_comm.batch_isend_irecv(p2p_ops) + + +def broadcast_expert_mapping( + physical_to_logical: Optional[torch.Tensor], + num_local_physical_experts: Optional[int], + num_logical_experts: Optional[int], + dp_group: StatelessGroupCoordinator, + device: torch.device, + src_rank: int = 0, +) -> tuple[torch.Tensor, int, int]: + if dp_group.rank_in_group == src_rank: + assert physical_to_logical is not None + assert num_local_physical_experts is not None + assert num_logical_experts is not None + assert physical_to_logical.dtype == torch.int64 + shape_tensor = torch.tensor(list(physical_to_logical.shape), dtype=torch.int64, device='cpu') + metadata_tensor = torch.tensor([num_local_physical_experts, num_logical_experts], dtype=torch.int64, device='cpu') + else: + shape_tensor = torch.empty(2, dtype=torch.int64, device='cpu') + metadata_tensor = torch.empty(2, dtype=torch.int64, device='cpu') + + shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank) + metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank) + + if dp_group.rank_in_group != src_rank: + assert device is not None + physical_to_logical = torch.empty( + tuple(shape_tensor.tolist()), + dtype=torch.int64, + device=device, + ) + + assert physical_to_logical is not None + physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank) + num_local_physical_experts = int(metadata_tensor[0].item()) + num_logical_experts = int(metadata_tensor[1].item()) + + return physical_to_logical, 
num_local_physical_experts, num_logical_experts + + +def clear_compile_and_cache(model: nn.Module) -> None: + if not isinstance(model, TorchCompileWrapperWithCustomDispatcher): + if hasattr(model, 'model'): + model = model.model + if not isinstance(model, TorchCompileWrapperWithCustomDispatcher): + return + if model.do_not_compile: + return + # reset the compilation counter + compilation_counter.num_models_seen = 0 + compilation_counter.num_graphs_seen = 0 + compilation_counter.num_piecewise_graphs_seen = 0 + compilation_counter.num_piecewise_capturable_graphs_seen = 0 + compilation_counter.num_backend_compilations = 0 + compilation_counter.num_gpu_runner_capture_triggers = 0 + compilation_counter.num_cudagraph_captured = 0 + compilation_counter.num_inductor_compiles = 0 + compilation_counter.num_eager_compiles = 0 + compilation_counter.num_cache_entries_updated = 0 + compilation_counter.num_compiled_artifacts_saved = 0 + compilation_counter.dynamo_as_is_count = 0 + compilation_level = get_current_vllm_config().compilation_config.level + TorchCompileWrapperWithCustomDispatcher.__init__( + model, + compilation_level=compilation_level + ) + + +class ElasticScalingExecutor: + def __init__(self, worker): + self.worker_ref = weakref.ref(worker) + self.reconfig_request = None + + @property + def worker(self): + worker = self.worker_ref() + if worker is None: + raise RuntimeError("Worker has been garbage collected") + return worker + + def execute(self, execute_method: str, *args, **kwargs): + method = getattr(self, execute_method, None) + if method is None: + raise ValueError(f"Unknown execute method: {execute_method}") + return method(*args, **kwargs) + + def create_standby_groups(self, reconfig_request: ReconfigureDistributedRequest) -> None: + self.reconfig_request = reconfig_request + new_dp_size = reconfig_request.new_data_parallel_size + world_size = self.worker.vllm_config.parallel_config.world_size + new_world_size_across_dp = world_size * new_dp_size + # TODO(yongji): check whether we need to use updated vllm_config here + with set_current_vllm_config(self.worker.vllm_config): + create_standby_groups( + new_dp_size=new_dp_size, + new_world_size_across_dp=new_world_size_across_dp, + master_ip=reconfig_request.new_data_parallel_master_ip, + world_group_ports=reconfig_request.new_stateless_world_group_port_list, + dp_group_ports=reconfig_request.new_stateless_dp_group_port_list, + ep_group_ports=reconfig_request.new_stateless_ep_group_port_list, + ) + self.worker.model_runner.eplb_disabled = True + if get_standby_ep_group().rank == 0: + logger.info("[Elastic EP] EPLB disabled during elastic scaling transition") + + def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None: + standby_dp_group = get_standby_dp_group() + + # Broadcast old_dp_size to all workers in standby group + if standby_dp_group.rank_in_group < old_dp_size: + old_dp_size_tensor = torch.tensor([old_dp_size], dtype=torch.int64, device='cpu') + else: + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device='cpu') + old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) + + num_new_workers = new_dp_size - old_dp_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + ranks_to_send = [] + num_dst_per_sender = num_new_workers // old_dp_size + sender_pos = dp_rank + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end)) + + remainder_start = 
old_dp_size * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < num_new_workers: + ranks_to_send.append(old_dp_size + recver_pos) + + model = self.worker.model_runner.get_model() + for new_worker_rank in sorted(ranks_to_send): + batch_transfer_weights( + model=model, + is_sender=True, + peer_rank=new_worker_rank, + dp_group=standby_dp_group, + expert_weights=model.expert_weights, + ) + torch.cuda.synchronize() + + def broadcast_expert_mapping(self) -> None: + standby_dp_group = get_standby_dp_group() + physical_to_logical = self.worker.model_runner.eplb_state.physical_to_logical_map + num_physical_experts = physical_to_logical.shape[1] + num_local_physical_experts = num_physical_experts // get_ep_group().world_size + num_logical_experts = self.worker.model_runner.eplb_state.logical_replica_count.shape[1] + broadcast_expert_mapping( + physical_to_logical=physical_to_logical, + num_local_physical_experts=num_local_physical_experts, + num_logical_experts=num_logical_experts, + dp_group=standby_dp_group, + src_rank=0, + device=self.worker.device, + ) + + def switch_and_prepare(self) -> None: + from vllm.platforms import current_platform + + old_dp_size = get_dp_group().world_size + old_ep_size = get_ep_group().world_size + + switch_to_standby_groups() + + parallel_config = self.worker.vllm_config.parallel_config + reconfig_request = self.reconfig_request + assert reconfig_request is not None + new_dp_size = reconfig_request.new_data_parallel_size + new_ep_size = get_ep_group().world_size + + parallel_config.data_parallel_size = new_dp_size + if reconfig_request.new_data_parallel_rank != ReconfigureRankType.KEEP_CURRENT_RANK: + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if reconfig_request.new_data_parallel_rank_local != ReconfigureRankType.KEEP_CURRENT_RANK: + parallel_config.data_parallel_rank_local = reconfig_request.new_data_parallel_rank_local + parallel_config.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = reconfig_request.new_data_parallel_master_port + + # Reconfigure MoE modules with new EP size + moe_modules = [ + module for module in self.worker.model_runner.model.modules() + if (module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE") + ] + num_local_experts = moe_modules[0].moe_config.num_local_experts + assert all(module.moe_config.num_local_experts == num_local_experts + for module in moe_modules), ( + "All MoE modules must have the same number of experts") + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=get_tp_group().world_size, + dp_size_=get_dp_group().world_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + if hasattr(module.quant_method, 'topk_indices_dtype'): + module.quant_method.topk_indices_dtype = None + module.quant_method.fused_experts = None + + # Update EPLB state + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + num_physical_experts = num_local_experts * new_ep_size + num_logical_experts = eplb_state.logical_replica_count.shape[1] + parallel_config.eplb_config.num_redundant_experts = num_physical_experts - num_logical_experts + old_physical_to_logical = eplb_state.physical_to_logical_map + num_moe_layers = 
old_physical_to_logical.shape[0] + num_local_experts = old_physical_to_logical.shape[1] // old_ep_size + if new_dp_size > old_dp_size: + expanded_physical_to_logical = torch.full( + (num_moe_layers, num_local_experts * new_ep_size), + -1, + dtype=old_physical_to_logical.dtype, + device=old_physical_to_logical.device + ) + expanded_physical_to_logical[:, :num_local_experts * old_ep_size] = old_physical_to_logical + eplb_state.physical_to_logical_map = expanded_physical_to_logical + + old_num_physical_experts = eplb_state.expert_load_pass.shape[1] + pad_size = num_physical_experts - old_num_physical_experts + expanded_expert_load_pass = F.pad( + eplb_state.expert_load_pass, + (0, pad_size), + value=0 + ) + expanded_expert_load_window = F.pad( + eplb_state.expert_load_window, + (0, pad_size), + value=0 + ) + eplb_state.expert_load_pass = expanded_expert_load_pass + eplb_state.expert_load_window = expanded_expert_load_window + eplb_state.num_valid_physical_experts = old_num_physical_experts + + model = self.worker.model_runner.get_model() + model.expert_weights = [] + model.set_eplb_state( + expanded_expert_load_pass, + eplb_state.logical_to_physical_map, + eplb_state.logical_replica_count, + ) + model.update_physical_experts_metadata( + num_physical_experts=num_physical_experts, + num_local_physical_experts=num_local_experts, + ) + + prepare_communication_buffer_for_model(self.worker.model_runner.model) + if (self.worker.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS and + current_platform.is_cuda_alike()): + backend = self.worker.vllm_config.compilation_config.init_backend(self.worker.vllm_config) + compilation_counter.dynamo_as_is_count += 1 + self.worker.model_runner.model.compile(fullgraph=True, backend=backend) + + # release all previously captured CUDA graphs + if isinstance(self.worker.model_runner.model, CUDAGraphWrapper): + # TODO(yongji): do we need to reset graph pool here? 
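# Illustrative sketch of the table expansion performed above when EP grows:
# the per-layer physical->logical map keeps its old entries, and the slots
# owned by newly added ranks are filled with -1 (unassigned) until the next
# EPLB reshuffle places experts there. Shapes here are made up for the example.
import torch

old_map = torch.tensor([[0, 1, 2, 3]])        # 1 MoE layer, old_ep_size * num_local = 4
new_num_physical = 8                          # new_ep_size * num_local experts
expanded = torch.full((old_map.shape[0], new_num_physical), -1, dtype=old_map.dtype)
expanded[:, :old_map.shape[1]] = old_map      # preserve the existing placement
assert expanded.tolist() == [[0, 1, 2, 3, -1, -1, -1, -1]]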
+ wrapper = self.worker.model_runner.model + wrapper.concrete_cudagraph_entries = {} + elif isinstance(self.worker.model_runner.model, UBatchWrapper): + raise RuntimeError("DBO is not yet supported in elastic EP") + + # clear all torch.compile + with set_current_vllm_config(self.worker.vllm_config): + clear_compile_and_cache(self.worker.model_runner.get_model()) + + gc.collect() + torch.cuda.empty_cache() + self.worker.compile_or_warm_up_model() + + def perform_eplb_reshuffle(self, new_dp_size: Optional[int] = None) -> None: + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding...") + + model = self.worker.model_runner.get_model() + assert self.worker.model_runner.eplb_state is not None + + if new_dp_size is None: + self.worker.model_runner.eplb_state.rearrange(model) + else: + # scale down + parallel_config = self.worker.vllm_config.parallel_config + tp_size = parallel_config.tensor_parallel_size + old_ep_size = parallel_config.data_parallel_size * tp_size + new_ep_size = new_dp_size * tp_size + + rank_mapping = { + old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 + for old_ep_rank in range(old_ep_size) + } + + self.worker.model_runner.eplb_state.rearrange( + model, + rank_mapping=rank_mapping + ) + # NOTE(yongji): check whether we need to synchronize here + torch.cuda.synchronize() + # reset expert_rearrangement_step to ensure all ranks are synchronized + self.worker.model_runner.eplb_state.expert_rearrangement_step = 0 + self.worker.model_runner.eplb_disabled = False + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed") + + def receive_weights(self) -> None: + dp_group = get_dp_group() + new_dp_size = dp_group.world_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + # Receive old_dp_size broadcasted during transfer_weights + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device='cpu') + old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) + old_dp_size = int(old_dp_size_tensor[0].item()) + + # Calculate which existing worker will send to this new worker + num_new_workers = new_dp_size - old_dp_size + new_worker_idx = dp_rank - old_dp_size + num_dst_per_sender = num_new_workers // old_dp_size + remainder = num_new_workers % old_dp_size + + if new_worker_idx < remainder * (num_dst_per_sender + 1): + sender_rank = new_worker_idx // (num_dst_per_sender + 1) + else: + sender_rank = remainder + (new_worker_idx - remainder * (num_dst_per_sender + 1)) // num_dst_per_sender + + model = self.worker.model_runner.get_model() + batch_transfer_weights( + model=model, + is_sender=False, + peer_rank=sender_rank, + dp_group=dp_group, + expert_weights=model.expert_weights, + ) + torch.cuda.synchronize() + + def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]: + physical_to_logical, num_local_physical_experts, num_logical_experts = broadcast_expert_mapping( + physical_to_logical=None, + num_local_physical_experts=None, + num_logical_experts=None, + dp_group=get_dp_group(), + src_rank=0, + device=self.worker.device, + ) + num_moe_layers = physical_to_logical.shape[0] + new_dp_size = get_dp_group().world_size + tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size + new_ep_size = new_dp_size * tp_size + expanded_physical_to_logical = torch.full( + (num_moe_layers, num_local_physical_experts * new_ep_size), + -1, + dtype=physical_to_logical.dtype, + device=physical_to_logical.device + ) + old_num_physical_experts = 
physical_to_logical.shape[1] + expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical + return expanded_physical_to_logical, num_logical_experts, old_num_physical_experts + + def prepare_new_worker(self) -> None: + prepare_communication_buffer_for_model(self.worker.model_runner.get_model()) + + from vllm.platforms import current_platform + if (self.worker.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS and + current_platform.is_cuda_alike()): + backend = self.worker.vllm_config.compilation_config.init_backend(self.worker.vllm_config) + compilation_counter.dynamo_as_is_count += 1 + self.worker.model_runner.get_model().compile(fullgraph=True, backend=backend) + diff --git a/vllm/distributed/elastic_ep/elastic_state.py b/vllm/distributed/elastic_ep/elastic_state.py new file mode 100644 index 000000000000..8f047c03928f --- /dev/null +++ b/vllm/distributed/elastic_ep/elastic_state.py @@ -0,0 +1,306 @@ +import enum +import weakref +from typing import TYPE_CHECKING, Literal, Optional + +from vllm.logger import init_logger +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.engine.core import DPEngineCoreProc +from vllm.distributed import stateless_destroy_torch_distributed_process_group +from vllm.config import ParallelConfig + + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) + +WorkerType = Literal["existing", "new", "shutdown"] + + +class ScaleUpExistingWorkerState(enum.IntEnum): + WAIT_NEW_WORKERS_INIT = 0 + CREATE_STANDBY_GROUPS = 1 + TRANSFER_EXPERT_MAPPING = 2 + WAIT_NEW_WORKERS_WEIGHTS_INIT = 3 + TRANSFER_WEIGHTS = 4 + SYNC_KV_CACHE_MEMORY = 5 + SWITCH_AND_PREPARE = 6 + EPLB_RESHUFFLE = 7 + COMPLETE = 8 + + +class ScaleUpNewWorkerState(enum.IntEnum): + FINISH_BOOTUP = 0 + EPLB_RESHUFFLE = 1 + COMPLETE = 2 + + +class ScaleDownRemainingWorkerState(enum.IntEnum): + CREATE_STANDBY_GROUPS = 0 + EPLB_RESHUFFLE = 1 + SWITCH_AND_PREPARE = 2 + COMPLETE = 3 + + +class ScaleDownShutdownWorkerState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + COMPLETE = 2 + + +class ElasticScalingState: + + def __init__( + self, + model_executor: "Executor", + engine_core: "DPEngineCoreProc", + vllm_config: "VllmConfig", + new_parallel_config: ParallelConfig, + worker_type: WorkerType, + scale_type: Literal["scale_up", "scale_down"], + reconfig_request: Optional[ReconfigureDistributedRequest] = None, + ): + self.model_executor_ref = weakref.ref(model_executor) + self.engine_core_ref = weakref.ref(engine_core) + self.vllm_config = vllm_config + self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None + self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None + self.new_dp_group = self.engine_core.dp_group if worker_type == "new" else new_parallel_config + self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None + self.worker_type = worker_type + self.scale_type = scale_type + self.reconfig_request = reconfig_request + self.waiting_for_notification = False + + if scale_type == "scale_up": + self.state = (ScaleUpNewWorkerState.EPLB_RESHUFFLE if worker_type == "new" + else ScaleUpExistingWorkerState.WAIT_NEW_WORKERS_INIT) + else: + self.state = (ScaleDownShutdownWorkerState.EPLB_RESHUFFLE if worker_type == "shutdown" + else ScaleDownRemainingWorkerState.CREATE_STANDBY_GROUPS) + + @property + def model_executor(self) -> "Executor": + model_executor = 
self.model_executor_ref()
+        if model_executor is None:
+            raise RuntimeError("Model executor has been garbage collected")
+        return model_executor
+
+    @property
+    def engine_core(self) -> "DPEngineCoreProc":
+        engine_core = self.engine_core_ref()
+        if engine_core is None:
+            raise RuntimeError("Engine core has been garbage collected")
+        return engine_core
+
+    def progress(self) -> bool:
+        if self.waiting_for_notification:
+            return False
+
+        if self.scale_type == "scale_up":
+            return (self._progress_new_worker() if self.worker_type == "new"
+                    else self._progress_existing_worker())
+        return (self._progress_shutdown_worker() if self.worker_type == "shutdown"
+                else self._progress_remaining_worker())
+
+    def _progress_existing_worker(self) -> bool:
+        state = self.state
+
+        if state == ScaleUpExistingWorkerState.WAIT_NEW_WORKERS_INIT:
+            self.waiting_for_notification = True
+            return False
+
+        elif state == ScaleUpExistingWorkerState.CREATE_STANDBY_GROUPS:
+            # NOTE(yongji): wait for all existing workers to receive the request
+            if int(self.old_dp_store.get("elastic_ep_request_first_barrier")) < self.old_dp_group.size():
+                return False
+            self._create_standby_groups()
+            self.state = ScaleUpExistingWorkerState.TRANSFER_EXPERT_MAPPING
+            return True
+
+        elif state == ScaleUpExistingWorkerState.TRANSFER_EXPERT_MAPPING:
+            self._transfer_expert_mapping()
+            self.state = ScaleUpExistingWorkerState.WAIT_NEW_WORKERS_WEIGHTS_INIT
+            self.old_dp_store.add("elastic_ep_request_second_barrier", 1)
+            return True
+
+        elif state == ScaleUpExistingWorkerState.WAIT_NEW_WORKERS_WEIGHTS_INIT:
+            self.waiting_for_notification = True
+            return False
+
+        elif state == ScaleUpExistingWorkerState.TRANSFER_WEIGHTS:
+            if int(self.old_dp_store.get("elastic_ep_request_second_barrier")) < self.old_dp_group.size():
+                return False
+            self._transfer_weights()
+            self.state = ScaleUpExistingWorkerState.SYNC_KV_CACHE_MEMORY
+            return True
+
+        elif state == ScaleUpExistingWorkerState.SYNC_KV_CACHE_MEMORY:
+            self._sync_kv_cache_memory()
+            self.state = ScaleUpExistingWorkerState.SWITCH_AND_PREPARE
+            return True
+
+        elif state == ScaleUpExistingWorkerState.SWITCH_AND_PREPARE:
+            self._switch_and_prepare()
+            self.new_dp_store.add("elastic_ep_request_third_barrier", 1)
+            self.state = ScaleUpExistingWorkerState.EPLB_RESHUFFLE
+            return True
+
+        elif state == ScaleUpExistingWorkerState.EPLB_RESHUFFLE:
+            if int(self.new_dp_store.get("elastic_ep_request_third_barrier")) < self.new_dp_group.size():
+                return False
+            self._eplb_reshuffle()
+            self.state = ScaleUpExistingWorkerState.COMPLETE
+            self._update_parallel_config()
+            return True
+
+        return False
+
+    def _progress_new_worker(self) -> bool:
+        state = self.state
+
+        if state == ScaleUpNewWorkerState.FINISH_BOOTUP:
+            self.new_dp_store.add("elastic_ep_request_third_barrier", 1)
+            self.state = ScaleUpNewWorkerState.EPLB_RESHUFFLE
+            return True
+
+        elif state == ScaleUpNewWorkerState.EPLB_RESHUFFLE:
+            if int(self.new_dp_store.get("elastic_ep_request_third_barrier")) < self.new_dp_group.size():
+                return False
+            assert self.new_dp_group.rank() > 0
+            self._eplb_reshuffle()
+            self.state = ScaleUpNewWorkerState.COMPLETE
+            return True
+
+        return False
+
+    def _progress_remaining_worker(self) -> bool:
+        state = self.state
+
+        if state == ScaleDownRemainingWorkerState.CREATE_STANDBY_GROUPS:
+            if int(self.old_dp_store.get("elastic_ep_request_first_barrier")) < self.old_dp_group.size():
+                return False
+            self._create_standby_groups()
+            self.state = ScaleDownRemainingWorkerState.EPLB_RESHUFFLE
+            return True
+
+        elif
state == ScaleDownRemainingWorkerState.EPLB_RESHUFFLE: + self._eplb_reshuffle_before_scale_down() + self.state = ScaleDownRemainingWorkerState.SWITCH_AND_PREPARE + return True + + elif state == ScaleDownRemainingWorkerState.SWITCH_AND_PREPARE: + self._switch_and_prepare() + self.state = ScaleDownRemainingWorkerState.COMPLETE + return True + + return False + + def _progress_shutdown_worker(self) -> bool: + state = self.state + + if state == ScaleDownShutdownWorkerState.PREPARE: + if int(self.old_dp_store.get("elastic_ep_request_first_barrier")) < self.old_dp_group.size(): + return False + assert self.old_dp_group.rank() > 0 + self.state = ScaleDownShutdownWorkerState.EPLB_RESHUFFLE + return True + + if state == ScaleDownShutdownWorkerState.EPLB_RESHUFFLE: + self._eplb_reshuffle_before_scale_down() + self.state = ScaleDownShutdownWorkerState.COMPLETE + self.engine_core.shutdown() + return True + + return False + + def handle_notification(self, notification_type: str): + assert self.worker_type != 'new' + if (notification_type == "NEW_WORKERS_INIT_READY" and + self.state == ScaleUpExistingWorkerState.WAIT_NEW_WORKERS_INIT): + self.waiting_for_notification = False + self.state = ScaleUpExistingWorkerState.CREATE_STANDBY_GROUPS + elif (notification_type == "NEW_WORKERS_WEIGHTS_INIT_READY" and + self.state == ScaleUpExistingWorkerState.WAIT_NEW_WORKERS_WEIGHTS_INIT): + self.waiting_for_notification = False + self.state = ScaleUpExistingWorkerState.TRANSFER_WEIGHTS + + def is_complete(self) -> bool: + if self.scale_type == "scale_up": + return (self.state == ScaleUpNewWorkerState.COMPLETE if self.worker_type == "new" + else self.state == ScaleUpExistingWorkerState.COMPLETE) + return (self.state == ScaleDownShutdownWorkerState.COMPLETE if self.worker_type == "shutdown" + else self.state == ScaleDownRemainingWorkerState.COMPLETE) + + def _create_standby_groups(self): + assert isinstance(self.new_dp_group, ParallelConfig) + self.new_dp_group, self.new_dp_store = self.new_dp_group.stateless_init_dp_group(return_store=True) + self.model_executor.collective_rpc( + "elastic_ep_execute", + args=("create_standby_groups", self.reconfig_request) + ) + logger.info("[Elastic EP] Created standby communication groups") + + def _transfer_weights(self): + old_dp_size = self.old_dp_group.size() + new_dp_size = self.reconfig_request.new_data_parallel_size + + self.model_executor.collective_rpc( + "elastic_ep_execute", + args=("transfer_weights", old_dp_size, new_dp_size) + ) + logger.info("[Elastic EP] Transferred weights to new workers") + + def _transfer_expert_mapping(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", + args=("broadcast_expert_mapping",) + ) + logger.info("[Elastic EP] Broadcasted expert mapping to new workers") + + def _sync_kv_cache_memory(self): + assert self.engine_core.available_gpu_memory_for_kv_cache > 0 + ParallelConfig.sync_kv_cache_memory_size( + self.new_dp_group, self.engine_core.available_gpu_memory_for_kv_cache) + logger.info("[Elastic EP] Synced KV cache memory size to new workers") + + def _switch_and_prepare(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", + args=("switch_and_prepare",) + ) + old_dp_group = self.old_dp_group + stateless_destroy_torch_distributed_process_group(old_dp_group) + new_dp_group = self.new_dp_group + self.engine_core.dp_group = new_dp_group + self.engine_core.dp_rank = new_dp_group.rank() + self.engine_core.dp_store = self.new_dp_store + logger.info("[Elastic EP] Switched to new comm group and prepare model for new 
setup") + + def _eplb_reshuffle(self): + self.model_executor.collective_rpc("elastic_ep_execute", args=("perform_eplb_reshuffle",)) + logger.info("[Elastic EP] EPLB reshuffle completed") + + def _eplb_reshuffle_before_scale_down(self): + assert self.reconfig_request is not None + self.model_executor.collective_rpc( + "elastic_ep_execute", + args=("perform_eplb_reshuffle", self.reconfig_request.new_data_parallel_size) + ) + logger.info("[Elastic EP] EPLB reshuffle completed") + + def _update_parallel_config(self): + reconfig_request = self.reconfig_request + parallel_config = self.vllm_config.parallel_config + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if reconfig_request.new_data_parallel_rank != ReconfigureRankType.KEEP_CURRENT_RANK: + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if reconfig_request.new_data_parallel_rank_local != ReconfigureRankType.KEEP_CURRENT_RANK: + parallel_config.data_parallel_rank_local = reconfig_request.new_data_parallel_rank_local + parallel_config.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = reconfig_request.new_data_parallel_master_port + parallel_config._data_parallel_master_port_list = reconfig_request.new_data_parallel_master_port_list + parallel_config._stateless_world_group_port_list = reconfig_request.new_stateless_world_group_port_list + parallel_config._stateless_dp_group_port_list = reconfig_request.new_stateless_dp_group_port_list + parallel_config._stateless_ep_group_port_list = reconfig_request.new_stateless_ep_group_port_list diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 3e318d784832..c0f01c3ac6c4 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -37,6 +37,7 @@ from vllm.config import ParallelConfig from vllm.distributed.parallel_state import (get_ep_group, get_node_count, in_the_same_node_as) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -156,6 +157,15 @@ class EplbState: Interval for expert rearrangement steps. This is a constant and is taken from the config. """ + + num_valid_physical_experts: int = 0 + """ + Number of valid physical experts. + This is the number of physical experts that are + actually mapped to logical experts. In elastic EP, + newly started EP ranks may not have physical experts + mapped yet. + """ @staticmethod def build_initial_global_physical_to_logical_map( @@ -183,9 +193,6 @@ def build( model: MixtureOfExperts, device: torch.device, parallel_config: ParallelConfig, - global_expert_load: Optional[torch.Tensor] = None, - old_global_expert_indices: Optional[torch.Tensor] = None, - rank_mapping: Optional[dict[int, int]] = None, ) -> "EplbState": """ Build the initial EPLB state. 
@@ -257,63 +264,11 @@ def build( expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) - if global_expert_load is not None: - ep_group = get_ep_group().device_group - assert global_expert_load.shape == (model.num_moe_layers, - model.num_logical_experts) - assert global_expert_load.dtype == torch.int64 - - num_replicas = model.num_physical_experts - num_groups = model.num_expert_groups - num_nodes = get_node_count() - num_gpus = ep_group.size() - - if num_gpus % num_nodes != 0: - num_nodes = 1 - logger.warning_once( - f"num_gpus % num_nodes != 0, " - "not using hierarchical rearrangement algorithm.\n" - f"{num_gpus=}, {num_nodes=}") - - # Get new expert mappings - ( - new_physical_to_logical_map, - new_logical_to_physical_map, - new_logical_replica_count, - ) = (rebalance_experts( - global_expert_load, - num_replicas, - num_groups, - num_nodes, - num_gpus, - )) - - max_physical_slots = new_logical_to_physical_map.shape[-1] - assert max_physical_slots <= logical_to_physical_map.shape[-1] - new_logical_to_physical_map = torch.nn.functional.pad( - new_logical_to_physical_map, - (0, logical_to_physical_map.shape[-1] - max_physical_slots), - value=-1, - ) - physical_to_logical_map = new_physical_to_logical_map.to(device) - logical_to_physical_map.copy_(new_logical_to_physical_map) - logical_replica_count.copy_(new_logical_replica_count) - model.set_eplb_state( expert_load_pass, logical_to_physical_map, logical_replica_count, ) - if global_expert_load is not None: - rearrange_expert_weights_inplace( - old_global_expert_indices, - new_physical_to_logical_map, - model.expert_weights, - ep_group, - False, - rank_mapping, - ) - expert_rearrangement_step = 0 return cls( physical_to_logical_map, @@ -324,6 +279,7 @@ def build( expert_load_window_size=expert_load_window_size, expert_rearrangement_step=expert_rearrangement_step, expert_rearrangement_step_interval=eplb_step_interval, + num_valid_physical_experts=model.num_physical_experts ) def step(self, @@ -414,8 +370,6 @@ def rearrange( self, model: MixtureOfExperts, is_profile: bool = False, - execute_shuffle: bool = True, - global_expert_load: Optional[torch.Tensor] = None, rank_mapping: Optional[dict[int, int]] = None) -> Optional[torch.Tensor]: """ @@ -433,49 +387,26 @@ def rearrange( logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") - if global_expert_load is None: - # Map the physical expert load to global logical experts - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - model.num_moe_layers, - model.num_logical_experts, - dtype=self.expert_load_window.dtype, - device=self.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=self.physical_to_logical_map.unsqueeze(0).expand_as( - self.expert_load_window).long(), - src=self.expert_load_window, - ) + # Map the physical expert load to global logical experts + expert_load_window = self.expert_load_window[:, :, :self.num_valid_physical_experts] - if not execute_shuffle: - metadata = torch.tensor( - [ - model.num_moe_layers, model.num_logical_experts, - self.physical_to_logical_map.shape[1] - ], - dtype=torch.int32, - device="cpu", - ) - torch.distributed.broadcast(metadata, - group=get_ep_group().cpu_group, - group_src=0) - - # Perform all-reduce to get the expert load across all ranks - global_expert_load_window = logical_expert_load_window.sum(dim=0) - all_reduce(global_expert_load_window, group=ep_group) - - if not execute_shuffle: - # (num_moe_layers, old_num_physical_experts) - 
old_global_expert_indices = self.physical_to_logical_map - torch.distributed.broadcast(old_global_expert_indices, - group=ep_group, - group_src=0) - return global_expert_load_window - else: - assert execute_shuffle - global_expert_load_window = global_expert_load + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + model.num_moe_layers, + model.num_logical_experts, + dtype=self.expert_load_window.dtype, + device=self.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=self.physical_to_logical_map[:, :self.num_valid_physical_experts].unsqueeze(0).expand_as( + expert_load_window).long(), + src=expert_load_window, + ) + + # Perform all-reduce to get the expert load across all ranks + global_expert_load_window = logical_expert_load_window.sum(dim=0) + all_reduce(global_expert_load_window, group=ep_group) # TODO(bowen): Treat differently for prefill and decode nodes num_replicas = model.num_physical_experts @@ -553,34 +484,49 @@ def rearrange( ) return None - @staticmethod - def recv_state() -> tuple[torch.Tensor, torch.Tensor]: - """ - Receive the expert load and old placement from the master rank. - """ - ep_group = get_ep_group() - metadata = torch.empty(3, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(metadata, - group=ep_group.cpu_group, - group_src=0) - num_moe_layers, num_logical_experts, num_old_physical_experts = ( - metadata.tolist()) - global_expert_load = torch.zeros( - (num_moe_layers, num_logical_experts), + @classmethod + def from_mapping( + cls, + model: MixtureOfExperts, + device: torch.device, + parallel_config: ParallelConfig, + expanded_physical_to_logical: torch.Tensor, + num_valid_physical_experts: int, + ) -> "EplbState": + eplb_state = cls.build( + model=model, + device=device, + parallel_config=parallel_config, + ) + eplb_state.num_valid_physical_experts = num_valid_physical_experts + num_moe_layers = expanded_physical_to_logical.shape[0] + num_physical_experts = expanded_physical_to_logical.shape[1] + eplb_state.physical_to_logical_map.copy_(expanded_physical_to_logical) + + logical_to_physical_map = torch.full( + (num_moe_layers, model.num_logical_experts, + eplb_state.logical_to_physical_map.shape[2]), + -1, dtype=torch.int64, - device=ep_group.device, ) - all_reduce(global_expert_load, group=ep_group.device_group) - old_global_expert_indices = torch.empty( - (num_moe_layers, num_old_physical_experts), + logical_replica_count = torch.zeros( + (num_moe_layers, model.num_logical_experts), dtype=torch.int64, - device=ep_group.device, ) - torch.distributed.broadcast(old_global_expert_indices, - group=ep_group.device_group, - group_src=0) - - return global_expert_load, old_global_expert_indices + expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy() + for layer_idx in range(num_moe_layers): + for phys_idx in range(num_physical_experts): + logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx] + if logical_idx >= 0: + replica_idx = logical_replica_count[layer_idx, logical_idx] + logical_to_physical_map[layer_idx, logical_idx, replica_idx] = phys_idx + logical_replica_count[layer_idx, logical_idx] += 1 + + logical_to_physical_map = logical_to_physical_map.to(device) + logical_replica_count = logical_replica_count.to(device) + eplb_state.logical_to_physical_map.copy_(logical_to_physical_map) + eplb_state.logical_replica_count.copy_(logical_replica_count) + return eplb_state def _node_count_with_rank_mapping( diff --git 
a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index f8a7d1170bb0..d6d434aa3cf2 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -13,6 +13,9 @@ import torch from torch.distributed import (P2POp, ProcessGroup, all_gather, batch_isend_irecv, get_global_rank) +from torch.distributed.distributed_c10d import _world + +from vllm.distributed.parallel_state import get_ep_group def idx_local_to_global( @@ -137,6 +140,11 @@ def shuffle_layer( buffer[dst].copy_(weight[src]) p2p_ops: list[P2POp] = [] + if ep_group not in _world.pg_map: + ep_group = get_ep_group() + is_stateless = True + else: + is_stateless = False # 2. Initiate sending of weights. experts_send_loc: dict[int, int] = {} @@ -171,14 +179,22 @@ def shuffle_layer( recv_ranks.append(ranks_to_recv[recver_pos]) for dst in recv_ranks: - dst_global = get_global_rank(ep_group, dst) - p2p_ops += [ - P2POp( - torch.distributed.isend, - weight[src], - dst_global, - ) for weight in expert_weights - ] + if is_stateless: + for weight in expert_weights: + op = object.__new__(P2POp) + op.op = torch.distributed.isend + op.tensor = weight[src] + op.group_peer = dst + p2p_ops.append(op) + else: + dst_global = get_global_rank(ep_group, dst) + p2p_ops += [ + P2POp( + torch.distributed.isend, + weight[src], + dst_global, + ) for weight in expert_weights + ] # 3. Initiate receiving of weights. experts_recv_loc: dict[int, int] = {} @@ -210,20 +226,31 @@ def shuffle_layer( else: src = ranks_to_send[recver_pos - remainder_start] - src_global = get_global_rank(ep_group, src) - p2p_ops += [ - P2POp( - torch.distributed.irecv, - weight[dst], - src_global, - ) for weight in expert_weights_buffer - ] + if is_stateless: + for weight in expert_weights_buffer: + op = object.__new__(P2POp) + op.op = torch.distributed.irecv + op.tensor = weight[dst] + op.group_peer = src + p2p_ops.append(op) + else: + src_global = get_global_rank(ep_group, src) + p2p_ops += [ + P2POp( + torch.distributed.irecv, + weight[dst], + src_global, + ) for weight in expert_weights_buffer + ] # 4. Execute the P2P operations. The real communication happens here. if p2p_ops: - reqs = batch_isend_irecv(p2p_ops) - for req in reqs: - req.wait() + if is_stateless: + ep_group.device_communicator.batch_isend_irecv(p2p_ops) + else: + reqs = batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() # 5. Copy the weights from the buffer back to the original weights. for dst in range(num_local_experts): diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 12571afaa4c1..bacd6cd83ddd 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -982,6 +982,24 @@ def get_pp_group() -> GroupCoordinator: return _PP +_STANDBY_DP: Optional[GroupCoordinator] = None +_STANDBY_EP: Optional[GroupCoordinator] = None +_STANDBY_WORLD: Optional[GroupCoordinator] = None +_STANDBY_WORLD_NODE_COUNT: Optional[int] = None + + +def get_standby_dp_group() -> Optional[GroupCoordinator]: + return _STANDBY_DP + + +def get_standby_ep_group() -> Optional[GroupCoordinator]: + return _STANDBY_EP + + +def get_standby_world_group() -> Optional[GroupCoordinator]: + return _STANDBY_WORLD + + @deprecated("`get_pipeline_model_parallel_group` has been replaced with " "`get_pp_group` and may be removed in v0.12. 
Please use " "`get_pp_group` instead.") @@ -1026,43 +1044,90 @@ def init_distributed_environment(world_size: int = -1, local_rank: int = -1, backend: str = "nccl", timeout: Optional[timedelta] = None): - logger.debug( + logger.info( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, distributed_init_method, backend) from vllm.config import get_current_vllm_config config = get_current_vllm_config() + enable_elastic_ep = config.parallel_config.enable_elastic_ep if config is not None and config.parallel_config.data_parallel_size > 1: - parallel_config = config.parallel_config - # adjust to take into account data parallelism - # offset the rank by the data parallel rank - rank = parallel_config.data_parallel_rank * world_size + rank - # adjust the world size to take into account data parallelism - world_size = parallel_config.world_size_across_dp - ip = parallel_config.data_parallel_master_ip - port = parallel_config.get_next_dp_init_port() - distributed_init_method = get_distributed_init_method(ip, port) - logger.info( - "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", - world_size, rank, distributed_init_method) - if not torch.distributed.is_initialized(): - assert distributed_init_method is not None, ( - "distributed_init_method must be provided when initializing " - "distributed environment") - if not torch.distributed.is_backend_available(backend): - logger.warning( - "Distributed backend %s is not available; " - "falling back to gloo.", backend) - assert torch.distributed.is_gloo_available(), ( - "Fallback Gloo backend is not available.") - backend = "gloo" - # this backend is used for WORLD - torch.distributed.init_process_group( - backend=backend, - init_method=distributed_init_method, - world_size=world_size, - rank=rank, - timeout=timeout) + if enable_elastic_ep: + # NOTE(yongji): In elastic EP, only init PyTorch distributed for TP * PP ranks + # to avoid destroying and re-initializing all communication groups during scaling up/down + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + if not torch.distributed.is_backend_available(backend): + logger.warning( + "Distributed backend %s is not available; " + "falling back to gloo.", backend) + assert torch.distributed.is_gloo_available(), ( + "Fallback Gloo backend is not available.") + backend = "gloo" + + # Initialize PyTorch distributed with only TP * PP ranks + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, # TP * PP size only + rank=rank % world_size, # rank within TP * PP + timeout=timeout) + cpu_group = torch.distributed.new_group(backend="gloo", timeout=timeout) + if _node_count(cpu_group) > 1: + # NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip + # to initialize all DP/EP groups, hence all ranks within TP/PP group + # must reside on the same node + raise RuntimeError("Elastic EP is not supported with multi-node TP/PP") + else: + parallel_config = config.parallel_config + rank = parallel_config.data_parallel_rank * world_size + rank + world_size = parallel_config.world_size_across_dp + ip = parallel_config.data_parallel_master_ip + port = parallel_config.get_next_dp_init_port() + distributed_init_method = get_distributed_init_method(ip, port) + logger.info( + "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
+ world_size, rank, distributed_init_method) + + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + if not torch.distributed.is_backend_available(backend): + logger.warning( + "Distributed backend %s is not available; " + "falling back to gloo.", backend) + assert torch.distributed.is_gloo_available(), ( + "Fallback Gloo backend is not available.") + backend = "gloo" + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + timeout=timeout) + else: + # No data parallelism + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + if not torch.distributed.is_backend_available(backend): + logger.warning( + "Distributed backend %s is not available; " + "falling back to gloo.", backend) + assert torch.distributed.is_gloo_available(), ( + "Fallback Gloo backend is not available.") + backend = "gloo" + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + timeout=timeout) + # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1074,15 +1139,40 @@ def init_distributed_environment(world_size: int = -1, else: local_rank = rank global _WORLD, _NODE_COUNT - if _WORLD is None: - ranks = list(range(torch.distributed.get_world_size())) - _WORLD = init_world_group(ranks, local_rank, backend) - _NODE_COUNT = _node_count(_WORLD.cpu_group) - logger.debug("Detected %d nodes in the distributed environment", - _NODE_COUNT) + if enable_elastic_ep: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + # Create stateless world group with all ranks + assert _WORLD is None, "world group already initialized" + parallel_config = config.parallel_config + global_rank = parallel_config.data_parallel_rank * world_size + rank + global_world_size = parallel_config.world_size_across_dp + all_ranks = list(range(global_world_size)) + group_ranks = [all_ranks[i:i+1] for i in range(global_world_size)] + if global_rank in all_ranks: + group_ranks = [all_ranks] + group_ports = [parallel_config.get_next_stateless_world_group_port()] + _WORLD = StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + host=parallel_config.data_parallel_master_ip, + group_ports=group_ports, + global_rank=global_rank, + global_world_size=global_world_size + ) + _NODE_COUNT = _node_count(_WORLD.tcp_store_group) else: - assert _WORLD.world_size == torch.distributed.get_world_size(), ( - "world group already initialized with a different world size") + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + _NODE_COUNT = _node_count(_WORLD.cpu_group) + logger.debug("Detected %d nodes in the distributed environment", + _NODE_COUNT) + else: + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size") def initialize_model_parallel( @@ -1115,11 +1205,21 @@ def initialize_model_parallel( ranks 8 to 15 belong to the second box. """ # Get world size and rank. 
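# --- Illustrative sketch (not part of the diff) ---------------------------
# Rank arithmetic behind the elastic-EP branch above: torch.distributed is
# initialized only over the TP * PP ranks of one DP replica, while the
# stateless world group addresses every rank across DP using
# global_rank = data_parallel_rank * (tp * pp) + local rank in the replica.
# Plain-Python sketch with hypothetical sizes:

def elastic_ranks(dp_rank: int, rank_in_replica: int, tp: int, pp: int, dp: int):
    tp_pp_size = tp * pp                 # world size of the torch.distributed group
    global_rank = dp_rank * tp_pp_size + rank_in_replica
    global_world_size = tp_pp_size * dp  # world_size_across_dp
    return tp_pp_size, global_rank, global_world_size


if __name__ == "__main__":
    # TP=2, PP=1, DP=4: replica 3, local rank 1 -> global rank 7 of 8
    print(elastic_ranks(dp_rank=3, rank_in_replica=1, tp=2, pp=1, dp=4))  # (2, 7, 8)
# ---------------------------------------------------------------------------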
Ensure some consistencies. - assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep + + if enable_elastic_ep: + # Use stateless world group for global information + world_size: int = get_world_group().world_size + rank = get_world_group().rank + backend = backend or "nccl" + else: + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) data_parallel_size = 1 from vllm.config import get_current_vllm_config @@ -1145,14 +1245,18 @@ def initialize_model_parallel( assert _TP is None, ("tensor model parallel group is already initialized") group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - + if enable_elastic_ep: + tp_pp_size = tensor_model_parallel_size * pipeline_model_parallel_size + local_all_ranks = torch.arange(tp_pp_size).reshape( + pipeline_model_parallel_size, tensor_model_parallel_size) + group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_message_queue_broadcaster=True, group_name="tp") - # Build the DCP model-parallel groups. global _DCP assert _DCP is None, ( @@ -1164,6 +1268,13 @@ def initialize_model_parallel( group_ranks = all_ranks.reshape( -1, decode_context_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + tp_pp_size = tensor_model_parallel_size * pipeline_model_parallel_size + local_all_ranks = torch.arange(tp_pp_size).reshape( + pipeline_model_parallel_size, tensor_model_parallel_size) + group_ranks = local_all_ranks.reshape( + -1, decode_context_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] _DCP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, @@ -1177,6 +1288,13 @@ def initialize_model_parallel( group_ranks = all_ranks.transpose(2, 3).reshape( -1, pipeline_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + tp_pp_size = tensor_model_parallel_size * pipeline_model_parallel_size + local_all_ranks = torch.arange(tp_pp_size).reshape( + pipeline_model_parallel_size, tensor_model_parallel_size) + group_ranks = local_all_ranks.transpose(0, 1).reshape( + -1, pipeline_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, @@ -1188,20 +1306,54 @@ def initialize_model_parallel( 3).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - _DP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="dp") + + if enable_elastic_ep: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + parallel_config = config.parallel_config + group_ports = [parallel_config.get_next_stateless_dp_group_port() for _ in 
group_ranks] + _DP = StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + group_name="dp", + host=parallel_config.data_parallel_master_ip, + group_ports=group_ports, + global_rank=get_world_group().rank, + global_world_size=get_world_group().world_size + ) + else: + _DP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="dp") global _EP assert _EP is None, ("expert parallel group is already initialized") group_ranks = all_ranks.transpose(1, 2).reshape( -1, data_parallel_size * tensor_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - _EP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="ep") + + if enable_elastic_ep: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + parallel_config = config.parallel_config + group_ports = [parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks] + _EP = StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + group_name="ep", + host=parallel_config.data_parallel_master_ip, + group_ports=group_ports, + global_rank=get_world_group().rank, + global_world_size=get_world_group().world_size + ) + else: + _EP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="ep") logger.info( "rank %s in world size %s is assigned as " @@ -1220,7 +1372,11 @@ def ensure_model_parallel_initialized( or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. 
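# --- Illustrative sketch (not part of the diff) ---------------------------
# How the DP/EP rank lists above fall out of the reshape/transpose on
# torch.arange(world_size_across_dp). Standalone, needs only torch (no process
# groups); the sizes are hypothetical:
import torch

dp, pp, tp = 2, 1, 2
all_ranks = torch.arange(dp * pp * tp).reshape(-1, dp, pp, tp)

# DP groups: ranks that share the same PP and TP coordinates.
dp_groups = [x.tolist() for x in
             all_ranks.transpose(1, 3).reshape(-1, dp).unbind(0)]

# EP groups: the DP x TP ranks that share the same PP coordinate.
ep_groups = [x.tolist() for x in
             all_ranks.transpose(1, 2).reshape(-1, dp * tp).unbind(0)]

print(dp_groups)  # [[0, 2], [1, 3]]
print(ep_groups)  # [[0, 1, 2, 3]]
# ---------------------------------------------------------------------------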
""" - backend = backend or torch.distributed.get_backend( + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + if isinstance(get_world_group(), StatelessGroupCoordinator): + backend = backend or get_world_group().backend + else: + backend = backend or torch.distributed.get_backend( get_world_group().device_group) if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, @@ -1240,6 +1396,89 @@ def ensure_model_parallel_initialized( f"wanted: {pipeline_model_parallel_size=}") +def create_standby_groups( + new_dp_size: int, + new_world_size_across_dp: int, + master_ip: str, + world_group_ports: list[list[int]], + dp_group_ports: list[list[int]], + ep_group_ports: list[list[int]], + backend: Optional[str] = None, +) -> None: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + global _STANDBY_WORLD, _STANDBY_WORLD_NODE_COUNT, _STANDBY_DP, _STANDBY_EP + + assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size + backend = backend or get_world_group().backend + local_rank = get_world_group().local_rank + global_rank = get_world_group().rank + + standby_world_ranks = [list(range(new_world_size_across_dp))] + _STANDBY_WORLD = StatelessGroupCoordinator( + group_ranks=standby_world_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + host=master_ip, + group_ports=world_group_ports, + global_rank=global_rank, + global_world_size=new_world_size_across_dp, + ) + _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group) + + tp_size = get_tp_group().world_size + pp_size = get_pp_group().world_size + + all_ranks = torch.arange(new_world_size_across_dp).reshape( + -1, new_dp_size, pp_size, tp_size) + standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0) + standby_dp_ranks = [x.tolist() for x in standby_dp_ranks] + _STANDBY_DP = StatelessGroupCoordinator( + group_ranks=standby_dp_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + group_name="dp", + host=master_ip, + group_ports=dp_group_ports, + global_rank=global_rank, + global_world_size=new_world_size_across_dp, + ) + + standby_ep_ranks = all_ranks.transpose(1, 2).reshape( + -1, new_dp_size * tp_size).unbind(0) + standby_ep_ranks = [x.tolist() for x in standby_ep_ranks] + _STANDBY_EP = StatelessGroupCoordinator( + group_ranks=standby_ep_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + group_name="ep", + host=master_ip, + group_ports=ep_group_ports, + global_rank=global_rank, + global_world_size=new_world_size_across_dp, + ) + + +def switch_to_standby_groups() -> None: + global _WORLD, _STANDBY_WORLD, _NODE_COUNT, _STANDBY_WORLD_NODE_COUNT + global _DP, _EP, _STANDBY_DP, _STANDBY_EP + _DP.destroy() + _EP.destroy() + _WORLD.destroy() + _DP = _STANDBY_DP + _EP = _STANDBY_EP + _WORLD = _STANDBY_WORLD + _NODE_COUNT = _STANDBY_WORLD_NODE_COUNT + _STANDBY_DP = None + _STANDBY_EP = None + _STANDBY_WORLD = None + _STANDBY_WORLD_NODE_COUNT = None + + def prepare_communication_buffer_for_model(model: torch.nn.Module): """Prepare the communication buffer for the model. 
Traditional communication libraries like NCCL are almost diff --git a/vllm/distributed/stateless_coordinator.py b/vllm/distributed/stateless_coordinator.py new file mode 100644 index 000000000000..f8cef6da2f71 --- /dev/null +++ b/vllm/distributed/stateless_coordinator.py @@ -0,0 +1,296 @@ +from typing import Optional, Union, Any +from torch.distributed import Backend, ProcessGroup +import torch + +from vllm.distributed.parallel_state import GroupCoordinator, TensorMetadata +from vllm.distributed.parallel_state import _get_unique_name, _register_group, _split_tensor_dict +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.utils import ( + StatelessProcessGroup, stateless_init_torch_distributed_process_group, + stateless_destroy_torch_distributed_process_group) +from vllm.logger import init_logger +from vllm.utils import resolve_obj_by_qualname + +logger = init_logger(__name__) + + +class StatelessGroupCoordinator(GroupCoordinator): + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + host: str = "127.0.0.1", + group_ports: list[list[int]] = None, + global_rank: int = 0, + global_world_size: int = 1, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = global_rank + self.local_rank = local_rank + + self_device_group = None + self_cpu_group = None + self_tcp_store_group = None + + from vllm.platforms import current_platform + + backend = str(torch_distributed_backend) + self.backend = backend + + for idx, ranks in enumerate(group_ranks): + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + + ports = group_ports[idx] + device_port = ports[0] + cpu_port = ports[1] + tcp_store_port = ports[2] + + device_group = stateless_init_torch_distributed_process_group( + host=host, + port=device_port, + rank=self.rank_in_group, + world_size=self.world_size, + backend=backend, + group_name=f"{self.unique_name}_device" + ) + cpu_group = stateless_init_torch_distributed_process_group( + host=host, + port=cpu_port, + rank=self.rank_in_group, + world_size=self.world_size, + backend="gloo", + group_name=f"{self.unique_name}_cpu" + ) + tcp_store_group = StatelessProcessGroup.create( + host=host, + port=tcp_store_port, + rank=self.rank_in_group, + world_size=self.world_size, + ) + + self_device_group = device_group + self_cpu_group = cpu_group + self_tcp_store_group = tcp_store_group + + assert self_cpu_group is not None + assert self_device_group is not None + assert self_tcp_store_group is not None + + self.cpu_group = self_cpu_group + self.device_group = self_device_group + self.tcp_store_group = self_tcp_store_group + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + elif current_platform.is_xpu(): + self.device = torch.device(f"xpu:{local_rank}") + elif current_platform.is_out_of_tree(): + self.device = torch.device( + f"{current_platform.device_name}:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_device_communicator = use_device_communicator + self.device_communicator = None + if use_device_communicator and self.world_size > 1: + device_comm_cls = resolve_obj_by_qualname( + current_platform.get_device_communicator_cls()) + assert device_comm_cls 
== CudaCommunicator + self.device_communicator = CudaCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + global_ranks=self.ranks, + global_world_size=global_world_size, + tcp_store_group=self.tcp_store_group + ) + + self.mq_broadcaster = None + + self.use_custom_op_call = (current_platform.is_cuda_alike() + or current_platform.is_tpu()) + self.use_cpu_custom_send_recv = False + + def destroy(self): + if self.device_communicator: + self.device_communicator.destroy() + if self.device_group: + stateless_destroy_torch_distributed_process_group(self.device_group) + if self.cpu_group: + stateless_destroy_torch_distributed_process_group(self.cpu_group) + self.tcp_store_group = None + + def broadcast(self, input_: torch.Tensor, src: int = 0): + if self.world_size == 1: + return input_ + + if self.device_communicator and input_.is_cuda: + return self.device_communicator.broadcast(input_, src) + else: + return self.tcp_store_group.broadcast(input_, src) + + def broadcast_object(self, obj=None, src: int = 0): + if self.world_size == 1: + return obj + return self.tcp_store_group.broadcast_obj(obj, src) + + def broadcast_object_list(self, + obj_list: list[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + assert src < self.world_size + + if self.world_size == 1: + return obj_list + + if self.rank_in_group == src: + for obj in obj_list: + self.tcp_store_group.broadcast_obj(obj, src) + else: + for i in range(len(obj_list)): + obj_list[i] = self.tcp_store_group.broadcast_obj(None, src) + + return obj_list + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[dict[str, Union[torch.Tensor, any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[dict[str, Union[torch.Tensor, any]]]: + if self.world_size == 1: + return tensor_dict + + if self.rank_in_group == src: + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + else: + metadata_list = None + tensor_list = [] + + metadata_list = self.tcp_store_group.broadcast_obj(metadata_list, src) + + if self.rank_in_group != src: + tensor_dict = {} + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + tensor_list.append(tensor) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + + for tensor in tensor_list: + if tensor.numel() == 0: + continue + if self.device_communicator and tensor.is_cuda: + self.device_communicator.broadcast(tensor, src) + else: + self.tcp_store_group.broadcast(tensor, src) + + return tensor_dict + + def send_object(self, obj, dst: int) -> None: + assert dst < self.world_size + assert dst != self.rank_in_group + self.tcp_store_group.send_obj(obj, dst) + + def recv_object(self, src: int): + assert src < self.world_size + assert src != self.rank_in_group + return self.tcp_store_group.recv_obj(src) + + def send_tensor_dict( + self, + tensor_dict: dict[str, Union[torch.Tensor, any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: Optional[dict[str, bool]] = None, + ) -> Optional[dict[str, Union[torch.Tensor, any]]]: + if self.world_size == 1: + return tensor_dict + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size + + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + self.tcp_store_group.send_obj(metadata_list, dst) + + for 
tensor in tensor_list: + if tensor.numel() == 0: + continue + if self.device_communicator and tensor.is_cuda: + self.device_communicator.send(tensor, dst) + else: + self.tcp_store_group.send(tensor, dst) + + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: Optional[dict[str, bool]] = None, + ) -> Optional[dict[str, Union[torch.Tensor, any]]]: + if self.world_size == 1: + return None + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size + + recv_metadata_list = self.tcp_store_group.recv_obj(src) + tensor_dict = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() > 0: + if self.device_communicator and tensor.is_cuda: + tensor = self.device_communicator.recv(tensor.size(), tensor.dtype, src) + else: + tensor = self.tcp_store_group.recv(tensor, src) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + self.tcp_store_group.barrier() + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + if self.world_size == 1: + return input_ + + if self.device_communicator is None: + raise ValueError("No device communicator found") + + if self.rank_in_group == dst: + gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)] + gathered_list[self.rank_in_group] = input_ + for src_rank in range(self.world_size): + if src_rank != self.rank_in_group: + gathered_list[src_rank] = self.device_communicator.recv(input_.size(), input_.dtype, src_rank) + return torch.cat(gathered_list, dim=dim) + else: + self.device_communicator.send(input_, dst) + return None \ No newline at end of file diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 67f71643d039..ec6e59518ae2 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -18,7 +18,7 @@ from typing import Any, Optional import torch -from torch.distributed import ProcessGroup, TCPStore +from torch.distributed import ProcessGroup, TCPStore, Store from torch.distributed.distributed_c10d import (Backend, PrefixStore, _get_default_timeout, _unregister_process_group) @@ -170,6 +170,10 @@ def __post_init__(self): for i in range(self.world_size) } + def size(self) -> int: + """Return the world size of the process group.""" + return self.world_size + def send_obj(self, obj: Any, dst: int): """Send an object to a destination rank.""" self.expire_data() @@ -229,6 +233,71 @@ def all_gather_obj(self, obj: Any) -> list[Any]: gathered_objs.append(recv_obj) return gathered_objs + def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + """Broadcast a tensor from source rank to all other ranks.""" + if self.rank == src: + tensor_bytes = pickle.dumps(tensor) + self.expire_data() + key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}" + self.store.set(key, tensor_bytes) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return tensor + else: + key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}" + tensor = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return tensor + + def send(self, tensor: torch.Tensor, dst: int) -> None: + """Send a tensor to a destination rank.""" + self.expire_data() + key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}" + 
self.store.set(key, pickle.dumps(tensor)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + """Receive a tensor from a source rank.""" + key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}" + received = pickle.loads(self.store.get(key)) + self.recv_src_counter[src] += 1 + if tensor is not None: + tensor.copy_(received) + return tensor + return received + + def all_reduce(self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM) -> torch.Tensor: + """All-reduce a tensor across all ranks.""" + tensors = self.all_gather_obj(tensor) + result = tensors[0].clone() + for t in tensors[1:]: + if op == torch.distributed.ReduceOp.SUM: + result.add_(t) + elif op == torch.distributed.ReduceOp.PRODUCT: + result.mul_(t) + elif op == torch.distributed.ReduceOp.MAX: + result = torch.maximum(result, t) + elif op == torch.distributed.ReduceOp.MIN: + result = torch.minimum(result, t) + return result + + def gather(self, tensor: torch.Tensor, gather_list: Optional[list] = None, dst: int = 0) -> Optional[list[torch.Tensor]]: + """Gather tensors from all ranks to the destination rank.""" + if self.rank == dst: + if gather_list is None: + gather_list = [] + for i in range(self.world_size): + if i == self.rank: + gather_list.append(tensor) + else: + recv_tensor = self.recv(None, src=i) + gather_list.append(recv_tensor) + return gather_list + else: + self.send(tensor, dst=dst) + return None + def barrier(self, timeout: float = 30.0): """A robust barrier to synchronize all ranks. @@ -458,7 +527,8 @@ def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, def stateless_init_torch_distributed_process_group( host: str, port: int, rank: int, world_size: int, - backend: str) -> ProcessGroup: + backend: str, group_name: Optional[str] = None, + return_store: bool = False) -> ProcessGroup | tuple[ProcessGroup, Store]: """ A replacement for `torch.distributed.init_process_group` that does not pollute the global state. 
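# --- Illustrative sketch (not part of the diff) ---------------------------
# The StatelessProcessGroup helpers above move tensors by pickling them into
# the shared TCPStore under counter-keyed entries ("send_tensor/{dst}/{n}",
# "broadcast_tensor/{src}/{n}") and build all_reduce/gather on top of those
# object exchanges. Toy emulation with a plain dict standing in for the
# TCPStore and Python lists standing in for tensors:
import pickle
from collections import defaultdict

store: dict[str, bytes] = {}  # stand-in for the shared TCPStore

class ToyPeer:
    def __init__(self, rank: int):
        self.rank = rank
        self.send_dst_counter = defaultdict(int)  # one counter per destination
        self.recv_src_counter = defaultdict(int)  # one counter per source

    def send(self, payload, dst: int) -> None:
        key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}"
        store[key] = pickle.dumps(payload)
        self.send_dst_counter[dst] += 1

    def recv(self, src: int):
        key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}"
        self.recv_src_counter[src] += 1
        return pickle.loads(store[key])

rank0, rank1 = ToyPeer(0), ToyPeer(1)
rank0.send([1.0, 2.0], dst=1)
print(rank1.recv(src=0))  # [1.0, 2.0]
# ---------------------------------------------------------------------------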
The created ProcessGroup object can be used for @@ -506,18 +576,29 @@ def stateless_init_torch_distributed_process_group( prefix_store = PrefixStore(init_method, store) if backend == "gloo": - return init_gloo_process_group(backend=backend, + pg = init_gloo_process_group(backend=backend, prefix_store=prefix_store, group_rank=group_rank, group_size=group_size, timeout=timeout) - from vllm.platforms import current_platform - return current_platform.stateless_init_device_torch_dist_pg( - backend=backend, - prefix_store=prefix_store, - group_rank=group_rank, - group_size=group_size, - timeout=timeout) + else: + from vllm.platforms import current_platform + pg = current_platform.stateless_init_device_torch_dist_pg( + backend=backend, + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout) + + if isinstance(group_name, str): + from torch._C._distributed_c10d import _register_process_group + pg._set_group_name(group_name) + _register_process_group(group_name, pg) + + if return_store: + return pg, store + else: + return pg def stateless_destroy_torch_distributed_process_group( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4d801b155e1..2d61b827af97 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -327,6 +327,7 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep enable_dbo: bool = ParallelConfig.enable_dbo dbo_decode_token_threshold: int = \ ParallelConfig.dbo_decode_token_threshold @@ -693,6 +694,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument( + "--enable-elastic-ep", + **parallel_kwargs["enable_elastic_ep"]) parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) parallel_group.add_argument( @@ -1319,6 +1323,7 @@ def create_engine_config( data_parallel_backend=self.data_parallel_backend, data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, + enable_elastic_ep=self.enable_elastic_ep, enable_dbo=self.enable_dbo, dbo_decode_token_threshold=self.dbo_decode_token_threshold, enable_eplb=self.enable_eplb, diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 7a753d608a43..9221ded193de 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -50,9 +50,13 @@ def _init_executor(self) -> None: self.async_output_thread = ThreadPoolExecutor( max_workers=1, thread_name_prefix="WorkerAsyncOutput") + is_new_worker = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" + self.collective_rpc("init_worker", args=([kwargs], )) - self.collective_rpc("init_device") - self.collective_rpc("load_model") + + if not is_new_worker: + self.collective_rpc("init_device") + self.collective_rpc("load_model") def _distributed_args(self) -> tuple[str, int, int]: """Return (distributed_init_method, rank, local_rank).""" diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 345f5a464c2c..037e55352979 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -193,6 +193,10 @@ class ReconfigureDistributedRequest(msgspec.Struct): new_data_parallel_rank_local: int new_data_parallel_master_ip: str 
new_data_parallel_master_port: int + new_data_parallel_master_port_list: Optional[list[int]] = None + new_stateless_world_group_port_list: Optional[list[list[int]]] = None + new_stateless_dp_group_port_list: Optional[list[list[int]]] = None + new_stateless_ep_group_port_list: Optional[list[list[int]]] = None class ReconfigureRankType(enum.IntEnum): diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 757baecea9ce..7ffa73dc5fec 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -430,7 +430,6 @@ def _run_output_handler(self): engine_core = self.engine_core output_processor = self.output_processor log_stats = self.log_stats - logger_manager = self.logger_manager async def output_handler(): try: @@ -470,8 +469,10 @@ async def output_handler(): # 4) Logging. # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. - if logger_manager: - logger_manager.record( + if self.logger_manager: + # NOTE(yongji): we need to use self.logger_manager here + # since it can be reinstantiated during scaling up + self.logger_manager.record( engine_idx=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, @@ -701,16 +702,6 @@ async def scale_elastic_ep(self, logger.info("Data parallel size is already %s, skipping scale", new_data_parallel_size) return - logger.info( - "Waiting for requests to drain before " - "scaling up to %s engines...", new_data_parallel_size) - await self.wait_for_requests_to_drain(drain_timeout) - logger.info( - "Requests have been drained, proceeding with scale " - "to %s engines", new_data_parallel_size) - await self.engine_core.scale_elastic_ep(new_data_parallel_size) - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size # recreate stat loggers if new_data_parallel_size > old_data_parallel_size and self.log_stats: @@ -723,6 +714,11 @@ async def scale_elastic_ep(self, engine_idxs=list(range(new_data_parallel_size)), custom_stat_loggers=None, ) + self.logger_manager.log_engine_initialized() + + await self.engine_core.scale_elastic_ep(new_data_parallel_size) + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size @property def is_running(self) -> bool: diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 596edfdbe24f..a8e5bf35b309 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -71,6 +71,9 @@ def __init__(self, parallel_config: ParallelConfig): local_only=local_only, host=host) local_only_eng = dp_size == parallel_config.data_parallel_size_local + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + local_only_eng = False back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) back_output_address = get_engine_client_zmq_addr(local_only_eng, host) @@ -188,6 +191,7 @@ def process_input_socket(self, front_publish_address: str, poller = zmq.Poller() poller.register(publish_front, zmq.POLLIN) + poller.register(publish_back, zmq.POLLIN) poller.register(output_back, zmq.POLLIN) last_publish_time = 0 while True: @@ -221,6 +225,20 @@ def process_input_socket(self, front_publish_address: str, events = dict(events) wave_state_changed = False + if publish_back in events: + buffer = publish_back.recv() + if buffer == b'\x01': + # NOTE(yongji): newly started engine subscribed + # We need to send READY message here instead of receiving + # SCALE_ELASTIC_EP notification from 
engine core client + # as SCALE_ELASTIC_EP is only sent when + # new engines finished initialization. + # Subscription message, on the other hand, is sent + # by each engine during initialization + publish_back.send(b"READY") + else: + logger.error("DP Coordinator receives unexpected message from engines") + if publish_front in events: buffer = publish_front.recv() if buffer in (b'\x01', b'\x00'): diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a43042a5510a..9b67f6840ad2 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -87,6 +87,9 @@ def __init__(self, self.available_gpu_memory_for_kv_cache = -1 + if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": + self._elastic_scale_up_post_init() + # Setup KV Caches and update CacheConfig after profiling. num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ self._initialize_kv_caches(vllm_config) @@ -176,10 +179,8 @@ def _initialize_kv_caches( has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs) if has_kv_cache: if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": - dp_group = getattr(self, "dp_group", None) - assert dp_group is not None - self.available_gpu_memory_for_kv_cache = \ - ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) + # NOTE(yongji): should already be set during _elastic_scale_up_post_init + assert self.available_gpu_memory_for_kv_cache > 0 available_gpu_memory = [ self.available_gpu_memory_for_kv_cache ] * len(kv_cache_specs) @@ -493,6 +494,10 @@ def __init__( self.has_coordinator and not vllm_config.parallel_config.data_parallel_external_lb) + self.addresses = addresses + self.process_input_queue_non_block = False + if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": + self._send_worker_notification("NEW_WORKERS_INIT_READY", vllm_config=vllm_config) self._init_data_parallel(vllm_config) super().__init__(vllm_config, executor_class, log_stats, @@ -736,8 +741,14 @@ def _process_input_queue(self): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True - req = self.input_queue.get() - self._handle_client_request(*req) + block = not self.process_input_queue_non_block + try: + req = self.input_queue.get(block=block) + self._handle_client_request(*req) + except queue.Empty: + break + if not block: + break if waited: logger.debug("EngineCore loop active.") @@ -866,6 +877,13 @@ def process_input_sockets(self, input_addresses: list[str], # (RequestType, RequestData) type_frame, *data_frames = input_socket.recv_multipart( copy=False) + + # NOTE(yongji): ignore READY message sent by DP coordinator + # that is used to notify newly started engines + if type_frame.buffer == b"READY": + assert input_socket == coord_socket + continue + request_type = EngineCoreRequestType( bytes(type_frame.buffer)) @@ -959,6 +977,7 @@ def __init__( self.step_counter = 0 self.current_wave = 0 self.last_counts = (0, 0) + self.elastic_scaling_state = None # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank @@ -986,7 +1005,7 @@ def _init_data_parallel(self, vllm_config: VllmConfig): vllm_config.kv_transfer_config.engine_id) self.dp_rank = dp_rank - self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.dp_group, self.dp_store = vllm_config.parallel_config.stateless_init_dp_group(return_store=True) def shutdown(self): super().shutdown() @@ -1041,7 +1060,14 @@ def run_busy_loop(self): # 1) Poll the input queue until there is work to do. 
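# --- Illustrative sketch (not part of the diff) ---------------------------
# The process_input_queue_non_block flag above lets the engine busy loop poll
# the input queue without blocking while an elastic scaling operation is in
# flight, so the scaling state keeps making progress between requests.
# Simplified standard-library sketch of that polling pattern:
import queue

def poll_once(q: queue.Queue, non_block: bool):
    """Fetch at most one queued request; never wait when non_block is True."""
    try:
        return q.get(block=not non_block)
    except queue.Empty:
        return None  # nothing queued; caller can advance the scaling state

q = queue.Queue()
q.put("req-a")
print(poll_once(q, non_block=True))  # 'req-a'
print(poll_once(q, non_block=True))  # None -> loop continues without blocking
# ---------------------------------------------------------------------------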
self._process_input_queue() - # 2) Step the engine core. + if self.elastic_scaling_state is not None: + has_progress = self.elastic_scaling_state.progress() + if self.elastic_scaling_state.is_complete(): + self.process_input_queue_non_block = False + self.elastic_scaling_state = None + elif has_progress: + continue + executed = self._process_engine_step() self._maybe_publish_request_counts() @@ -1064,9 +1090,6 @@ def run_busy_loop(self): # Notify client that we are pausing the loop. logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) - # In the coordinator case, dp rank 0 sends updates to the - # coordinator. Otherwise (offline spmd case), each rank - # sends the update to its colocated front-end process. client_index = -1 if self.has_coordinator else 0 self.output_queue.put_nowait( (client_index, @@ -1087,47 +1110,80 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest) -> None: - stateless_destroy_torch_distributed_process_group(self.dp_group) - self.shutdown() - - parallel_config = self.vllm_config.parallel_config - old_dp_size = parallel_config.data_parallel_size - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size - if reconfig_request.new_data_parallel_rank != -1: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank - # local rank specifies device visibility, it should not be changed - assert reconfig_request.new_data_parallel_rank_local == \ - ReconfigureRankType.KEEP_CURRENT_RANK - parallel_config.data_parallel_master_ip = \ - reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ - reconfig_request.new_data_parallel_master_port - if reconfig_request.new_data_parallel_rank != -2: - self.dp_rank = parallel_config.data_parallel_rank - self.dp_group = parallel_config.stateless_init_dp_group() - reconfig_request.new_data_parallel_master_port = \ - parallel_config.data_parallel_master_port - - self.model_executor.reinitialize_distributed(reconfig_request) - if reconfig_request.new_data_parallel_size > old_dp_size: - assert self.available_gpu_memory_for_kv_cache > 0 - # pass available_gpu_memory_for_kv_cache from existing - # engine-cores to new engine-cores so they can directly - # use it in _initialize_kv_caches() rather than profiling. 
- ParallelConfig.sync_kv_cache_memory_size( - self.dp_group, self.available_gpu_memory_for_kv_cache) - # NOTE(yongji): newly joined workers require dummy_run even - # CUDA graph is not used - self.model_executor.collective_rpc("compile_or_warm_up_model") - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: - self.shutdown() - logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) + from vllm.distributed.elastic_ep.elastic_state import ElasticScalingState + from copy import deepcopy + + new_parallel_config = deepcopy(self.vllm_config.parallel_config) + old_dp_size = new_parallel_config.data_parallel_size + new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if reconfig_request.new_data_parallel_rank != ReconfigureRankType.KEEP_CURRENT_RANK: + new_parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + new_parallel_config.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip + new_parallel_config.data_parallel_master_port = reconfig_request.new_data_parallel_master_port + new_parallel_config._data_parallel_master_port_list = reconfig_request.new_data_parallel_master_port_list + + is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size + is_shutdown = reconfig_request.new_data_parallel_rank == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + worker_type = "shutdown" if is_shutdown else "existing" + scale_type = "scale_down" if is_scale_down else "scale_up" + + self.elastic_scaling_state = ElasticScalingState( + model_executor=self.model_executor, + engine_core=self, + vllm_config=self.vllm_config, + new_parallel_config=new_parallel_config if not is_shutdown else None, + worker_type=worker_type, + scale_type=scale_type, + reconfig_request=reconfig_request, + ) + self.process_input_queue_non_block = True + self.dp_store.add("elastic_ep_request_first_barrier", 1) + logger.info("[Elastic EP] Received reconfiguration request and starting scaling up/down") + + def _send_worker_notification(self, notification_type: str, vllm_config: Optional[VllmConfig] = None): + if vllm_config is None: + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + else: + dp_rank = vllm_config.parallel_config.data_parallel_rank + notification_data = (notification_type, dp_rank) + outputs = EngineCoreOutputs(utility_output=UtilityOutput( + call_id=-1, + result=UtilityResult(notification_data) + )) + outputs.engine_index = self.engine_index + + if hasattr(self, 'output_thread') and self.output_thread.is_alive(): + self.output_queue.put_nowait((0, outputs)) else: - logger.info("Distributed environment reinitialized for DP rank %s", - self.dp_rank) + encoder = MsgpackEncoder() + with zmq.Context() as ctx, \ + make_zmq_socket(ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000) as socket: + socket.send_multipart(encoder.encode(outputs)) + + def handle_worker_notification(self, notification_type: str): + assert self.elastic_scaling_state is not None + self.elastic_scaling_state.handle_notification(notification_type) + + def _elastic_scale_up_post_init(self): + from vllm.distributed.elastic_ep.elastic_state import ElasticScalingState + + self.elastic_scaling_state = ElasticScalingState( + model_executor=self.model_executor, + engine_core=self, + vllm_config=self.vllm_config, + new_parallel_config=None, + worker_type="new", + scale_type="scale_up", + reconfig_request=None, + ) + self.model_executor.collective_rpc("init_device") + self.model_executor.collective_rpc("load_model") + 
self._send_worker_notification("NEW_WORKERS_WEIGHTS_INIT_READY") + self.model_executor.collective_rpc("elastic_ep_execute", args=("receive_weights",)) + self.available_gpu_memory_for_kv_cache = \ + ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1) + self.model_executor.collective_rpc("elastic_ep_execute", args=("prepare_new_worker",)) + self.process_input_queue_non_block = True class DPEngineCoreActor(DPEngineCoreProc): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index a84b0e55105b..7357ec68ca24 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -23,7 +23,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask -from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path, +from vllm.utils import (close_sockets, get_open_zmq_inproc_path, in_loop, make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, @@ -396,6 +396,41 @@ def validate_alive(self, frames: Sequence[zmq.Frame]): raise EngineDeadError() +@dataclass +class ElasticScalingCache: + existing_workers: list[EngineIdentity] + num_new_workers: int + pending_notifications: dict[str, set[int]] + + +def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int): + """ + Allocate stateless group ports for elastic EP. + """ + from vllm.utils import get_open_ports_list + assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled" + world_size = parallel_config.world_size + new_world_size_across_dp = world_size * new_data_parallel_size + num_world_groups = 1 + num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size) + num_ep_groups = max(1, new_world_size_across_dp // (new_data_parallel_size * parallel_config.tensor_parallel_size)) + total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3 + 5 + all_ports = get_open_ports_list(total_ports_needed) + new_data_parallel_master_port_list = all_ports[-5:] + all_ports = all_ports[:-5] + new_stateless_world_group_port_list = [all_ports[i:i+3] for i in range(0, num_world_groups * 3, 3)] + start_idx = num_world_groups * 3 + new_stateless_dp_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)] + start_idx += num_dp_groups * 3 + new_stateless_ep_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)] + + parallel_config._stateless_world_group_port_list = new_stateless_world_group_port_list + parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list + parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list + parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop() + parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list + + class MPClient(EngineCoreClient): """ MPClient: base client for multi-proc EngineCore. 
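# --- Illustrative sketch (not part of the diff) ---------------------------
# Port budget behind allocate_stateless_group_ports() above: each stateless
# group consumes three ports (device group, gloo CPU group, TCPStore) and five
# extra ports are reserved for the DP master port list. Plain-Python sketch
# with fake port numbers instead of get_open_ports_list():

def partition_ports(world_size: int, tp: int, new_dp: int, base_port: int = 29000):
    world_across_dp = world_size * new_dp
    num_world = 1
    num_dp = max(1, world_across_dp // new_dp)         # one DP group per TP*PP rank
    num_ep = max(1, world_across_dp // (new_dp * tp))  # one EP group per PP rank
    total = (num_world + num_dp + num_ep) * 3 + 5
    ports = list(range(base_port, base_port + total))  # stand-in for get_open_ports_list()
    master_ports, ports = ports[-5:], ports[:-5]
    triplets = [ports[i:i + 3] for i in range(0, len(ports), 3)]
    world_ports = triplets[:num_world]
    dp_ports = triplets[num_world:num_world + num_dp]
    ep_ports = triplets[num_world + num_dp:]
    return world_ports, dp_ports, ep_ports, master_ports


if __name__ == "__main__":
    # world_size = TP * PP = 2 (TP=2, PP=1), scaling to DP=2
    for name, part in zip(("world", "dp", "ep", "master"),
                          partition_ports(world_size=2, tp=2, new_dp=2)):
        print(name, part)
# ---------------------------------------------------------------------------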
@@ -806,6 +841,9 @@ def _ensure_output_queue_task(self): output_socket = resources.output_socket assert output_socket is not None + notification_callback_handler: Optional[Callable[[tuple[str, int]], None]] = getattr( + self.__class__, "process_worker_notification", None) + async def process_outputs_socket(): try: while True: @@ -813,8 +851,15 @@ async def process_outputs_socket(): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) + if outputs.utility_output.call_id == -1: + if notification_callback_handler and outputs.utility_output.result: + _self = _self_ref() if _self_ref else None + if _self: + notification_data = outputs.utility_output.result.result + asyncio.create_task(notification_callback_handler(_self, notification_data)) + else: + _process_utility_output(outputs.utility_output, + utility_results) continue if output_handler is not None: @@ -979,6 +1024,8 @@ def __init__(self, # Used only by DPLBAsyncMPClient subclass. self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines] + self.elastic_scaling_cache: Optional[ElasticScalingCache] = None + self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = self.resources.first_req_send_socket = ( make_zmq_socket(self.ctx, @@ -1027,6 +1074,7 @@ async def run_engine_stats_update_task(): poller.register(socket, zmq.POLLIN) poller.register(first_req_rcv_socket, zmq.POLLIN) + nonlocal count_slice while True: events = await poller.poll() if not self.engines_running and len(events) == 2 or ( @@ -1043,6 +1091,24 @@ async def run_engine_stats_update_task(): 0] == "SCALE_ELASTIC_EP": # Extract new engine count from the decoded message new_engine_count = decoded[1] + # Update engine_ranks_managed and count_slice + parallel_config = self.vllm_config.parallel_config + dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank + assert dp_rank == 0 + assert dp_size == new_engine_count + assert not (parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb) + num_ranks = dp_size + self.engine_ranks_managed = list( + range(dp_rank, dp_rank + num_ranks)) + count_slice = slice(self.engine_ranks_managed[0], + self.engine_ranks_managed[-1] + 1) + if len(self.lb_engines) < new_engine_count: + self.lb_engines = self.lb_engines + \ + [[0, 0] for _ in range(new_engine_count - len(self.lb_engines))] + else: + self.lb_engines = self.lb_engines[:new_engine_count] # Send scale up notification to coordinator scale_msg = msgspec.msgpack.encode( ("SCALE_ELASTIC_EP", new_engine_count)) @@ -1171,6 +1237,30 @@ async def process_engine_outputs(self: "DPLBAsyncMPClient", for req_id in outputs.finished_requests: self.reqs_in_flight.pop(req_id, None) + @staticmethod + async def process_worker_notification(self: "DPLBAsyncMPClient", + notification_data: tuple[str, int]): + cache = self.elastic_scaling_cache + assert cache is not None + notification_type, dp_rank = notification_data + if notification_type not in cache.pending_notifications: + cache.pending_notifications[notification_type] = set() + if dp_rank in cache.pending_notifications[notification_type]: + raise ValueError(f"Duplicate notification {notification_type} from dp_rank {dp_rank}") + cache.pending_notifications[notification_type].add(dp_rank) + if len(cache.pending_notifications[notification_type]) >= cache.num_new_workers: + logger.info("Received %d/%d notifications of type %s, 
forwarding to existing workers", + len(cache.pending_notifications[notification_type]), + cache.num_new_workers, + notification_type) + await asyncio.gather(*[ + self._call_utility_async("handle_worker_notification", + notification_type, + engine=engine) + for engine in cache.existing_workers + ]) + cache.pending_notifications[notification_type] = set() + async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids or self.resources.engine_dead: return @@ -1219,41 +1309,53 @@ async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, and reconfiguring existing ones.""" cur_data_parallel_size = len(self.core_engines) - # Phase 1: Send reconfigure messages to all existing engines and wait - # for them to be sent + self.elastic_scaling_cache = ElasticScalingCache( + existing_workers=list(self.core_engines), + num_new_workers=new_data_parallel_size - cur_data_parallel_size, + pending_notifications=dict(), + ) + + parallel_config = self.vllm_config.parallel_config + allocate_stateless_group_ports(parallel_config, new_data_parallel_size) + + # Phase 1: Send reconfig messages to existing engines reconfig_futures = [] - self.vllm_config.parallel_config.data_parallel_master_port = \ - get_open_port() for engine in self.core_engines: reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=\ ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config. - data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config. - data_parallel_master_port) + new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=parallel_config.data_parallel_master_port, + new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, + new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, + new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list, + new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list) coro = self._call_utility_async("reinitialize_distributed", reconfig_request, engine=engine) reconfig_futures.append(asyncio.create_task(coro)) - logger.info("All reconfigure messages sent, starting engine creation") - - # Phase 2: Create new engines now that reconfig messages have been sent - # self.resources.engine_manager is guaranteed to be - # CoreEngineActorManager for RayDPClient + # Phase 2: Create new engines assert isinstance(self.resources.engine_manager, CoreEngineActorManager) - self.resources.engine_manager.scale_up_elastic_ep( + parallel_config.eplb_config.num_redundant_experts = 0 + start_new_worker_future = asyncio.to_thread( + self.resources.engine_manager.scale_up_elastic_ep, self.vllm_config, new_data_parallel_size) + + # Phase 3: Wait for new engines to be created and reconfig messages to be received + await asyncio.gather(start_new_worker_future, *reconfig_futures) + logger.info("[Elastic EP] Successfully started new engines") # Create new CoreEngine objects for the new engines new_engine_identities = set() for i in range(cur_data_parallel_size, new_data_parallel_size): new_engine = i.to_bytes(2, "little") self.core_engines.append(new_engine) + # NOTE(yongji): we don't update lb_engines here, + # we let run_engine_stats_update_task to update it. 
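# --- Illustrative sketch (not part of the diff) ---------------------------
# Concurrency shape of _scale_up_elastic_ep() above: the reconfiguration RPCs
# to the existing engines run as asyncio tasks while the blocking launch of the
# new engine actors is pushed to a thread, and both are awaited together.
# Dummy stand-ins below (reconfigure/launch_new_engines are hypothetical
# placeholders, not vLLM APIs):
import asyncio
import time

async def reconfigure(engine_id: int) -> str:   # stand-in for the reconfigure RPC
    await asyncio.sleep(0.1)
    return f"engine-{engine_id} reconfigured"

def launch_new_engines(count: int) -> str:      # stand-in for scale_up_elastic_ep()
    time.sleep(0.2)                             # blocking work, hence asyncio.to_thread
    return f"launched {count} new engines"

async def scale_up(existing: list, num_new: int) -> None:
    reconfig_tasks = [asyncio.create_task(reconfigure(e)) for e in existing]
    launch_task = asyncio.to_thread(launch_new_engines, num_new)
    print(await asyncio.gather(launch_task, *reconfig_tasks))

asyncio.run(scale_up(existing=[0, 1], num_new=2))
# ---------------------------------------------------------------------------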
new_engine_identities.add(new_engine) # Wait for ready messages from new engines on the input socket @@ -1266,10 +1368,6 @@ async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, identity, _ = sync_input_socket.recv_multipart() new_engine_identities.discard(identity) - # Phase 3: Wait for all existing engines to complete reconfiguration - logger.info("Waiting for existing engines to complete reconfiguration") - await asyncio.gather(*reconfig_futures) - # Notify coordinator about scale up through existing # stats_update_task connection self._ensure_stats_update_task() @@ -1280,6 +1378,10 @@ async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, # Update the parallel config self.vllm_config.parallel_config.data_parallel_size = \ new_data_parallel_size + self.elastic_scaling_cache = None + + # NOTE(yongji): at this point, reconfiguration may not be fully completed. + # But we can already start sending requests to the new engines. logger.info( "[Elastic EP] Scale up completed, new data parallel size: %s", new_data_parallel_size) @@ -1290,8 +1392,8 @@ async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, reconfiguring existing engine cores.""" cur_data_parallel_size = len(self.core_engines) - self.vllm_config.parallel_config.data_parallel_master_port = \ - get_open_port() + parallel_config = self.vllm_config.parallel_config + allocate_stateless_group_ports(parallel_config, new_data_parallel_size) reconfig_futures = [] for cur_dp_rank, engine in enumerate(self.core_engines): @@ -1300,10 +1402,12 @@ async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=\ ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config. - data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config. - data_parallel_master_port) + new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=parallel_config.data_parallel_master_port, + new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, + new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, + new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list, + new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list) if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = \ ReconfigureRankType.SHUTDOWN_CURRENT_RANK @@ -1312,8 +1416,8 @@ async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, engine=engine) reconfig_futures.append(asyncio.create_task(coro)) - for _ in range(new_data_parallel_size, cur_data_parallel_size): - self.core_engines.pop() + self.core_engines = self.core_engines[:new_data_parallel_size] + self.lb_engines = self.lb_engines[:new_data_parallel_size] await asyncio.gather(*reconfig_futures) @@ -1327,6 +1431,8 @@ async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, ("SCALE_ELASTIC_EP", new_data_parallel_size)) await self.first_req_send_socket.send(scale_down_marker) + # NOTE(yongji): at this point, reconfiguration may not be fully completed. + # But we will no longer send requests to the shutdown engines. 
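# --- Illustrative sketch (not part of the diff) ---------------------------
# Bookkeeping in _scale_down_elastic_ep() above: every engine receives a
# reconfigure request, ranks beyond the new DP size are told to shut down, and
# the client's core_engines / lb_engines lists are truncated. Plain-Python
# sketch with string labels standing in for the ReconfigureRankType sentinels:

def plan_scale_down(cur_dp_size: int, new_dp_size: int) -> dict:
    return {rank: ("shutdown" if rank >= new_dp_size else "keep")
            for rank in range(cur_dp_size)}

core_engines = [f"engine-{i}" for i in range(4)]
print(plan_scale_down(cur_dp_size=len(core_engines), new_dp_size=2))
# {0: 'keep', 1: 'keep', 2: 'shutdown', 3: 'shutdown'}
print(core_engines[:2])  # engines that keep receiving requests
# ---------------------------------------------------------------------------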
self.vllm_config.parallel_config.data_parallel_size = \ new_data_parallel_size logger.info( diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 18ef25ceb6f5..a02cbaebe8d9 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -424,6 +424,8 @@ def add_dp_placement_groups( node_ip = node.node_ip node_id = node.node_id + if device_str not in available_resources[node_id]: + continue available_gpus = int(available_resources[node_id][device_str]) # Get total GPUs on this node from the node's resources @@ -623,6 +625,9 @@ def launch_core_engines( # sends requests only to colocated engines. client_local_only = (offline_mode or local_engines_only or (local_engine_count == dp_size)) + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + client_local_only = False # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -694,6 +699,10 @@ def launch_core_engines( # will be False. handshake_local_only = offline_mode or local_engine_count == dp_size + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + handshake_local_only = False + handshake_address = get_engine_client_zmq_addr( handshake_local_only, host, parallel_config.data_parallel_rpc_port) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 2aa732f34bcc..849d800cbf75 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -426,15 +426,15 @@ def __init__( self.mm_receiver_cache = worker_receiver_cache_from_config( vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock) - # Initialize device - self.worker.init_device() + import os + is_new_worker = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" - # Set process title and log prefix self.setup_proc_title_and_log_prefix( enable_ep=vllm_config.parallel_config.enable_expert_parallel) - # Load model - self.worker.load_model() + if not is_new_worker: + self.worker.init_device() + self.worker.load_model() @staticmethod def make_worker_process( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b0cd0f413307..5953e32388be 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -245,6 +245,7 @@ def __init__( self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.eplb_state: Optional[EplbState] = None + self.eplb_disabled = False """ State of the expert parallelism load balancer. @@ -1719,7 +1720,7 @@ def eplb_step(self, """ Step for the EPLB (Expert Parallelism Load Balancing) state. 
""" - if not self.parallel_config.enable_eplb: + if not self.parallel_config.enable_eplb or self.eplb_disabled: return assert self.eplb_state is not None @@ -1732,6 +1733,22 @@ def eplb_step(self, log_stats=self.parallel_config.eplb_config.log_balancedness, ) + def setup_eplb_from_mapping( + self, + expanded_physical_to_logical: torch.Tensor, + old_num_physical_experts: int, + ) -> None: + model = self.get_model() + assert is_mixture_of_experts(model) + + self.eplb_state = EplbState.from_mapping( + model=model, + device=self.device, + parallel_config=self.parallel_config, + expanded_physical_to_logical=expanded_physical_to_logical, + num_valid_physical_experts=old_num_physical_experts, + ) + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: """ @@ -2521,42 +2538,17 @@ def update_config(self, overrides: dict[str, Any]) -> None: new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) - def load_model(self, eep_scale_up: bool = False) -> None: + def load_model(self, dummy_weights: bool = False) -> None: """ Args: - eep_scale_up: the model loading is for elastic EP scale up. + dummy_weights: load dummy weights instead of real weights. """ logger.info("Starting to load model %s...", self.model_config.model) - if eep_scale_up: - from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) - num_local_physical_experts = int(num_local_physical_experts.item()) - new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) - num_logical_experts = global_expert_load.shape[1] - self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts) - assert old_global_expert_indices.shape[ - 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts - rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) - } - else: - global_expert_load = None - old_global_expert_indices = None - rank_mapping = None with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() + if dummy_weights: + self.load_config.load_format = "dummy" model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") self.model = model_loader.load_model( @@ -2580,22 +2572,22 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Model loading took %.4f GiB and %.6f seconds", self.model_memory_usage / GiB_bytes, time_after_load - time_before_load) - prepare_communication_buffer_for_model(self.model) + + if not dummy_weights: + prepare_communication_buffer_for_model(self.model) if is_mixture_of_experts( - self.model) and self.parallel_config.enable_eplb: + self.model) and self.parallel_config.enable_eplb and not dummy_weights: logger.info("EPLB is enabled for model %s.", self.model_config.model) self.eplb_state = EplbState.build( self.model, self.device, self.parallel_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, ) if ( + not dummy_weights and self.vllm_config.compilation_config.level == \ CompilationLevel.DYNAMO_AS_IS and supports_dynamo() ): @@ -2896,6 +2888,8 @@ def _dummy_run( # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively 
# has num_tokens in total. + if num_tokens > self.scheduler_config.max_num_batched_tokens: + logger.info("num_tokens: %d, max_num_batched_tokens: %d", num_tokens, self.scheduler_config.max_num_batched_tokens) assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs if create_mixed_batch: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 8b1e1bb8f45c..fa1151d2f1f2 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union import torch -import torch.distributed import torch.nn as nn import vllm.envs as envs @@ -26,7 +25,6 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling -from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput) @@ -59,6 +57,9 @@ def __init__( distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker) + from vllm.distributed.elastic_ep.elastic_execute import ElasticScalingExecutor + self.elastic_scaling_executor = ElasticScalingExecutor(self) + if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules @@ -209,9 +210,18 @@ def init_device(self): # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. def load_model(self) -> None: - eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" + dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" + if dummy_weights: + expanded_physical_to_logical, num_logical_experts, old_num_physical_experts = self.elastic_scaling_executor.receive_expert_mapping() + num_physical_experts = expanded_physical_to_logical.shape[1] + self.parallel_config.eplb_config.num_redundant_experts = num_physical_experts - num_logical_experts + with self._maybe_get_memory_pool_context(tag="weights"): - self.model_runner.load_model(eep_scale_up=eep_scale_up) + self.model_runner.load_model(dummy_weights=dummy_weights) + + if dummy_weights: + self.model_runner.setup_eplb_from_mapping(expanded_physical_to_logical, old_num_physical_experts) + self.model_runner.eplb_disabled = True def update_config(self, overrides: dict[str, Any]) -> None: self.model_runner.update_config(overrides) @@ -505,162 +515,6 @@ def check_health(self) -> None: # worker will always be healthy as long as it's running.
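# Illustrative sketch of the arithmetic a newly launched worker performs in
# load_model above when VLLM_ELASTIC_EP_SCALE_UP_LAUNCH=1: it receives the
# expanded physical-to-logical expert mapping from an existing worker and
# derives the redundant-expert count from its shape. The tensor below is a toy
# stand-in for what receive_expert_mapping() would return, assuming the second
# dimension indexes physical expert slots (shape[1] is what the patch uses).
import torch

# (num_moe_layers, num_physical_experts); entries are logical expert ids.
expanded_physical_to_logical = torch.tensor([[0, 1, 2, 3, 0, 1],
                                             [2, 3, 0, 1, 2, 3]])
num_logical_experts = 4       # distinct logical experts in the model
old_num_physical_experts = 4  # physical slots before the scale-up

num_physical_experts = expanded_physical_to_logical.shape[1]        # 6
num_redundant_experts = num_physical_experts - num_logical_experts  # 2
print(num_physical_experts, num_redundant_experts)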
return - def _eplb_before_scale_down(self, old_ep_size: int, - new_ep_size: int) -> None: - from vllm.distributed.parallel_state import get_ep_group - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding " - "before scaling down...") - rank_mapping = { - old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 - for old_ep_rank in range(old_ep_size) - } - assert self.model_runner.eplb_state is not None - self.model_runner.eplb_state.rearrange(self.model_runner.model, - execute_shuffle=True, - global_expert_load=None, - rank_mapping=rank_mapping) - torch.cuda.synchronize() - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Expert resharding completed!") - - def _eplb_after_scale_up( - self, old_ep_size: int, new_ep_size: int, - global_expert_load: Optional[torch.Tensor]) -> None: - from vllm.distributed.parallel_state import get_ep_group - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding " - "after scaling up...") - rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) - } - assert self.model_runner.eplb_state is not None - self.model_runner.eplb_state.rearrange( - self.model_runner.model, - execute_shuffle=True, - global_expert_load=global_expert_load, - rank_mapping=rank_mapping) - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Expert resharding completed!") - - def _reconfigure_parallel_config( - self, reconfig_request: ReconfigureDistributedRequest) -> None: - """ - Update parallel config with provided reconfig_request - """ - parallel_config = self.vllm_config.parallel_config - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size - if reconfig_request.new_data_parallel_rank != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank - if reconfig_request.new_data_parallel_rank_local != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank_local = \ - reconfig_request.new_data_parallel_rank_local - parallel_config.data_parallel_master_ip = \ - reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ - reconfig_request.new_data_parallel_master_port - - def _reconfigure_moe(self, old_ep_size: int, - new_ep_size: int) -> Optional[torch.Tensor]: - """ - Reconfigure MoE modules with provided reconfig_request - - Return the global expert load if new_ep_size > old_ep_size, - otherwise None - """ - from vllm.distributed.parallel_state import ( - get_dp_group, get_ep_group, prepare_communication_buffer_for_model) - from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoEParallelConfig) - - parallel_config = self.vllm_config.parallel_config - moe_modules = [ - module for module in self.model_runner.model.modules() - if (module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE") - ] - num_local_experts = moe_modules[0].moe_config.num_local_experts - assert all(module.moe_config.num_local_experts == num_local_experts - for module in moe_modules), ( - "All MoE modules must have the same number of experts") - for module in moe_modules: - module.moe_config.num_experts = num_local_experts * new_ep_size - module.global_num_experts = module.moe_config.num_experts - module.moe_parallel_config = FusedMoEParallelConfig.make( - tp_size_=get_tp_group().world_size, - dp_size_=get_dp_group().world_size, - vllm_parallel_config=parallel_config, - ) - module.moe_config.moe_parallel_config = 
module.moe_parallel_config - if new_ep_size < old_ep_size: - num_local_physical_experts = num_local_experts - assert self.model_runner.eplb_state is not None - new_physical_experts = \ - self.model_runner.eplb_state.physical_to_logical_map.shape[1] - parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - - self.model_runner.eplb_state.logical_replica_count.shape[1]) - global_expert_load = None - else: - num_local_physical_experts = torch.tensor([num_local_experts], - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) - num_local_physical_experts = num_local_physical_experts.item() - new_physical_experts = num_local_physical_experts * new_ep_size - assert self.model_runner.eplb_state is not None - global_expert_load = self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=False) - parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - global_expert_load.shape[1]) - prepare_communication_buffer_for_model(self.model_runner.model) - self.model_runner.model.update_physical_experts_metadata( - num_physical_experts=new_physical_experts, - num_local_physical_experts=num_local_physical_experts) - return global_expert_load - - def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: - from vllm.config import set_current_vllm_config - from vllm.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_ep_group) - - old_ep_size = get_ep_group().world_size - old_ep_rank = get_ep_group().rank - new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group( - ).world_size * get_pp_group().world_size - if new_ep_size < old_ep_size: - self._eplb_before_scale_down(old_ep_size, new_ep_size) - - cleanup_dist_env_and_memory() - - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: - assert old_ep_rank >= new_ep_size - # shutdown - return - - self._reconfigure_parallel_config(reconfig_request) - - with set_current_vllm_config(self.vllm_config): - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank) - - global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) - - if new_ep_size > old_ep_size: - assert global_expert_load is not None - self._eplb_after_scale_up(old_ep_size, new_ep_size, - global_expert_load) - def save_sharded_state( self, path: str, @@ -686,6 +540,9 @@ def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): runner.ensure_kv_transfer_shutdown() + def elastic_ep_execute(self, execute_method: str, *args, **kwargs): + return self.elastic_scaling_executor.execute(execute_method, *args, **kwargs) + def init_worker_distributed_environment( vllm_config: VllmConfig,