diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 737b2eede9c6..42506730e868 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -66,10 +66,10 @@ function cpu_tests() { tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" # Run AWQ test - # docker exec cpu-test-"$NUMA_NODE" bash -c " - # set -e - # VLLM_USE_V1=0 pytest -s -v \ - # tests/quantization/test_ipex_quant.py" + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + VLLM_USE_V1=0 pytest -s -v \ + tests/quantization/test_ipex_quant.py" # Run chunked-prefill and prefix-cache test docker exec cpu-test-"$NUMA_NODE" bash -c " diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index cf3aaab8493b..827649bfcf54 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -26,5 +26,7 @@ docker run \ --name "${container_name}" \ "${image_name}" \ sh -c ' + VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m + VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager ' diff --git a/examples/online_serving/chart-helm/values.yaml b/examples/online_serving/chart-helm/values.yaml index 815f02a4bfd5..28dba9a6f688 100644 --- a/examples/online_serving/chart-helm/values.yaml +++ b/examples/online_serving/chart-helm/values.yaml @@ -8,7 +8,7 @@ image: # -- Image tag tag: "latest" # -- Container launch command - command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"] + command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"] # -- Container port containerPort: 8000 diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 93bf20da4adb..3722e0eb537f 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -36,8 +36,7 @@ def clear_cache(): DEVICE_MLA_BLOCK_SIZES = { "cuda": [16, 64], # CUDA supports both standard and extended block sizes "hip": [16, 1], # HIP requires special handling for block_size=1 - # "cpu": [16] # CPU uses fixed block size from test cases - "cpu": [] # FIXME(woosuk): Temporarily disable CPU tests + "cpu": [16] # CPU uses fixed block size from test cases } @@ -82,14 +81,14 @@ def test_env( m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") if device == "cpu": - if not use_v1: - pytest.skip("CPU backend only supports V1") - with patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float16, torch.float16, block_size, False) - assert backend.get_name() == "TORCH_SDPA_VLLM_V1" + if use_v1: + assert backend.get_name() == "TORCH_SDPA_VLLM_V1" + else: + assert backend.get_name() == "TORCH_SDPA" elif device == "hip": with patch("vllm.attention.selector.current_platform", @@ -205,14 +204,12 @@ def test_fp32_fallback( m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") if device == "cpu": - if not use_v1: - pytest.skip("CPU backend only supports V1") - with 
patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float32, torch.float32, 16, False) - assert backend.get_name() == "TORCH_SDPA_VLLM_V1" + assert (backend.get_name() == "TORCH_SDPA_VLLM_V1" + if use_v1 else "TORCH_SDPA") elif device == "cuda": with patch("vllm.attention.selector.current_platform", diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py new file mode 100644 index 000000000000..793cb87b7434 --- /dev/null +++ b/vllm/attention/backends/cpu_mla.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +import vllm._custom_ops as ops +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState +from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder + + +class CPUMLABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "CPU_MLA" + + @staticmethod + def get_metadata_cls() -> Type["CPUMLAMetadata"]: + return CPUMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]: + return CPUMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_impl_cls() -> Type["CPUMLAImpl"]: + return CPUMLAImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +@dataclass +class CPUMLAMetadata(TorchSDPAMetadata): + # New for MLA + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor = None + + # required by MLACommonImpl + is_profile_run: bool = False + + +class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + assert not self.chunked_prefill, \ + "chunked prefill is currently not supported" + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size): + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # metadata for prefill + if input_data.num_prefills > 0: + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + 
kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + + # for chunked-prefill + if self.chunked_prefill: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + prefill_block_tables = None + + else: + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + prefill_block_tables = None + + # metadata for decode + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + return CPUMLAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + prefill_query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + input_positions=torch.tensor([self.input_data.input_positions])) + + +class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "CPUMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "CPUMLAImpl") + + # states is implemented. 
+ if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "CPUMLAImpl with FP8 KV cache not yet supported") + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: CPUMLAMetadata, # type: ignore[override] + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = torch.empty_like(q) + ipex_ops.varlen_attention( + query=q, + key=k, + value=v_padded, + out=output, + seqlen_q=prefill_metadata.prefill_query_start_loc, + seqlen_k=prefill_metadata.prefill_query_start_loc, + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.max_query_len, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + logits_soft_cap=0.0, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + ) + + # remove padding + output = output.view(-1, self.num_heads, + q.shape[-1])[..., :v.shape[-1]] + return output.reshape(-1, self.num_heads * v.shape[-1]) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: CPUMLAMetadata, # type: ignore[override] + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1) + o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank) + + # Run MQA + ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor) + return self._v_up_proj(o) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py new file mode 100644 index 000000000000..410ada3b0828 --- /dev/null +++ b/vllm/attention/backends/ipex_attn.py @@ -0,0 +1,403 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + +_PARTITION_SIZE = 512 + + +class IpexAttnBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "IPEX" + + @staticmethod + def get_impl_cls() -> Type["IpexAttnBackendImpl"]: + return IpexAttnBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["IpexAttnMetadata"]: + return IpexAttnMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return 
CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for IpexAttnBackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + seq_lens: Optional[List[int]] + seqlen_q: Optional[torch.Tensor] + max_seqlen: Optional[int] + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + @property + def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in Ipex is not supported yet, it will fall" + " back to global attention for long context.") + if blocksparse_params is not None: + raise ValueError( + "IPEX backend does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.sliding_window is not None) + if logits_soft_cap is None: + logits_soft_cap = -1 + self.logits_soft_cap = logits_soft_cap + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not 
supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError( + "IPEX backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") + + def split_kv_cache( + self, + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 1 + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: IpexAttnMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with IPEX varlen_attention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for IpexAttentionImpl") + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
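+        # Illustrative shapes (hypothetical sizes, not taken from this diff):
+        # with num_heads=32, num_kv_heads=8, head_size=128, a query of shape
+        # [num_tokens, 4096] becomes [num_tokens, 32, 128], and key/value of
+        # shape [num_tokens, 1024] become [num_tokens, 8, 128].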
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = self.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + ipex_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + + if attn_metadata.is_prompt: + assert attn_metadata.seq_lens is not None + if (kv_cache.numel() == 0 + or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, None, dtype=query.dtype) + attn_metadata.attn_bias = att_masks + + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + ipex_ops.varlen_attention( + query, + key, + value, + output, + attn_metadata.seqlen_q, + attn_metadata.seqlen_q, + self.alibi_slopes, + attn_metadata.max_seqlen, + attn_metadata.max_seqlen, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + window_size_left=-1, + window_size_right=-1, + logits_soft_cap=self.logits_soft_cap, + ) + else: + # prefix-enabled attention + raise RuntimeError( + "IPEX backend doesn't support prefix decoding.") + + else: + # Decoding run. + max_seq_len = attn_metadata.max_decode_seq_len + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory + # shortage. + use_v1 = (max_seq_len <= 8192 and + (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: + # Run PagedAttention V1. + ipex_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + else: + # Run PagedAttention V2. 
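+                # The V2 kernel splits each sequence into partitions of
+                # _PARTITION_SIZE tokens, writes per-partition results into
+                # tmp_output, and then reduces across partitions using the
+                # per-partition exp_sums and max_logits, which is why the
+                # scratch buffers below carry a max_num_partitions dimension.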
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ipex_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype, + device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py new file mode 100644 index 000000000000..c900666955a3 --- /dev/null +++ b/vllm/attention/backends/pallas.py @@ -0,0 +1,356 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. 
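+# (This import is what registers the torch.ops.xla.flash_attention,
+# torch.ops.xla.multi_queries_paged_attention and torch.ops.xla.paged_attention
+# custom ops used later in this file.)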
+ +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in Pallas is not supported yet, it will fall back " + "to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.logits_soft_cap = logits_soft_cap + if head_size % 128 != 0: + raise NotImplementedError( + f"Head size must be a multiple of 128, found {head_size}.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if 
sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for PallasAttentionImpl") + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. 
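+                # Illustrative shapes (hypothetical sizes): with batch_size=8,
+                # seq_len=512, num_heads=32, head_size=128, the permute turns
+                # [8, 512, 32, 128] into [8, 32, 512, 128] for the kernel, and
+                # the permute after the call restores [8, 512, 32, 128].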
+ output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. 
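+        # output is [batch_size, seq_len, num_heads, head_size] on the prefill
+        # paths and [batch_size, num_heads, head_size] on the decode path
+        # (seq_len == 1 there), so the reshape below flattens the head
+        # dimensions back into hidden_size in both cases.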
+ return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], + *, + attn_logits_soft_cap: Optional[float], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + return torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + attn_logits_soft_cap=attn_logits_soft_cap, + ) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index a490aa397991..af5fe81dc883 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -3,24 +3,78 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention # yapf conflicts with isort for this block # yapf: disable -from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, - AttentionMetadata, AttentionType, +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, is_quantized_kv_cache) # yapf: enable +from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.logger import init_logger +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder logger = init_logger(__name__) +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TORCH_SDPA" + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TorchSDPAMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: + return TorchSDPAMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError("Swap is not supported in TorchSDPABackend.") + + @staticmethod + def copy_blocks( + kv_caches: 
List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @dataclass class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for TorchSDPABackend. @@ -233,6 +287,113 @@ def get_seq_len_block_table_args( raise AttributeError(f"Invalid attention type {str(attn_type)}") +class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # For chunked-prefill + if self.chunked_prefill and input_data.num_prefill_tokens != 0: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + else: + prefill_block_tables = None + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + + # For paged attention + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + attn_metadata = TorchSDPAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + prefill_query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + ) + + return attn_metadata + + class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( 
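The torch_sdpa.py changes above re-add the V0 TorchSDPABackend, and the platform changes further down route CPU attention selection to it when V1 is disabled. A minimal sketch of how that resolution can be exercised, assuming the same monkeypatched platform setup as the updated test in tests/kernels/attention/test_attention_selector.py:

    from unittest.mock import patch

    import torch

    from vllm.attention.selector import get_attn_backend
    from vllm.platforms.cpu import CpuPlatform

    # With VLLM_USE_V1=0 this resolves to the re-added V0 backend
    # ("TORCH_SDPA"); with VLLM_USE_V1=1 it stays on "TORCH_SDPA_VLLM_V1".
    with patch("vllm.attention.selector.current_platform", CpuPlatform()):
        backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
    assert backend.get_name() in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1")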
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 56ae1acf8571..a5157d6671b1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -205,13 +205,13 @@ def get_num_new_matched_tokens( """ For remote prefill, pull all prompt blocks from remote asynchronously relative to engine execution. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - * the number of tokens that can be loaded from the + * the number of tokens that can be loaded from the external KV cache beyond what is already computed. * true if the external KV cache tokens will be loaded asynchronously (between scheduler steps). @@ -687,14 +687,14 @@ def add_remote_agent(self, blocks from remote. In particular, handle both homogeneous and heterogeneous TP. The former - requires local rank_i to read from remote rank_i. - The latter, assuming D.world_size > P.world_size, requires that two or + requires local rank_i to read from remote rank_i. + The latter, assuming D.world_size > P.world_size, requires that two or more local TP worker share the xfer from a single TP worker. Here's an example: rank_offset p_remote_tp_rank - (kv split no) + (kv split no) -------------------------------- 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] / @@ -707,14 +707,14 @@ def add_remote_agent(self, Decoder TP workers Prefix TP workers (world_size=4) (world_size=2) - tp_ratio = 4 // 2 = 2 - - Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] + tp_ratio = 4 // 2 = 2 + + Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. - Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio + Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split - along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. - + along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. + Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 @@ -904,8 +904,14 @@ def _pop_done_transfers( for handle, _xfer_stime in handles: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": + logger.info( + "Transfer KVCache for req %s has completed with handle %s", + req_id, handle) self.nixl_wrapper.release_xfer_handle(handle) elif xfer_state == "PROC": + logger.info( + "Transfer KVCache for req %s still in progress with handle %s", + req_id, handle) in_progress = True continue else: @@ -923,7 +929,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): """ for req_id, meta in metadata.requests.items(): remote_engine_id = meta.remote_engine_id - logger.debug( + logger.info( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. 
", req_id, remote_engine_id, len(meta.local_block_ids), diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 1050d3c59344..dccd60f4463a 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -64,11 +64,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: - raise NotImplementedError("MLA is not supported on CPU.") + logger.info("Using CPU MLA backend.") + return "vllm.attention.backends.cpu_mla.CPUMLABackend" logger.info("Using Torch SDPA backend.") - if not use_v1: - raise ValueError("CPU backend only supports V1.") - return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + if use_v1: + return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" + else: + return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: @@ -145,14 +147,26 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "mp" if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker" + if vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.cpu_worker.CPUWorker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.cpu_worker.CPUWorker" + else: + parallel_config.worker_cls = \ + "vllm.worker.cpu_worker.CPUWorker" # Note: workaround for v1 gpu_model_runner from vllm.config import CompilationLevel vllm_config.compilation_config.cudagraph_capture_sizes = [] compilation_config = vllm_config.compilation_config - if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: + if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE): # Note: vLLM V1 is using PIECEWISE level compilation, which will # take time to compile kernels just-in-time with the inductor diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index a8c8cb46de2c..0387e348965d 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -6,6 +6,7 @@ import torch from tpu_info import device +import vllm.envs as envs from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType @@ -49,10 +50,12 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) - if not use_v1: - raise ValueError("TPU backend only supports V1.") - logger.info("Using Pallas V1 backend.") - return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + if use_v1: + logger.info("Using Pallas V1 backend.") + return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + else: + logger.info("Using Pallas backend.") + return "vllm.attention.backends.pallas.PallasAttentionBackend" @classmethod def get_device_name(cls, device_id: int = 0) -> str: @@ -65,7 +68,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False + return not envs.VLLM_USE_V1 @classmethod def get_punica_wrapper(cls) -> str: @@ -114,19 +117,31 @@ def 
check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Using bfloat16 instead.", vllm_config.model_config.dtype) vllm_config.model_config.dtype = torch.bfloat16 - from vllm.v1.attention.backends.pallas import PallasAttentionBackend - cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) # type: ignore[assignment] + if envs.VLLM_USE_V1: + from vllm.v1.attention.backends.pallas import ( + PallasAttentionBackend) + cache_config.block_size = PallasAttentionBackend.get_page_size( + vllm_config) # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": if scheduler_config.is_multi_step: - raise NotImplementedError( - "Multi-step scheduling is not supported (and not " - "needed) on vLLM V1. Please launch without " - "--num-scheduler-steps.") - parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" + if envs.VLLM_USE_V1: + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on vLLM V1. Please launch without " + "--num-scheduler-steps.") + else: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.tpu_worker.TPUWorker" + else: + parallel_config.worker_cls = \ + "vllm.worker.tpu_worker.TPUWorker" assert not vllm_config.speculative_config, ( "Speculative decoding is not yet supported for TPU backend") @@ -174,9 +189,13 @@ def validate_request( processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" - if (isinstance(params, SamplingParams) - and params.sampling_type == SamplingType.RANDOM_SEED): - raise ValueError("Torch XLA does not support per-request seed.") + if isinstance(params, SamplingParams): + if params.guided_decoding is not None and not envs.VLLM_USE_V1: + raise ValueError("Structured output is not supported on " + f"{cls.device_name} V0.") + if params.sampling_type == SamplingType.RANDOM_SEED: + raise ValueError( + "Torch XLA does not support per-request seed.") try: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 5bd34033233a..61a0453dcbc8 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -39,10 +39,12 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) use_v1 = envs.VLLM_USE_V1 - if not use_v1: - raise ValueError("XPU backend only supports V1.") - logger.info("Using Flash Attention backend on V1 engine.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + if use_v1: + logger.info("Using Flash Attention backend on V1 engine.") + return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + else: + logger.info("Using IPEX attention backend.") + return "vllm.attention.backends.ipex_attn.IpexAttnBackend" @classmethod def get_device_capability( @@ -75,7 +77,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config # in V1(or with ipex chunked prefill) block_size is 64 if cache_config and cache_config.block_size is None: - cache_config.block_size = 64 + if envs.VLLM_USE_V1: + cache_config.block_size = 64 + else: + cache_config.block_size = 16 # Instances created using VllmConfig() typically have model_config as # None by default. 
The modification involves adding a check to prevent @@ -101,7 +106,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # check and update parallel config parallel_config = vllm_config.parallel_config - parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" + if envs.VLLM_USE_V1: + parallel_config.worker_cls =\ + "vllm.v1.worker.xpu_worker.XPUWorker" + else: + parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" if parallel_config.distributed_executor_backend is None: if parallel_config.world_size > 1: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index fe552db74e2f..e720f8ef30b6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -343,9 +343,11 @@ def schedule(self) -> SchedulerOutput: is_ready = self._update_waiting_for_remote_kv(request) if is_ready: request.status = RequestStatus.WAITING + logger.info("[wxl debug] %s has already received kvcache and enter into waiting status.", + request.request_id) else: - logger.debug( - "%s is still in WAITING_FOR_REMOTE_KVS state.", + logger.info( + "[wxl debug] %s is still in WAITING_FOR_REMOTE_KVS state.", request.request_id) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3754570dfaaa..792e46bd762c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -224,7 +224,7 @@ async def add_request( data_parallel_rank: Optional[int] = None, ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" - + logger.info("[wxl debug] Request %s start add_request.", request_id) if self.errored: raise EngineDeadError() @@ -301,6 +301,7 @@ async def generate( """ try: + logger.info("[wxl debug] Request %s start generate.", request_id) # We start the output_handler on the first call to generate() so # we can call __init__ before the event loop, which enables us # to handle startup failure gracefully in the OpenAI server. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e2fdf6f8a11c..00180c30c0e1 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -195,7 +195,6 @@ def add_request(self, request: EngineCoreRequest): not self.scheduler.get_kv_connector()): logger.warning("Got kv_transfer_params, but no KVConnector found. 
" "Disabling KVTransfer for this request.") - self.scheduler.add_request(req) def abort_requests(self, request_ids: list[str]): @@ -638,7 +637,7 @@ def _process_engine_step(self) -> bool: def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: """Dispatch request from client.""" - + logger.info("[wxl debug] EngineCore handling client request %s.", request.request_id) if request_type == EngineCoreRequestType.ADD: self.add_request(request) elif request_type == EngineCoreRequestType.ABORT: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bcd61d1f0aa..6f8939756ea4 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -19,7 +19,8 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, RequestStateStats) - +from vllm.logger import init_logger +logger = init_logger(__name__) class RequestOutputCollector: """ @@ -323,6 +324,7 @@ def add_request( request_index: int = 0, queue: Optional[RequestOutputCollector] = None, ) -> None: + logger.info("[wxl debug] Request %s start outputprocessor add_request.", request.request_id) request_id = request.request_id if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") @@ -351,17 +353,17 @@ def process_outputs( 1) Compute stats for logging 2) Detokenize 3) Create and handle RequestOutput objects: - * If there is a queue (for usage with AsyncLLM), + * If there is a queue (for usage with AsyncLLM), put the RequestOutput objects into the queue for handling by the per-request generate() tasks. - * If there is no queue (for usage with LLMEngine), + * If there is no queue (for usage with LLMEngine), return a list of RequestOutput objects. NOTE FOR DEVELOPERS vLLM V1 minimizes the number of python loops over the full - batch to ensure system overheads are minimized. This is the + batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. If you need to touch every element of the batch, do it from @@ -433,6 +435,7 @@ def process_outputs( self._update_stats_from_finished(req_state, finish_reason, iteration_stats) + logger.info("[wxl debug] Request %s outputprocessor processed output and will return to client.", req_id) self.lora_states.update_iteration_stats(iteration_stats) return OutputProcessorOutput( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 9fc52543efde..669cddee38a6 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -25,6 +25,8 @@ validate_guidance_grammar) from vllm.v1.structured_output.backend_xgrammar import ( validate_xgrammar_grammar) +from vllm.logger import init_logger +logger = init_logger(__name__) class Processor: @@ -225,7 +227,7 @@ def process_inputs( priority: int = 0, data_parallel_rank: Optional[int] = None, ) -> tuple[Optional[str], EngineCoreRequest]: - + logger.info("[wxl debug] Request %s start process_inputs.", request_id) # TODO(woosuk): Support pooling models. # TODO(woosuk): Support encoder-decoder models. 
self._validate_lora(lora_request) diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py new file mode 100644 index 000000000000..c99e2652a397 --- /dev/null +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -0,0 +1,326 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast + +import torch + +from vllm.attention import AttentionMetadata +from vllm.forward_context import set_forward_context +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, + ModelInputForCPUBuilder, + ModelInputForCPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +@dataclasses.dataclass(frozen=True) +class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata): + """ + Used by the EncoderDecoderModelRunner. + """ + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInputForCPU": + return cast( + EncoderDecoderModelInputForCPU, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + +class CPUEncoderDecoderModelRunner( + CPUModelRunnerBase[EncoderDecoderModelInputForCPU]): + _model_input_cls: Type[EncoderDecoderModelInputForCPU] = ( + EncoderDecoderModelInputForCPU) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def _list_to_int32_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.int32, device=self.device) + + def _list_to_long_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.long, device=self.device) + + def _empty_int32_tensor(self) -> torch.Tensor: + return self._list_to_int32_tensor([]) + + def _empty_long_tensor(self) -> torch.Tensor: + return self._list_to_long_tensor([]) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, + Any]) -> EncoderDecoderModelInputForCPU: + return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> EncoderDecoderModelInputForCPU: + model_input = self._prepare_model_input_tensors( + 
seq_group_metadata_list, finished_requests_ids) + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + return dataclasses.replace( + model_input, + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + virtual_engine=virtual_engine, + ) + + def _prepare_encoder_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: EncoderDecoderModelInputForCPU, + ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Helper method to prepare the encoder- and cross-attn-related + model inputs based on a given sequence group. These additional inputs + are used to augment an already-computed `EncoderDecoderModelInput` + data structure which already has decoder-related model inputs + populated. + + Sets the following attn_metadata fields: + * `num_encoder_tokens` + * `encoder_seq_lens` + * `encoder_seq_lens_tensor` + * `max_encoder_seq_len` + * `cross_slot_mapping` + * `cross_block_tables` + + Constructs a new model inputs data structure, based on + (1) the existing fields in the `model_inputs` argument, + and (2) the following additional fields which are + computed (or in the case of `attn_metadata`, updated) + by this function: + * attn_metadata + * encoder_input_tokens + * encoder_input_positions + + Arguments: + + * seq_group_metadata_list: list of sequence groups for which to + compute inputs + * model_inputs: model inputs data structure with decoder-oriented + fields already computed. + + Return: + + * Updated model inputs data structure + """ + + if len(seq_group_metadata_list) == 0: + return (model_input.attn_metadata, None, None) + + # Since we are not supporting chunked prefill either the entire + # batch is prefill or it is decode + is_prompt = seq_group_metadata_list[0].is_prompt + + # Build encoder inputs + encoder_seq_lens: List[int] = [] + if is_prompt: + # Prefill phase. 
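# Illustrative sketch (made-up block table and block size, not from this
# patch): the cross_slot_mapping built just below maps every encoder token
# position to a flat slot in the paged KV cache via
# slot = block_table[pos // block_size] * block_size + pos % block_size.
from typing import List

def build_slot_mapping(block_table: List[int], seq_len: int,
                       block_size: int) -> List[int]:
    """Map token positions 0..seq_len-1 to flat KV-cache slots."""
    slots = []
    for pos in range(seq_len):
        block_number = block_table[pos // block_size]
        block_offset = pos % block_size
        slots.append(block_number * block_size + block_offset)
    return slots

# Example: block_size=4, physical blocks 7 and 9 assigned, 6 encoder tokens.
# Positions 0-3 land in block 7 (slots 28..31), positions 4-5 in block 9
# (slots 36, 37).
assert build_slot_mapping([7, 9], 6, 4) == [28, 29, 30, 31, 36, 37]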
+ cross_block_tables = self._empty_int32_tensor().view( + len(seq_group_metadata_list), -1) + + # Extract input tokens/positions, cross-attention slot-mapping, + # & seq len from each sequence group metadata + ( + encoder_input_tokens, + encoder_input_positions, + cross_slot_mapping, + ) = ( + [], + [], + [], + ) + for seq_group_metadata in seq_group_metadata_list: + # Build seq lens + seq_len = seq_group_metadata.encoder_seq_data.get_len() + token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() + encoder_seq_lens.append(seq_len) + + # Build slot mapping + for i in range(0, seq_len): + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + cross_slot_mapping.append(slot) + + # Build encoder input tokens + encoder_input_tokens.extend(token_ids) + encoder_input_positions.extend(list(range(0, seq_len))) + + # Convert tokens/positions & cross-attention + # slot-mapping to encoder input tensors + encoder_input_tokens_tensor = self._list_to_long_tensor( + encoder_input_tokens) + encoder_input_positions_tensor = self._list_to_long_tensor( + encoder_input_positions) + cross_slot_mapping_tensor = self._list_to_long_tensor( + cross_slot_mapping) + + else: + # Decode phase. + encoder_input_tokens_tensor = self._empty_long_tensor() + encoder_input_positions_tensor = self._empty_long_tensor() + cross_slot_mapping_tensor = self._empty_long_tensor() + # Extract cross-attention block tables & + # seq len from each sequence group metadata. + # Cross-attention block tables are empty + # during vLLM memory profiling. + cross_block_tables = [] + for seq_group_metadata in seq_group_metadata_list: + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) + + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + + cross_block_tables = make_tensor_with_pad( + cross_block_tables, + max_len=max_len_of_block_table, + pad=0, + dtype=torch.int32, + device=self.device, + ) + + # Compute encoder sequence lengths & encoder + # sequence starting offset tensors + max_encoder_seq_len = max(encoder_seq_lens, default=0) + encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=self.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + + # Update attention metadata with encoder-oriented attributes + attn_metadata = model_input.attn_metadata + assert attn_metadata is not None + ( + attn_metadata.num_encoder_tokens, + attn_metadata.encoder_seq_lens, + attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_slot_mapping, + attn_metadata.cross_block_tables, + ) = ( + sum(encoder_seq_lens), + encoder_seq_lens, + encoder_seq_lens_tensor, + max_encoder_seq_len, + cross_slot_mapping_tensor, + cross_block_tables, + ) + + return (attn_metadata, encoder_input_tokens_tensor, + encoder_input_positions_tensor) + + @torch.no_grad() + def execute_model( + self, + model_input: EncoderDecoderModelInputForCPU, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> 
Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + model_executable = self.model + execute_model_kwargs = { + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + "encoder_input_ids": + model_input.encoder_input_tokens, + "encoder_positions": + model_input.encoder_input_positions, + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + device=self.device, + ), + "intermediate_tensors": + intermediate_tensors, + } + + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Sample the next token. + output = self.sampler( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return [output] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 000000000000..68cdf65cafa7 --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,671 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +import weakref +from collections import defaultdict +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, + TypeVar, Union) + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs, + MultiModalPlaceholderMap) +from vllm.sequence import (IntermediateTensors, SequenceData, + SequenceGroupMetadata) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU") +_PAD_SLOT_ID = -1 + + +@dataclass(frozen=True) +class ModelInputForCPU(ModelRunnerInputBase): + """ + Base class contains metadata needed for the base model forward pass on CPU + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + token_type_ids: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = 
None + lora_requests: Optional[Set[LoRARequest]] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForCPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None + ) -> TModelInputForCPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForCPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): + + class ModelInputData: + + def __init__(self, use_mrope: bool): + self.use_mrope = use_mrope + self.input_tokens: List[int] = [] + self.input_positions: List[int] = [] + self.token_type_ids: Optional[List[int]] = [] + self.seq_lens: List[int] = [] + self.query_lens: List[int] = [] + self.prefill_block_tables: List[List[int]] = [] + self.decode_block_tables: List[List[int]] = [] + self.max_decode_seq_len: int = 0 + self.num_prefills: int = 0 + self.num_prefill_tokens: int = 0 + self.num_decode_tokens: int = 0 + self.slot_mapping: List[int] = [] + self.multi_modal_inputs_list: List[MultiModalKwargs] = [] + self.multi_modal_placeholder_maps: Dict[ + str, MultiModalPlaceholderMap] = defaultdict( + MultiModalPlaceholderMap) + self.input_mrope_positions: List[List[int]] = [[] + for _ in range(3)] + + def __init__(self, + runner: "CPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.runner = runner + self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled + or runner.cache_config.enable_prefix_caching) + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + self.enable_lora = self.runner.lora_config is not None + if self.runner.attn_backend is not None: + # spec decode (e.g. 
Medusa) does not have atten backend + attn_backend = self.runner.attn_backend + self.att_metadata_builder = attn_backend.get_builder_cls()(self) + + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.input_data = ModelInputForCPUBuilder.ModelInputData( + self.runner.model_config.uses_mrope) + self.att_metadata_builder.prepare() + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) + + def set_seq_group_list( + self, seq_group_metadata_list: List[SequenceGroupMetadata]): + self.seq_group_metadata_list = seq_group_metadata_list + + def build(self) -> ModelInputForCPU: + self._build_input_data() + + input_data = self.input_data + input_tokens = torch.tensor(input_data.input_tokens, + dtype=torch.long, + device="cpu") + input_positions = torch.tensor( + input_data.input_positions + if not any(input_data.input_mrope_positions) else + input_data.input_mrope_positions, + dtype=torch.long, + device="cpu") + token_type_ids = torch.tensor(input_data.token_type_ids, + dtype=torch.long, + device="cpu") \ + if input_data.token_type_ids else None + + # For multi-modal models + multi_modal_kwargs = None + if len(input_data.multi_modal_inputs_list) != 0: + multi_modal_kwargs = MultiModalKwargs.batch( + input_data.multi_modal_inputs_list) + + attn_metadata = self.att_metadata_builder.build( + input_data.seq_lens, input_data.query_lens, -1, -1) + + is_prompt = (self.seq_group_metadata_list[0].is_prompt + if self.seq_group_metadata_list else None) + # LoRA data. + lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(seq.lora_request + for seq in self.seq_group_metadata_list + if seq.lora_request is not None) + + lora_mapping = self._prepare_lora_input( + self.seq_group_metadata_list, is_prompt) + + return self.model_input_cls(input_tokens=input_tokens, + input_positions=input_positions, + token_type_ids=token_type_ids, + seq_lens=input_data.seq_lens, + query_lens=input_data.query_lens, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + lora_mapping=lora_mapping, + lora_requests=lora_requests) + + def _build_input_data(self): + for seq_group_metadata in self.seq_group_metadata_list: + for seq_id, seq_data in seq_group_metadata.seq_data.items(): + if seq_group_metadata.is_prompt: + self._compute_prompt_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + if seq_group_metadata.multi_modal_data: + self._compute_multi_modal_input( + seq_group_metadata, seq_data) + else: + self._compute_decode_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + + def _compute_decode_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute decode input tokens, positions, block table and slot mapping. 
+ """ + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + + tokens = seq_data.get_last_token_id() + token_positions = seq_len - 1 + block_number = block_table[token_positions // block_size] + block_offset = token_positions % block_size + slot = block_number * block_size + block_offset + + # For paged_attention kernel + if self.runner.sliding_window: + start_idx = max(0, seq_len - self.runner.sliding_window) + start_block = start_idx // block_size + start_idx = start_block * block_size + seq_len = seq_len - start_idx + block_table = block_table[start_block:] + + # For MRotaryEmbedding + if seq_data.mrope_position_delta is not None: + next_pos = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + for idx in range(3): + data.input_mrope_positions[idx].extend( # type: ignore + next_pos[idx]) + else: + data.input_positions.append(token_positions) # type: ignore + + # Update fields + data.input_tokens.append(tokens) + data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len) + data.num_decode_tokens += 1 + data.slot_mapping.append(slot) + data.decode_block_tables.append(block_table) + data.query_lens.append(1) + data.seq_lens.append(seq_len) + + def _compute_prompt_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute prompt input tokens, positions, block table and slot mapping. + """ + token_chunk_size = seq_group_metadata.token_chunk_size + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + seq_len = min(seq_len, context_len + token_chunk_size) + + # For prefix caching + prefix_cache_block_num = len(seq_group_metadata.computed_block_nums) + if prefix_cache_block_num > 0: + prefix_cache_len = (prefix_cache_block_num * + self.runner.block_size) + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass + elif context_len < prefix_cache_len < seq_len: + # Partial hit. Compute the missing part. + context_len = prefix_cache_len + token_chunk_size = seq_len - context_len + elif seq_len <= prefix_cache_len: + # Full hit. Only compute the last token to avoid + # erroneous behavior. FIXME: Ideally we should directly + # mark all tokens as computed in the scheduler and do not + # schedule this sequence, so this case should not happen. + context_len = seq_len - 1 + token_chunk_size = 1 + + tokens = seq_data.get_token_ids() + tokens = tokens[context_len:seq_len] + token_positions = range(context_len, seq_len) + token_types = seq_group_metadata.token_type_ids + + # For encoder-only models, the block_table is None, + # and there is no need to initialize the slot_mapping. 
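# Illustrative sketch (assumed values, not from this patch): the
# prefix-caching branch above narrows [context_len, seq_len) to the tokens
# that still need computation. Assuming no chunked prefill, block_size=16 and
# two fully cached blocks (prefix_cache_len=32):
#   * context_len=40, seq_len=50 -> cache already passed, nothing changes.
#   * context_len=10, seq_len=50 -> partial hit: recompute [32, 50).
#   * context_len=0,  seq_len=20 -> full hit: recompute only the last token.
from typing import Tuple

def clamp_to_prefix_cache(context_len: int, seq_len: int,
                          prefix_cache_len: int) -> Tuple[int, int]:
    """Return (context_len, token_chunk_size) after applying the cache hit."""
    if prefix_cache_len <= context_len:
        return context_len, seq_len - context_len
    if context_len < prefix_cache_len < seq_len:
        return prefix_cache_len, seq_len - prefix_cache_len
    # seq_len <= prefix_cache_len: full hit, recompute only the last token.
    return seq_len - 1, 1

assert clamp_to_prefix_cache(40, 50, 32) == (40, 10)
assert clamp_to_prefix_cache(10, 50, 32) == (32, 18)
assert clamp_to_prefix_cache(0, 20, 32) == (19, 1)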
+ if block_table is not None: + slot_mapping = [_PAD_SLOT_ID] * len(token_positions) + for i, pos in enumerate(token_positions): + block_number = block_table[pos // block_size] + block_offset = pos % block_size + slot = block_number * block_size + block_offset + slot_mapping[i] = slot + data.slot_mapping.extend(slot_mapping) + + # The MROPE positions are prepared in _compute_multi_modal_input + data.input_positions.extend(token_positions) + + if data.token_type_ids is not None: + data.token_type_ids.extend(token_types if token_types else []) + + # Update fields + data.input_tokens.extend(tokens) + data.num_prefills += 1 + data.num_prefill_tokens += len(tokens) + data.query_lens.append(len(tokens)) + data.prefill_block_tables.append(block_table) + data.seq_lens.append(seq_len) + + def _compute_multi_modal_input(self, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData): + computed_len = seq_data.get_num_computed_tokens() + seq_len = self.input_data.seq_lens[-1] + + # NOTE: mm_kwargs only includes the subset of multi-modal items that + # intersect with the current prefill positions. + mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + seq_group_metadata, range(computed_len, seq_len)) + + if not mm_kwargs: + return + + # special processing for mrope position deltas. + if self.runner.model_config.uses_mrope: + assert not self.chunked_prefill, \ + "MROPE on CPU does not support chunked-prefill." + + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", + None) + assert ( + image_grid_thw is not None or video_grid_thw is not None + or audio_feature_lengths is not None), ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw' or " + "'audio_feature_lengths'.") + + second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + + mrope_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=computed_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + seq_data.mrope_position_delta = mrope_position_delta + + for i in range(3): + self.input_data.input_mrope_positions[ # type: ignore + i].extend(mrope_positions[i]) + + self.input_data.multi_modal_inputs_list.append(mm_kwargs) + for modality, placeholder_map in placeholder_maps.items(): + self.input_data.multi_modal_placeholder_maps[modality].extend( + placeholder_map) + + def _prepare_lora_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + is_prefill: bool) -> LoRAMapping: + index_mapping = [] + prompt_mapping = [] + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + query_len = seq.token_chunk_size + + index_mapping += [lora_id] * query_len + prompt_mapping += [lora_id] * ( + query_len if seq.sampling_params + and seq.sampling_params.prompt_logprobs is not None else 1) + + return LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=is_prefill) + + +class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): + """ + Helper class for shared methods between CPU model runners. 
+ """ + _model_input_cls: Type[TModelInputForCPU] + _builder_cls: Type[ModelInputForCPUBuilder] + builder: ModelInputForCPUBuilder + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + *args, + **kwargs, + ): + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config + cache_config = self.cache_config + + self.is_driver_worker = is_driver_worker + self.return_hidden_states = return_hidden_states + + self.device = self.device_config.device + self.pin_memory = False + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) if needs_attn_backend else None + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.sampler = get_sampler() + + if hasattr(self, "_builder_cls"): + # multi-step model runner does not have `_builder_cls` + self.builder = self._builder_cls(weakref.proxy(self)) + + def load_model(self) -> None: + self.model = get_model(vllm_config=self.vllm_config) + + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # Use get_text_config() in case of multimodal models + text_config = self.model_config.hf_config.get_text_config() + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=text_config.max_position_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + def get_model(self) -> nn.Module: + return self.model + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> TModelInputForCPU: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. 
+ + """ + self.builder.prepare(finished_requests_ids) + self.builder.set_seq_group_list(seq_group_metadata_list) + + return self.builder.build() # type: ignore + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + + +class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): + _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( + ModelInputForCPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForCPUWithSamplingMetadata: + return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501 + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. 
+ + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine, + is_prompt=is_prompt) + + @torch.no_grad() + def execute_model( + self, + model_input: ModelInputForCPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + previous_hidden_states: Optional[torch.Tensor] = None, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + model_executable = self.model + + multimodal_kwargs = {} + if model_input.multi_modal_kwargs is not None: + multimodal_kwargs = MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs, + device=self.device, + ) + execute_model_kwargs = {} + if previous_hidden_states is not None: + execute_model_kwargs.update( + {"previous_hidden_states": previous_hidden_states}) + + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + intermediate_tensors=intermediate_tensors, + **execute_model_kwargs, + **multimodal_kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Sample the next token. 
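# Illustrative sketch (toy reduction, not vLLM's actual Sampler): the
# self.sampler call below runs the full sampling pipeline. Conceptually, for
# one sequence it amounts to scaling the last-position logits by temperature
# and drawing from the resulting distribution; greedy decoding is the
# temperature-zero limit (argmax).
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
    """logits: [vocab_size] tensor for the last position of one sequence."""
    if temperature == 0.0:
        return int(torch.argmax(logits))
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))

# Made-up 5-token vocabulary: token 3 has the highest logit, so greedy
# decoding always picks it.
assert sample_next_token(torch.tensor([0.1, 0.2, 0.3, 2.0, 0.5]), 0.0) == 3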
+ output = self.sampler( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + if model_input.is_prompt: + output.prefill_hidden_states = hidden_states + output.hidden_states = hidden_states + return [output] + + def generate_proposals(self, *args, **kwargs): + return self.model.generate_proposals(*args, **kwargs) diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py new file mode 100644 index 000000000000..203fdf225a41 --- /dev/null +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch + +from vllm.forward_context import set_forward_context +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MultiModalKwargs +from vllm.pooling_params import PoolingParams +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, + SequenceGroupMetadata) +from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU, + ModelInputForCPUBuilder) + + +@dataclasses.dataclass(frozen=True) +class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): + """ + Used by the CPUPoolingModelRunner. + """ + pooling_metadata: Optional["PoolingMetadata"] = None + + +class CPUPoolingModelRunner( + CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]): + _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = ( + ModelInputForCPUWithPoolingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForCPUWithPoolingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + model_executable = self.model + cross_enc_kwargs = {} + if model_input.token_type_ids is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids + execute_model_kwargs = { + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + device=self.device, + ), + **cross_enc_kwargs, + "intermediate_tensors": + intermediate_tensors, + } + + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_states = model_executable(**execute_model_kwargs) + + # Only perform pooling in the driver worker. 
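# Illustrative sketch (mean pooling only, not vLLM's full Pooler): for
# pooling ("embedding") models the pooler below collapses the per-token
# hidden states of each prompt into one vector. With the flattened
# [num_tokens, hidden_size] layout used here, mean pooling per sequence can
# be expressed as:
import torch

def mean_pool(hidden_states: torch.Tensor,
              prompt_lens: list) -> list:
    """hidden_states: [sum(prompt_lens), hidden_size], flattened over seqs."""
    pooled = []
    offset = 0
    for n in prompt_lens:
        pooled.append(hidden_states[offset:offset + n].mean(dim=0))
        offset += n
    return pooled

# Two prompts of 3 and 2 tokens with hidden_size=4.
outs = mean_pool(torch.randn(5, 4), [3, 2])
assert [o.shape for o in outs] == [torch.Size([4]), torch.Size([4])]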
+ if not self.is_driver_worker: + return [] + + return [ + self.model.pooler(hidden_states=hidden_states, + pooling_metadata=model_input.pooling_metadata) + ] + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForCPUWithPoolingMetadata: + return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithPoolingMetadata: + assert seq_group_metadata_list is not None + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Prepare PoolingMetadata. + assert model_input.seq_lens is not None + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + model_input.seq_lens) + + return dataclasses.replace(model_input, + virtual_engine=virtual_engine, + pooling_metadata=pooling_metadata) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py new file mode 100644 index 000000000000..a8998127b60f --- /dev/null +++ b/vllm/worker/cpu_worker.py @@ -0,0 +1,452 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""A CPU worker class.""" +import os +from importlib import util +from typing import List, Optional, Set, Tuple, Type + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest +from vllm.utils import bind_kv_cache +from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase +from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, + WorkerInput) + +logger = init_logger(__name__) + + +class CPUCacheEngine: + """Manages the KV cache for CPU backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. 
+ """ + + def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: + assert device_config.device_type == "cpu" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for CPU backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, + model_config) + + # Get attention backend. + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + + # Initialize the cache. + self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[torch.Tensor]: + """Allocates KV cache on CPU.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_heads, self.head_size) + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + kv_cache.append( + torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) + return kv_cache + + def swap_in(self, src_to_dst: torch.Tensor) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def swap_out(self, src_to_dst: torch.Tensor) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def copy(self, src_to_dsts: torch.Tensor) -> None: + self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) + + @staticmethod + def get_kv_cache_dtype(cache_config: CacheConfig, + model_config: ModelConfig): + if cache_config.cache_dtype == "auto": + return model_config.dtype + elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]: + return torch.float8_e5m2 + else: + raise NotImplementedError(f"Unsupported KV cache type " + f"{cache_config.cache_dtype}.") + + @staticmethod + def get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block if not model_config.use_mla else 0 + total = num_layers * (key_cache_block + value_cache_block) + dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config) + dtype_size = torch.tensor([], dtype=dtype).element_size() + return dtype_size * total + + +class CPUWorker(LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a CPU socket. + + Each worker is associated with a single CPU socket. The worker is + responsible for maintaining the KV cache and executing the model on the + CPU. In case of distributed inference, each worker is assigned a partition + of the model. 
+ """ + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[CPUModelRunner]] = None, + ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + + self.local_rank = local_rank + self.rank = rank + vllm_config.parallel_config.rank = rank + + self.distributed_init_method = distributed_init_method + + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Setup OpenMP threads affinity. + omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND + self.local_omp_cpuid = "all" + if omp_cpuids == "auto": + self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes( + ) + else: + self.local_omp_cpuid = omp_cpuids.split("|")[rank] + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.model == + model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator", "eagle"]) \ + else {"return_hidden_states": True} + ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner + if self.model_config.runner_type == "pooling": + ModelRunnerClass = CPUPoolingModelRunner + elif self.model_config.is_encoder_decoder: + ModelRunnerClass = CPUEncoderDecoderModelRunner + self.model_runner: CPUModelRunnerBase = ModelRunnerClass( + vllm_config=vllm_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + **speculative_args, + ) + if model_runner_cls is not None: + self.model_runner = model_runner_cls(self.model_runner) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CPUCacheEngine] + # Initialize cpu_cache as pooling models don't initialize kv_caches + self.cpu_cache: Optional[List[List[torch.Tensor]]] = None + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() + + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() + + def init_device(self) -> None: + if self.local_omp_cpuid != "all": + ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + if ret: + logger.info(ret) + + # Note: unique identifier for creating allreduce shared memory + os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split( + ":")[-1] + self.device = torch.device("cpu") + self.init_distributed_environment() + # Set random seed. 
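# Illustrative sketch (assumed model numbers, not from this patch):
# determine_num_available_blocks() below divides the configured CPU KV-cache
# space by the per-block byte size from CPUCacheEngine.get_cache_block_size().
# With block_size=16, 8 KV heads, head_size=128, 32 layers and bfloat16, one
# block holds K and V for 16 tokens of every layer:
def cache_block_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                           num_layers: int, dtype_size: int) -> int:
    key_cache_block = block_size * num_kv_heads * head_size
    value_cache_block = key_cache_block  # no MLA in this example
    return num_layers * (key_cache_block + value_cache_block) * dtype_size

block_bytes = cache_block_size_bytes(16, 8, 128, 32, 2)
assert block_bytes == 2 * 1024 * 1024  # 2 MiB per block
# A 4 GiB VLLM_CPU_KVCACHE_SPACE therefore yields 2048 blocks, i.e. room for
# 2048 * 16 = 32768 cached tokens.
assert (4 * 1024 ** 3) // block_bytes == 2048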
+ set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured CPU + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. + """ + # For CPU device, the block number will be calculated based on the + # cpu_kvcache_space. + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes // + cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid. + """ + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). 
Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + def _init_cache_engine(self) -> None: + self.cache_engine = [ + CPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.cpu_cache = [ + self.cache_engine[ve].cpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + bind_kv_cache(self.compilation_config.static_forward_context, + self.cpu_cache) + self.model_runner.block_size = self.cache_engine[0].block_size + + assert all( + self.cpu_cache[ve] is not None + for ve in range(self.parallel_config.pipeline_parallel_size)) + + # Populate the cache to warmup the memory + for ve in range(self.parallel_config.pipeline_parallel_size): + for layer_cache in self.cpu_cache[ve]: + layer_cache.fill_(0) + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return self.cpu_cache + + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + + def execute_worker( + self, + worker_input: WorkerInput, + ) -> None: + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine[worker_input.virtual_engine].copy( + worker_input.blocks_to_copy) + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + assert execute_model_req is not None + virtual_engine: int = execute_model_req.virtual_engine + num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block. + """ + return CPUCacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + + def get_cpus_id_binding_based_on_numa_nodes(self) -> str: + """Return CPUs id binding based on NUMA nodes. 
+ """ + rank_to_cpus = self.local_omp_cpuid + # Setup OpenMP thread affinity based on NUMA nodes automatically + world_size = self.vllm_config.parallel_config.world_size + libnuma_found = util.find_spec("numa") is not None + psutil_found = util.find_spec("psutil") is not None + if libnuma_found and psutil_found: + import psutil + from numa import info + cpu_count = psutil.cpu_count(logical=False) + cpus_allow_list = psutil.Process().cpu_affinity() + numa_size = info.get_num_configured_nodes() + cpu_count_per_numa = cpu_count // numa_size + num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU, + cpu_count_per_numa // 2) + + # check allow node_to_cpus list + node_to_cpus = [] + for i in range(numa_size): + node_intersect = set( + info.node_to_cpus(i)).intersection(cpus_allow_list) + if bool(node_intersect): + node_to_cpus.append(list(node_intersect)) + + if world_size > len(node_to_cpus): + logger.error( + "Auto thread-binding failed due to " + "world size: %d is larger than " + "allowed NUMA nodes number: %d." + "Please try to bind threads manually.", world_size, + len(node_to_cpus)) + else: + end = cpu_count_per_numa - num_of_reserved_cpu + rank_to_cpus_list = node_to_cpus[self.rank][:end] + rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list) + logger.info("auto thread-binding list: %s", rank_to_cpus) + else: + logger.warning( + "Auto thread-binding is not supported due to " + "the lack of package numa and psutil," + "fallback to no thread-binding. To get better performance," + "please try to manually bind threads.") + return rank_to_cpus diff --git a/vllm/worker/multi_step_tpu_worker.py b/vllm/worker/multi_step_tpu_worker.py new file mode 100644 index 000000000000..ed9f00166615 --- /dev/null +++ b/vllm/worker/multi_step_tpu_worker.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from typing import Dict, Optional, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict +from vllm.sequence import ExecuteModelRequest +from vllm.worker.tpu_model_runner import ModelInputForTPU +from vllm.worker.tpu_worker import TPUWorker +from vllm.worker.worker_base import WorkerInput + + +class MultiStepTPUWorker(TPUWorker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cached_model_input: Optional[ModelInputForTPU] = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]: + assert self.is_driver_worker + assert execute_model_req.virtual_engine == 0 + + is_first_multi_step = execute_model_req.is_first_multi_step + is_last_step = execute_model_req.is_last_step + if is_first_multi_step: + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + worker_input = dataclasses.replace( + worker_input, + num_steps=execute_model_req.num_lookahead_slots + 1) + model_input: ModelInputForTPU = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input = dataclasses.replace( + model_input, + async_callback=execute_model_req.async_callback) + else: + assert self.cached_model_input is not None + model_input = self.cached_model_input + worker_input = WorkerInput() + model_input = dataclasses.replace( + model_input, + 
is_first_multi_step=is_first_multi_step, + is_last_step=is_last_step) + + if self.do_metadata_broadcast: + if is_first_multi_step: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + else: + broadcast_data = { + "is_first_multi_step": is_first_multi_step, + "is_last_step": is_last_step, + } + broadcast_tensor_dict(broadcast_data, src=0) + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str, + torch.Tensor]]]: + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + broadcast_tensor_dict({}, src=0) + return None + + model_input, worker_input, _ = self._get_driver_input_and_broadcast( + execute_model_req) + if model_input.is_first_multi_step: + self.cached_model_input = model_input + return model_input, worker_input, {} + else: + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + if len(broadcast_data) == 2: + assert self.cached_model_input is not None + self.cached_model_input = dataclasses.replace( + self.cached_model_input, + is_first_multi_step=broadcast_data["is_first_multi_step"], + is_last_step=broadcast_data["is_last_step"]) + empty_worker_input = WorkerInput() + return self.cached_model_input, empty_worker_input, {} + + worker_input = WorkerInput.from_broadcasted_tensor_dict( + broadcast_data) + model_input = ( + self.model_runner. + make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + self.cached_model_input = model_input + return model_input, worker_input, {} diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py new file mode 100644 index 000000000000..336bc0bcec36 --- /dev/null +++ b/vllm/worker/tpu_model_runner.py @@ -0,0 +1,909 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +import time +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, Union) +from unittest.mock import patch + +import numpy as np +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. 
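# Illustrative sketch (eager-mode equivalent, not the TPU kernel): padded
# entries in the slot mapping get an index far beyond any real cache slot, so
# the XLA/Pallas cache write simply drops them. Eager PyTorch would raise on
# such an index, so the same intent can be expressed by masking the padded
# slots before scattering:
import torch

def write_kv(kv_cache: torch.Tensor, values: torch.Tensor,
             slot_mapping: torch.Tensor, pad_slot_id: int) -> None:
    """kv_cache: [num_slots, head_size]; values: [num_tokens, head_size]."""
    valid = slot_mapping != pad_slot_id
    kv_cache[slot_mapping[valid]] = values[valid]

cache = torch.zeros(8, 4)
write_kv(cache, torch.ones(3, 4), torch.tensor([2, 5, 10**9]), 10**9)
assert cache[2].sum() == 4 and cache[5].sum() == 4 and cache.sum() == 8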
+_PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. +_MAX_NUM_SAMPLES = 128 + + +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + +@dataclass(frozen=True) +class ModelInputForTPU(ModelRunnerInputBase): + token_ids: torch.Tensor + position_ids: torch.Tensor + attn_metadata: AttentionMetadata + input_lens: torch.Tensor + t: torch.Tensor + p: torch.Tensor + num_samples: int + n: List[int] + seq_groups: List[List[int]] + is_first_multi_step: bool = True + is_last_step: bool = True + virtual_engine: int = 0 + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "token_ids": self.token_ids, + "position_ids": self.position_ids, + "input_lens": self.input_lens, + "t": self.t, + "p": self.p, + "num_samples": self.num_samples, + "n": self.n, + "seq_groups": self.seq_groups, + "is_first_multi_step": self.is_first_multi_step, + "is_last_step": self.is_last_step, + "virtual_engine": self.virtual_engine, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["ModelInputForTPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForTPU": + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): + + def __init__( + self, + vllm_config: VllmConfig, + is_driver_worker: bool = False, + ): + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.is_driver_worker = is_driver_worker + + self.block_size = self.cache_config.block_size + self.max_num_blocks_per_seq = (self.model_config.max_model_len // + self.block_size) + self.block_tables = np.zeros( + (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), + dtype=np.int32) + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + False, + ) + self.cached_step_outputs: List[torch.Tensor] = [] + + smem_size = 512 * 1024 + block_table_size = 4 * self.block_tables.size + if block_table_size >= smem_size: + logger.warning( + "The max_model_len (%d) is too large. This may degrade the " + "performance due to the insufficient smem size. Consider " + "setting --max-model-len to a smaller value, like %d.", + self.model_config.max_model_len, + self.model_config.max_model_len / + (block_table_size / smem_size)) + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. 
+ # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + model = ModelWrapper(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) + + def get_model(self) -> nn.Module: + return self.model.model + + def _dummy_run( + self, + batch_size: int, + seq_len: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + exec_mode: ExecutionMode, + ) -> None: + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + else: + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device=self.device) + effective_query_lens = torch.ones_like(context_lens) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) + else: + assert seq_len == 1 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_seq), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size * seq_len, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=block_tables, + context_lens=context_lens, + ) + t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. 
+ # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if exec_mode.is_prefill(): + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + # Decode + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(t, 0) + torch._dynamo.mark_dynamic(p, 0) + # Dummy run. + with set_forward_context(attn_metadata, self.vllm_config, 0): + self.model(token_ids, position_ids, input_lens, t, p, num_samples, + kv_caches) + + def warmup_model( + self, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> None: + # Prefill + logger.info("Compiling the model with different input shapes...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFILL) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.time() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Prefix prefill + if self.cache_config.enable_prefix_caching: + logger.info("Compiling the model with different input shapes for " + "prefix prefill...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFIX_PREFILL) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, + seq_len) + num_tokens = batch_size * seq_len + if (num_tokens + >= self.scheduler_config.max_num_batched_tokens): + break + seq_len = seq_len * 2 + end = time.time() + logger.info("Compilation for prefix prefill done in %.2f s.", + end - start) + + # Decode + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.DECODE) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + prompt_lens: List[int] = [] + context_lens: List[int] = [] + slot_mapping: List[int] = [] + + for batch_idx, seq_group_metadata in enumerate( + seq_group_metadata_list): + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = 
seq_group_metadata.seq_data[seq_id] + # Could include output tokens when a request is preempted. + prompt_tokens = seq_data.get_token_ids() + seq_len = len(prompt_tokens) + + num_computed_blocks = len(seq_group_metadata.computed_block_nums) + num_computed_tokens = num_computed_blocks * self.block_size + if num_computed_tokens > 0: + prompt_tokens = prompt_tokens[num_computed_tokens:] + context_lens.append(seq_len) + else: + context_lens.append(0) + + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.extend(prompt_tokens) + input_positions.extend(range(num_computed_tokens, seq_len)) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + for i in range(num_computed_tokens, seq_len): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + if num_computed_tokens > 0: + self.block_tables[batch_idx, :len(block_table)] = block_table + + # Add paddings to EACH prompt to the smallest power of 2 that is + # greater than or equal to the prompt length. + # We pad the seq_len to reduce the compilation overhead. + # We execute each prompt individually (i.e., with batch_size 1) + # because the FlashAttention kernel does not support ragged inputs. + # TODO(woosuk): Use SplashAttention to support ragged inputs. + padded_prompt_len = _get_padded_prefill_len(prompt_len) + num_paddings = padded_prompt_len - prompt_len + input_tokens += [0] * num_paddings + input_positions += [0] * num_paddings + slot_mapping += [_PAD_SLOT_ID] * num_paddings + + assert len(prompt_lens) > 0 + num_prefills = len(prompt_lens) + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int32, + device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:num_prefills], + dtype=torch.int32, + device="cpu") + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + num_prefill_tokens=0, # NOTE: This is not used. 
+ num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=prompt_lens, + ) + return input_tokens, input_positions, attn_metadata, prompt_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + + batch_idx = 0 + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + context_lens.append(seq_len) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + self.block_tables[batch_idx, :len(block_table)] = block_table + batch_idx += 1 + + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + batch_size = _get_padded_batch_size(batch_idx) + num_paddings = batch_size - batch_idx + input_tokens = input_tokens + [[0]] * num_paddings + input_positions = input_positions + [[0]] * num_paddings + slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings + context_lens = context_lens + [0] * num_paddings + + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device="cpu") + input_lens = torch.tensor([1] * batch_size, + dtype=torch.int32, + device="cpu") + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + block_tables=block_tables, + context_lens=context_lens, + ) + return input_tokens, input_positions, attn_metadata, input_lens + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + padded_batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + assert len(seq_group_metadata_list) > 0 + t = [] + p = [] + n = [] + for seq_group_metadata in seq_group_metadata_list: + sampling_params = seq_group_metadata.sampling_params + t.append(sampling_params.temperature) + if sampling_params.top_p != 1 and not _ENABLE_TOP_P: + raise NotImplementedError( + "Top-p sampling is currently disabled for the TPU backend " + "due to performance issues.") + p.append(sampling_params.top_p) + if sampling_params.top_k > 0: + raise NotImplementedError( + "Top-k sampling is currently disabled for the TPU backend " + "due to performance issues.") + if sampling_params.n > _MAX_NUM_SAMPLES: + raise NotImplementedError( + f"Best of > {_MAX_NUM_SAMPLES} is not 
supported by the TPU " + "backend.") + n.append(sampling_params.n) + if sampling_params.logprobs is not None: + raise NotImplementedError( + "logprobs is not currently supported by the TPU backend.") + if sampling_params.prompt_logprobs is not None: + raise NotImplementedError( + "prompt_logprobs is not currently supported by the TPU " + "backend.") + + # Repeat the sampling params if the seq group has multiple seqs. + num_seqs = len(seq_group_metadata.seq_data) + t += [t[-1]] * (num_seqs - 1) + p += [p[-1]] * (num_seqs - 1) + n += [n[-1]] * (num_seqs - 1) + + num_paddings = padded_batch_size - len(t) + t += [1.0] * num_paddings + p += [1.0] * num_paddings + + t = torch.tensor(t, dtype=torch.float32, device="cpu") + p = torch.tensor(p, dtype=torch.float32, device="cpu") + return t, p, n + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputForTPU: + del finished_requests_ids # Unused. + assert virtual_engine == 0 + assert len(seq_group_metadata_list) > 0 + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + if is_prompt: + inputs = self._prepare_prompt(seq_group_metadata_list) + else: + inputs = self._prepare_decode(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, input_lens = inputs + padded_batch_size = input_tokens.shape[0] + t, p, n = self._prepare_sample(seq_group_metadata_list, + padded_batch_size) + num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + + seq_groups = [ + list(metadata.seq_data.keys()) + for metadata in seq_group_metadata_list + ] + return ModelInputForTPU(input_tokens, input_positions, attn_metadata, + input_lens, t, p, num_samples, n, seq_groups) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: + model_input = ModelInputForTPU.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=self.attn_backend) + return model_input + + @torch.no_grad() + def execute_model( + self, + model_input: ModelInputForTPU, + kv_caches: Optional[List[Any]], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> List[SamplerOutput]: + assert intermediate_tensors is None + if not model_input.is_first_multi_step: + if not model_input.is_last_step: + return [] + + use_async_out_proc = model_input.async_callback is not None + sampler_outputs = [] + num_outputs = len(self.cached_step_outputs) + for i in range(num_outputs): + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + sampler_outputs.append(sampler_output) + + if i < num_outputs - 1 and use_async_out_proc: + assert model_input.async_callback is not None + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + ctx.append_output( + outputs=[sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=i == 0) + model_input.async_callback() + if use_async_out_proc: + return [sampler_outputs[-1]] + else: + return sampler_outputs + + is_prompt = model_input.attn_metadata.num_prefills > 0 + if is_prompt: + assert num_steps == 1 + # NOTE(woosuk): Since the FlashAttention kernel does not support + # ragged inputs, we split the prompts into 
different batches and + # process them separately. This is a temporary hack that should be + # optimized by using SplashAttention. + orig_slot_mapping = model_input.attn_metadata.slot_mapping + orig_block_tables = model_input.attn_metadata.block_tables + orig_context_lens = model_input.attn_metadata.context_lens + orig_effective_query_lens = \ + model_input.attn_metadata.effective_query_lens + batch_size = model_input.input_lens.shape[0] + start_idx = 0 + next_token_ids = [] + for i in range(batch_size): + # Get the actual prefill_len. + prefill_len = model_input.input_lens[i:i + 1].item() + prefill_len = _get_padded_prefill_len(prefill_len) + end_idx = start_idx + prefill_len + + token_ids = model_input.token_ids[None, start_idx:end_idx].to( + self.device) + position_ids = model_input.position_ids[None, + start_idx:end_idx].to( + self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.num_prefills = 1 + attn_metadata.slot_mapping = orig_slot_mapping[ + None, start_idx:end_idx].to(self.device) + if orig_context_lens[i].item() > 0: + attn_metadata.context_lens = orig_context_lens[i:i + 1].to( + self.device) + attn_metadata.block_tables = orig_block_tables[ + i].unsqueeze(0).to(self.device) + attn_metadata.effective_query_lens = \ + orig_effective_query_lens[i:i + 1].to(self.device) + else: + attn_metadata.context_lens = None + attn_metadata.block_tables = None + attn_metadata.effective_query_lens = None + input_lens = model_input.input_lens[i:i + 1].to(self.device) + t = model_input.t[i:i + 1].to(self.device) + p = model_input.p[i:i + 1].to(self.device) + with set_forward_context(model_input.attn_metadata, + self.vllm_config, + model_input.virtual_engine): + output_token_ids = self.model(token_ids, position_ids, + input_lens, t, p, + model_input.num_samples, + kv_caches) + next_token_ids.append(output_token_ids[0]) + start_idx = end_idx + + if model_input.async_callback is not None: + model_input.async_callback() + # Retrieve the outputs to CPU. + next_token_ids = [ + output_token_ids.cpu().tolist() + for output_token_ids in next_token_ids + ] + + # NOTE(woosuk): Minimal code to construct the sampler outputs. + # The TPU backend does not reuse the sampler, since the TPU backend + # does not support advanced sampling parameters such as logprobs. 
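+            # Sketch of the expected shapes (hypothetical values): with
+            # seq_groups=[[7]] and n=[2], next_token_ids[0] holds the two
+            # sampled ids for sequence 7, each paired with a placeholder
+            # Logprob(0.0) since real logprobs are never computed here.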
+ zero_logprob = Logprob(0.0) + sampler_outputs = [] + for i, seq_group in enumerate(model_input.seq_groups): + seq_ids = seq_group + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + seq_outputs = [] + for j in range(model_input.n[i]): + next_token_id = next_token_ids[i][j] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return [SamplerOutput(sampler_outputs)] + else: + token_ids = model_input.token_ids.to(self.device) + position_ids = model_input.position_ids.to(self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( + self.device) + attn_metadata.block_tables = attn_metadata.block_tables.to( + self.device) + attn_metadata.context_lens = attn_metadata.context_lens.to( + self.device) + t = model_input.t.to(self.device) + p = model_input.p.to(self.device) + input_lens = model_input.input_lens.to(self.device) + for i in range(num_steps): + slot_mapping = attn_metadata.slot_mapping + with set_forward_context(model_input.attn_metadata, + self.vllm_config, + model_input.virtual_engine): + output_token_ids = self.model(token_ids, position_ids, + input_lens, t, p, + model_input.num_samples, + kv_caches) + self.cached_step_outputs.append(output_token_ids) + + if i < num_steps - 1: + # Prepare the inputs for the next step. + token_ids = output_token_ids.unsqueeze(dim=1).int() + position_ids = position_ids + 1 + attn_metadata.context_lens = attn_metadata.context_lens + 1 + + block_tables = attn_metadata.block_tables + block_number = block_tables.gather( + 1, + position_ids.long() // self.block_size) + block_offset = position_ids % self.block_size + + is_padding = slot_mapping == _PAD_SLOT_ID + slot_mapping = block_number * self.block_size + block_offset + slot_mapping = slot_mapping.long() + slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, + slot_mapping) + attn_metadata.slot_mapping = slot_mapping + + if model_input.async_callback is not None: + model_input.async_callback() + + if num_steps > 1: + return [] + # Retrieve the outputs to CPU. + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + return [sampler_output] + + +class ModelWrapper(nn.Module): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + input_lens: torch.Tensor, + t: torch.Tensor, + p: torch.Tensor, + num_samples: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + batch_size, seq_len = token_ids.shape + # Calculate the positions to sample from. 
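+        # Worked example (hypothetical): with batch_size=2, seq_len=8 and
+        # input_lens=[5, 3], sampling happens at flattened indices
+        # i * seq_len + input_lens[i] - 1, i.e. [4, 10].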
+ start_indices = torch.arange( + batch_size, dtype=torch.int32, device=input_lens.device) * seq_len + logits_indices = start_indices + input_lens - 1 + attn_metadata = get_forward_context().attn_metadata + + # FIXME(woosuk): This is a temporary hack to avoid using the existing + # sampler and sampling metadata. + sampling_metadata = SamplingMetadata( + seq_groups=[], + selected_token_indices=logits_indices, + categorized_sample_indices={}, + num_prompts=attn_metadata.num_prefills, + ) + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indices = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indices *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indices.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model(token_ids, position_ids) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Argmax sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + + # Zero temperature means greedy decoding. Avoid division by zero. + nonzero_t = torch.where(t != 0, t, 1.0) + logits = logits / nonzero_t.unsqueeze(dim=1) + if _ENABLE_TOP_P: + logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + + # Random sampling. + probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + sampled_token_ids = torch.multinomial(probs, + num_samples, + replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + next_token_ids = torch.where(t != 0, sampled_token_ids, + argmax_token_ids) + return next_token_ids + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. 
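+    # Examples of the rounding below:
+    #   _get_padded_batch_size(5)  -> 8
+    #   _get_padded_batch_size(9)  -> 16
+    #   _get_padded_batch_size(17) -> 32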
+ if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + logits_sorted = torch.sort(logits, dim=-1, descending=True).values + sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) + cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) + cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) + logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) + return logits + + +def _make_decode_output( + next_token_ids: List[int], + seq_groups: List[List[int]], +) -> SamplerOutput: + zero_logprob = Logprob(0.0) + sampler_outputs = [] + batch_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group + seq_outputs = [] + for seq_id in seq_ids: + next_token_id = next_token_ids[batch_idx] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + batch_idx += 1 + sampler_outputs.append(CompletionSequenceGroupOutput( + seq_outputs, None)) + return SamplerOutput(sampler_outputs) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py new file mode 100644 index 000000000000..ad5ed19e2f89 --- /dev/null +++ b/vllm/worker/tpu_worker.py @@ -0,0 +1,337 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.debug.profiler as xp +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size +from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoRANotSupportedWorkerBase, WorkerBase, + WorkerInput) + +logger = init_logger(__name__) + + +class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool, + ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + self.parallel_config.rank = rank + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + + assert self.device_config.device_type == "tpu" + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + self.model_runner: TPUModelRunner = TPUModelRunner( + vllm_config=vllm_config, is_driver_worker=is_driver_worker) + + if self.model_config.seed is None: + self.model_config.seed = 0 + + if vllm_config.lora_config is not None: + raise NotImplementedError( + "The V0 TPU backend doesn't support LoRA serving") + + def init_device(self) -> None: + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. 
The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + # The PyTorch/XLA compilation cache uses the Torch IR to generate keys. + # Consequently, changes in optimization flags, which affect compilation + # results, don't change the cache key. This can result in the wrong + # compilation being used. To prevent this, disabling the XLA compilation + # cache during development is recommended.We can disable it by + # `export VLLM_XLA_CACHE_PATH=` + if envs.VLLM_XLA_CACHE_PATH: + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + self.profiler = None + if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: + # For TPU, we can only have 1 active profiler session for 1 profiler + # server. So we only profile on rank0. + self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + self.profile_dir) + self.profiler = xp.start_server(9012) + + def start_profile(self): + if self.rank < 1: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + xp.start_trace(self.profile_dir) + + def stop_profile(self): + if self.rank < 1: + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + xp.stop_trace() + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_layers = self.model_config.get_num_layers(self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [(torch.tensor([], dtype=torch.float32, + device=self.device), + torch.tensor([], dtype=torch.float32, + device=self.device)) + for _ in range(num_layers)] + bind_kv_cache(self.compilation_config.static_forward_context, + [kv_caches]) + self.model_runner._dummy_run( + batch_size=1, + seq_len=self.scheduler_config.max_num_batched_tokens, + kv_caches=kv_caches, + exec_mode=ExecutionMode.PREFILL, + ) + # Synchronize before measuring the memory usage. 
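+        # The dummy prefill above runs at max_num_batched_tokens, so the peak
+        # usage read below covers the worst-case activations. As a rough,
+        # hypothetical example: a 16 GiB limit, 0.9 utilization and a 6 GiB
+        # peak leave about 8.4 GiB for the TPU KV cache.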
+ xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + profiled = m["peak_bytes_used"] # Weights + intermediate activations. + + # Calculate the TPU KV cache size based on profiling. + usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + dtype_bytes = get_dtype_size(self.cache_dtype) + block_size_bytes = (dtype_bytes * self.cache_config.block_size * + num_layers * 2 * head_size * num_kv_heads) + num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes + num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. + + # Calculate the CPU KV cache size based on the config. + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + block_size_bytes) + num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. + return num_tpu_blocks, num_cpu_blocks + + def initialize_cache( + self, + num_gpu_blocks: int, + num_cpu_blocks: int, + ) -> None: + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + self.block_size = self.cache_config.block_size + + dtype = self.cache_dtype + num_layers = self.model_config.get_num_layers(self.parallel_config) + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + + self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( + num_gpu_blocks, self.block_size, num_kv_heads, head_size) + cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( + num_cpu_blocks, self.block_size, num_kv_heads, head_size) + for _ in range(num_layers): + tpu_k_cache = torch.zeros(tpu_cache_shape, + dtype=dtype, + device=self.device) + tpu_v_cache = torch.zeros_like(tpu_k_cache) + self.tpu_cache.append((tpu_k_cache, tpu_v_cache)) + cpu_k_cache = torch.zeros(cpu_cache_shape, + dtype=dtype, + device="cpu") + cpu_v_cache = torch.zeros_like(cpu_k_cache) + self.cpu_cache.append((cpu_k_cache, cpu_v_cache)) + bind_kv_cache(self.compilation_config.static_forward_context, + [self.tpu_cache]) + self._warmup_model() + + def _warmup_model(self) -> None: + # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined + # for CUDA graphs. We should refactor this part. + if not self.model_config.enforce_eager: + # Warm up the model with all possible input shapes so that + # compilation never happens during the actual execution. + # This may take ~30 mins for the first run and ~20 mins for the + # subsequent runs. + # If `enforce_eager` is True, the ahead-of-time compilation is + # skipped and the compilation happens during the actual execution, + # which is bad for performance but useful for development. 
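+            # Shape sweep performed by warmup_model (see tpu_model_runner):
+            # prefill sequence lengths double from 16 up to max_model_len,
+            # and decode batch sizes grow 8 -> 16 -> 32 -> 48 -> ... up to
+            # max_num_seqs.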
+ self.model_runner.warmup_model(self.tpu_cache) + + def get_cache_block_size_bytes(self) -> int: + head_size = self.model_config.get_head_size() + num_heads = self.model_config.get_num_kv_heads(self.parallel_config) + num_layers = self.model_config.get_num_layers(self.parallel_config) + + key_cache_block = self.cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = get_dtype_size(self.cache_dtype) + return dtype_size * total + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline + # parallelism. + return [self.tpu_cache] + + def prepare_worker_input( + self, + execute_model_req: ExecuteModelRequest, + ) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + blocks_to_swap_in = _make_src_to_dst( + execute_model_req.blocks_to_swap_in, "cpu", self.device) + blocks_to_swap_out = _make_src_to_dst( + execute_model_req.blocks_to_swap_out, self.device, "cpu") + blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy, + self.device, self.device) + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + assert virtual_engine == 0 + attn_backend = self.model_runner.attn_backend + num_layers = self.model_config.get_num_layers(self.parallel_config) + + # Issue cache operations. + if worker_input.blocks_to_swap_in is not None: + src_indices, dst_indices = worker_input.blocks_to_swap_in + if src_indices.numel() > 0: + # Swap from CPU to TPU. + for i in range(num_layers): + tpu_k_cache, tpu_v_cache = self.tpu_cache[i] + cpu_k_cache, cpu_v_cache = self.cpu_cache[i] + k = cpu_k_cache[:, src_indices].to(self.device) + v = cpu_v_cache[:, src_indices].to(self.device) + _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) + + if worker_input.blocks_to_swap_out is not None: + src_indices, dst_indices = worker_input.blocks_to_swap_out + if src_indices.numel() > 0: + # Swap from TPU to CPU. 
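+                # Block ids index dim 1 of each cache tensor, i.e. the
+                # num_blocks axis of the (assumed) Pallas cache layout
+                # [num_kv_heads, num_blocks, block_size, head_size].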
+ for i in range(num_layers): + tpu_k_cache, tpu_v_cache = self.tpu_cache[i] + cpu_k_cache, cpu_v_cache = self.cpu_cache[i] + cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices] + cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices] + + if worker_input.blocks_to_copy is not None: + src_indices, dst_indices = worker_input.blocks_to_copy + if src_indices.numel() > 0: + attn_backend.copy_blocks(self.tpu_cache, + (src_indices, dst_indices)) + + +def _make_src_to_dst( + mapping: List[Tuple[int, int]], + src_device: Union[torch.device, str], + dst_device: Union[torch.device, str], +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + if not mapping: + return None + + src_indices = [i for i, _ in mapping] + dst_indices = [i for _, i in mapping] + src_indices = torch.tensor(src_indices, + device=src_device, + dtype=torch.int64) + dst_indices = torch.tensor(dst_indices, + device=dst_device, + dtype=torch.int64) + return src_indices, dst_indices + + +@torch.compile(backend="openxla") +def _insert_kv( + k: torch.Tensor, + v: torch.Tensor, + indices: torch.Tensor, + tpu_k_cache: torch.Tensor, + tpu_v_cache: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True) + tpu_k_cache[:, indices] = k + tpu_v_cache[:, indices] = v diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py new file mode 100644 index 000000000000..b2d3ce8526d5 --- /dev/null +++ b/vllm/worker/xpu_model_runner.py @@ -0,0 +1,606 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +import time +import weakref +from collections import defaultdict +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, TypeVar) + +import torch +import torch.nn as nn + +from vllm.attention import get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadataCache +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs, MultiModalPlaceholderMap, + MultiModalRegistry) +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad +from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + +TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") + + +@dataclass(frozen=True) +class ModelInputForXPU(ModelRunnerInputBase): + """ + Used by the NeuronModelRunner. 
+ """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForXPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForXPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForXPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): + + def __init__(self, + runner: "XPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) + + def build(self) -> ModelInputForXPU: + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. 
+ if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = None + multi_modal_kwargs = None + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + seq_lens=seq_lens, + query_lens=seq_lens, + ) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + BatchedTensorInputs]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + seq_len = len(prompt_tokens) + + seq_lens.append(seq_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) + + if seq_group_metadata.multi_modal_data: + # NOTE: mm_kwargs only includes the subset of multi-modal items + # that intersect with the current prefill positions. + mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) + + multi_modal_kwargs_list.append(mm_kwargs) + + for modality, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[modality].extend( + placeholder_map) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
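+            # Slot formula: slot = block_number * block_size + block_offset.
+            # E.g. (hypothetical) block_table=[12, 7] with block_size=4 maps
+            # token i=5 to block 7, offset 1, i.e. slot 29.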
+ start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, seq_len - self.sliding_window) + + for i in range(computed_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + multi_modal_placeholder_maps.items() + } + + max_seqlen = max(seq_lens) + tmp = [0] + tmp.extend(seq_lens) + seqlen = torch.tensor(tmp) + seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + seq_lens=seq_lens, + seqlen_q=seqlen_q, + max_seqlen=max_seqlen, + seq_lens_tensor=torch.tensor([]), + max_decode_seq_len=0, + num_prefills=len(seq_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + block_tables=torch.tensor([], device=self.device, dtype=torch.int), + ) + + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + + return (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_decode_seq_len = max(seq_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + block_tables = make_tensor_with_pad( + 
block_tables, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + seq_lens=seq_lens, + seqlen_q=torch.tensor([]), + max_seqlen=0, + seq_lens_tensor=seq_lens_tensor, + max_decode_seq_len=max_decode_seq_len, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + num_prefills=0, + block_tables=block_tables, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + +class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): + _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = ( + ModelInputForXPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + model_config = self.model_config + cache_config = self.cache_config + self.is_driver_worker = is_driver_worker + self.return_hidden_states = return_hidden_states + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Multi-modal data support + self.input_registry = input_registry + self.mm_registry = mm_registry + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + self.sampler = get_sampler() + + self.sampling_metadata_cache: SamplingMetadataCache = \ + SamplingMetadataCache() \ + if self.parallel_config.pipeline_parallel_size == 1 else None + + self.builder = self._builder_cls(weakref.proxy(self)) + + def load_model(self) -> None: + with DeviceMemoryProfiler() as m: + self.model = get_model(vllm_config=self.vllm_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GiB", + self.model_memory_usage / GiB_bytes) + + def get_model(self) -> nn.Module: + return self.model + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for multi-modal encoding, which + # needs to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. 
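+        # E.g. (hypothetical): max_num_batched_tokens=8192 and a model whose
+        # largest multi-modal item costs 1024 tokens cap the profiling batch
+        # at min(max_num_seqs, 8) sequences.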
+ max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) + if max_mm_tokens > 0: + max_num_seqs_orig = max_num_seqs + max_num_seqs = min(max_num_seqs, + max_num_batched_tokens // max_mm_tokens) + if max_num_seqs < 1: + expr = (f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})") + logger.warning( + "Computed max_num_seqs (%s) to be less than 1. " + "Setting it to the minimum value of 1.", expr) + max_num_seqs = 1 + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: dummy_data.seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=None, + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders) + seqs.append(seq) + + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, None, intermediate_tensors) + torch.xpu.synchronize() + return + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForXPUWithSamplingMetadata: + return ( + ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForXPUWithSamplingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + """ + builder = self.builder + builder.prepare(finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + builder.add_seq_group(seq_group_metadata) + + return builder.build() # type: ignore + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForXPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. 
+ + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators, + cache=self.sampling_metadata_cache) + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForXPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "XPUModelRunner does not support multi-step execution.") + + model_executable = self.model + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start_time = time.time() + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + device=self.device, + ), + ) + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end_time = time.time() + + # Compute the logits. + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. + output: SamplerOutput = self.sampler( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_time = (model_forward_end_time - + model_forward_start_time) + # If there are multiple workers, we are still tracking the latency + # from the start time of the driver worker to the end time of the + # driver worker. The model forward time will then end up covering + # the communication time as well. 
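+            # model_forward_time is the wall-clock duration (from time.time())
+            # around the forward pass above; it is only recorded when the
+            # observability config enables collect_model_forward_time.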
+            output.model_forward_time = model_forward_time
+
+        return [output]
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
new file mode 100644
index 000000000000..fe321c059f52
--- /dev/null
+++ b/vllm/worker/xpu_worker.py
@@ -0,0 +1,186 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""An XPU worker class."""
+import gc
+import os
+from typing import List, Optional, Tuple
+
+import intel_extension_for_pytorch  # noqa: F401
+import oneccl_bindings_for_pytorch  # noqa: F401
+import torch
+import torch.distributed
+
+from vllm.config import VllmConfig
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment)
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor import set_random_seed
+from vllm.platforms import current_platform
+from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.worker import Worker
+from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
+from vllm.worker.xpu_model_runner import XPUModelRunner
+
+logger = init_logger(__name__)
+
+
+class XPUWorker(LoRANotSupportedWorkerBase, Worker):
+    """A worker class that executes (a partition of) the model on an XPU.
+
+    Each worker is associated with a single XPU device. The worker is
+    responsible for maintaining the KV cache and executing the model on the
+    XPU. In case of distributed inference, each worker is assigned a partition
+    of the model.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+    ) -> None:
+        WorkerBase.__init__(self, vllm_config=vllm_config)
+        device_config = self.device_config
+        parallel_config = self.parallel_config
+        assert device_config.device_type == "xpu"
+        assert current_platform.is_xpu()
+
+        self.parallel_config.rank = rank
+
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                "Driver worker should be rank 0 of tensor parallel group."
+
+        self.model_runner = XPUModelRunner(  # type: ignore
+            vllm_config=vllm_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            is_driver_worker=is_driver_worker,
+        )
+        # Uninitialized cache engine. Will be initialized by
+        # initialize_cache.
+        self.cache_engine: List[CacheEngine]
+        self.gpu_cache: Optional[List[List[torch.Tensor]]]
+
+    def init_device(self) -> None:
+        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
+        ):
+            self.device = torch.device(f"xpu:{self.local_rank}")
+            torch.xpu.set_device(self.device)
+            torch.xpu.empty_cache()
+            self.init_gpu_memory = torch.xpu.get_device_properties(
+                self.local_rank).total_memory
+        else:
+            raise RuntimeError(
+                f"Unsupported device type: {self.device_config.device}")
+        # Initialize the distributed environment.
+        self.init_worker_distributed_environment()
+        # Initialize the model.
+        set_random_seed(self.model_config.seed)
+
+    # Keep this method for the `empty_cache` and `synchronize` APIs.
+    @torch.inference_mode()
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Profiles the peak memory usage of the model to determine how many
+        KV blocks may be allocated without OOMs.
+
+        The engine first profiles the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+
+        Tip:
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Profile the memory usage of the model and get the maximum number of
+        # cache blocks that can be allocated with the remaining free memory.
+        torch.xpu.empty_cache()
+
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        self.model_runner.profile_run()
+
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        torch.xpu.synchronize()
+        used_memory = torch.xpu.memory_allocated()
+        total_gpu_memory = torch.xpu.get_device_properties(
+            self.local_rank).total_memory
+        free_gpu_memory = total_gpu_memory - used_memory
+
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        peak_memory = self.init_gpu_memory - free_gpu_memory
+        assert peak_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_gpu_memory}, current free memory"
+            f" {free_gpu_memory}. This happens when the GPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
+        cache_block_size = self.get_cache_block_size_bytes()
+        num_gpu_blocks = int(
+            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
+             peak_memory) // cache_block_size)
+        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
+                             cache_block_size)
+        num_gpu_blocks = max(num_gpu_blocks, 0)
+        num_cpu_blocks = max(num_cpu_blocks, 0)
+        gc.collect()
+        torch.xpu.empty_cache()
+        return num_gpu_blocks, num_cpu_blocks
+
+    def _warm_up_model(self) -> None:
+        # IPEX does not support graph capture yet.
+        pass
+
+    def init_worker_distributed_environment(self) -> None:
+        """Initialize the distributed environment."""
+
+        parallel_config = self.parallel_config
+        rank = self.rank
+        distributed_init_method = self.distributed_init_method
+
+        if torch.distributed.is_initialized():
+            torch_world_size = torch.distributed.get_world_size()
+            if torch_world_size != parallel_config.world_size:
+                raise RuntimeError(
+                    "torch.distributed is already initialized but the torch "
+                    "world size does not match parallel_config.world_size "
+                    f"({torch_world_size} vs. {parallel_config.world_size}).")
+        elif not distributed_init_method:
+            raise ValueError(
+                "distributed_init_method must be set if torch.distributed "
+                "is not already initialized")
+        else:
+            # Use sockets as the default Level Zero IPC exchange backend. By
+            # default, oneCCL uses `drmfd` as the mechanism, which needs extra
+            # dependencies (libdrm and drm headers) on the system.
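+            # For example, a deployment that prefers the MPI transport could
+            # export CCL_ATL_TRANSPORT=mpi before launch; the defaults below
+            # are only applied when the variables are not already set.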
+            ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
+            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                             str(parallel_config.world_size))
+            os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
+            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+            os.environ["LOCAL_RANK"] = str(self.local_rank)
+            init_distributed_environment(
+                world_size=parallel_config.world_size,
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+                local_rank=self.local_rank,
+                backend="ccl")
+
+        ensure_model_parallel_initialized(
+            parallel_config.tensor_parallel_size,
+            parallel_config.pipeline_parallel_size)
+        # A global all_reduce is needed for overall oneCCL warm-up.
+        torch.distributed.all_reduce(torch.zeros(1).xpu())
+
+        if parallel_config.pipeline_parallel_size > 1:
+            # Initialize the PP group here so that its first call is not
+            # a p2p communication.
+            get_pp_group().all_reduce(torch.zeros(1).xpu())
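
For readers skimming the patch, the following is a minimal, standalone sketch of the block-count arithmetic that `determine_num_available_blocks` performs above. The numbers and the `gib`/`estimate_num_blocks` helpers are hypothetical and exist only for illustration; in vLLM the per-block size comes from the cache engine, not from a hard-coded constant.

from typing import Tuple


def gib(x: float) -> int:
    # Convert GiB to bytes (helper for this sketch only).
    return int(x * (1 << 30))


def estimate_num_blocks(total_gpu_memory: int, peak_memory: int,
                        gpu_memory_utilization: float,
                        swap_space_bytes: int,
                        cache_block_size: int) -> Tuple[int, int]:
    # Budget = the fraction of device memory given to vLLM minus what the
    # profiled forward pass itself consumed; the rest becomes KV-cache blocks.
    kv_budget = total_gpu_memory * gpu_memory_utilization - peak_memory
    num_gpu_blocks = max(int(kv_budget // cache_block_size), 0)
    num_cpu_blocks = max(int(swap_space_bytes // cache_block_size), 0)
    return num_gpu_blocks, num_cpu_blocks


# Hypothetical example: a 64 GiB device, 30 GiB peak usage during profiling,
# 90% memory utilization, 4 GiB of CPU swap space, ~2 MiB per cache block.
print(estimate_num_blocks(gib(64), gib(30), 0.9, gib(4), 2 * (1 << 20)))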