From 426734423bc815a717cbc7c90fb4ffbc0f4e7357 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 1 Feb 2025 04:42:02 +0000 Subject: [PATCH 01/38] chunked mla Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 7 +- vllm/attention/backends/triton_mla.py | 94 ++++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index a41140ec8378..c2f8dc0a2df6 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -442,10 +442,6 @@ def forward( is_decode = attn_metadata.decode_metadata is not None is_prefill = attn_metadata.prefill_metadata is not None - if (is_decode and is_prefill): - raise NotImplementedError( - "chunked prefill is not supported for MLAImplBase") - # Restore head dim (for rotary embedding) k_pe = k_pe.unsqueeze(1) assert hasattr(attn_metadata, "input_positions") @@ -479,7 +475,8 @@ def forward( ) if attn_metadata.prefill_metadata is not None: - return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata) + return self._forward_prefill(q, k_c_normed, k_pe, kv_cache, + attn_metadata) if attn_metadata.decode_metadata is not None: return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 9a1984a931b5..d9e14c487b55 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -6,6 +6,9 @@ from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +import triton +import triton.language as tl + from vllm.multimodal import MultiModalPlaceholderMap try: @@ -647,6 +650,63 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) +@triton.jit +def _gather_kv_cache( + # Pointers to inputs and output + seq_start_locs, # (batch_size + 1,) + block_tables, # (batch_size, max_blocks_per_seq) + 
block_table_stride, + kv_cache, # (num_blocks, block_size, head_size) + kv_page_stride, + kv_out, + CACHE_PAGE_SIZE: tl.constexpr, + CACHE_ENTRY_SIZE: tl.constexpr, + CACHE_ENTRIES_PER_PAGE: tl.constexpr, + CACHE_PAGE_SIZE_POW_2: tl.constexpr, + CACHE_ENTRY_SIZE_POW_2: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + + seq_start_loc = tl.load(seq_start_locs + g_id) + seq_len = tl.load(seq_start_locs + g_id + 1) - seq_start_loc + + pages_to_copy = tl.cdiv(seq_len, CACHE_ENTRIES_PER_PAGE) + kv_out = kv_out + seq_start_loc * CACHE_ENTRY_SIZE + block_table = block_tables + g_id * block_table_stride + + cache_page_range = tl.arange(0, CACHE_PAGE_SIZE_POW_2) + cache_page_mask = cache_page_range < CACHE_PAGE_SIZE + for i in range(pages_to_copy - 1): + page = tl.load(block_table + i) + page_start = kv_cache + page * kv_page_stride + page_data = tl.load(page_start + cache_page_range, + mask=cache_page_mask) + tl.store(kv_out + i * CACHE_PAGE_SIZE + cache_page_range, + page_data, + mask=cache_page_mask) + + last_page_len = seq_len % CACHE_ENTRIES_PER_PAGE + last_page = tl.load(block_table + pages_to_copy - 1) + last_page_start = kv_cache + last_page * kv_page_stride + + cache_entry_range = tl.arange(0, CACHE_ENTRY_SIZE_POW_2) + cache_entry_mask = cache_entry_range < CACHE_ENTRY_SIZE + kv_out_page = kv_out + (pages_to_copy - 1) * CACHE_PAGE_SIZE + for i in range(last_page_len): + last_page_data = tl.load(last_page_start + \ + i * CACHE_ENTRY_SIZE + cache_entry_range, + mask=cache_entry_mask) + tl.store(kv_out_page + i * CACHE_ENTRY_SIZE + cache_entry_range, + last_page_data, + mask=cache_entry_mask) + + class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): def __init__( @@ -686,12 +746,42 @@ def __init__( def _forward_prefill( self, q: torch.Tensor, - 
kv_c_normed: torch.Tensor, + kv_c: torch.Tensor, k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: TritonMLAMetadata, ) -> torch.Tensor: assert isinstance(attn_metadata, TritonMLAMetadata) - return self._forward_prefill_flash(q, kv_c_normed, k_pe, + + if attn_metadata.prefill_metadata.context_lens_tensor is not None and \ + max(attn_metadata.prefill_metadata.context_lens_tensor) > 0: + entries_total = attn_metadata.prefill_metadata.seq_start_loc[-1] + kv_c_k_pe_cache = torch.empty( + (entries_total, kv_c_and_k_pe_cache.shape[-1]), + dtype=kv_c_and_k_pe_cache.dtype, + device=kv_c_and_k_pe_cache.device, + ) + + assert kv_c_and_k_pe_cache.shape[-1] == 576 + assert kv_c_and_k_pe_cache.shape[-2] == 16 + _gather_kv_cache[(attn_metadata.num_prefills, )]( + attn_metadata.prefill_metadata.seq_start_loc, + attn_metadata.prefill_metadata.block_tables, + attn_metadata.prefill_metadata.block_tables.stride(0), + kv_c_and_k_pe_cache, + kv_c_and_k_pe_cache.stride(0), + kv_c_k_pe_cache, + CACHE_PAGE_SIZE=576 * 16, + CACHE_ENTRY_SIZE=576, + CACHE_ENTRIES_PER_PAGE=16, + CACHE_ENTRY_SIZE_POW_2=triton.next_power_of_2(576), + CACHE_PAGE_SIZE_POW_2=triton.next_power_of_2(576 * 16), + ) + + kv_c = kv_c_k_pe_cache[..., :self.kv_lora_rank].unsqueeze(1) + k_pe = kv_c_k_pe_cache[..., self.kv_lora_rank:].unsqueeze(1) + + return self._forward_prefill_flash(q, kv_c, k_pe, attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) From 2821aed511390cd908ae20a091482cfc8546247a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Feb 2025 19:06:27 +0000 Subject: [PATCH 02/38] add gather cache kernel Signed-off-by: Lucas Wilkinson --- csrc/cache.h | 4 ++ csrc/cache_kernels.cu | 125 ++++++++++++++++++++++++++++++++++++ csrc/cuda_utils.h | 10 +++ csrc/torch_bindings.cpp | 6 ++ tests/kernels/test_cache.py | 72 ++++++++++++++++++++- vllm/_custom_ops.py | 7 ++ 6 files changed, 222 insertions(+), 2 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 
cf4a65c29055..f21673532071 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -39,3 +39,7 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); + +void gather_cache(torch::Tensor const& src_cache, torch::Tensor const& dst, + torch::Tensor const& block_table, + torch::Tensor const& cu_seq_lens, int64_t batch_size); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 0960888d1f75..79d88470d598 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -2,6 +2,7 @@ #include #include +#include "cuda_utils.h" #include "cuda_compat.h" #include "dispatch_utils.h" @@ -570,3 +571,127 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); } } + +namespace vllm { + +// grid (batch, num_splits) +template +__global__ void gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRIES...] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] 
+ const int32_t __restrict__* block_table, // [BATCH, BLOCK_INDICES] + const int32_t __restrict__* cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size); + const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits); + + const int32_t split_start = split * split_blocks; + const int32_t split_end = min((split + 1) * split_blocks, tot_blocks); + + const bool is_active_split = (split_start < tot_blocks); + const bool is_last_split = (split_end == tot_blocks); + + if (!is_active_split) return; + + int32_t full_blocks_end = split_end; + int32_t partial_block_size = 0; + + block_table += bid * block_table_stride; + dst += seq_start * dst_entry_stride; + + if (is_last_split) { + partial_block_size = seq_len % block_size; + if (partial_block_size) full_blocks_end -= 1; + } + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < full_blocks_end; ++pid) { + auto block_id = block_table[pid]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * block_size * dst_entry_stride; + for (int eid = 0; eid < block_size; ++eid) { + copy_entry(block_start_ptr + eid * cache_entry_stride, + block_dst_ptr + eid * dst_entry_stride); + } + } + + if (partial_block_size) { + auto block_id = block_table[full_blocks_end]; + auto block_start_ptr = src_cache + block_id * 
cache_block_stride; + auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride; + for (int eid = 0; eid < partial_block_size; ++eid) { + copy_entry(block_start_ptr + eid * cache_entry_stride, + block_dst_ptr + eid * dst_entry_stride); + } + } +} + +} // namespace vllm + +#define CALL_GATHER_CACHE(CPY_DTYPE) \ + vllm::gather_cache<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride); + +// Gather sequences from cache into dst tensor based on block_table +// cu_seq_lens is the cumulative sequence lengths of the batch and +// simultaneously used as the start locations in dst for each sequence +// in the batch. +void gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 
4 : 16; + + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + TORCH_CHECK(src_cache.dtype() == dst.dtype(), + "src_cache and dst must have the same dtype"); + + const int dtype_bits = src_cache.element_size() * 8; + if (dtype_bits == 32) { + CALL_GATHER_CACHE(uint32_t); + } else if (dtype_bits == 16) { + CALL_GATHER_CACHE(uint16_t); + } else if (dtype_bits == 8) { + CALL_GATHER_CACHE(uint8_t); + } else { + TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); + } +} \ No newline at end of file diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 6f79d2b74452..19e8228510c2 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -25,3 +25,13 @@ int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); + +namespace cuda_utils { + +template +HOST_DEVICE_INLINE constexpr std::enable_if_t, T> +ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +}; // namespace cuda_utils \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 784ded26299e..6d869a2b1e88 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -494,6 +494,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); + + // Gather cache blocks from src_cache to dst. + cache_ops.def( + "gather_cache(Tensor src_cache, Tensor! 
dst, Tensor block_table, " + "Tensor cu_seq_lens, int batch_size) -> ()"); + cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 21c02c5de35c..2c9a36915562 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -682,8 +682,6 @@ def test_swap_blocks_mla( torch.ops._C_cache_ops.swap_blocks, (src_cache, dst_cache, block_mapping_tensor), test_utils=DEFAULT_OPCHECK_TEST_UTILS, - cond=(kv_lora_rank == KV_LORA_RANKS[0] - and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]), ) ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor) @@ -694,3 +692,73 @@ def test_swap_blocks_mla( dst_cache[dst].cpu(), msg=f"Block {src} from src should have been swapped to block " f"{dst} in dst_cache.") + + +@pytest.mark.parametrize("kv_lora_rank", [512]) +@pytest.mark.parametrize("qk_rope_head_dim", [64]) +@pytest.mark.parametrize("block_size", [2]) +@pytest.mark.parametrize("num_blocks", [1024]) +@pytest.mark.parametrize("max_seq_len", [4]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", + ["auto"]) # You can also test "fp8" if needed. 
+@pytest.mark.parametrize("align_cache", [False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, + num_blocks, max_seq_len, batch_size, dtype, + kv_cache_dtype, align_cache, device): + entry_size = kv_lora_rank + qk_rope_head_dim + src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, + kv_cache_dtype, device, align_cache) + _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) + + seq_len_tensor = torch.randint(1, + max_seq_len + 1, (batch_size, ), + device=device) + + total_tokens = seq_len_tensor.sum() + cu_seq_lens = torch.empty((batch_size + 1), + dtype=torch.int32, + device=device) + cu_seq_lens[0] = 0 + cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) + + tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size + block_table = torch.empty((batch_size, num_blocks), + dtype=torch.int32, + device=device) + + for b in range(batch_size): + perm = torch.randperm(num_blocks, device=device) + block_table[b, :] = perm + + dst = torch.zeros((total_tokens, entry_size), + dtype=src_cache.dtype, + device=device) + + expected_batches = [] + for b in range(batch_size): + s = seq_len_tensor[b] + tot = tot_blocks_tensor[b] + blocks = block_table[b, :tot].tolist() + + gathered_rows = [] + for i in range(tot - 1): + gathered_rows.append(src_cache[blocks[i]]) + remaining = s - (tot - 1) * block_size + gathered_rows.append(src_cache[blocks[-1], :remaining, :]) + + batch_expected = torch.cat(gathered_rows, dim=0) + expected_batches.append(batch_expected) + expected = torch.cat(expected_batches, dim=0) + + opcheck( + torch.ops._C_cache_ops.gather_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + torch.testing.assert_close(dst, expected) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 
67843c177403..08be95f15634 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1111,6 +1111,13 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) +def gather_cache(src_cache: torch.Tensor, dst: torch.Tensor, + block_table: torch.Tensor, cu_seq_lens: torch.Tensor, + batch_size: int) -> None: + torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size) + + def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) From dc00371ba39f42dd290f7ea1e0a14a52e7ea7d8b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 6 Feb 2025 02:31:02 +0000 Subject: [PATCH 03/38] wip Signed-off-by: Lucas Wilkinson --- tests/kernels/test_cache.py | 6 +- vllm/attention/backends/mla/utils.py | 6 +- vllm/attention/backends/triton_mla.py | 137 ++++++++++---------------- vllm/config.py | 16 +-- 4 files changed, 67 insertions(+), 98 deletions(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 2c9a36915562..ff5b2ccf7a49 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -703,7 +703,7 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("kv_cache_dtype", ["auto"]) # You can also test "fp8" if needed. 
-@pytest.mark.parametrize("align_cache", [False]) +@pytest.mark.parametrize("align_cache", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, @@ -714,7 +714,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, kv_cache_dtype, device, align_cache) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) - seq_len_tensor = torch.randint(1, + seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size, ), device=device) @@ -741,6 +741,8 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, expected_batches = [] for b in range(batch_size): s = seq_len_tensor[b] + if s == 0: + continue tot = tot_blocks_tensor[b] blocks = block_table[b, :tot].tolist() diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index c2f8dc0a2df6..6c59c70cbbe7 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -487,7 +487,9 @@ def _forward_prefill_flash( q: torch.Tensor, k_c_normed: torch.Tensor, k_pe: torch.Tensor, + query_start_loc: torch.Tensor, seq_start_loc: torch.Tensor, + max_query_len: int, max_prefill_seq_len: int, ) -> torch.Tensor: @@ -507,9 +509,9 @@ def _forward_prefill_flash( q=q, k=k, v=v_padded, - cu_seqlens_q=seq_start_loc, + cu_seqlens_q=query_start_loc, cu_seqlens_k=seq_start_loc, - max_seqlen_q=max_prefill_seq_len, + max_seqlen_q=max_query_len, max_seqlen_k=max_prefill_seq_len, softmax_scale=self.scale, causal=True, diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index d9e14c487b55..cfcd90f78a1e 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -6,9 +6,6 @@ from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type -import triton -import triton.language as tl - from vllm.multimodal import 
MultiModalPlaceholderMap try: @@ -30,7 +27,8 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd -from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.utils import (align_to_256bytes, async_tensor_h2d, + make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -266,6 +264,9 @@ class TritonMLAMetadata(MLACommonMetadata): # The dimension of the attention heads head_dim: Optional[int] = None + chunked_prefill_workspace: Optional[torch.Tensor] = None + prefill_seq_len_total: Optional[int] = None + def __post_init__(self): supported_head_sizes = TritonMLABackend.get_supported_head_sizes() if self.head_dim is not None and self.head_dim \ @@ -323,7 +324,9 @@ def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - head_dim=self.head_dim) + head_dim=self.head_dim, + chunked_prefill_workspace=self.chunked_prefill_workspace, + prefill_seq_len_total=seq_lens_tensor.sum().item()) return self._cached_prefill_metadata @property @@ -455,6 +458,21 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size + self.chunked_prefill_workspace = None + + scheduler_config = self.runner.vllm_config.scheduler_config + model_config = self.runner.model_config + + if scheduler_config.enable_chunked_prefill: + head_size = model_config.get_head_size() + + self.chunked_prefill_workspace = torch.empty( + (model_config.max_model_len * scheduler_config.max_num_seqs, + align_to_256bytes(head_size, model_config.dtype)), + dtype=model_config.dtype, + device=self.runner.device, + )[..., :head_size] + def prepare(self): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] @@ -647,66 +665,10 @@ def build(self, seq_lens: List[int], query_lens: 
List[int], use_cuda_graph=use_captured_graph, num_kv_splits=4, # TODO(lucas) add heuristic head_dim=self.runner.model_config.get_head_size(), + chunked_prefill_workspace=self.chunked_prefill_workspace, ) -@triton.jit -def _gather_kv_cache( - # Pointers to inputs and output - seq_start_locs, # (batch_size + 1,) - block_tables, # (batch_size, max_blocks_per_seq) - block_table_stride, - kv_cache, # (num_blocks, block_size, head_size) - kv_page_stride, - kv_out, - CACHE_PAGE_SIZE: tl.constexpr, - CACHE_ENTRY_SIZE: tl.constexpr, - CACHE_ENTRIES_PER_PAGE: tl.constexpr, - CACHE_PAGE_SIZE_POW_2: tl.constexpr, - CACHE_ENTRY_SIZE_POW_2: tl.constexpr, -): - """A Triton-accelerated function to perform per-token-group - quantization on a tensor. - This function converts the tensor values into float8 values. - """ - - # Map the program id to the row of X and Y it should compute. - g_id = tl.program_id(0) - - seq_start_loc = tl.load(seq_start_locs + g_id) - seq_len = tl.load(seq_start_locs + g_id + 1) - seq_start_loc - - pages_to_copy = tl.cdiv(seq_len, CACHE_ENTRIES_PER_PAGE) - kv_out = kv_out + seq_start_loc * CACHE_ENTRY_SIZE - block_table = block_tables + g_id * block_table_stride - - cache_page_range = tl.arange(0, CACHE_PAGE_SIZE_POW_2) - cache_page_mask = cache_page_range < CACHE_PAGE_SIZE - for i in range(pages_to_copy - 1): - page = tl.load(block_table + i) - page_start = kv_cache + page * kv_page_stride - page_data = tl.load(page_start + cache_page_range, - mask=cache_page_mask) - tl.store(kv_out + i * CACHE_PAGE_SIZE + cache_page_range, - page_data, - mask=cache_page_mask) - - last_page_len = seq_len % CACHE_ENTRIES_PER_PAGE - last_page = tl.load(block_table + pages_to_copy - 1) - last_page_start = kv_cache + last_page * kv_page_stride - - cache_entry_range = tl.arange(0, CACHE_ENTRY_SIZE_POW_2) - cache_entry_mask = cache_entry_range < CACHE_ENTRY_SIZE - kv_out_page = kv_out + (pages_to_copy - 1) * CACHE_PAGE_SIZE - for i in range(last_page_len): - last_page_data = 
tl.load(last_page_start + \ - i * CACHE_ENTRY_SIZE + cache_entry_range, - mask=cache_entry_mask) - tl.store(kv_out_page + i * CACHE_ENTRY_SIZE + cache_entry_range, - last_page_data, - mask=cache_entry_mask) - - class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): def __init__( @@ -753,36 +715,39 @@ def _forward_prefill( ) -> torch.Tensor: assert isinstance(attn_metadata, TritonMLAMetadata) - if attn_metadata.prefill_metadata.context_lens_tensor is not None and \ - max(attn_metadata.prefill_metadata.context_lens_tensor) > 0: - entries_total = attn_metadata.prefill_metadata.seq_start_loc[-1] - kv_c_k_pe_cache = torch.empty( - (entries_total, kv_c_and_k_pe_cache.shape[-1]), - dtype=kv_c_and_k_pe_cache.dtype, - device=kv_c_and_k_pe_cache.device, - ) + prefill_metadata = attn_metadata.prefill_metadata + + if prefill_metadata.context_lens_tensor is not None \ + and kv_c_and_k_pe_cache.numel() > 0: + print("******** running chunked prefill") + print("prefill_metadata.seq_start_loc[-1]", + prefill_metadata.seq_start_loc[-1]) + + workspace = attn_metadata.chunked_prefill_workspace + print("workspace", workspace.shape) + + print(kv_c_and_k_pe_cache.shape) + print(workspace.shape) + print(prefill_metadata.block_tables.shape) + print(prefill_metadata.seq_start_loc.shape) - assert kv_c_and_k_pe_cache.shape[-1] == 576 - assert kv_c_and_k_pe_cache.shape[-2] == 16 - _gather_kv_cache[(attn_metadata.num_prefills, )]( - attn_metadata.prefill_metadata.seq_start_loc, - attn_metadata.prefill_metadata.block_tables, - attn_metadata.prefill_metadata.block_tables.stride(0), - kv_c_and_k_pe_cache, - kv_c_and_k_pe_cache.stride(0), - kv_c_k_pe_cache, - CACHE_PAGE_SIZE=576 * 16, - CACHE_ENTRY_SIZE=576, - CACHE_ENTRIES_PER_PAGE=16, - CACHE_ENTRY_SIZE_POW_2=triton.next_power_of_2(576), - CACHE_PAGE_SIZE_POW_2=triton.next_power_of_2(576 * 16), + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_tables, + 
cu_seq_lens=prefill_metadata.seq_start_loc, + batch_size=attn_metadata.num_prefills, ) - kv_c = kv_c_k_pe_cache[..., :self.kv_lora_rank].unsqueeze(1) - k_pe = kv_c_k_pe_cache[..., self.kv_lora_rank:].unsqueeze(1) + toks = prefill_metadata.prefill_seq_len_total + print("toks", toks) + kv_c = workspace[:toks][..., :self.kv_lora_rank].unsqueeze(1) + k_pe = workspace[:toks][..., self.kv_lora_rank:].unsqueeze(1) return self._forward_prefill_flash(q, kv_c, k_pe, + attn_metadata.query_start_loc, attn_metadata.seq_start_loc, + attn_metadata.max_query_len, attn_metadata.max_prefill_seq_len) def _forward_decode( diff --git a/vllm/config.py b/vllm/config.py index 10004b8f6291..23edc5387480 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3272,14 +3272,14 @@ def __post_init__(self): current_platform.check_and_update_config(self) # If MLA is enabled, force disable chunked prefill and prefix caching - if self.model_config and self.model_config.use_mla: - logger.info("MLA is enabled; forcing chunked prefill and prefix " - "caching to be disabled.") - self.scheduler_config.enable_chunked_prefill = False - self.scheduler_config.chunked_prefill_enabled = False - - if self.cache_config is not None: - self.cache_config.enable_prefix_caching = False + # if self.model_config and self.model_config.use_mla: + # logger.info("MLA is enabled; forcing chunked prefill and prefix " + # "caching to be disabled.") + # self.scheduler_config.enable_chunked_prefill = False + # self.scheduler_config.chunked_prefill_enabled = False + + # if self.cache_config is not None: + # self.cache_config.enable_prefix_caching = False if not self.instance_id: self.instance_id = random_uuid()[:5] From f50719b55b247a878f4b318c18f2266f427c44bf Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 6 Feb 2025 05:21:34 +0000 Subject: [PATCH 04/38] wip running Co-authored-by: Patrick Horn Signed-off-by: Lucas Wilkinson --- vllm/attention/__init__.py | 12 ++---- vllm/attention/backends/mla/utils.py | 58 
+++++++++++++++++---------- vllm/attention/backends/triton_mla.py | 17 ++++---- vllm/attention/backends/utils.py | 4 ++ 4 files changed, 53 insertions(+), 38 deletions(-) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 85c5715faba7..1ddba63078bb 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -4,16 +4,12 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) +from vllm.attention.backends.utils import FLASH_ATTN_VERSION from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend __all__ = [ - "Attention", - "AttentionBackend", - "AttentionMetadata", - "AttentionType", - "AttentionMetadataBuilder", - "Attention", - "AttentionState", - "get_attn_backend", + "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType", + "AttentionMetadataBuilder", "Attention", "AttentionState", + "get_attn_backend", "FLASH_ATTN_VERSION" ] diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 6c59c70cbbe7..af4ef9e923a8 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -439,29 +439,36 @@ def forward( raise NotImplementedError( "output is not yet supported for MLAImplBase") - is_decode = attn_metadata.decode_metadata is not None - is_prefill = attn_metadata.prefill_metadata is not None + has_decode = attn_metadata.decode_metadata is not None + has_prefill = attn_metadata.prefill_metadata is not None # Restore head dim (for rotary embedding) k_pe = k_pe.unsqueeze(1) assert hasattr(attn_metadata, "input_positions") + rope_fn = self.rotary_emb - if is_decode: - q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) - q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ + num_prefill_tokens: int = attn_metadata.num_prefill_tokens + + if has_decode: + decode_q = hidden_states_or_q_c[num_prefill_tokens:] + decode_q_nope = self._q_proj_and_k_up_proj(decode_q) + decode_q_pe = 
torch.matmul(decode_q, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) - q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, - k_pe) - else: - assert is_prefill - q = self.q_proj(hidden_states_or_q_c)[0]\ + decode_q_pe, k_pe[num_prefill_tokens:] = \ + rope_fn(attn_metadata.input_positions[num_prefill_tokens:], + decode_q_pe, k_pe[num_prefill_tokens:]) + if has_prefill: + prefill_q = hidden_states_or_q_c[:num_prefill_tokens] + prefill_k_pe = k_pe[:num_prefill_tokens] + prefill_q = self.q_proj(prefill_q)[0]\ .view(-1, self.num_heads, self.qk_head_dim) # TODO(lucas): there must be a nicer way to write this line - q[..., self.qk_nope_head_dim:], k_pe = \ - self.rotary_emb( - attn_metadata.input_positions, - q[..., self.qk_nope_head_dim:], k_pe) + prefill_q[..., self.qk_nope_head_dim:], prefill_k_pe = \ + rope_fn( + attn_metadata.input_positions[:num_prefill_tokens], + prefill_q[..., self.qk_nope_head_dim:], + prefill_k_pe) # write the latent and rope to kv cache if kv_cache.numel() > 0: @@ -473,13 +480,22 @@ def forward( kv_cache_dtype=self.kv_cache_dtype, scale=layer._k_scale, ) - - if attn_metadata.prefill_metadata is not None: - return self._forward_prefill(q, k_c_normed, k_pe, kv_cache, - attn_metadata) - - if attn_metadata.decode_metadata is not None: - return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata) + output = torch.empty(attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens, + self.o_proj.output_size, + device=hidden_states_or_q_c.device, + dtype=hidden_states_or_q_c.dtype) + + if has_prefill: + output[:num_prefill_tokens] = self._forward_prefill( + prefill_q, k_c_normed[:num_prefill_tokens].contiguous(), + prefill_k_pe.contiguous(), kv_cache, attn_metadata) + + if has_decode: + output[num_prefill_tokens:] = self._forward_decode( + decode_q_nope, decode_q_pe, kv_cache, attn_metadata) + + return output # Optional common flash-attn based prefill def _forward_prefill_flash( diff --git 
a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index cfcd90f78a1e..f2053de729cd 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -719,17 +719,17 @@ def _forward_prefill( if prefill_metadata.context_lens_tensor is not None \ and kv_c_and_k_pe_cache.numel() > 0: - print("******** running chunked prefill") - print("prefill_metadata.seq_start_loc[-1]", - prefill_metadata.seq_start_loc[-1]) + # print("******** running chunked prefill") + # print("prefill_metadata.seq_start_loc[-1]", + # prefill_metadata.seq_start_loc[-1]) workspace = attn_metadata.chunked_prefill_workspace - print("workspace", workspace.shape) + # print("workspace", workspace.shape) - print(kv_c_and_k_pe_cache.shape) - print(workspace.shape) - print(prefill_metadata.block_tables.shape) - print(prefill_metadata.seq_start_loc.shape) + # print(kv_c_and_k_pe_cache.shape) + # print(workspace.shape) + # print(prefill_metadata.block_tables.shape) + # print(prefill_metadata.seq_start_loc.shape) ops.gather_cache( src_cache=kv_c_and_k_pe_cache, @@ -740,7 +740,6 @@ def _forward_prefill( ) toks = prefill_metadata.prefill_seq_len_total - print("toks", toks) kv_c = workspace[:toks][..., :self.kv_lora_rank].unsqueeze(1) k_pe = workspace[:toks][..., self.kv_lora_rank:].unsqueeze(1) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 5c1f9916e22c..e327b1200968 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -16,6 +16,10 @@ from vllm.multimodal import MultiModalPlaceholderMap from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.vllm_flash_attn.flash_attn_interface import ( + fa_version_unsupported_reason, is_fa_version_supported) + +logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) From 3d2e770545d476af860cdbede53cddb47dfc00f0 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: 
Thu, 6 Feb 2025 05:29:29 +0000 Subject: [PATCH 05/38] more cleanup Signed-off-by: Lucas Wilkinson --- vllm/attention/__init__.py | 4 ++-- vllm/attention/backends/utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 1ddba63078bb..c710ce6d5a45 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -4,12 +4,12 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) -from vllm.attention.backends.utils import FLASH_ATTN_VERSION +from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend __all__ = [ "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType", "AttentionMetadataBuilder", "Attention", "AttentionState", - "get_attn_backend", "FLASH_ATTN_VERSION" + "get_attn_backend", "VLLM_FLASH_ATTN_VERSION" ] diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index e327b1200968..e62d97ac8162 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -616,3 +616,4 @@ def get_flash_attn_version(): return fa_version except (ImportError, AssertionError): return None + From ea19198d1aac5c90d8bdd711959b47b432c387bf Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 6 Feb 2025 05:48:16 +0000 Subject: [PATCH 06/38] better defaults Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/triton_mla.py | 10 ---- vllm/config.py | 10 ---- vllm/engine/arg_utils.py | 70 +++++++++++++++------------ 3 files changed, 39 insertions(+), 51 deletions(-) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index f2053de729cd..21986c19b2cc 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -719,17 +719,7 @@ def _forward_prefill( if prefill_metadata.context_lens_tensor is not None \ and 
kv_c_and_k_pe_cache.numel() > 0: - # print("******** running chunked prefill") - # print("prefill_metadata.seq_start_loc[-1]", - # prefill_metadata.seq_start_loc[-1]) - workspace = attn_metadata.chunked_prefill_workspace - # print("workspace", workspace.shape) - - # print(kv_c_and_k_pe_cache.shape) - # print(workspace.shape) - # print(prefill_metadata.block_tables.shape) - # print(prefill_metadata.seq_start_loc.shape) ops.gather_cache( src_cache=kv_c_and_k_pe_cache, diff --git a/vllm/config.py b/vllm/config.py index 23edc5387480..d145dd985ea7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3271,16 +3271,6 @@ def __post_init__(self): current_platform.check_and_update_config(self) - # If MLA is enabled, force disable chunked prefill and prefix caching - # if self.model_config and self.model_config.use_mla: - # logger.info("MLA is enabled; forcing chunked prefill and prefix " - # "caching to be disabled.") - # self.scheduler_config.enable_chunked_prefill = False - # self.scheduler_config.chunked_prefill_enabled = False - - # if self.cache_config is not None: - # self.cache_config.enable_prefix_caching = False - if not self.instance_id: self.instance_id = random_uuid()[:5] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 83ee6b97f938..3b94d609559b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -118,7 +118,7 @@ class EngineArgs: use_v2_block_manager: bool = True swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB - gpu_memory_utilization: float = 0.90 + gpu_memory_utilization: Optional[float] = None max_num_batched_tokens: Optional[int] = None max_num_seqs: Optional[int] = None max_logprobs: int = 20 # Default value for OpenAI Chat Completions API @@ -1097,33 +1097,6 @@ def create_engine_config(self, "has been disabled.") self.enable_prefix_caching = False - cache_config = CacheConfig( - block_size=self.block_size, - gpu_memory_utilization=self.gpu_memory_utilization, - swap_space=self.swap_space, - 
cache_dtype=self.kv_cache_dtype, - is_attention_free=model_config.is_attention_free, - num_gpu_blocks_override=self.num_gpu_blocks_override, - sliding_window=model_config.get_sliding_window(), - enable_prefix_caching=self.enable_prefix_caching, - cpu_offload_gb=self.cpu_offload_gb, - calculate_kv_scales=self.calculate_kv_scales, - ) - parallel_config = ParallelConfig( - pipeline_parallel_size=self.pipeline_parallel_size, - tensor_parallel_size=self.tensor_parallel_size, - max_parallel_loading_workers=self.max_parallel_loading_workers, - disable_custom_all_reduce=self.disable_custom_all_reduce, - tokenizer_pool_config=TokenizerPoolConfig.create_config( - self.tokenizer_pool_size, - self.tokenizer_pool_type, - self.tokenizer_pool_extra_config, - ), - ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend, - worker_cls=self.worker_cls, - ) - max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 if self.enable_chunked_prefill is None: @@ -1131,9 +1104,9 @@ def create_engine_config(self, # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. 
- # For multimodal models, chunked prefill is disabled by default in - # V0, but enabled by design in V1 - if model_config.is_multimodal_model: + # For multimodal models and models with MLA, chunked prefill is + # disabled by default in V0, but enabled by design in V1 + if model_config.is_multimodal_model or model_config.use_mla: + self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) elif use_long_context: @@ -1168,6 +1141,41 @@ msg = "Chunked prefill is not supported for pooling models" raise ValueError(msg) + # Default to `gpu_memory_utilization` of 0.9 if not specified + gpu_memory_utilization = self.gpu_memory_utilization if \ + self.gpu_memory_utilization is not None else 0.9 + # For models using MLA and chunked prefill, lower the default to 0.8 + # to account for the extra memory required to up-project the MLA cache + if self.gpu_memory_utilization is None and \ + (self.enable_chunked_prefill and model_config.use_mla): + gpu_memory_utilization = 0.8 + + cache_config = CacheConfig( + block_size=self.block_size, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=self.swap_space, + cache_dtype=self.kv_cache_dtype, + is_attention_free=model_config.is_attention_free, + num_gpu_blocks_override=self.num_gpu_blocks_override, + sliding_window=model_config.get_sliding_window(), + enable_prefix_caching=self.enable_prefix_caching, + cpu_offload_gb=self.cpu_offload_gb, + calculate_kv_scales=self.calculate_kv_scales, + ) + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, + tokenizer_pool_config=TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), + ray_workers_use_nsight=self.ray_workers_use_nsight, + 
distributed_executor_backend=self.distributed_executor_backend, + worker_cls=self.worker_cls, + ) speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, From d116752b3672d97336fa32ceab0753e85f91770e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 6 Feb 2025 05:53:09 +0000 Subject: [PATCH 07/38] increase MLA gpu_memory_utilization default Signed-off-by: Lucas Wilkinson --- vllm/engine/arg_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3b94d609559b..429f24b95f28 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1144,11 +1144,11 @@ def create_engine_config(self, # Default to `gpu_memory_utilization` of 0.9 if not specified gpu_memory_utilization = self.gpu_memory_utilization if \ self.gpu_memory_utilization is not None else 0.9 - # For models using MLA and chunked prefill, lower the default to 0.8 + # For models using MLA and chunked prefill, lower the default to 0.85 # to account for the extra memory required to up-project the MLA cache if self.gpu_memory_utilization is None and \ (self.enable_chunked_prefill and model_config.use_mla): - gpu_memory_utilization = 0.8 + gpu_memory_utilization = 0.85 cache_config = CacheConfig( block_size=self.block_size, From c3ad9886820664bf4da36475db73a0299a24c255 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 7 Feb 2025 21:08:45 +0000 Subject: [PATCH 08/38] wip Signed-off-by: Lucas Wilkinson --- csrc/cache.h | 9 +- csrc/cache_kernels.cu | 53 +++++--- csrc/torch_bindings.cpp | 2 +- examples/basic_deepseek.py | 107 +++++++++++++++ vllm/_custom_ops.py | 4 +- vllm/attention/backends/triton_mla.py | 182 ++++++++++++++++++++++---- vllm/distributed/parallel_state.py | 4 +- vllm/engine/arg_utils.py | 7 + 8 files changed, 322 insertions(+), 46 deletions(-) create mode 100644 examples/basic_deepseek.py diff --git a/csrc/cache.h b/csrc/cache.h index 
f21673532071..0970b704be3a 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -40,6 +40,9 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); -void gather_cache(torch::Tensor const& src_cache, torch::Tensor const& dst, - torch::Tensor const& block_table, - torch::Tensor const& cu_seq_lens, int64_t batch_size); \ No newline at end of file +void gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, std::optional seq_starts = std::nullopt); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 79d88470d598..203d4fc433b9 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -574,17 +574,20 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, namespace vllm { -// grid (batch, num_splits) +// grid is launched with dimensions (batch, num_splits) template __global__ void gather_cache( const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, // ENTRIES...] scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] 
- const int32_t __restrict__* block_table, // [BATCH, BLOCK_INDICES] - const int32_t __restrict__* cu_seq_lens, // [BATCH+1] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] const int32_t block_size, const int32_t entry_size, const int64_t block_table_stride, const int64_t cache_block_stride, - const int64_t cache_entry_stride, const int64_t dst_entry_stride) { + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per + // batch + const int64_t bid = blockIdx.x; // Batch ID const int32_t num_splits = gridDim.y; const int32_t split = blockIdx.y; @@ -605,7 +608,17 @@ __global__ void gather_cache( int32_t full_blocks_end = split_end; int32_t partial_block_size = 0; - block_table += bid * block_table_stride; + // Adjust the pointer for the block_table for this batch. + // If seq_starts is provided, compute an offset based on (seq_starts[bid] / + // page_size) + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = 0; + if (seq_starts != nullptr) { + offset = seq_starts[bid] / block_size; + } + const int32_t* batch_block_table = block_table + batch_offset + offset; + + // Adjust dst pointer based on the cumulative sequence lengths. 
dst += seq_start * dst_entry_stride; if (is_last_split) { @@ -620,7 +633,7 @@ __global__ void gather_cache( }; for (int pid = split_start; pid < full_blocks_end; ++pid) { - auto block_id = block_table[pid]; + auto block_id = batch_block_table[pid]; auto block_start_ptr = src_cache + block_id * cache_block_stride; auto block_dst_ptr = dst + pid * block_size * dst_entry_stride; for (int eid = 0; eid < block_size; ++eid) { @@ -630,7 +643,7 @@ __global__ void gather_cache( } if (partial_block_size) { - auto block_id = block_table[full_blocks_end]; + auto block_id = batch_block_table[full_blocks_end]; auto block_start_ptr = src_cache + block_id * cache_block_stride; auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride; for (int eid = 0; eid < partial_block_size; ++eid) { @@ -642,24 +655,27 @@ __global__ void gather_cache( } // namespace vllm +// Macro to dispatch the kernel based on the data type. #define CALL_GATHER_CACHE(CPY_DTYPE) \ vllm::gather_cache<<>>( \ reinterpret_cast(src_cache.data_ptr()), \ reinterpret_cast(dst.data_ptr()), \ block_table.data_ptr(), cu_seq_lens.data_ptr(), \ block_size, entry_size, block_table_stride, cache_block_stride, \ - cache_entry_stride, dst_entry_stride); + cache_entry_stride, dst_entry_stride, seq_starts_ptr); -// Gather sequences from cache into dst tensor based on block_table -// cu_seq_lens is the cumulative sequence lengths of the batch and -// simultaneously used as the start locations in dst for each sequence -// in the batch. +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting block index by +// (seq_starts[bid] / page_size) void gather_cache( torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] 
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] - int64_t batch_size) { + int64_t batch_size, + std::optional seq_starts = std::nullopt) { at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -670,14 +686,18 @@ void gather_cache( "block_table must be int32"); TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } int64_t block_table_stride = block_table.stride(0); int64_t cache_block_stride = src_cache.stride(0); int64_t cache_entry_stride = src_cache.stride(1); int64_t dst_entry_stride = dst.stride(0); + // Decide on the number of splits based on the batch size. int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; - dim3 grid(batch_size, num_splits); dim3 block(1024); @@ -685,6 +705,9 @@ void gather_cache( "src_cache and dst must have the same dtype"); const int dtype_bits = src_cache.element_size() * 8; + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; + if (dtype_bits == 32) { CALL_GATHER_CACHE(uint32_t); } else if (dtype_bits == 16) { @@ -694,4 +717,4 @@ void gather_cache( } else { TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); } -} \ No newline at end of file +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6d869a2b1e88..ada9abbf2fde 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -498,7 +498,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Gather cache blocks from src_cache to dst. cache_ops.def( "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " - "Tensor cu_seq_lens, int batch_size) -> ()"); + "Tensor cu_seq_lens, int batch_size, Tensor? 
seq_starts) -> ()"); cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); } diff --git a/examples/basic_deepseek.py b/examples/basic_deepseek.py new file mode 100644 index 000000000000..05c2690a9c13 --- /dev/null +++ b/examples/basic_deepseek.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + """FROM fairest creatures we desire increase, +That thereby beauty's rose might never die, +But as the riper should by time decease, +His tender heir might bear his memory: +But thou, contracted to thine own bright eyes, +Feed'st thy light'st flame with self-substantial fuel, +Making a famine where abundance lies, +Thyself thy foe, to thy sweet self too cruel. +Thou that art now the world's fresh ornament +And only herald to the gaudy spring, +Within thine own bud buriest thy content +And, tender churl, makest waste in niggarding. +Pity the world, or else this glutton be, +To eat the world's due, by the grave and thee. +When forty winters shall besiege thy brow, +And dig deep trenches in thy beauty's field, +Thy youth's proud livery, so gazed on now, +Will be a tatter'd weed, of small worth held: +Then being ask'd where all thy beauty lies, +Where all the treasure of thy lusty days, +To say, within thine own deep-sunken eyes, +Were an all-eating shame and thriftless praise. +How much more praise deserved thy beauty's use, +If thou couldst answer 'This fair child of mine +Shall sum my count and make my old excuse,' +Proving his beauty by succession thine! +This were to be new made when thou art old, +And see thy blood warm when thou feel'st it cold. +Look in thy glass, and tell the face thou viewest +Now is the time that face should form another; +Whose fresh repair if now thou not renewest, +Thou dost beguile the world, unbless some mother. +For where is she so fair whose unear'd womb +Disdains the tillage of thy husbandry? 
+Or who is he so fond will be the tomb +Of his self-love, to stop posterity? +Thou art thy mother's glass, and she in thee +Calls back the lovely April of her prime: +So thou through windows of thine age shall see +Despite of wrinkles this thy golden time. +But if thou live, remember'd not to be, +Die single, and thine image dies with thee. +Unthrifty loveliness, why dost thou spend +Upon thyself thy beauty's legacy? +Nature's bequest gives nothing but doth lend, +And being frank she lends to those are free. +Then, beauteous niggard, why dost thou abuse +The bounteous largess given thee to give? +Profitless usurer, why dost thou use +So great a sum of sums, yet canst not live? +For having traffic with thyself alone, +Thou of thyself thy sweet self dost deceive. +Then how, when nature calls thee to be gone, +What acceptable audit canst thou leave? +Thy unused beauty must be tomb'd with thee, +Which, used, lives th' executor to be. +Those hours, that with gentle work did frame +The lovely gaze where every eye doth dwell, +Will play the tyrants to the very same +And that unfair which fairly doth excel: +For never-resting time leads summer on +To hideous winter and confounds him there; +Sap cheque'd with frost and lusty leaves quite gone, +Beauty o'ersnow'd and bareness every where: +Then, were not summer's distillation left, +A liquid prisoner pent in walls of glass, +Beauty's effect with beauty were bereft, +Nor it nor no remembrance what it was: +But flowers distill'd though they with winter meet, +Leese but their show; their substance still lives sweet. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +""", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. 
+ +llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", + trust_remote_code=True, + enforce_eager=True, + max_model_len=2048, + tensor_parallel_size=2, + enable_chunked_prefill=True, + max_num_batched_tokens=256, + gpu_memory_utilization=0.7 + #cpu_offload_gb=10, + #hf_overrides={"num_hidden_layers": 14}, + ) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 08be95f15634..36a6102b953c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1113,9 +1113,9 @@ def convert_fp8(output: torch.Tensor, def gather_cache(src_cache: torch.Tensor, dst: torch.Tensor, block_table: torch.Tensor, cu_seq_lens: torch.Tensor, - batch_size: int) -> None: + batch_size: int, seq_starts: Optional[torch.Tensor]) -> None: torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size) + cu_seq_lens, batch_size, seq_starts) def get_device_attribute(attribute: int, device: int) -> int: diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 21986c19b2cc..c9b706f4c773 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -23,12 +23,20 @@ AttentionMetadataBuilder, AttentionState, AttentionType) from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, +from vllm.attention.backends.utils import (PAD_SLOT_ID, + VLLM_FLASH_ATTN_VERSION, + compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from 
vllm.utils import (align_to_256bytes, async_tensor_h2d, make_tensor_with_pad) +from vllm.v1.attention.backends.flash_attn import merge_attn_states + +try: + from vllm.vllm_flash_attn import flash_attn_varlen_func +except ImportError: + from flash_attn import flash_attn_varlen_func if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -264,9 +272,15 @@ class TritonMLAMetadata(MLACommonMetadata): # The dimension of the attention heads head_dim: Optional[int] = None - chunked_prefill_workspace: Optional[torch.Tensor] = None prefill_seq_len_total: Optional[int] = None + # For chunked prefill + chunk_prefill_workspace_size: Optional[int] = None + chunk_cu_seq_lens: Optional[List[torch.Tensor]] = None + chunk_seq_starts: Optional[List[torch.Tensor]] = None + chunk_iter_toks: Optional[List[int]] = None + chunk_max_seq_lens: Optional[List[int]] = None + def __post_init__(self): supported_head_sizes = TritonMLABackend.get_supported_head_sizes() if self.head_dim is not None and self.head_dim \ @@ -325,8 +339,15 @@ def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: block_tables=block_tables, use_cuda_graph=False, head_dim=self.head_dim, - chunked_prefill_workspace=self.chunked_prefill_workspace, - prefill_seq_len_total=seq_lens_tensor.sum().item()) + prefill_seq_len_total=seq_lens_tensor.sum().item(), + + # Chunk prefill meta-data + chunk_prefill_workspace_size=self.chunk_prefill_workspace_size, + chunk_cu_seq_lens=self.chunk_cu_seq_lens, + chunk_seq_starts=self.chunk_seq_starts, + chunk_iter_toks=self.chunk_iter_toks, + chunk_max_seq_lens=self.chunk_max_seq_lens, + ) return self._cached_prefill_metadata @property @@ -467,7 +488,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): head_size = model_config.get_head_size() self.chunked_prefill_workspace = torch.empty( - (model_config.max_model_len * scheduler_config.max_num_seqs, + (model_config.max_model_len * 2, align_to_256bytes(head_size, model_config.dtype)), 
dtype=model_config.dtype, device=self.runner.device, @@ -644,6 +665,46 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.multimodal_placeholder_maps.items() } + chunk_cu_seq_lens = None + chunk_seq_starts = None + chunk_iter_toks = None + chunk_max_seq_lens = None + + chunked_prefill_enabled = self.input_builder.chunked_prefill_enabled + if chunked_prefill_enabled: + chunk_prefill_workspace_size = \ + self.runner.model_config.max_model_len * 4 + page_size = self.runner.block_size + seq_chunk_size = chunk_prefill_workspace_size // self.num_prefills + # align seq_chunk_size to page_size by rounding down + seq_chunk_size = seq_chunk_size - (seq_chunk_size % page_size) + num_chunks = (max_prefill_seq_len + seq_chunk_size - + 1) // seq_chunk_size + + chunk_cu_seq_lens = [] + chunk_seq_starts = [] + chunk_iter_toks = [] + chunk_max_seq_lens = [] + + for chunk in range(num_chunks): + chunk_starts = chunk * seq_chunk_size + chunk_starts = torch.tensor([chunk_starts] * self.num_prefills, + dtype=torch.int32, + device=device) + chunk_ends = seq_lens_tensor.clamp(max=(chunk + 1) * + seq_chunk_size) + _chunk_cu_seq_lens = (chunk_ends - chunk_starts).clamp( + min=0).cumsum(dim=0).to(torch.int32) + chunk_iter_toks.append(_chunk_cu_seq_lens.sum()) + chunk_max_seq_lens.append(_chunk_cu_seq_lens.max().item()) + + zero = torch.zeros(1, dtype=torch.int32, device=device) + _chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], + dim=0) + + chunk_cu_seq_lens.append(_chunk_cu_seq_lens.contiguous()) + chunk_seq_starts.append(chunk_starts.contiguous()) + return TritonMLAMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -665,7 +726,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], use_cuda_graph=use_captured_graph, num_kv_splits=4, # TODO(lucas) add heuristic head_dim=self.runner.model_config.get_head_size(), - chunked_prefill_workspace=self.chunked_prefill_workspace, + 
chunk_prefill_workspace_size=chunk_prefill_workspace_size, + chunk_cu_seq_lens=chunk_cu_seq_lens, + chunk_seq_starts=chunk_seq_starts, + chunk_iter_toks=chunk_iter_toks, + chunk_max_seq_lens=chunk_max_seq_lens, ) @@ -719,25 +784,94 @@ def _forward_prefill( if prefill_metadata.context_lens_tensor is not None \ and kv_c_and_k_pe_cache.numel() > 0: - workspace = attn_metadata.chunked_prefill_workspace - - ops.gather_cache( - src_cache=kv_c_and_k_pe_cache, - dst=workspace, - block_table=prefill_metadata.block_tables, - cu_seq_lens=prefill_metadata.seq_start_loc, - batch_size=attn_metadata.num_prefills, - ) - - toks = prefill_metadata.prefill_seq_len_total - kv_c = workspace[:toks][..., :self.kv_lora_rank].unsqueeze(1) - k_pe = workspace[:toks][..., self.kv_lora_rank:].unsqueeze(1) - return self._forward_prefill_flash(q, kv_c, k_pe, - attn_metadata.query_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_query_len, - attn_metadata.max_prefill_seq_len) + if not hasattr(self, "workspace"): + self.workspace = torch.empty( + (attn_metadata.chunk_prefill_workspace_size, + self.kv_lora_rank + self.qk_rope_head_dim), + dtype=q.dtype, + device=q.device, + ) + + output = None + for chunk_cu_seq_lens, chunk_seq_starts, toks, max_seq_len in \ + zip( + attn_metadata.chunk_cu_seq_lens, \ + attn_metadata.chunk_seq_starts, \ + attn_metadata.chunk_iter_toks, \ + attn_metadata.chunk_max_seq_lens): + + print("cu_seq_lens", chunk_cu_seq_lens) + print("seq_starts", chunk_seq_starts) + print("toks", toks) + print("max_seq_len", max_seq_len) + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=self.workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=chunk_cu_seq_lens, + batch_size=attn_metadata.num_prefills, + seq_starts=chunk_seq_starts, + ) + + k_c_normed = self.workspace[:toks][ + ..., :self.kv_lora_rank].unsqueeze(1) + k_pe = self.workspace[:toks][..., + self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(k_c_normed)[0]\ + .view(-1, 
self.num_heads, \ + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + print("q", q.shape) + print("k", k.shape) + print("v", v_padded.shape) + + attn_output, attn_softmax_lse = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=chunk_cu_seq_lens, + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=True, + fa_version=VLLM_FLASH_ATTN_VERSION, + ) + + if output is None: + output = attn_output + else: + merge_attn_states( + output=output, + prefix_output=output, + prefix_lse=attn_softmax_lse, + new_output=attn_output, + new_lse=attn_softmax_lse, + ) + + output = output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + return self.o_proj(output)[0] + else: + return self._forward_prefill_flash( + q, kv_c, k_pe, attn_metadata.query_start_loc, + attn_metadata.seq_start_loc, attn_metadata.max_query_len, + attn_metadata.max_prefill_seq_len) def _forward_decode( self, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index bfc41703b94d..65da45b02893 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -609,7 +609,9 @@ def broadcast_tensor_dict( # all happening on CPU. Therefore, we can use the CPU group. 
self.broadcast_object(metadata_list, src=src) async_handles = [] - for tensor in tensor_list: + for i, tensor in enumerate(tensor_list): + print("broadcasting", metadata_list[i][0]) + print(" ", tensor.shape, tensor.stride()) if tensor.numel() == 0: # Skip broadcasting empty tensors. continue diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 429f24b95f28..130137b1f1f0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1150,6 +1150,13 @@ def create_engine_config(self, (self.enable_chunked_prefill and model_config.use_mla): gpu_memory_utilization = 0.85 + if self.enable_chunked_prefill and model_config.use_mla: + # Currently chunked prefill with MLA is not supported with CUDA + # graphs + logger.warning("Chunked prefill is not supported with MLA. " + "Disabling CUDA graphs.") + self.enforce_eager = True + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=gpu_memory_utilization, From ca1b07d571fa1b68c7cb0620f202bdf00663952f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 11 Feb 2025 21:54:35 +0000 Subject: [PATCH 09/38] wip fix tensor on wrong device Signed-off-by: Lucas Wilkinson --- csrc/cache_kernels.cu | 11 ++++ tests/kernels/test_cache.py | 2 +- vllm/_custom_ops.py | 9 ++- vllm/attention/backends/triton_mla.py | 80 +++++++++++++-------------- 4 files changed, 57 insertions(+), 45 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 203d4fc433b9..a6f8602a0588 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -691,6 +691,17 @@ void gather_cache( "seq_starts must be int32"); } + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + 
TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } + int64_t block_table_stride = block_table.stride(0); int64_t cache_block_stride = src_cache.stride(0); int64_t cache_entry_stride = src_cache.stride(1); diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index ff5b2ccf7a49..b9c525c96f88 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -758,7 +758,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, opcheck( torch.ops._C_cache_ops.gather_cache, - (src_cache, dst, block_table, cu_seq_lens, batch_size), + (src_cache, dst, block_table, cu_seq_lens, batch_size, None), test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 36a6102b953c..66afedcd5154 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1111,9 +1111,12 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) -def gather_cache(src_cache: torch.Tensor, dst: torch.Tensor, - block_table: torch.Tensor, cu_seq_lens: torch.Tensor, - batch_size: int, seq_starts: Optional[torch.Tensor]) -> None: +def gather_cache(src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None) -> None: torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index c9b706f4c773..68661577d3e6 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -276,8 +276,8 @@ class TritonMLAMetadata(MLACommonMetadata): # For chunked prefill chunk_prefill_workspace_size: Optional[int] = None - chunk_cu_seq_lens: Optional[List[torch.Tensor]] = None - chunk_seq_starts: Optional[List[torch.Tensor]] = None + 
chunk_cu_seq_lens: Optional[torch.Tensor] = None + chunk_seq_starts: Optional[torch.Tensor] = None chunk_iter_toks: Optional[List[int]] = None chunk_max_seq_lens: Optional[List[int]] = None @@ -665,15 +665,17 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.multimodal_placeholder_maps.items() } + chunk_prefill_workspace_size = \ + self.runner.model_config.max_model_len * 4 chunk_cu_seq_lens = None chunk_seq_starts = None chunk_iter_toks = None chunk_max_seq_lens = None + print("seq_start_loc_tensor", seq_start_loc_tensor.device) + chunked_prefill_enabled = self.input_builder.chunked_prefill_enabled - if chunked_prefill_enabled: - chunk_prefill_workspace_size = \ - self.runner.model_config.max_model_len * 4 + if chunked_prefill_enabled and self.num_prefills > 0: page_size = self.runner.block_size seq_chunk_size = chunk_prefill_workspace_size // self.num_prefills # align seq_chunk_size to page_size by rounding down @@ -681,29 +683,30 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_chunks = (max_prefill_seq_len + seq_chunk_size - 1) // seq_chunk_size - chunk_cu_seq_lens = [] - chunk_seq_starts = [] - chunk_iter_toks = [] - chunk_max_seq_lens = [] - - for chunk in range(num_chunks): - chunk_starts = chunk * seq_chunk_size - chunk_starts = torch.tensor([chunk_starts] * self.num_prefills, - dtype=torch.int32, - device=device) - chunk_ends = seq_lens_tensor.clamp(max=(chunk + 1) * - seq_chunk_size) - _chunk_cu_seq_lens = (chunk_ends - chunk_starts).clamp( - min=0).cumsum(dim=0).to(torch.int32) - chunk_iter_toks.append(_chunk_cu_seq_lens.sum()) - chunk_max_seq_lens.append(_chunk_cu_seq_lens.max().item()) - - zero = torch.zeros(1, dtype=torch.int32, device=device) - _chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], - dim=0) - - chunk_cu_seq_lens.append(_chunk_cu_seq_lens.contiguous()) - chunk_seq_starts.append(chunk_starts.contiguous()) + # if `seq_chunk_size = 256`, `num_chunks = 3`, and + # `num_prefills = 4`, create a tensor 
that looks like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + chunk_seq_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32)\ + .unsqueeze(1).expand(-1, self.num_prefills)\ + * seq_chunk_size + chunk_ends = torch.min(seq_lens_tensor.unsqueeze(0),\ + chunk_seq_starts + seq_chunk_size) + chunk_seq_lens = (chunk_ends - chunk_seq_starts).clamp(min=0) + _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(torch.int32) + zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ + .unsqueeze(0) + print(self.num_prefills, zero, _chunk_cu_seq_lens) + chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], dim=1) + chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist() + chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist() + + print(chunk_seq_starts) + print(chunk_ends) + #print(_chunk_cu_seq_lens) + print(chunk_cu_seq_lens) + print(chunk_max_seq_lens) + print(chunk_iter_toks) return TritonMLAMetadata( num_prefills=self.num_prefills, @@ -726,6 +729,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], use_cuda_graph=use_captured_graph, num_kv_splits=4, # TODO(lucas) add heuristic head_dim=self.runner.model_config.get_head_size(), + # Chunked prefill data chunk_prefill_workspace_size=chunk_prefill_workspace_size, chunk_cu_seq_lens=chunk_cu_seq_lens, chunk_seq_starts=chunk_seq_starts, @@ -794,17 +798,11 @@ def _forward_prefill( ) output = None - for chunk_cu_seq_lens, chunk_seq_starts, toks, max_seq_len in \ - zip( - attn_metadata.chunk_cu_seq_lens, \ - attn_metadata.chunk_seq_starts, \ - attn_metadata.chunk_iter_toks, \ - attn_metadata.chunk_max_seq_lens): - - print("cu_seq_lens", chunk_cu_seq_lens) - print("seq_starts", chunk_seq_starts) - print("toks", toks) - print("max_seq_len", max_seq_len) + for i in range(len(prefill_metadata.chunk_iter_toks)): + chunk_cu_seq_lens = prefill_metadata.chunk_cu_seq_lens[i] + chunk_seq_starts = prefill_metadata.chunk_seq_starts[i] + toks = 
prefill_metadata.chunk_iter_toks[i] + max_seq_len = prefill_metadata.chunk_max_seq_lens[i] ops.gather_cache( src_cache=kv_c_and_k_pe_cache, @@ -859,8 +857,8 @@ def _forward_prefill( output=output, prefix_output=output, prefix_lse=attn_softmax_lse, - new_output=attn_output, - new_lse=attn_softmax_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, ) output = output\ From 396f4db20cadeaad7002b2d753059a908f903693 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 12 Feb 2025 01:21:14 +0000 Subject: [PATCH 10/38] wip Signed-off-by: Lucas Wilkinson --- examples/basic_deepseek.py | 107 ++++++++------------------ examples/chat_deepseek.py | 88 +++++++++++++++++++++ tests/kernels/test_cache.py | 7 +- vllm/attention/__init__.py | 4 +- vllm/attention/backends/mla/utils.py | 24 +++--- vllm/attention/backends/triton_mla.py | 75 ++++++++++++------ vllm/attention/backends/utils.py | 7 +- 7 files changed, 193 insertions(+), 119 deletions(-) create mode 100644 examples/chat_deepseek.py diff --git a/examples/basic_deepseek.py b/examples/basic_deepseek.py index 05c2690a9c13..239a2596ee8c 100644 --- a/examples/basic_deepseek.py +++ b/examples/basic_deepseek.py @@ -4,83 +4,37 @@ # Sample prompts. prompts = [ - """FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. 
-When forty winters shall besiege thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. 
-Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -""", "The president of the United States is", "The capital of France is", "The future of AI is", + """ +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. + +llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", + trust_remote_code=True, + enforce_eager=True, + max_model_len=2048, + tensor_parallel_size=1, + enable_chunked_prefill=True, + max_num_batched_tokens=256, + gpu_memory_utilization=0.7 + #cpu_offload_gb=10, + #hf_overrides={"num_hidden_layers": 14}, + ) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +print("RUNNING GENERATION") +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(""", ] + # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) @@ -90,7 +44,7 @@ trust_remote_code=True, enforce_eager=True, max_model_len=2048, - tensor_parallel_size=2, + tensor_parallel_size=1, enable_chunked_prefill=True, max_num_batched_tokens=256, gpu_memory_utilization=0.7 @@ -99,6 +53,11 @@ ) # Generate texts from the prompts. 
The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. +print(""" + &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& + &&&&&&&&&&&&&&&&&&&&&&&&& RUNNING GENERATION &&&&&&&&&&&&&&&&&&&&&&&&&&&& + &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& + """) outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: diff --git a/examples/chat_deepseek.py b/examples/chat_deepseek.py new file mode 100644 index 000000000000..b7b0ee780a8e --- /dev/null +++ b/examples/chat_deepseek.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm import LLM, SamplingParams + +llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", + trust_remote_code=True, + enforce_eager=True, + max_model_len=2048, + tensor_parallel_size=1 + #cpu_offload_gb=10, + #hf_overrides={"num_hidden_layers": 14}, + ) +sampling_params = SamplingParams(temperature=0.5) + + +def print_outputs(outputs): + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print("-" * 80) + + +print("=" * 80) + +# In this script, we demonstrate how to pass input to the chat method: + +conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] +outputs = llm.chat(conversation, + sampling_params=sampling_params, + use_tqdm=False) +print_outputs(outputs) + +# You can run batch inference with llm.chat API +conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" 
+ }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] +conversations = [conversation for _ in range(10)] + +# We turn on tqdm progress bar to verify it's indeed running batch inference +outputs = llm.chat(messages=conversations, + sampling_params=sampling_params, + use_tqdm=True) +print_outputs(outputs) + +# A chat template can be optionally supplied. +# If not, the model will use its default chat template. + +# with open('template_falcon_180b.jinja', "r") as f: +# chat_template = f.read() + +# outputs = llm.chat( +# conversations, +# sampling_params=sampling_params, +# use_tqdm=False, +# chat_template=chat_template, +# ) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index b9c525c96f88..b8b5e2045457 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -696,10 +696,10 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("kv_lora_rank", [512]) @pytest.mark.parametrize("qk_rope_head_dim", [64]) -@pytest.mark.parametrize("block_size", [2]) +@pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("num_blocks", [1024]) -@pytest.mark.parametrize("max_seq_len", [4]) -@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("max_seq_len", [512]) +@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("kv_cache_dtype", ["auto"]) # You can also test "fp8" if needed. 
@@ -724,6 +724,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, device=device) cu_seq_lens[0] = 0 cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) + print("seq_len_tensor", seq_len_tensor) tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size block_table = torch.empty((batch_size, num_blocks), diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index c710ce6d5a45..89229e7b87a0 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -4,12 +4,12 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) -from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION +from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend __all__ = [ "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType", "AttentionMetadataBuilder", "Attention", "AttentionState", - "get_attn_backend", "VLLM_FLASH_ATTN_VERSION" + "get_attn_backend", "get_flash_attn_version" ] diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index af4ef9e923a8..92927f662b9f 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -445,30 +445,33 @@ def forward( # Restore head dim (for rotary embedding) k_pe = k_pe.unsqueeze(1) assert hasattr(attn_metadata, "input_positions") - rope_fn = self.rotary_emb num_prefill_tokens: int = attn_metadata.num_prefill_tokens + # print("has_decode:", has_decode, " has_prefill: ", has_prefill) + if has_decode: decode_q = hidden_states_or_q_c[num_prefill_tokens:] decode_q_nope = self._q_proj_and_k_up_proj(decode_q) decode_q_pe = torch.matmul(decode_q, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) decode_q_pe, k_pe[num_prefill_tokens:] = \ - rope_fn(attn_metadata.input_positions[num_prefill_tokens:], - decode_q_pe, k_pe[num_prefill_tokens:]) + self.rotary_emb( 
+ attn_metadata.input_positions[num_prefill_tokens:], + decode_q_pe, k_pe[num_prefill_tokens:]) + if has_prefill: - prefill_q = hidden_states_or_q_c[:num_prefill_tokens] - prefill_k_pe = k_pe[:num_prefill_tokens] - prefill_q = self.q_proj(prefill_q)[0]\ + prefill_q = \ + self.q_proj(hidden_states_or_q_c[:num_prefill_tokens])[0]\ .view(-1, self.num_heads, self.qk_head_dim) # TODO(lucas): there must be a nicer way to write this line - prefill_q[..., self.qk_nope_head_dim:], prefill_k_pe = \ - rope_fn( + prefill_q[..., self.qk_nope_head_dim:], \ + k_pe[:num_prefill_tokens] = \ + self.rotary_emb( attn_metadata.input_positions[:num_prefill_tokens], prefill_q[..., self.qk_nope_head_dim:], - prefill_k_pe) + k_pe[:num_prefill_tokens]) # write the latent and rope to kv cache if kv_cache.numel() > 0: @@ -489,7 +492,8 @@ def forward( if has_prefill: output[:num_prefill_tokens] = self._forward_prefill( prefill_q, k_c_normed[:num_prefill_tokens].contiguous(), - prefill_k_pe.contiguous(), kv_cache, attn_metadata) + k_pe[:num_prefill_tokens].contiguous(), kv_cache, + attn_metadata) if has_decode: output[num_prefill_tokens:] = self._forward_decode( diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 68661577d3e6..ea7f5a1ca63b 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -23,9 +23,7 @@ AttentionMetadataBuilder, AttentionState, AttentionType) from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata -from vllm.attention.backends.utils import (PAD_SLOT_ID, - VLLM_FLASH_ATTN_VERSION, - compute_slot_mapping, +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd @@ -281,6 +279,8 @@ class TritonMLAMetadata(MLACommonMetadata): chunk_iter_toks: Optional[List[int]] = None chunk_max_seq_lens: Optional[List[int]] = None + 
first_layer: bool = True + def __post_init__(self): supported_head_sizes = TritonMLABackend.get_supported_head_sizes() if self.head_dim is not None and self.head_dim \ @@ -520,6 +520,9 @@ def _add_seq_group( is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables + print("scheduling", inter_data.is_prompt, inter_data, + inter_data.seq_lens, inter_data.context_lens) + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block, input_positions) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], @@ -609,10 +612,13 @@ def build(self, seq_lens: List[int], query_lens: List[int], inter_data.prefix_cache_hit for inter_data in self.input_builder.inter_data_list ]) + + print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled, prefix_cache_hit) + print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 @@ -690,23 +696,24 @@ def build(self, seq_lens: List[int], query_lens: List[int], torch.arange(num_chunks, device=device, dtype=torch.int32)\ .unsqueeze(1).expand(-1, self.num_prefills)\ * seq_chunk_size - chunk_ends = torch.min(seq_lens_tensor.unsqueeze(0),\ - chunk_seq_starts + seq_chunk_size) + chunk_ends = torch.min(seq_lens_tensor[:self.num_prefills]\ + .unsqueeze(0), chunk_seq_starts + seq_chunk_size) chunk_seq_lens = (chunk_ends - chunk_seq_starts).clamp(min=0) _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(torch.int32) zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ .unsqueeze(0) - print(self.num_prefills, zero, _chunk_cu_seq_lens) + print("self.num_prefills", self.num_prefills, "_chunk_cu_seq_lens", + _chunk_cu_seq_lens) chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], dim=1) chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist() 
chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist() - print(chunk_seq_starts) - print(chunk_ends) + print("chunk_seq_starts", chunk_seq_starts) + print("chunk_ends", chunk_ends) #print(_chunk_cu_seq_lens) - print(chunk_cu_seq_lens) - print(chunk_max_seq_lens) - print(chunk_iter_toks) + print("chunk_cu_seq_lens", chunk_cu_seq_lens) + print("chunk_max_seq_lens", chunk_max_seq_lens) + print("chunk_iter_toks", chunk_iter_toks) return TritonMLAMetadata( num_prefills=self.num_prefills, @@ -787,6 +794,7 @@ def _forward_prefill( prefill_metadata = attn_metadata.prefill_metadata if prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 \ and kv_c_and_k_pe_cache.numel() > 0: if not hasattr(self, "workspace"): @@ -798,29 +806,37 @@ def _forward_prefill( ) output = None - for i in range(len(prefill_metadata.chunk_iter_toks)): + iters = len(prefill_metadata.chunk_iter_toks) + for i in range(iters): chunk_cu_seq_lens = prefill_metadata.chunk_cu_seq_lens[i] chunk_seq_starts = prefill_metadata.chunk_seq_starts[i] toks = prefill_metadata.chunk_iter_toks[i] max_seq_len = prefill_metadata.chunk_max_seq_lens[i] + if attn_metadata.first_layer: + print("running gather") + print(prefill_metadata.block_tables) + print("chunk_cu_seq_lens", chunk_cu_seq_lens) + print("chunk_seq_starts", chunk_seq_starts) + ops.gather_cache( src_cache=kv_c_and_k_pe_cache, dst=self.workspace, block_table=prefill_metadata.block_tables, cu_seq_lens=chunk_cu_seq_lens, - batch_size=attn_metadata.num_prefills, + batch_size=prefill_metadata.num_prefills, seq_starts=chunk_seq_starts, ) - k_c_normed = self.workspace[:toks][ - ..., :self.kv_lora_rank].unsqueeze(1) - k_pe = self.workspace[:toks][..., - self.kv_lora_rank:].unsqueeze(1) - - kv_nope = self.kv_b_proj(k_c_normed)[0]\ - .view(-1, self.num_heads, \ - self.qk_nope_head_dim + self.v_head_dim) + k_c_normed = self.workspace[:toks]\ + [..., :self.kv_lora_rank].unsqueeze(1) + k_pe = self.workspace[:toks]\ + [..., 
self.kv_lora_rank:].unsqueeze(1) + # print("====================================") + # print(prefill_metadata.block_tables) + # print("====================================") + kv_nope = self.kv_b_proj(k_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -832,9 +848,9 @@ def _forward_prefill( v_padded = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) - print("q", q.shape) - print("k", k.shape) - print("v", v_padded.shape) + if attn_metadata.first_layer: + print("prefill_metadata.query_start_loc", + prefill_metadata.query_start_loc) attn_output, attn_softmax_lse = flash_attn_varlen_func( q=q, @@ -847,9 +863,14 @@ def _forward_prefill( softmax_scale=self.scale, causal=True, return_softmax_lse=True, - fa_version=VLLM_FLASH_ATTN_VERSION, + fa_version=self.vllm_flash_attn_version, ) + if attn_metadata.first_layer: + print(q.shape) + print(k.shape) + print(v_padded.shape) + if output is None: output = attn_output else: @@ -861,6 +882,8 @@ def _forward_prefill( suffix_lse=attn_softmax_lse, ) + attn_metadata.first_layer = False + output = output\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ .reshape(-1, self.num_heads * v.shape[-1]) @@ -882,6 +905,8 @@ def _forward_decode( if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") + # print("running decode") + decode_meta = attn_metadata.decode_metadata assert decode_meta is not None B = q_nope.shape[0] diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index e62d97ac8162..d85c2ca06633 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -16,10 +16,6 @@ from vllm.multimodal import MultiModalPlaceholderMap from vllm.platforms import current_platform from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.vllm_flash_attn.flash_attn_interface import ( - 
fa_version_unsupported_reason, is_fa_version_supported) - -logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) @@ -217,9 +213,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ + print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled) + print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 @@ -616,4 +614,3 @@ def get_flash_attn_version(): return fa_version except (ImportError, AssertionError): return None - From b4a900e61c41a3417ddccabac2a77e74d3f18783 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 12 Feb 2025 23:16:41 +0000 Subject: [PATCH 11/38] finally :/ Signed-off-by: Lucas Wilkinson --- examples/basic_deepseek.py | 3 +- vllm/attention/backends/triton_mla.py | 279 +++++++++++++++-------- vllm/v1/attention/backends/flash_attn.py | 14 +- 3 files changed, 196 insertions(+), 100 deletions(-) diff --git a/examples/basic_deepseek.py b/examples/basic_deepseek.py index 239a2596ee8c..4043887deb04 100644 --- a/examples/basic_deepseek.py +++ b/examples/basic_deepseek.py @@ -63,4 +63,5 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f" Prompt: {prompt!r}") + print(f"** Generated text: {generated_text!r}") diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index ea7f5a1ca63b..fcc432e8726d 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -273,7 +273,7 @@ class TritonMLAMetadata(MLACommonMetadata): prefill_seq_len_total: Optional[int] = None # For chunked prefill - chunk_prefill_workspace_size: Optional[int] = None + 
chunked_prefill_workspace_size: Optional[int] = None chunk_cu_seq_lens: Optional[torch.Tensor] = None chunk_seq_starts: Optional[torch.Tensor] = None chunk_iter_toks: Optional[List[int]] = None @@ -342,7 +342,7 @@ def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: prefill_seq_len_total=seq_lens_tensor.sum().item(), # Chunk prefill meta-data - chunk_prefill_workspace_size=self.chunk_prefill_workspace_size, + chunked_prefill_workspace_size=self.chunked_prefill_workspace_size, chunk_cu_seq_lens=self.chunk_cu_seq_lens, chunk_seq_starts=self.chunk_seq_starts, chunk_iter_toks=self.chunk_iter_toks, @@ -671,8 +671,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], self.multimodal_placeholder_maps.items() } - chunk_prefill_workspace_size = \ - self.runner.model_config.max_model_len * 4 + chunked_prefill_workspace_size = max(16 * self.num_prefills, 128) chunk_cu_seq_lens = None chunk_seq_starts = None chunk_iter_toks = None @@ -681,29 +680,35 @@ def build(self, seq_lens: List[int], query_lens: List[int], print("seq_start_loc_tensor", seq_start_loc_tensor.device) chunked_prefill_enabled = self.input_builder.chunked_prefill_enabled - if chunked_prefill_enabled and self.num_prefills > 0: + if chunked_prefill_enabled and self.num_prefills > 0 \ + and context_lens_tensor is not None: page_size = self.runner.block_size - seq_chunk_size = chunk_prefill_workspace_size // self.num_prefills + seq_chunk_size = chunked_prefill_workspace_size // self.num_prefills + print("seq_chunk_size", seq_chunk_size) # align seq_chunk_size to page_size by rounding down seq_chunk_size = seq_chunk_size - (seq_chunk_size % page_size) - num_chunks = (max_prefill_seq_len + seq_chunk_size - + assert seq_chunk_size > 0 + num_chunks = (context_lens_tensor.max() + seq_chunk_size - 1) // seq_chunk_size # if `seq_chunk_size = 256`, `num_chunks = 3`, and # `num_prefills = 4`, create a tensor that looks like # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note: we only 
chunk the current context so we can separate the + # causal masked and un-masked portions of the computation chunk_seq_starts = \ torch.arange(num_chunks, device=device, dtype=torch.int32)\ .unsqueeze(1).expand(-1, self.num_prefills)\ * seq_chunk_size - chunk_ends = torch.min(seq_lens_tensor[:self.num_prefills]\ + chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ .unsqueeze(0), chunk_seq_starts + seq_chunk_size) chunk_seq_lens = (chunk_ends - chunk_seq_starts).clamp(min=0) _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(torch.int32) zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ - .unsqueeze(0) + .unsqueeze(-1) + print("zero", zero.shape) print("self.num_prefills", self.num_prefills, "_chunk_cu_seq_lens", - _chunk_cu_seq_lens) + _chunk_cu_seq_lens.shape) chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], dim=1) chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist() chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist() @@ -737,7 +742,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_kv_splits=4, # TODO(lucas) add heuristic head_dim=self.runner.model_config.get_head_size(), # Chunked prefill data - chunk_prefill_workspace_size=chunk_prefill_workspace_size, + chunked_prefill_workspace_size=chunked_prefill_workspace_size, chunk_cu_seq_lens=chunk_cu_seq_lens, chunk_seq_starts=chunk_seq_starts, chunk_iter_toks=chunk_iter_toks, @@ -781,112 +786,192 @@ def __init__( "are not implemented for " "TritonMLAImpl") - def _forward_prefill( + def _compute_prefill_context_attention( self, q: torch.Tensor, - kv_c: torch.Tensor, - k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: TritonMLAMetadata, - ) -> torch.Tensor: - assert isinstance(attn_metadata, TritonMLAMetadata) - + ): prefill_metadata = attn_metadata.prefill_metadata + if not hasattr(self, "workspace"): + self.workspace = torch.empty( + (attn_metadata.chunked_prefill_workspace_size, + self.kv_lora_rank + self.qk_rope_head_dim), + 
dtype=q.dtype, + device=q.device, + ) - if prefill_metadata.context_lens_tensor is not None \ - and prefill_metadata.context_lens_tensor.max() > 0 \ - and kv_c_and_k_pe_cache.numel() > 0: - - if not hasattr(self, "workspace"): - self.workspace = torch.empty( - (attn_metadata.chunk_prefill_workspace_size, - self.kv_lora_rank + self.qk_rope_head_dim), - dtype=q.dtype, - device=q.device, - ) + output = None + iters = len(prefill_metadata.chunk_iter_toks) + for i in range(iters): + chunk_cu_seq_lens = prefill_metadata.chunk_cu_seq_lens[i] + chunk_seq_starts = prefill_metadata.chunk_seq_starts[i] + toks = prefill_metadata.chunk_iter_toks[i] + max_seq_len = prefill_metadata.chunk_max_seq_lens[i] + + if attn_metadata.first_layer: + print("running gather") + print(prefill_metadata.block_tables) + print("chunk_cu_seq_lens", chunk_cu_seq_lens) + print("chunk_seq_starts", chunk_seq_starts) + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=self.workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=chunk_cu_seq_lens, + batch_size=prefill_metadata.num_prefills, + seq_starts=chunk_seq_starts, + ) - output = None - iters = len(prefill_metadata.chunk_iter_toks) - for i in range(iters): - chunk_cu_seq_lens = prefill_metadata.chunk_cu_seq_lens[i] - chunk_seq_starts = prefill_metadata.chunk_seq_starts[i] - toks = prefill_metadata.chunk_iter_toks[i] - max_seq_len = prefill_metadata.chunk_max_seq_lens[i] + k_c_normed = self.workspace[:toks]\ + [..., :self.kv_lora_rank].unsqueeze(1) + k_pe = self.workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + # print("====================================") + # print(prefill_metadata.block_tables) + # print("====================================") + kv_nope = self.kv_b_proj(k_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + 
+ # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, + [0, q.shape[-1] - v.shape[-1]], + value=0) + + if attn_metadata.first_layer: + print("prefill_metadata.query_start_loc", + prefill_metadata.query_start_loc) + + attn_output, attn_softmax_lse = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=chunk_cu_seq_lens, + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + ) + if attn_metadata.first_layer: + print("q", q.shape) + print("k", k.shape) + print("v", v_padded.shape) + print("prefill_metadata.query_start_loc", + prefill_metadata.query_start_loc) + print("chunk_cu_seq_lens", chunk_cu_seq_lens) + print(attn_output.shape) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse if attn_metadata.first_layer: - print("running gather") - print(prefill_metadata.block_tables) - print("chunk_cu_seq_lens", chunk_cu_seq_lens) - print("chunk_seq_starts", chunk_seq_starts) - - ops.gather_cache( - src_cache=kv_c_and_k_pe_cache, - dst=self.workspace, - block_table=prefill_metadata.block_tables, - cu_seq_lens=chunk_cu_seq_lens, - batch_size=prefill_metadata.num_prefills, - seq_starts=chunk_seq_starts, - ) - - k_c_normed = self.workspace[:toks]\ - [..., :self.kv_lora_rank].unsqueeze(1) - k_pe = self.workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) - # print("====================================") - # print(prefill_metadata.block_tables) - # print("====================================") - kv_nope = self.kv_b_proj(k_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, 
k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0) + print("============ first output =============") + print("output", output.shape, output[0]) + print("output_lse", output_lse.shape, output_lse[0]) + print("========================================") + else: if attn_metadata.first_layer: - print("prefill_metadata.query_start_loc", - prefill_metadata.query_start_loc) - - attn_output, attn_softmax_lse = flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=chunk_cu_seq_lens, - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=True, - fa_version=self.vllm_flash_attn_version, + print("merging", output.shape, attn_output.shape) + + output_tmp = torch.ones_like(output) + output_lse_tmp = torch.ones_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, ) if attn_metadata.first_layer: - print(q.shape) - print(k.shape) - print(v_padded.shape) + print("attn_softmax_lse", attn_softmax_lse.shape, + attn_softmax_lse[0]) + print("output_lse_tmp", output_lse_tmp.shape, + output_lse_tmp[0]) + print("output_tmp", output_tmp.shape, output_tmp[0]) + print("output", output.shape, output[0]) + print("att_output", attn_output.shape, attn_output[0]) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse - if output is None: - output = attn_output - else: - merge_attn_states( - output=output, - prefix_output=output, - prefix_lse=attn_softmax_lse, - suffix_output=attn_output, - suffix_lse=attn_softmax_lse, - ) + def _forward_prefill( + self, + q: torch.Tensor, + kv_c: torch.Tensor, + 
k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: TritonMLAMetadata, + ) -> torch.Tensor: + assert isinstance(attn_metadata, TritonMLAMetadata) - attn_metadata.first_layer = False + prefill_metadata = attn_metadata.prefill_metadata + + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + + if has_context: + context_output, context_lse = \ + self._compute_prefill_context_attention( + q, kv_c_and_k_pe_cache, attn_metadata) + + kv_nope = self.kv_b_proj(kv_c)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, + [0, q.shape[-1] - v.shape[-1]], + value=0) + + extend_output, extend_lse = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_prefill_seq_len, + max_seqlen_k=attn_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + ) + + output = torch.empty_like(context_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=extend_output, + suffix_lse=extend_lse, + ) output = output\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ .reshape(-1, self.num_heads * v.shape[-1]) + + attn_metadata.first_layer = False + return self.o_proj(output)[0] else: return self._forward_prefill_flash( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5cb1e2fd26a5..d21085a6f3d3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ 
b/vllm/v1/attention/backends/flash_attn.py @@ -378,6 +378,7 @@ def merge_attn_states( prefix_lse: torch.Tensor, suffix_output: torch.Tensor, suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, ) -> None: num_tokens = output.shape[0] num_query_heads = output.shape[1] @@ -387,24 +388,28 @@ def merge_attn_states( # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. merge_attn_states_kernel[(num_tokens, num_query_heads)]( output, + output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse, head_size, padded_head_size, + output_lse is not None, ) @triton.jit def merge_attn_states_kernel( output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse, # [NUM_HEADS, NUM_TOKENS] prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] prefix_lse, # [NUM_HEADS, NUM_TOKENS] suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] suffix_lse, # [NUM_HEADS, NUM_TOKENS] HEAD_SIZE: tl.constexpr, PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, ): token_idx = tl.program_id(0) num_tokens = tl.num_programs(0) @@ -416,6 +421,11 @@ def merge_attn_states_kernel( max_lse = tl.maximum(p_lse, s_lse) p_lse = p_lse - max_lse s_lse = s_lse - max_lse + out_se = (tl.exp(p_lse) + tl.exp(s_lse)) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) head_arange = tl.arange(0, PADDED_HEAD_SIZE) head_mask = head_arange < HEAD_SIZE @@ -429,8 +439,8 @@ def merge_attn_states_kernel( # NOTE(woosuk): Be careful with the numerical stability. # We should compute the scale first, and then multiply it with the output. # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. 
- p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse)) - s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse)) + p_scale = tl.exp(p_lse) / out_se + s_scale = tl.exp(s_lse) / out_se out = p_out * p_scale + s_out * s_scale tl.store(output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange, From 04c6042ea7f0bf5044e9702a17e4daa19ac610e4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 03:16:35 +0000 Subject: [PATCH 12/38] working! Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 50 +------ vllm/attention/backends/triton_mla.py | 182 +++++++++++--------------- 2 files changed, 80 insertions(+), 152 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 92927f662b9f..3cfc61c01ccf 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -30,11 +30,6 @@ from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) -try: - from vllm.vllm_flash_attn import flash_attn_varlen_func -except ImportError: - from flash_attn import flash_attn_varlen_func - @dataclass class MLACommonMetadata(AttentionMetadata): @@ -411,6 +406,7 @@ def _forward_prefill( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: T, ) -> torch.Tensor: raise NotImplementedError @@ -448,8 +444,6 @@ def forward( num_prefill_tokens: int = attn_metadata.num_prefill_tokens - # print("has_decode:", has_decode, " has_prefill: ", has_prefill) - if has_decode: decode_q = hidden_states_or_q_c[num_prefill_tokens:] decode_q_nope = self._q_proj_and_k_up_proj(decode_q) @@ -500,45 +494,3 @@ def forward( decode_q_nope, decode_q_pe, kv_cache, attn_metadata) return output - - # Optional common flash-attn based prefill - def _forward_prefill_flash( - self, - q: torch.Tensor, - k_c_normed: torch.Tensor, - k_pe: torch.Tensor, - query_start_loc: torch.Tensor, 
- seq_start_loc: torch.Tensor, - max_query_len: int, - max_prefill_seq_len: int, - ) -> torch.Tensor: - - kv_nope = self.kv_b_proj(k_c_normed)[0]\ - .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output = flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=query_start_loc, - cu_seqlens_k=seq_start_loc, - max_seqlen_q=max_query_len, - max_seqlen_k=max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - fa_version=self.vllm_flash_attn_version, - ) - attn_output = attn_output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(attn_output)[0] diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index fcc432e8726d..d7b7ed359224 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -279,8 +279,6 @@ class TritonMLAMetadata(MLACommonMetadata): chunk_iter_toks: Optional[List[int]] = None chunk_max_seq_lens: Optional[List[int]] = None - first_layer: bool = True - def __post_init__(self): supported_head_sizes = TritonMLABackend.get_supported_head_sizes() if self.head_dim is not None and self.head_dim \ @@ -520,9 +518,6 @@ def _add_seq_group( is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables - print("scheduling", inter_data.is_prompt, inter_data, - inter_data.seq_lens, inter_data.context_lens) - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block, input_positions) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], @@ -613,12 +608,10 @@ 
def build(self, seq_lens: List[int], query_lens: List[int], for inter_data in self.input_builder.inter_data_list ]) - print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled, prefix_cache_hit) - print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 @@ -677,14 +670,12 @@ def build(self, seq_lens: List[int], query_lens: List[int], chunk_iter_toks = None chunk_max_seq_lens = None - print("seq_start_loc_tensor", seq_start_loc_tensor.device) - chunked_prefill_enabled = self.input_builder.chunked_prefill_enabled if chunked_prefill_enabled and self.num_prefills > 0 \ and context_lens_tensor is not None: page_size = self.runner.block_size seq_chunk_size = chunked_prefill_workspace_size // self.num_prefills - print("seq_chunk_size", seq_chunk_size) + # align seq_chunk_size to page_size by rounding down seq_chunk_size = seq_chunk_size - (seq_chunk_size % page_size) assert seq_chunk_size > 0 @@ -706,20 +697,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(torch.int32) zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ .unsqueeze(-1) - print("zero", zero.shape) - print("self.num_prefills", self.num_prefills, "_chunk_cu_seq_lens", - _chunk_cu_seq_lens.shape) chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], dim=1) chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist() chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist() - print("chunk_seq_starts", chunk_seq_starts) - print("chunk_ends", chunk_ends) - #print(_chunk_cu_seq_lens) - print("chunk_cu_seq_lens", chunk_cu_seq_lens) - print("chunk_max_seq_lens", chunk_max_seq_lens) - print("chunk_iter_toks", chunk_iter_toks) - return TritonMLAMetadata( num_prefills=self.num_prefills, 
slot_mapping=slot_mapping_tensor, @@ -809,12 +790,6 @@ def _compute_prefill_context_attention( toks = prefill_metadata.chunk_iter_toks[i] max_seq_len = prefill_metadata.chunk_max_seq_lens[i] - if attn_metadata.first_layer: - print("running gather") - print(prefill_metadata.block_tables) - print("chunk_cu_seq_lens", chunk_cu_seq_lens) - print("chunk_seq_starts", chunk_seq_starts) - ops.gather_cache( src_cache=kv_c_and_k_pe_cache, dst=self.workspace, @@ -828,9 +803,7 @@ def _compute_prefill_context_attention( [..., :self.kv_lora_rank].unsqueeze(1) k_pe = self.workspace[:toks]\ [..., self.kv_lora_rank:].unsqueeze(1) - # print("====================================") - # print(prefill_metadata.block_tables) - # print("====================================") + kv_nope = self.kv_b_proj(k_c_normed)[0].view( \ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ @@ -845,10 +818,6 @@ def _compute_prefill_context_attention( [0, q.shape[-1] - v.shape[-1]], value=0) - if attn_metadata.first_layer: - print("prefill_metadata.query_start_loc", - prefill_metadata.query_start_loc) - attn_output, attn_softmax_lse = flash_attn_varlen_func( q=q, k=k, @@ -863,28 +832,10 @@ def _compute_prefill_context_attention( fa_version=self.vllm_flash_attn_version, ) - if attn_metadata.first_layer: - print("q", q.shape) - print("k", k.shape) - print("v", v_padded.shape) - print("prefill_metadata.query_start_loc", - prefill_metadata.query_start_loc) - print("chunk_cu_seq_lens", chunk_cu_seq_lens) - print(attn_output.shape) - if output is None: output = attn_output output_lse = attn_softmax_lse - if attn_metadata.first_layer: - print("============ first output =============") - print("output", output.shape, output[0]) - print("output_lse", output_lse.shape, output_lse[0]) - print("========================================") - else: - if attn_metadata.first_layer: - print("merging", output.shape, attn_output.shape) - output_tmp = torch.ones_like(output) output_lse_tmp = 
torch.ones_like(output_lse) merge_attn_states( @@ -895,20 +846,53 @@ def _compute_prefill_context_attention( suffix_output=attn_output, suffix_lse=attn_softmax_lse, ) - - if attn_metadata.first_layer: - print("attn_softmax_lse", attn_softmax_lse.shape, - attn_softmax_lse[0]) - print("output_lse_tmp", output_lse_tmp.shape, - output_lse_tmp[0]) - print("output_tmp", output_tmp.shape, output_tmp[0]) - print("output", output.shape, output[0]) - print("att_output", attn_output.shape, attn_output[0]) output = output_tmp output_lse = output_lse_tmp return output, output_lse + # Optional common flash-attn based prefill + def _forward_prefill_flash( + self, + q: torch.Tensor, + k_c_normed: torch.Tensor, + k_pe: torch.Tensor, + query_start_loc: torch.Tensor, + seq_start_loc: torch.Tensor, + max_query_len: int, + max_prefill_seq_len: int, + ) -> torch.Tensor: + + kv_nope = self.kv_b_proj(k_c_normed)[0]\ + .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=query_start_loc, + cu_seqlens_k=seq_start_loc, + max_seqlen_q=max_query_len, + max_seqlen_k=max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + fa_version=self.vllm_flash_attn_version, + ) + attn_output = attn_output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + return self.o_proj(attn_output)[0] + def _forward_prefill( self, q: torch.Tensor, @@ -924,60 +908,54 @@ def _forward_prefill( has_context = prefill_metadata.context_lens_tensor is not None \ and prefill_metadata.context_lens_tensor.max() > 0 + 
kv_nope = self.kv_b_proj(kv_c)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_prefill_seq_len, + max_seqlen_k=attn_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + fa_version=self.vllm_flash_attn_version, + ) + if has_context: + suffix_output, suffix_lse = output context_output, context_lse = \ self._compute_prefill_context_attention( q, kv_c_and_k_pe_cache, attn_metadata) - kv_nope = self.kv_b_proj(kv_c)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - extend_output, extend_lse = flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=attn_metadata.query_start_loc, - cu_seqlens_k=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_prefill_seq_len, - max_seqlen_k=attn_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=True, - fa_version=self.vllm_flash_attn_version, - ) - - output = torch.empty_like(context_output) + output = torch.empty_like(suffix_output) merge_attn_states( output=output, 
prefix_output=context_output, prefix_lse=context_lse, - suffix_output=extend_output, - suffix_lse=extend_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, ) - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) + output = output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) - attn_metadata.first_layer = False + attn_metadata.first_layer = False - return self.o_proj(output)[0] - else: - return self._forward_prefill_flash( - q, kv_c, k_pe, attn_metadata.query_start_loc, - attn_metadata.seq_start_loc, attn_metadata.max_query_len, - attn_metadata.max_prefill_seq_len) + return self.o_proj(output)[0] def _forward_decode( self, @@ -990,8 +968,6 @@ def _forward_decode( if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") - # print("running decode") - decode_meta = attn_metadata.decode_metadata assert decode_meta is not None B = q_nope.shape[0] From d0925bbff38015eec6fb23305a06af224d80bb58 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 04:47:43 +0000 Subject: [PATCH 13/38] clean-up Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 41 ++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 3cfc61c01ccf..9fb60c6448e5 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -444,28 +444,30 @@ def forward( num_prefill_tokens: int = attn_metadata.num_prefill_tokens + decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:] + decode_k_pe = k_pe[num_prefill_tokens:] + decode_input_positions = \ + attn_metadata.input_positions[num_prefill_tokens:] + + prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens] + prefill_k_pe = k_pe[:num_prefill_tokens] + prefill_input_positions = \ + 
attn_metadata.input_positions[:num_prefill_tokens] + prefill_k_c_normed = k_c_normed[:num_prefill_tokens] + if has_decode: - decode_q = hidden_states_or_q_c[num_prefill_tokens:] - decode_q_nope = self._q_proj_and_k_up_proj(decode_q) - decode_q_pe = torch.matmul(decode_q, self.W_QR)\ + decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) - decode_q_pe, k_pe[num_prefill_tokens:] = \ - self.rotary_emb( - attn_metadata.input_positions[num_prefill_tokens:], - decode_q_pe, k_pe[num_prefill_tokens:]) + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( + decode_input_positions, decode_q_pe, decode_k_pe) if has_prefill: - prefill_q = \ - self.q_proj(hidden_states_or_q_c[:num_prefill_tokens])[0]\ + prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ .view(-1, self.num_heads, self.qk_head_dim) - - # TODO(lucas): there must be a nicer way to write this line - prefill_q[..., self.qk_nope_head_dim:], \ - k_pe[:num_prefill_tokens] = \ - self.rotary_emb( - attn_metadata.input_positions[:num_prefill_tokens], - prefill_q[..., self.qk_nope_head_dim:], - k_pe[:num_prefill_tokens]) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_pe[...], prefill_k_pe[...] 
= self.rotary_emb( + prefill_input_positions, prefill_q_pe, prefill_k_pe) # write the latent and rope to kv cache if kv_cache.numel() > 0: @@ -477,16 +479,15 @@ def forward( kv_cache_dtype=self.kv_cache_dtype, scale=layer._k_scale, ) + output = torch.empty(attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens, self.o_proj.output_size, device=hidden_states_or_q_c.device, dtype=hidden_states_or_q_c.dtype) - if has_prefill: output[:num_prefill_tokens] = self._forward_prefill( - prefill_q, k_c_normed[:num_prefill_tokens].contiguous(), - k_pe[:num_prefill_tokens].contiguous(), kv_cache, + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata) if has_decode: From 0e173be0ea560b77a4ed45132f6c5fa0a86fec0d Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 04:48:34 +0000 Subject: [PATCH 14/38] delete files Signed-off-by: Lucas Wilkinson --- examples/basic_deepseek.py | 67 ----------------------------- examples/chat_deepseek.py | 88 -------------------------------------- 2 files changed, 155 deletions(-) delete mode 100644 examples/basic_deepseek.py delete mode 100644 examples/chat_deepseek.py diff --git a/examples/basic_deepseek.py b/examples/basic_deepseek.py deleted file mode 100644 index 4043887deb04..000000000000 --- a/examples/basic_deepseek.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "The president of the United States is", - "The capital of France is", - "The future of AI is", - """ -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. 
- -llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", - trust_remote_code=True, - enforce_eager=True, - max_model_len=2048, - tensor_parallel_size=1, - enable_chunked_prefill=True, - max_num_batched_tokens=256, - gpu_memory_utilization=0.7 - #cpu_offload_gb=10, - #hf_overrides={"num_hidden_layers": 14}, - ) -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -print("RUNNING GENERATION") -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(""", -] - -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - -# Create an LLM. - -llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", - trust_remote_code=True, - enforce_eager=True, - max_model_len=2048, - tensor_parallel_size=1, - enable_chunked_prefill=True, - max_num_batched_tokens=256, - gpu_memory_utilization=0.7 - #cpu_offload_gb=10, - #hf_overrides={"num_hidden_layers": 14}, - ) -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -print(""" - &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& - &&&&&&&&&&&&&&&&&&&&&&&&& RUNNING GENERATION &&&&&&&&&&&&&&&&&&&&&&&&&&&& - &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& - """) -outputs = llm.generate(prompts, sampling_params) -# Print the outputs. 
-for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f" Prompt: {prompt!r}") - print(f"** Generated text: {generated_text!r}") diff --git a/examples/chat_deepseek.py b/examples/chat_deepseek.py deleted file mode 100644 index b7b0ee780a8e..000000000000 --- a/examples/chat_deepseek.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from vllm import LLM, SamplingParams - -llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", - trust_remote_code=True, - enforce_eager=True, - max_model_len=2048, - tensor_parallel_size=1 - #cpu_offload_gb=10, - #hf_overrides={"num_hidden_layers": 14}, - ) -sampling_params = SamplingParams(temperature=0.5) - - -def print_outputs(outputs): - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - print("-" * 80) - - -print("=" * 80) - -# In this script, we demonstrate how to pass input to the chat method: - -conversation = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" - }, - { - "role": "user", - "content": "Write an essay about the importance of higher education.", - }, -] -outputs = llm.chat(conversation, - sampling_params=sampling_params, - use_tqdm=False) -print_outputs(outputs) - -# You can run batch inference with llm.chat API -conversation = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" 
- }, - { - "role": "user", - "content": "Write an essay about the importance of higher education.", - }, -] -conversations = [conversation for _ in range(10)] - -# We turn on tqdm progress bar to verify it's indeed running batch inference -outputs = llm.chat(messages=conversations, - sampling_params=sampling_params, - use_tqdm=True) -print_outputs(outputs) - -# A chat template can be optionally supplied. -# If not, the model will use its default chat template. - -# with open('template_falcon_180b.jinja', "r") as f: -# chat_template = f.read() - -# outputs = llm.chat( -# conversations, -# sampling_params=sampling_params, -# use_tqdm=False, -# chat_template=chat_template, -# ) From 25bd9cb581b46a0946a46578a3a465f2b10fa015 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 06:27:23 +0000 Subject: [PATCH 15/38] cleanup Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 894 +++++++++++++++++++++++- vllm/attention/backends/triton_mla.py | 936 +------------------------- 2 files changed, 907 insertions(+), 923 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 9fb60c6448e5..74dfa8be87b0 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -1,18 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod +from collections import defaultdict +from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, Tuple +from itertools import accumulate +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, + Type) import torch from compressed_tensors.quantization import QuantizationStrategy from vllm import _custom_ops as ops from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, - MLAAttentionImpl, T) -from 
vllm.attention.backends.utils import get_flash_attn_version + AttentionMetadataBuilder, + AttentionState, MLAAttentionImpl, + T) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + get_flash_attn_version, + is_block_tables_empty) from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -29,14 +38,710 @@ scaled_dequantize, scaled_quantize) from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.v1.attention.backends.flash_attn import merge_attn_states +from vllm.vllm_flash_attn import flash_attn_varlen_func + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class MLACommonBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + 
@staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +class MLACommonState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + + scheduler_config = runner.scheduler_config + model_config = runner.model_config + cache_config = runner.cache_config + + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + if self.chunked_prefill_enabled: + workspace_size = min( + # Max sure there is enough for 1 full length request or at least + # 2 pages of cache per request + max( + model_config.max_model_len, 2 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens + 64 * 1024) + + self.chunked_prefill_workspace = torch.empty( + (workspace_size, model_config.get_head_size()), + dtype=model_config.dtype, + device=runner.device, + ) + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + self._positions = torch.zeros((max_batch_size, ), + dtype=torch.long, + device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._positions + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + 
num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + input_positions=self._positions[:batch_size], + head_dim=self.runner.model_config.get_head_size()) + + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + "input_positions": attn_metadata.decode_metadata.input_positions, + } + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_positions = attn_metadata.input_positions + num_positions = input_positions.shape[0] + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + # CUDA graph buffer is padded so only perform a partial copy based on + # num_positions + input_buffers["input_positions"][:num_positions].copy_( + input_positions, non_blocking=True) + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + def begin_forward(self, model_input): + return @dataclass class 
MLACommonMetadata(AttentionMetadata): + """Metadata for MLACommon. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # Smuggle the state to the impl via meta-data, we need this for the + # `chunked_prefill_workspace` but passing that directly will result in + # it unnecessarily being broadcasted to all workers. + attn_state: MLACommonState + # Input positions for rotrary embeddings since for MLA the rotary # position embeddings are applied inside the attention backend input_positions: torch.Tensor + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. 
(Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["MLACommonMetadata"] = None + _cached_decode_metadata: Optional["MLACommonMetadata"] = None + + num_prefill_tokens: int + + # The dimension of the attention heads + head_dim: Optional[int] = None + + # For chunked prefill + chunk_cu_seq_lens: Optional[torch.Tensor] = None + chunk_seq_starts: Optional[torch.Tensor] = None + chunk_iter_toks: Optional[List[int]] = None + chunk_max_seq_lens: Optional[List[int]] = None + + def __post_init__(self): + supported_head_sizes = MLACommonBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + @property + def prefill_metadata(self) -> Optional["MLACommonMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + + # 
Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + input_positions = (None if self.input_positions is None else + self.input_positions[:self.num_prefill_tokens]) + + self._cached_prefill_metadata = MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=False, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + attn_state=self.attn_state, + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.head_dim, + # MLACommonMetadata Chunk prefill specific + chunk_cu_seq_lens=self.chunk_cu_seq_lens, + chunk_seq_starts=self.chunk_seq_starts, + chunk_iter_toks=self.chunk_iter_toks, + chunk_max_seq_lens=self.chunk_max_seq_lens, + ) + return 
self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["MLACommonMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + input_positions = (None if self.input_positions is None else + self.input_positions[self.num_prefill_tokens:]) + + self._cached_decode_metadata = MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=self.use_cuda_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + attn_state=self.attn_state, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. 
+ query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + input_positions=input_positions, + head_dim=self.head_dim) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. 
Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.chunked_prefill_enabled = \ + self.runner.scheduler_config.chunked_prefill_enabled + self.attn_state = self.input_builder.runner.attn_state + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.input_positions: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. 
+ max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + input_positions = async_tensor_h2d(self.input_positions, torch.long, + device, self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + 
+ chunk_cu_seq_lens = None + chunk_seq_starts = None + chunk_iter_toks = None + chunk_max_seq_lens = None + + if self.chunked_prefill_enabled and self.num_prefills > 0 \ + and context_lens_tensor is not None: + chunked_prefill_workspace_size = \ + self.attn_state.chunked_prefill_workspace.shape[0] + page_size = self.runner.block_size + seq_chunk_size = chunked_prefill_workspace_size // self.num_prefills + + print(self.attn_state.chunked_prefill_workspace.shape, + self.num_prefills, page_size, seq_chunk_size) + + # align seq_chunk_size to page_size by rounding down + seq_chunk_size = seq_chunk_size - (seq_chunk_size % page_size) + assert seq_chunk_size > 0 + num_chunks = (context_lens_tensor.max() + seq_chunk_size - + 1) // seq_chunk_size + + # if `seq_chunk_size = 256`, `num_chunks = 3`, and + # `num_prefills = 4`, create a tensor that looks like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + # Note: we only chunk the current context so we can separate the + # causal masked and un-masked portions of the computation + chunk_seq_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32)\ + .unsqueeze(1).expand(-1, self.num_prefills)\ + * seq_chunk_size + chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ + .unsqueeze(0), chunk_seq_starts + seq_chunk_size) + chunk_seq_lens = (chunk_ends - chunk_seq_starts).clamp(min=0) + _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(torch.int32) + zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ + .unsqueeze(-1) + chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], dim=1) + chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist() + chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist() + + return MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=use_captured_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + 
num_decode_tokens=num_decode_tokens, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, # Not Attention Related + enable_kv_scales_calculation=False, + # MLACommonMetadata + attn_state=self.attn_state, + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.runner.model_config.get_head_size(), + # MLACommonMetadata Chunk prefill specific + chunk_cu_seq_lens=chunk_cu_seq_lens, + chunk_seq_starts=chunk_seq_starts, + chunk_iter_toks=chunk_iter_toks, + chunk_max_seq_lens=chunk_max_seq_lens, + ) + class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): """ @@ -400,23 +1105,196 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.W_UK = W_UK self.W_Q = W_Q.flatten(start_dim=1) - @abstractmethod + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ): + prefill_metadata = attn_metadata.prefill_metadata + + output = None + iters = len(prefill_metadata.chunk_iter_toks) + workspace = prefill_metadata.attn_state.chunked_prefill_workspace + + for i in range(iters): + chunk_cu_seq_lens = prefill_metadata.chunk_cu_seq_lens[i] + chunk_seq_starts = prefill_metadata.chunk_seq_starts[i] + toks = prefill_metadata.chunk_iter_toks[i] + max_seq_len = prefill_metadata.chunk_max_seq_lens[i] + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=chunk_cu_seq_lens, + batch_size=prefill_metadata.num_prefills, + seq_starts=chunk_seq_starts, + ) + + k_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank].unsqueeze(1) + k_pe = 
workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(k_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, + [0, q.shape[-1] - v.shape[-1]], + value=0) + + attn_output, attn_softmax_lse = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=chunk_cu_seq_lens, + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.ones_like(output) + output_lse_tmp = torch.ones_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + # Optional common flash-attn based prefill + def _forward_prefill_flash( + self, + q: torch.Tensor, + k_c_normed: torch.Tensor, + k_pe: torch.Tensor, + query_start_loc: torch.Tensor, + seq_start_loc: torch.Tensor, + max_query_len: int, + max_prefill_seq_len: int, + ) -> torch.Tensor: + + kv_nope = self.kv_b_proj(k_c_normed)[0]\ + .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to 
match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=query_start_loc, + cu_seqlens_k=seq_start_loc, + max_seqlen_q=max_query_len, + max_seqlen_k=max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + fa_version=self.vllm_flash_attn_version, + ) + attn_output = attn_output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + return self.o_proj(attn_output)[0] + def _forward_prefill( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: T, + attn_metadata: MLACommonMetadata, ) -> torch.Tensor: - raise NotImplementedError + assert isinstance(attn_metadata, MLACommonMetadata) + + prefill_metadata = attn_metadata.prefill_metadata + + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_prefill_seq_len, + max_seqlen_k=attn_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + fa_version=self.vllm_flash_attn_version, + ) + + if has_context: + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + 
q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + output = output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + attn_metadata.first_layer = False + + return self.o_proj(output)[0] @abstractmethod def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, - kv_cache: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: T, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index d7b7ed359224..edecda2ff897 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -1,47 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass -from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type - -from vllm.multimodal import MultiModalPlaceholderMap - -try: - from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - BatchDecodeMlaWithPagedKVCacheWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 +from typing import Any, Dict, List, Optional, Type import torch -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, AttentionType) -from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.mla.utils 
import (MLACommonBackend, MLACommonImpl, + MLACommonMetadata) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd -from vllm.utils import (align_to_256bytes, async_tensor_h2d, - make_tensor_with_pad) -from vllm.v1.attention.backends.flash_attn import merge_attn_states - -try: - from vllm.vllm_flash_attn import flash_attn_varlen_func -except ImportError: - from flash_attn import flash_attn_varlen_func -if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) - -class TritonMLABackend(AttentionBackend): +class TritonMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: @@ -51,687 +20,8 @@ def get_name() -> str: def get_impl_cls() -> Type["TritonMLAImpl"]: return TritonMLAImpl - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return TritonMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]: - return TritonMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["TritonMLAState"]: - return TritonMLAState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - ops.copy_blocks_mla(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [576] - - -class TritonMLAState(AttentionState): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - - @contextmanager - def graph_capture(self, max_batch_size: int): - self._is_graph_capturing = True - - 
self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - - self._positions = torch.zeros((max_batch_size, ), - dtype=torch.long, - device=self.runner.device) - - yield - - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - del self._positions - - def graph_clone(self, batch_size: int): - assert self._is_graph_capturing - return self.__class__(self.runner) - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - assert self._is_graph_capturing - - attn_metadata = self.runner.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=self._graph_slot_mapping[:batch_size], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=1, - max_decode_query_len=1, - max_prefill_seq_len=0, - max_decode_seq_len=self.runner.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self._graph_block_tables[:batch_size], - use_cuda_graph=True, - input_positions=self._positions[:batch_size], - head_dim=self.runner.model_config.get_head_size()) - - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - return attn_metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": 
attn_metadata.decode_metadata.block_tables, - "input_positions": attn_metadata.decode_metadata.input_positions, - } - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_positions = attn_metadata.input_positions - num_positions = input_positions.shape[0] - input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - # CUDA graph buffer is padded so only perform a partial copy based on - # num_positions - input_buffers["input_positions"][:num_positions].copy_( - input_positions, non_blocking=True) - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - def begin_forward(self, model_input): - return - - -@dataclass -class TritonMLAMetadata(MLACommonMetadata): - """Metadata for TritonMLAMetadata. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. 
- # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. 
- seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional["TritonMLAMetadata"] = None - _cached_decode_metadata: Optional["TritonMLAMetadata"] = None - - num_prefill_tokens: int - - num_kv_splits: int = 4 # TODO(lucas) add heuristic - attn_logits: Optional[torch.Tensor] = None - req_idx: Optional[torch.Tensor] = None - - # The dimension of the attention heads - head_dim: Optional[int] = None - - prefill_seq_len_total: Optional[int] = None - - # For chunked prefill - chunked_prefill_workspace_size: Optional[int] = None - chunk_cu_seq_lens: Optional[torch.Tensor] = None - chunk_seq_starts: Optional[torch.Tensor] = None - chunk_iter_toks: Optional[List[int]] = None - chunk_max_seq_lens: Optional[List[int]] = None - - def __post_init__(self): - supported_head_sizes = TritonMLABackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f"received {self.head_dim}.") - - @property - def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - 
self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - input_positions = (None if self.input_positions is None else - self.input_positions[:self.num_prefill_tokens]) - - self._cached_prefill_metadata = TritonMLAMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - input_positions=input_positions, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - head_dim=self.head_dim, - prefill_seq_len_total=seq_lens_tensor.sum().item(), - - # Chunk prefill meta-data - chunked_prefill_workspace_size=self.chunked_prefill_workspace_size, - chunk_cu_seq_lens=self.chunk_cu_seq_lens, - chunk_seq_starts=self.chunk_seq_starts, - chunk_iter_toks=self.chunk_iter_toks, - chunk_max_seq_lens=self.chunk_max_seq_lens, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["TritonMLAMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - 
input_positions = (None if self.input_positions is None else - self.input_positions[self.num_prefill_tokens:]) - - self._cached_decode_metadata = TritonMLAMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - input_positions=input_positions, - head_dim=self.head_dim) - return self._cached_decode_metadata - - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Mutli-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. 
This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. 
Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - - -class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - self.chunked_prefill_workspace = None - - scheduler_config = self.runner.vllm_config.scheduler_config - model_config = self.runner.model_config - - if scheduler_config.enable_chunked_prefill: - head_size = model_config.get_head_size() - - self.chunked_prefill_workspace = torch.empty( - (model_config.max_model_len * 2, - align_to_256bytes(head_size, model_config.dtype)), - dtype=model_config.dtype, - device=self.runner.device, - )[..., :head_size] - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.input_positions: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): - """Add a sequence group to the 
metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block, input_positions) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks, - inter_data.input_positions): - self.input_positions.extend(input_positions) - self.context_lens.append(context_len) - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. 
- is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
- """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - input_positions = async_tensor_h2d(self.input_positions, torch.long, - device, self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - 
placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - chunked_prefill_workspace_size = max(16 * self.num_prefills, 128) - chunk_cu_seq_lens = None - chunk_seq_starts = None - chunk_iter_toks = None - chunk_max_seq_lens = None - - chunked_prefill_enabled = self.input_builder.chunked_prefill_enabled - if chunked_prefill_enabled and self.num_prefills > 0 \ - and context_lens_tensor is not None: - page_size = self.runner.block_size - seq_chunk_size = chunked_prefill_workspace_size // self.num_prefills - - # align seq_chunk_size to page_size by rounding down - seq_chunk_size = seq_chunk_size - (seq_chunk_size % page_size) - assert seq_chunk_size > 0 - num_chunks = (context_lens_tensor.max() + seq_chunk_size - - 1) // seq_chunk_size - - # if `seq_chunk_size = 256`, `num_chunks = 3`, and - # `num_prefills = 4`, create a tensor that looks like - # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - # Note: we only chunk the current context so we can separate the - # causal masked and un-masked portions of the computation - chunk_seq_starts = \ - torch.arange(num_chunks, device=device, dtype=torch.int32)\ - .unsqueeze(1).expand(-1, self.num_prefills)\ - * seq_chunk_size - chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ - .unsqueeze(0), chunk_seq_starts + seq_chunk_size) - chunk_seq_lens = (chunk_ends - chunk_seq_starts).clamp(min=0) - _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(torch.int32) - zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ - .unsqueeze(-1) - chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], dim=1) - chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist() - chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist() - - return TritonMLAMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - 
seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - input_positions=input_positions, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - num_kv_splits=4, # TODO(lucas) add heuristic - head_dim=self.runner.model_config.get_head_size(), - # Chunked prefill data - chunked_prefill_workspace_size=chunked_prefill_workspace_size, - chunk_cu_seq_lens=chunk_cu_seq_lens, - chunk_seq_starts=chunk_seq_starts, - chunk_iter_toks=chunk_iter_toks, - chunk_max_seq_lens=chunk_max_seq_lens, - ) - - -class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): +class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): def __init__( self, @@ -746,11 +36,11 @@ def __init__( logits_soft_cap: Optional[float], attn_type: str, # MLA Specific Arguments - **kwargs) -> None: + **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **kwargs) + **mla_args) unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap @@ -767,202 +57,12 @@ def __init__( "are not implemented for " "TritonMLAImpl") - def _compute_prefill_context_attention( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: TritonMLAMetadata, - ): - prefill_metadata = attn_metadata.prefill_metadata - if not hasattr(self, "workspace"): - self.workspace = torch.empty( - (attn_metadata.chunked_prefill_workspace_size, - self.kv_lora_rank + self.qk_rope_head_dim), - dtype=q.dtype, - device=q.device, - ) - - output = None - iters = len(prefill_metadata.chunk_iter_toks) - for i 
in range(iters): - chunk_cu_seq_lens = prefill_metadata.chunk_cu_seq_lens[i] - chunk_seq_starts = prefill_metadata.chunk_seq_starts[i] - toks = prefill_metadata.chunk_iter_toks[i] - max_seq_len = prefill_metadata.chunk_max_seq_lens[i] - - ops.gather_cache( - src_cache=kv_c_and_k_pe_cache, - dst=self.workspace, - block_table=prefill_metadata.block_tables, - cu_seq_lens=chunk_cu_seq_lens, - batch_size=prefill_metadata.num_prefills, - seq_starts=chunk_seq_starts, - ) - - k_c_normed = self.workspace[:toks]\ - [..., :self.kv_lora_rank].unsqueeze(1) - k_pe = self.workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) - - kv_nope = self.kv_b_proj(k_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output, attn_softmax_lse = flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=chunk_cu_seq_lens, - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - fa_version=self.vllm_flash_attn_version, - ) - - if output is None: - output = attn_output - output_lse = attn_softmax_lse - else: - output_tmp = torch.ones_like(output) - output_lse_tmp = torch.ones_like(output_lse) - merge_attn_states( - output=output_tmp, - output_lse=output_lse_tmp, - prefix_output=output, - prefix_lse=output_lse, - suffix_output=attn_output, - suffix_lse=attn_softmax_lse, - ) - output = output_tmp - output_lse = output_lse_tmp - - return output, output_lse - - # Optional common flash-attn based prefill - def _forward_prefill_flash( - 
self, - q: torch.Tensor, - k_c_normed: torch.Tensor, - k_pe: torch.Tensor, - query_start_loc: torch.Tensor, - seq_start_loc: torch.Tensor, - max_query_len: int, - max_prefill_seq_len: int, - ) -> torch.Tensor: - - kv_nope = self.kv_b_proj(k_c_normed)[0]\ - .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output = flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=query_start_loc, - cu_seqlens_k=seq_start_loc, - max_seqlen_q=max_query_len, - max_seqlen_k=max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - fa_version=self.vllm_flash_attn_version, - ) - attn_output = attn_output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(attn_output)[0] - - def _forward_prefill( - self, - q: torch.Tensor, - kv_c: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: TritonMLAMetadata, - ) -> torch.Tensor: - assert isinstance(attn_metadata, TritonMLAMetadata) - - prefill_metadata = attn_metadata.prefill_metadata - - has_context = prefill_metadata.context_lens_tensor is not None \ - and prefill_metadata.context_lens_tensor.max() > 0 - - kv_nope = self.kv_b_proj(kv_c)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - 
v.shape[-1]], - value=0) - - output = flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=attn_metadata.query_start_loc, - cu_seqlens_k=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_prefill_seq_len, - max_seqlen_k=attn_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - fa_version=self.vllm_flash_attn_version, - ) - - if has_context: - suffix_output, suffix_lse = output - context_output, context_lse = \ - self._compute_prefill_context_attention( - q, kv_c_and_k_pe_cache, attn_metadata) - - output = torch.empty_like(suffix_output) - merge_attn_states( - output=output, - prefix_output=context_output, - prefix_lse=context_lse, - suffix_output=suffix_output, - suffix_lse=suffix_lse, - ) - - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - attn_metadata.first_layer = False - - return self.o_proj(output)[0] - def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: TritonMLAMetadata, + attn_metadata: MLACommonMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 if self.kv_cache_dtype.startswith("fp8"): @@ -984,7 +84,7 @@ def _forward_decode( ( B, self.num_heads, - attn_metadata.num_kv_splits, + 4, #attn_metadata.num_kv_splits, # NOTE(lucas) idk why the +1 is here but sglang has it so we # just mirror that self.kv_lora_rank + 1, @@ -999,10 +99,16 @@ def _forward_decode( PAGE_SIZE = kv_c_and_k_pe_cache.size(1) # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, attn_logits, - attn_metadata.num_kv_splits, self.scale, - PAGE_SIZE) + decode_attention_fwd( + q, + kv_c_and_k_pe_cache, + kv_c_cache, + o, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + attn_logits, + 4, + self.scale, #attn_metadata.num_kv_splits + PAGE_SIZE) return 
self._v_up_proj_and_o_proj(o) From b73fb74c30e62f1906823b05e7c8f8d0a1e0be70 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 15:55:38 +0000 Subject: [PATCH 16/38] cleanup Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 54 +++++----------------------- vllm/attention/backends/utils.py | 2 -- vllm/distributed/parallel_state.py | 4 +-- vllm/engine/arg_utils.py | 25 +++++-------- 4 files changed, 18 insertions(+), 67 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 74dfa8be87b0..07718dcd16d1 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -1112,6 +1112,13 @@ def _compute_prefill_context( attn_metadata: MLACommonMetadata, ): prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + assert prefill_metadata.chunk_iter_toks is not None + assert prefill_metadata.chunk_cu_seq_lens is not None + assert prefill_metadata.chunk_seq_starts is not None + assert prefill_metadata.chunk_max_seq_lens is not None + # assert prefill_metadata.block_tables is not None + assert prefill_metadata.context_lens_tensor is not None output = None iters = len(prefill_metadata.chunk_iter_toks) @@ -1184,59 +1191,18 @@ def _compute_prefill_context( return output, output_lse - # Optional common flash-attn based prefill - def _forward_prefill_flash( - self, - q: torch.Tensor, - k_c_normed: torch.Tensor, - k_pe: torch.Tensor, - query_start_loc: torch.Tensor, - seq_start_loc: torch.Tensor, - max_query_len: int, - max_prefill_seq_len: int, - ) -> torch.Tensor: - - kv_nope = self.kv_b_proj(k_c_normed)[0]\ - .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the 
qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output = flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=query_start_loc, - cu_seqlens_k=seq_start_loc, - max_seqlen_q=max_query_len, - max_seqlen_k=max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - fa_version=self.vllm_flash_attn_version, - ) - attn_output = attn_output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(attn_output)[0] - def _forward_prefill( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, + attn_metadata: T, ) -> torch.Tensor: assert isinstance(attn_metadata, MLACommonMetadata) prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None has_context = prefill_metadata.context_lens_tensor is not None \ and prefill_metadata.context_lens_tensor.max() > 0 @@ -1285,8 +1251,6 @@ def _forward_prefill( .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ .reshape(-1, self.num_heads * v.shape[-1]) - attn_metadata.first_layer = False - return self.o_proj(output)[0] @abstractmethod diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index d85c2ca06633..5c1f9916e22c 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -213,11 +213,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], -1 if cuda graph is not used. batch_size: The maybe padded batch size. 
""" - print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled) - print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 65da45b02893..bfc41703b94d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -609,9 +609,7 @@ def broadcast_tensor_dict( # all happening on CPU. Therefore, we can use the CPU group. self.broadcast_object(metadata_list, src=src) async_handles = [] - for i, tensor in enumerate(tensor_list): - print("broadcasting", metadata_list[i][0]) - print(" ", tensor.shape, tensor.stride()) + for tensor in tensor_list: if tensor.numel() == 0: # Skip broadcasting empty tensors. continue diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 130137b1f1f0..3cfdc361cf6e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1141,25 +1141,9 @@ def create_engine_config(self, msg = "Chunked prefill is not supported for pooling models" raise ValueError(msg) - # Default to `gpu_memory_utilization` of 0.9 if not specified - gpu_memory_utilization = self.gpu_memory_utilization if \ - self.gpu_memory_utilization is not None else 0.9 - # For models using MLA and chunked prefill, lower the default to 0.85 - # to account for the extra memory required to up-project the MLA cache - if self.gpu_memory_utilization is None and \ - (self.enable_chunked_prefill and model_config.use_mla): - gpu_memory_utilization = 0.85 - - if self.enable_chunked_prefill and model_config.use_mla: - # Currently chunked prefill with MLA is not supported with CUDA - # graphs - logger.warning("Chunked prefill is not supported with MLA. 
" - "Disabling CUDA graphs.") - self.enforce_eager = True - cache_config = CacheConfig( block_size=self.block_size, - gpu_memory_utilization=gpu_memory_utilization, + gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, is_attention_free=model_config.is_attention_free, @@ -1184,6 +1168,13 @@ def create_engine_config(self, worker_cls=self.worker_cls, ) + if self.enable_chunked_prefill and model_config.use_mla: + # Currently chunked prefill with MLA is not supported with CUDA + # graphs + logger.warning("Chunked prefill is not supported with MLA. " + "Disabling CUDA graphs.") + self.enforce_eager = True + speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, From 829ce2bdd904938351d528363b80dcd9bb828bfb Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 16:57:14 +0000 Subject: [PATCH 17/38] relocate merge_attn_states Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 2 +- .../attention/ops/triton_merge_attn_states.py | 82 +++++++++++++++++++ vllm/v1/attention/backends/flash_attn.py | 81 +----------------- 3 files changed, 85 insertions(+), 80 deletions(-) create mode 100644 vllm/attention/ops/triton_merge_attn_states.py diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 07718dcd16d1..dd635e83b949 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -22,6 +22,7 @@ compute_slot_mapping_start_idx, get_flash_attn_version, is_block_tables_empty) +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -40,7 +41,6 @@ DeepseekScalingRotaryEmbedding, RotaryEmbedding) from vllm.multimodal import 
MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.v1.attention.backends.flash_attn import merge_attn_states from vllm.vllm_flash_attn import flash_attn_varlen_func if TYPE_CHECKING: diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py new file mode 100644 index 000000000000..0cf9a4bb818c --- /dev/null +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import triton +import triton.language as tl + + +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + + # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. 
+ merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + output_lse is not None, + ) + + +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse, # [NUM_HEADS, NUM_TOKENS] + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse, # [NUM_HEADS, NUM_TOKENS] + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse, # [NUM_HEADS, NUM_TOKENS] + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + out_se = (tl.exp(p_lse) + tl.exp(s_lse)) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + + # NOTE(woosuk): Be careful with the numerical stability. + # We should compute the scale first, and then multiply it with the output. + # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. 
+ p_scale = tl.exp(p_lse) / out_se + s_scale = tl.exp(s_lse) / out_se + out = p_out * p_scale + s_out * s_scale + tl.store(output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d21085a6f3d3..fd66af9db97c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -5,12 +5,11 @@ import numpy as np import torch -import triton -import triton.language as tl from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import get_flash_attn_version +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.logger import init_logger from vllm.utils import cdiv from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -369,80 +368,4 @@ def cascade_attention( # Merge prefix and suffix outputs, and store the result in output. merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) - - -def merge_attn_states( - output: torch.Tensor, - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - suffix_output: torch.Tensor, - suffix_lse: torch.Tensor, - output_lse: Optional[torch.Tensor] = None, -) -> None: - num_tokens = output.shape[0] - num_query_heads = output.shape[1] - head_size = output.shape[2] - padded_head_size = triton.next_power_of_2(head_size) - - # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. 
- merge_attn_states_kernel[(num_tokens, num_query_heads)]( - output, - output_lse, - prefix_output, - prefix_lse, - suffix_output, - suffix_lse, - head_size, - padded_head_size, - output_lse is not None, - ) - - -@triton.jit -def merge_attn_states_kernel( - output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - output_lse, # [NUM_HEADS, NUM_TOKENS] - prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_lse, # [NUM_HEADS, NUM_TOKENS] - suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - suffix_lse, # [NUM_HEADS, NUM_TOKENS] - HEAD_SIZE: tl.constexpr, - PADDED_HEAD_SIZE: tl.constexpr, - OUTPUT_LSE: tl.constexpr, -): - token_idx = tl.program_id(0) - num_tokens = tl.num_programs(0) - head_idx = tl.program_id(1) - num_heads = tl.num_programs(1) - - p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) - s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) - max_lse = tl.maximum(p_lse, s_lse) - p_lse = p_lse - max_lse - s_lse = s_lse - max_lse - out_se = (tl.exp(p_lse) + tl.exp(s_lse)) - - if OUTPUT_LSE: - out_lse = tl.log(out_se) + max_lse - tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) - - head_arange = tl.arange(0, PADDED_HEAD_SIZE) - head_mask = head_arange < HEAD_SIZE - p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - - # NOTE(woosuk): Be careful with the numerical stability. - # We should compute the scale first, and then multiply it with the output. - # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. 
- p_scale = tl.exp(p_lse) / out_se - s_scale = tl.exp(s_lse) / out_se - out = p_out * p_scale + s_out * s_scale - tl.store(output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - out, - mask=head_mask) + suffix_lse) \ No newline at end of file From dee34f73ce9c2e844d0a4c19cab07cce45499519 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 21:11:25 +0000 Subject: [PATCH 18/38] add comments Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 405 +++++++++++++++++---------- vllm/utils.py | 4 + 2 files changed, 268 insertions(+), 141 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index dd635e83b949..3a22886f3ca6 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -1,4 +1,194 @@ # SPDX-License-Identifier: Apache-2.0 +""" +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the entire KV cache. +* The attention "simulates" a multi-head attention, while the compute is + similar to multi-query attention. 
+ +Below is example of both paths assuming batchsize = 1 + +## More Extent Definitions: + +C Context length, `Skv - Sq` +H hidden size +N number of attention heads +Lq latent dimension for Q 1536 in DSV3 +Lkv latent dimension for K/V 512 in DSV3 +P nope dimension, no rope. 128 in DSV3 +R rope dimension, goes through rope. 64 in DSV3 +V V head dim. 128 in DSV3 + +## Vector/Matrix Definitions + +h_t hidden states (input to attention) shape [Sq, H] +q_c latent/compressed Q shape [Sq, Lq] +q_nope uncompressed Q (no-rope) shape [Sq, N, P] +q_pe uncompressed Q (rope) shape [Sq, N, R] +kv_c latent/compressed KV shape [Skv, Lkv] +k_pe decoupled k position embeddings shape [Skv, R] +new_kv_c new kv_c from current iter shape [Sq, Lkv] +new_k_pe new k_pe from current iter shape [Sq, R] +cache_kv_c cached k_c from previous iters shape [C, N, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, N, R] +W_DQ project h_t to q_c shape [H, Lq] +W_UQ project q_c to q_nope shape [Lq, N * P] +W_QR project q_c to q_pe shape [Lq, N * R] +W_DKV project h_t to kv_c shape [H, Lkv] +W_UK project kv_c to k_nope shape [Lkv, N * P] +W_KR project h_t to k_pe shape [H, N * R] +W_UV project kv_c to v shape [Lkv, N * V] +W_O project v to h_t shape [N * V, H] + + +## Compute Friendly Approach (i.e. 
"_forward_prefill"): + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) +k_nope = (kv_c @ W_UK).view(Skv, N, P) +v = (kv_c @ W_UV).view(Skv, N, V) + +// MHA with QK headdim = P + R +// V headdim = V +// spda_o shape [Sq, N, V] +spda_o = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([k_nope, k_pe.repeat(Skv, N, R)], dim=-1), + v +) +return spda_o @ W_O + +NOTE: in the actual code, + `kv_b_proj` is [W_UK; W_UV] + `q_b_proj` is [W_UQ; W_QR] + `out_proj` is W_O + + +## Data-Movement Friendly Approach (i.e. "_forward_decode"): + +Ahead of time, compute: + +% this projects from q_c to [Sq, N * Lkv] +W_UQ_UK = einsum("qnp,knp -> qnk" + W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P) + ).view(Lkv, N * Lkv) +% this projects from attn output [Sq, N * Lkv] to [Sq, H] +W_UV_O = einsum("knv,nvh -> nkh" + W_UV.view(Lkv, N, V), W_O.view(N, V, H) + ).view(N * Lkv, H) + +Runtime +q_c = h_t @ W_DQ +q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) + +// MQA with QK headdim = Lkv + R +// V headdim = Lkv +// spda_o shape [Sq, N, Lkv] +// NOTE: this is less compute-friendly since Lkv > P +// but is more data-movement friendly since its MQA vs MHA +spda_o = scaled_dot_product_attention( + torch.cat([q_latent, q_pe], dim=-1), + torch.cat([kv_c, k_pe)], dim=-1), + kv_c +) +return spda_o.reshape(-1, N * Lkv) @ W_UV_O + + +## Chunked Prefill + +For chunked prefill we want to use the compute friendly. We are assuming +sufficiently large Sq / Skv ratio, in the future may want to switch to the +data-movement friendly approach if the chunk (i.e. `Sq`) is small. 
+ +However, the compute-friendly approach can potentially run out of memory if Skv +is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` + +To mitigate this, we chunk the computation of attention with current context +i.e. `cache_kv_c` and `cache_k_pe` so that we can used a fixed workspace size. + +The chunked prefill approach is as follows: + +MCC Max chunk of context to process per iter, computed dynamically, + used to bound the memory usage + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) +new_v = (new_kv_c @ W_UV).view(Sq, N, V) + +// MHA between queries and new KV +// with QK headdim = P + R +// V headdim = V +// curr_o shape [Sq, N, V] +// curr_lse shape [N, Sq], this is just order FA returns +curr_o, curr_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([new_k_nope, new_k_pe.repeat(Sq, N, R)], dim=-1), + new_v, + casual=True, + return_softmax_lse=True +) + +// Compute attention with the already existing context +for chunk_idx in range(cdiv(C, MCC)): + chunk_start = chunk_idx * MCC + chunk_end = min(chunk_start + MCC, C) + Sc = chunk_end - chunk_start + cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] + cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] + cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) + cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) + + chunk_o, chunk_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([cache_k_nope_chunk, cache_k_pe_chunk.repeat(Sc, N, R)], + dim=-1), + cache_v_chunk, + casual=False, + return_softmax_lse=True + ) + + curr_o, curr_lse = merge_attn_states( + suffix_output=curr_o, + suffix_lse=curr_lse, + prefix_output=chunk_o, + prefix_lse=chunk_lse, + ) + +return curr_o @ W_O +""" from abc import abstractmethod from collections import defaultdict @@ -40,7 +230,7 @@ from 
vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down from vllm.vllm_flash_attn import flash_attn_varlen_func if TYPE_CHECKING: @@ -221,7 +411,10 @@ def begin_forward(self, model_input): @dataclass class MLACommonMetadata(AttentionMetadata): - """Metadata for MLACommon. + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class NOTE: Any python object stored here is not updated when it is cuda-graph replayed. If you have values that need to be changed @@ -298,10 +491,10 @@ class MLACommonMetadata(AttentionMetadata): head_dim: Optional[int] = None # For chunked prefill - chunk_cu_seq_lens: Optional[torch.Tensor] = None - chunk_seq_starts: Optional[torch.Tensor] = None - chunk_iter_toks: Optional[List[int]] = None - chunk_max_seq_lens: Optional[List[int]] = None + context_chunk_cu_seq_lens: Optional[torch.Tensor] = None + context_chunk_starts: Optional[torch.Tensor] = None + context_chunk_seq_tot: Optional[List[int]] = None + context_chunk_max_seq_lens: Optional[List[int]] = None def __post_init__(self): supported_head_sizes = MLACommonBackend.get_supported_head_sizes() @@ -366,10 +559,10 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: block_tables=block_tables, head_dim=self.head_dim, # MLACommonMetadata Chunk prefill specific - chunk_cu_seq_lens=self.chunk_cu_seq_lens, - chunk_seq_starts=self.chunk_seq_starts, - chunk_iter_toks=self.chunk_iter_toks, - chunk_max_seq_lens=self.chunk_max_seq_lens, + context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, + context_chunk_starts=self.context_chunk_starts, + context_chunk_seq_tot=self.context_chunk_seq_tot, + context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, ) return 
self._cached_prefill_metadata @@ -499,6 +692,10 @@ def advance_step(self, class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.input_builder = input_builder @@ -670,45 +867,59 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, device, self.runner.pin_memory) - chunk_cu_seq_lens = None - chunk_seq_starts = None - chunk_iter_toks = None - chunk_max_seq_lens = None + context_chunk_cu_seq_lens = None + context_chunk_starts = None + context_chunk_seq_tot = None + context_chunk_max_seq_lens = None if self.chunked_prefill_enabled and self.num_prefills > 0 \ - and context_lens_tensor is not None: + and context_lens_tensor is not None \ + and context_lens_tensor[:self.num_prefills].max() > 0: + + # NOTE: it is recommend you read the `Chunked Prefill` section in + # the comment at the top of the file before trying to understand + # the following code + chunked_prefill_workspace_size = \ self.attn_state.chunked_prefill_workspace.shape[0] page_size = self.runner.block_size - seq_chunk_size = chunked_prefill_workspace_size // self.num_prefills - print(self.attn_state.chunked_prefill_workspace.shape, - self.num_prefills, page_size, seq_chunk_size) - - # align seq_chunk_size to page_size by rounding down - seq_chunk_size = seq_chunk_size - (seq_chunk_size % page_size) - assert seq_chunk_size > 0 - num_chunks = (context_lens_tensor.max() + seq_chunk_size - - 1) // seq_chunk_size - - # if `seq_chunk_size = 256`, `num_chunks = 3`, and - # `num_prefills = 4`, create a tensor that looks like + num_prefills_with_context = \ + (context_lens_tensor[:self.num_prefills] > 0).sum().item() + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced 
+ # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + chunked_prefill_workspace_size // num_prefills_with_context + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, page_size) + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks like # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - # Note: we only chunk the current context so we can separate the - # causal masked and un-masked portions of the computation - chunk_seq_starts = \ + context_chunk_starts = \ torch.arange(num_chunks, device=device, dtype=torch.int32)\ .unsqueeze(1).expand(-1, self.num_prefills)\ - * seq_chunk_size + * max_context_chunk chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ - .unsqueeze(0), chunk_seq_starts + seq_chunk_size) - chunk_seq_lens = (chunk_ends - chunk_seq_starts).clamp(min=0) - _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(torch.int32) + .unsqueeze(0), context_chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) + _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ .unsqueeze(-1) - chunk_cu_seq_lens = torch.cat([zero, _chunk_cu_seq_lens], dim=1) - chunk_max_seq_lens = chunk_seq_lens.max(dim=1).values.tolist() - chunk_iter_toks = chunk_seq_lens.sum(dim=1).tolist() + context_chunk_cu_seq_lens = \ + torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) + context_chunk_max_seq_lens = \ + chunk_seq_lens.max(dim=1).values.tolist() + context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() return MLACommonMetadata( # 
Required by ModelRunner @@ -736,103 +947,17 @@ def build(self, seq_lens: List[int], query_lens: List[int], block_tables=block_tables, head_dim=self.runner.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific - chunk_cu_seq_lens=chunk_cu_seq_lens, - chunk_seq_starts=chunk_seq_starts, - chunk_iter_toks=chunk_iter_toks, - chunk_max_seq_lens=chunk_max_seq_lens, + context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, + context_chunk_starts=context_chunk_starts, + context_chunk_seq_tot=context_chunk_seq_tot, + context_chunk_max_seq_lens=context_chunk_max_seq_lens, ) class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): """ - Common class for implementing repeated parts - - Main reference: DeepseekV2 paper, and FlashInfer Implementation - (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - - Deepseek's MLA attention works the following way: - * Use a single latent vector to represent the entire KV cache. - * The attention "simulates" a multi-head attention, while the compute is - similar to multi-query attention. - * The dataflow is as follows, - - * B: batch/sequence length - * H: hidden size - * N: number of attention heads - * Lq: latent dimension for Q - * Lkv: latent dimension for K/V - * P: nope dimension, P+R is the actual head_dim in common attention. - * R: rope dimension, this slide of the head_dim goes through rope. - * V: V head dim. - * kv_c: latent/compressed KV - * q_c: latent/compressed Q - - # - # Outside the MLA attention backend - # - - 1. The hidden states (B, H) are projected down into cq (B, Lq) and - kv_c_k_pe (B, Lkv+R). - 2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq - and kv_c are normalized. - - # - # Inside the MLA attention backend - # - - * if prefill: - - 3. The q_c is then projected up into the multi-head version. - * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope - (B, N, P) and q_pe (B, N, R). - 4. 
q_pe, k_pe are then passed through rotary embeddings. - 5. kv_c and k_pe are concatenated and inserted into the cache - 6. The kv_c is then projected up into the multi-head version. - * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope - dimensions for K and V, which is split into k_nope (B, N, P) - and v (B, N, V). - 7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from - q_nope, q_pe, k_nope, k_pe. - 8. Attention is computued with q, k, v. - 9. The attention computation returns (B, N, V), which is projected back - to (B, H) using out projection. - - * if decode: - - 3. Here's the change, we do not perform up the full up projection for - q_c, and there is no up projection at all for kv_c. This is - achieved by the technique of "weight absorption". The paper says - "Fortunately, due to the associative law of matrix multiplication, - we can absorb WUK into WUQ, and WUV into WO" - * The q up projection turns (B, Lq) into (B, N, (P+R)), we split it - into W_UQ (Lq, N, P) and W_QR (Lq, N, R). - * The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split - it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). - * The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H). - * We can precompute the product of W_UQ and W_UK into - W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in - attention. - * We can precompute the product of W_UV and W_O into - W_UV_O (N, Lkv, H), which is possible due to V@O as the - "epilogue" of attention - 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. - 5. q_pe, k_pe are then passed through rotary embeddings. - 6. kv_c and k_pe are concatenated and inserted into the cache - 7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape - (B, N, Lkv). - 8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, - kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a. - 9. The attention is computed with q, k, v. 
Note that we just performed - a MQA attention with (LKv+R) as our head dim. - 10. The KV cache is updated using the new entries k (B, N, (Lkv+R)), - which included the v and rope values. - 11. The attention computation returns (B, N, Lkv), which is projected - back to (B, H) using W_UV_O. - - From @tsu-bin's calculation, we only want to use the absorption technique - for decode. The prefill algorithm should still use the up-projected MHA - for less flops and memory usage. - + NOTE: Please read the comment at the top of the file before trying to + understand this class """ def __init__( @@ -1113,30 +1238,28 @@ def _compute_prefill_context( ): prefill_metadata = attn_metadata.prefill_metadata assert prefill_metadata is not None - assert prefill_metadata.chunk_iter_toks is not None - assert prefill_metadata.chunk_cu_seq_lens is not None - assert prefill_metadata.chunk_seq_starts is not None - assert prefill_metadata.chunk_max_seq_lens is not None + assert prefill_metadata.context_chunk_seq_tot is not None + assert prefill_metadata.context_chunk_cu_seq_lens is not None + assert prefill_metadata.context_chunk_starts is not None + assert prefill_metadata.context_chunk_max_seq_lens is not None # assert prefill_metadata.block_tables is not None assert prefill_metadata.context_lens_tensor is not None output = None - iters = len(prefill_metadata.chunk_iter_toks) + iters = len(prefill_metadata.context_chunk_seq_tot) workspace = prefill_metadata.attn_state.chunked_prefill_workspace for i in range(iters): - chunk_cu_seq_lens = prefill_metadata.chunk_cu_seq_lens[i] - chunk_seq_starts = prefill_metadata.chunk_seq_starts[i] - toks = prefill_metadata.chunk_iter_toks[i] - max_seq_len = prefill_metadata.chunk_max_seq_lens[i] + toks = prefill_metadata.context_chunk_seq_tot[i] + max_seq_len = prefill_metadata.context_chunk_max_seq_lens[i] ops.gather_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_tables, - cu_seq_lens=chunk_cu_seq_lens, + 
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], batch_size=prefill_metadata.num_prefills, - seq_starts=chunk_seq_starts, + seq_starts=prefill_metadata.context_chunk_starts[i], ) k_c_normed = workspace[:toks]\ @@ -1163,7 +1286,7 @@ def _compute_prefill_context( k=k, v=v_padded, cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=chunk_cu_seq_lens, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], max_seqlen_q=prefill_metadata.max_query_len, max_seqlen_k=max_seq_len, softmax_scale=self.scale, diff --git a/vllm/utils.py b/vllm/utils.py index 1d7fbd4a7879..b64f1f3420f9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -565,6 +565,10 @@ def round_up(x: int, y: int) -> int: return ((x + y - 1) // y) * y +def round_down(x: int, y: int) -> int: + return (x // y) * y + + def _generate_random_fp8( tensor: torch.Tensor, low: float, From b28f99a6f4c611596c3de9bca4053eaf5c40c784 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 21:19:24 +0000 Subject: [PATCH 19/38] comment fixes Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 3a22886f3ca6..376a9b2a8c99 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -48,8 +48,8 @@ k_pe decoupled k position embeddings shape [Skv, R] new_kv_c new kv_c from current iter shape [Sq, Lkv] new_k_pe new k_pe from current iter shape [Sq, R] -cache_kv_c cached k_c from previous iters shape [C, N, Lkv] -cache_k_pe cached k_pe from previous iters shape [C, N, R] +cache_kv_c cached k_c from previous iters shape [C, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, R] W_DQ project h_t to q_c shape [H, Lq] W_UQ project q_c to q_nope shape [Lq, N * P] W_QR project q_c to q_pe shape [Lq, N * R] @@ -77,7 +77,7 @@ // spda_o shape [Sq, N, V] spda_o = 
scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), - torch.cat([k_nope, k_pe.repeat(Skv, N, R)], dim=-1), + torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), v ) return spda_o @ W_O @@ -155,7 +155,7 @@ // curr_lse shape [N, Sq], this is just order FA returns curr_o, curr_lse = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), - torch.cat([new_k_nope, new_k_pe.repeat(Sq, N, R)], dim=-1), + torch.cat([new_k_nope, new_k_pe..unsqueeze(1).expand(-1, N, -1)], dim=-1), new_v, casual=True, return_softmax_lse=True @@ -173,7 +173,8 @@ chunk_o, chunk_lse = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), - torch.cat([cache_k_nope_chunk, cache_k_pe_chunk.repeat(Sc, N, R)], + torch.cat([cache_k_nope_chunk, + cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], dim=-1), cache_v_chunk, casual=False, From 54ae713ac314c80ebe96da452002aedf1b3eaea8 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 21:20:22 +0000 Subject: [PATCH 20/38] minor fixes Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 376a9b2a8c99..030fb2e1304c 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -117,7 +117,7 @@ // but is more data-movement friendly since its MQA vs MHA spda_o = scaled_dot_product_attention( torch.cat([q_latent, q_pe], dim=-1), - torch.cat([kv_c, k_pe)], dim=-1), + torch.cat([kv_c, k_pe], dim=-1), kv_c ) return spda_o.reshape(-1, N * Lkv) @ W_UV_O @@ -155,7 +155,7 @@ // curr_lse shape [N, Sq], this is just order FA returns curr_o, curr_lse = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), - torch.cat([new_k_nope, new_k_pe..unsqueeze(1).expand(-1, N, -1)], dim=-1), + torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), new_v, casual=True, return_softmax_lse=True 
From 3db8ab6603e01d352ec5b7744d93567fddc16c7a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 21:25:53 +0000 Subject: [PATCH 21/38] remove no-longer necessary changes Signed-off-by: Lucas Wilkinson --- vllm/engine/arg_utils.py | 56 ++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3cfdc361cf6e..485e4caa5520 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -118,7 +118,7 @@ class EngineArgs: use_v2_block_manager: bool = True swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB - gpu_memory_utilization: Optional[float] = None + gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: Optional[int] = None max_logprobs: int = 20 # Default value for OpenAI Chat Completions API @@ -1097,6 +1097,33 @@ def create_engine_config(self, "has been disabled.") self.enable_prefix_caching = False + cache_config = CacheConfig( + block_size=self.block_size, + gpu_memory_utilization=self.gpu_memory_utilization, + swap_space=self.swap_space, + cache_dtype=self.kv_cache_dtype, + is_attention_free=model_config.is_attention_free, + num_gpu_blocks_override=self.num_gpu_blocks_override, + sliding_window=model_config.get_sliding_window(), + enable_prefix_caching=self.enable_prefix_caching, + cpu_offload_gb=self.cpu_offload_gb, + calculate_kv_scales=self.calculate_kv_scales, + ) + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, + tokenizer_pool_config=TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), + ray_workers_use_nsight=self.ray_workers_use_nsight, + 
distributed_executor_backend=self.distributed_executor_backend, + worker_cls=self.worker_cls, + ) + max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 if self.enable_chunked_prefill is None: @@ -1141,33 +1168,6 @@ def create_engine_config(self, msg = "Chunked prefill is not supported for pooling models" raise ValueError(msg) - cache_config = CacheConfig( - block_size=self.block_size, - gpu_memory_utilization=self.gpu_memory_utilization, - swap_space=self.swap_space, - cache_dtype=self.kv_cache_dtype, - is_attention_free=model_config.is_attention_free, - num_gpu_blocks_override=self.num_gpu_blocks_override, - sliding_window=model_config.get_sliding_window(), - enable_prefix_caching=self.enable_prefix_caching, - cpu_offload_gb=self.cpu_offload_gb, - calculate_kv_scales=self.calculate_kv_scales, - ) - parallel_config = ParallelConfig( - pipeline_parallel_size=self.pipeline_parallel_size, - tensor_parallel_size=self.tensor_parallel_size, - max_parallel_loading_workers=self.max_parallel_loading_workers, - disable_custom_all_reduce=self.disable_custom_all_reduce, - tokenizer_pool_config=TokenizerPoolConfig.create_config( - self.tokenizer_pool_size, - self.tokenizer_pool_type, - self.tokenizer_pool_extra_config, - ), - ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend, - worker_cls=self.worker_cls, - ) - if self.enable_chunked_prefill and model_config.use_mla: # Currently chunked prefill with MLA is not supported with CUDA # graphs From 04644e3fa19b95a17225e36062920e67eabb4ab3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 21:46:23 +0000 Subject: [PATCH 22/38] clean-up Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 4 ++++ vllm/engine/arg_utils.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 030fb2e1304c..18f1fd5c70f8 
100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -350,6 +350,10 @@ def graph_capture_get_metadata_for_batch( assert self._is_graph_capturing attn_metadata = self.runner.attn_backend.make_metadata( + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + use_cuda_graph=True, + attn_state=self.runner.attn_state, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=batch_size, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 485e4caa5520..5e136325f498 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1171,9 +1171,11 @@ def create_engine_config(self, if self.enable_chunked_prefill and model_config.use_mla: # Currently chunked prefill with MLA is not supported with CUDA # graphs - logger.warning("Chunked prefill is not supported with MLA. " - "Disabling CUDA graphs.") + logger.warning( + "Chunked prefill + CUDA Graphs is not supported with" + " MLA in V0. Disabling CUDA graphs.") self.enforce_eager = True + model_config.enforce_eager = True speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, From 4398787009b45e47aeb195783e89c00f33066e24 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 13 Feb 2025 22:42:27 +0000 Subject: [PATCH 23/38] fix tp Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 31 ++++++++++++---------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 18f1fd5c70f8..b467ffcbf525 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -353,7 +353,6 @@ def graph_capture_get_metadata_for_batch( multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, use_cuda_graph=True, - attn_state=self.runner.attn_state, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=batch_size, @@ -411,7 +410,9 @@ def 
prepare_graph_input_buffers(self, "TritonMLAState does not support encoder/decoder yet") def begin_forward(self, model_input): - return + if self.chunked_prefill_enabled: + model_input.attn_metadata.chunked_prefill_workspace = \ + self.chunked_prefill_workspace @dataclass @@ -431,11 +432,6 @@ class MLACommonMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # Smuggle the state to the impl via meta-data, we need this for the - # `chunked_prefill_workspace` but passing that directly will result in - # it unnecessarily being broadcasted to all workers. - attn_state: MLACommonState - # Input positions for rotrary embeddings since for MLA the rotary # position embeddings are applied inside the attention backend input_positions: torch.Tensor @@ -550,7 +546,6 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, # MLACommonMetadata - attn_state=self.attn_state, input_positions=input_positions, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, @@ -602,7 +597,6 @@ def decode_metadata(self) -> Optional["MLACommonMetadata"]: multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, # MLACommonMetadata - attn_state=self.attn_state, seq_lens=None, seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, @@ -709,7 +703,12 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.block_size = input_builder.block_size self.chunked_prefill_enabled = \ self.runner.scheduler_config.chunked_prefill_enabled - self.attn_state = self.input_builder.runner.attn_state + + if self.chunked_prefill_enabled: + attn_state = self.input_builder.runner.attn_state + self.chunked_prefill_workspace_size = \ + attn_state.chunked_prefill_workspace.shape[0] + self.page_size = self.runner.block_size def prepare(self): self.slot_mapping: List[int] = [] @@ -885,10 +884,6 @@ def build(self, 
seq_lens: List[int], query_lens: List[int], # the comment at the top of the file before trying to understand # the following code - chunked_prefill_workspace_size = \ - self.attn_state.chunked_prefill_workspace.shape[0] - page_size = self.runner.block_size - num_prefills_with_context = \ (context_lens_tensor[:self.num_prefills] > 0).sum().item() @@ -897,12 +892,12 @@ def build(self, seq_lens: List[int], query_lens: List[int], # algorithm here and allocate more workspace to prefills with # longer context lengths max_context_chunk = \ - chunked_prefill_workspace_size // num_prefills_with_context + self.chunked_prefill_workspace_size // num_prefills_with_context # align max_context_chunk to page_size by rounding down, # currently the `gather_cache` kernel cannot handle # `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, page_size) + max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) @@ -938,7 +933,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], multi_modal_placeholder_index_maps=None, # Not Attention Related enable_kv_scales_calculation=False, # MLACommonMetadata - attn_state=self.attn_state, input_positions=input_positions, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, @@ -1252,7 +1246,8 @@ def _compute_prefill_context( output = None iters = len(prefill_metadata.context_chunk_seq_tot) - workspace = prefill_metadata.attn_state.chunked_prefill_workspace + assert hasattr(attn_metadata, "chunked_prefill_workspace") + workspace = attn_metadata.chunked_prefill_workspace for i in range(iters): toks = prefill_metadata.context_chunk_seq_tot[i] From d73f9ff67b42af7c57d4f1656279a7b8682addd4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 14 Feb 2025 03:52:02 +0000 Subject: [PATCH 24/38] review comments Signed-off-by: Lucas Wilkinson --- csrc/core/math.hpp | 5 ----- 
.../cutlass_w8a8/scaled_mm_c3x.cu | 5 +++-- vllm/attention/backends/mla/utils.py | 19 +++++++++++-------- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index ddfaca27147b..b8171133f6aa 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } - -template -inline constexpr std::enable_if_t, T> ceil_div(T a, T b) { - return (a + b - 1) / b; -} \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index e40f28229968..53921abc951c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -1,7 +1,7 @@ #include #include "c3x/scaled_mm_kernels.hpp" -#include "core/math.hpp" +#include "cuda_utils.h" /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for @@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, auto make_group_shape = [](torch::Tensor const& x, torch::Tensor const& s) -> GroupShape { TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))}; + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; }; GroupShape a_scale_group_shape = make_group_shape(a, a_scales); diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index b467ffcbf525..353eff1bf435 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -21,9 +21,9 @@ (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the entire KV cache. 
-* The attention "simulates" a multi-head attention, while the compute is - similar to multi-query attention. +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a +multi-head attention, while the compute is similar to multi-query attention. Below is example of both paths assuming batchsize = 1 @@ -125,15 +125,16 @@ ## Chunked Prefill -For chunked prefill we want to use the compute friendly. We are assuming -sufficiently large Sq / Skv ratio, in the future may want to switch to the -data-movement friendly approach if the chunk (i.e. `Sq`) is small. +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +the data-movement friendly approach if the chunk (i.e. `Sq`) is small. However, the compute-friendly approach can potentially run out of memory if Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` -To mitigate this, we chunk the computation of attention with current context -i.e. `cache_kv_c` and `cache_k_pe` so that we can used a fixed workspace size. +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +fixed workspace size. 
The chunked prefill approach is as follows: @@ -308,6 +309,8 @@ def __init__(self, runner): # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens 64 * 1024) + assert scheduler_config.max_num_seqs * cache_config.block_size \ + > 64 * 1024 self.chunked_prefill_workspace = torch.empty( (workspace_size, model_config.get_head_size()), From f4da0b6152c4a342c933f418fefaa2f3a275c929 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 14 Feb 2025 04:11:38 +0000 Subject: [PATCH 25/38] minor fix Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 353eff1bf435..fe18e6a8949f 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -309,8 +309,8 @@ def __init__(self, runner): # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens 64 * 1024) - assert scheduler_config.max_num_seqs * cache_config.block_size \ - > 64 * 1024 + assert workspace_size > \ + scheduler_config.max_num_seqs * cache_config.block_size self.chunked_prefill_workspace = torch.empty( (workspace_size, model_config.get_head_size()), From 50a53aa875f0eab66028b770d465cda4e2b856c4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 14 Feb 2025 06:17:31 +0000 Subject: [PATCH 26/38] fix wrong device, increase workspace, enable cuda-graphs Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 46 +++++++++++++++++++--------- vllm/engine/arg_utils.py | 9 ------ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index fe18e6a8949f..0d08742cb92d 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -294,30 +294,30 @@ def __init__(self, runner): 
self._is_graph_capturing = False scheduler_config = runner.scheduler_config - model_config = runner.model_config + self.model_config = runner.model_config cache_config = runner.cache_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled if self.chunked_prefill_enabled: - workspace_size = min( - # Max sure there is enough for 1 full length request or at least - # 2 pages of cache per request + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request max( - model_config.max_model_len, 2 * + 8 * self.model_config.max_model_len, 4 * scheduler_config.max_num_seqs * cache_config.block_size), # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens - 64 * 1024) - assert workspace_size > \ + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ scheduler_config.max_num_seqs * cache_config.block_size - self.chunked_prefill_workspace = torch.empty( - (workspace_size, model_config.get_head_size()), - dtype=model_config.dtype, - device=runner.device, - ) - @contextmanager def graph_capture(self, max_batch_size: int): self._is_graph_capturing = True @@ -414,6 +414,20 @@ def prepare_graph_input_buffers(self, def begin_forward(self, model_input): if self.chunked_prefill_enabled: + if not hasattr(self, "chunked_prefill_workspace"): + # not self.runner.device does not return the correct device + # for this process, (init_device sets the correct device but + # only on the Worker). 
The only way Ive figured out to get the + # correct device is to allocate the workspace on the first call + # to begin_forward and use the device of the input tokens + assert model_input.input_tokens is not None + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=model_input.input_tokens.device, + ) + model_input.attn_metadata.chunked_prefill_workspace = \ self.chunked_prefill_workspace @@ -710,7 +724,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): if self.chunked_prefill_enabled: attn_state = self.input_builder.runner.attn_state self.chunked_prefill_workspace_size = \ - attn_state.chunked_prefill_workspace.shape[0] + attn_state.chunked_prefill_workspace_size self.page_size = self.runner.block_size def prepare(self): @@ -923,6 +937,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], context_chunk_max_seq_lens = \ chunk_seq_lens.max(dim=1).values.tolist() context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() + assert max(context_chunk_seq_tot) <= \ + self.chunked_prefill_workspace_size return MLACommonMetadata( # Required by ModelRunner diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5e136325f498..94a65486c580 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1168,15 +1168,6 @@ def create_engine_config(self, msg = "Chunked prefill is not supported for pooling models" raise ValueError(msg) - if self.enable_chunked_prefill and model_config.use_mla: - # Currently chunked prefill with MLA is not supported with CUDA - # graphs - logger.warning( - "Chunked prefill + CUDA Graphs is not supported with" - " MLA in V0. 
Disabling CUDA graphs.") - self.enforce_eager = True - model_config.enforce_eager = True - speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, From e0a758e76d535c2a5504ad5c4fa80c7d3072bc59 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 14 Feb 2025 21:04:58 +0000 Subject: [PATCH 27/38] minor changes Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 0d08742cb92d..65a7cdb449e9 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -83,8 +83,8 @@ return spda_o @ W_O NOTE: in the actual code, - `kv_b_proj` is [W_UK; W_UV] - `q_b_proj` is [W_UQ; W_QR] + `kv_b_proj` is [W_UK; W_UV] concatnated per head + `q_b_proj` is [W_UQ; W_QR] concatnated per head `out_proj` is W_O @@ -449,6 +449,7 @@ class MLACommonMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
use_cuda_graph: bool + # New for MLA (compared to FlashAttention) # Input positions for rotrary embeddings since for MLA the rotary # position embeddings are applied inside the attention backend input_positions: torch.Tensor @@ -508,6 +509,7 @@ class MLACommonMetadata(AttentionMetadata): # The dimension of the attention heads head_dim: Optional[int] = None + # New for MLA (compared to FlashAttention) # For chunked prefill context_chunk_cu_seq_lens: Optional[torch.Tensor] = None context_chunk_starts: Optional[torch.Tensor] = None @@ -1270,7 +1272,6 @@ def _compute_prefill_context( for i in range(iters): toks = prefill_metadata.context_chunk_seq_tot[i] - max_seq_len = prefill_metadata.context_chunk_max_seq_lens[i] ops.gather_cache( src_cache=kv_c_and_k_pe_cache, @@ -1281,12 +1282,12 @@ def _compute_prefill_context( seq_starts=prefill_metadata.context_chunk_starts[i], ) - k_c_normed = workspace[:toks]\ + kv_c_normed = workspace[:toks]\ [..., :self.kv_lora_rank].unsqueeze(1) k_pe = workspace[:toks]\ [..., self.kv_lora_rank:].unsqueeze(1) - kv_nope = self.kv_b_proj(k_c_normed)[0].view( \ + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -1307,7 +1308,7 @@ def _compute_prefill_context( cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=max_seq_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], softmax_scale=self.scale, causal=False, # Context is unmasked return_softmax_lse=True, @@ -1318,8 +1319,8 @@ def _compute_prefill_context( output = attn_output output_lse = attn_softmax_lse else: - output_tmp = torch.ones_like(output) - output_lse_tmp = torch.ones_like(output_lse) + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) merge_attn_states( output=output_tmp, 
output_lse=output_lse_tmp, @@ -1389,6 +1390,7 @@ def _forward_prefill( suffix_lse=suffix_lse, ) + # slice by `:v.shape[-1]` in order to remove v headdim padding output = output\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ .reshape(-1, self.num_heads * v.shape[-1]) From 3c800bb02ab5a639e5bd17beb5cf63dec7b8bf4a Mon Sep 17 00:00:00 2001 From: simon-mo Date: Fri, 14 Feb 2025 13:57:09 -0800 Subject: [PATCH 28/38] add comment Signed-off-by: Lucas Wilkinson --- vllm/attention/ops/triton_merge_attn_states.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index 0cf9a4bb818c..31545b607fec 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -6,6 +6,8 @@ import triton.language as tl +# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) def merge_attn_states( output: torch.Tensor, prefix_output: torch.Tensor, From a79ee4ca8a8642fde64e6c4e0db01beb0816111c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 14 Feb 2025 16:44:54 -0800 Subject: [PATCH 29/38] fix assert Signed-off-by: Lucas Wilkinson --- .../layers/quantization/utils/fp8_utils.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9895537c219a..9e4875e3168a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -162,6 +162,9 @@ def _per_token_group_quant_fp8( y_q_ptr, y_s_ptr, group_size, + # Num columns of y + y_num_columns, + y_row_stride, # Avoid to divide zero eps, # Information for float8 @@ -174,9 +177,14 @@ def _per_token_group_quant_fp8( quantization on a tensor. 
This function converts the tensor values into float8 values. """ + groups_per_row = y_num_columns // group_size + # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) - y_ptr += g_id * group_size + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) y_q_ptr += g_id * group_size y_s_ptr += g_id @@ -202,6 +210,7 @@ def _per_token_group_quant_fp8_colmajor( group_size, # Num columns of y y_num_columns, + y_row_stride, # Stride from one column to the next of y_s y_s_col_stride, # Avoid to divide zero @@ -216,9 +225,14 @@ def _per_token_group_quant_fp8_colmajor( quantization on a tensor. This function converts the tensor values into float8 values. """ + groups_per_row = y_num_columns // group_size + # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) - y_ptr += g_id * group_size + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) y_q_ptr += g_id * group_size # Convert g_id the flattened block coordinate to 2D so we can index @@ -267,7 +281,7 @@ def per_token_group_quant_fp8( assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}") - assert x.is_contiguous(), "`x` must be contiguous" + assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) fp8_min = finfo.min @@ -295,6 +309,7 @@ def per_token_group_quant_fp8( x_s, group_size, x.shape[1], + x.stride(0), x_s.stride(1), eps, fp8_min=fp8_min, @@ -309,6 +324,8 @@ def per_token_group_quant_fp8( x_q, x_s, group_size, + x.shape[1], + x.stride(0), eps, fp8_min=fp8_min, fp8_max=fp8_max, From 1c595972ef844b97538cd8f93ad38c3904287cf0 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 14 Feb 2025 17:13:34 -0800 Subject: [PATCH 30/38] extra workspace allocation during profile run Signed-off-by: Lucas Wilkinson 
--- vllm/attention/backends/mla/utils.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 65a7cdb449e9..3628c4287ade 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -509,12 +509,18 @@ class MLACommonMetadata(AttentionMetadata): # The dimension of the attention heads head_dim: Optional[int] = None + # Used when chunked prefill is enabled to simulate worst case workspace + # allocations, hopefully to avoid going OOM + is_profile_run: bool = False + # New for MLA (compared to FlashAttention) # For chunked prefill context_chunk_cu_seq_lens: Optional[torch.Tensor] = None context_chunk_starts: Optional[torch.Tensor] = None context_chunk_seq_tot: Optional[List[int]] = None context_chunk_max_seq_lens: Optional[List[int]] = None + # Set by MLAAttentionState in `begin_forward` so it doesnt get broadcasted + chunked_prefill_workspace: Optional[torch.Tensor] = None def __post_init__(self): supported_head_sizes = MLACommonBackend.get_supported_head_sizes() @@ -577,6 +583,7 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: context_lens_tensor=context_lens_tensor, block_tables=block_tables, head_dim=self.head_dim, + is_profile_run=self.is_profile_run, # MLACommonMetadata Chunk prefill specific context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, context_chunk_starts=self.context_chunk_starts, @@ -633,7 +640,8 @@ def decode_metadata(self) -> Optional["MLACommonMetadata"]: context_lens_tensor=None, block_tables=block_tables, input_positions=input_positions, - head_dim=self.head_dim) + head_dim=self.head_dim, + is_profile_run=self.is_profile_run) return self._cached_decode_metadata def advance_step(self, @@ -966,6 +974,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], context_lens_tensor=context_lens_tensor, block_tables=block_tables, 
head_dim=self.runner.model_config.get_head_size(), + is_profile_run=self.runner.in_profile_run, # MLACommonMetadata Chunk prefill specific context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, context_chunk_starts=context_chunk_starts, @@ -1421,6 +1430,18 @@ def forward( raise NotImplementedError( "output is not yet supported for MLAImplBase") + if attn_metadata.is_profile_run and \ + attn_metadata.chunked_prefill_workspace is not None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + (attn_metadata.chunked_prefill_workspace.shape[0], + self.num_heads, self.qk_nope_head_dim + self.v_head_dim), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + has_decode = attn_metadata.decode_metadata is not None has_prefill = attn_metadata.prefill_metadata is not None From 1137f76f7ac30a429db00fd56435a65a6ef8920f Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 14 Feb 2025 17:16:29 -0800 Subject: [PATCH 31/38] rename Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/{utils.py => common.py} | 0 vllm/attention/backends/triton_mla.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename vllm/attention/backends/mla/{utils.py => common.py} (100%) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/common.py similarity index 100% rename from vllm/attention/backends/mla/utils.py rename to vllm/attention/backends/mla/common.py diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index edecda2ff897..d2358287c39e 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -5,7 +5,7 @@ import torch from vllm.attention.backends.abstract import AttentionType -from vllm.attention.backends.mla.utils import (MLACommonBackend, MLACommonImpl, +from vllm.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, 
MLACommonMetadata) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd From 920ecc69ab68849eeed14204a8a6fa88179b684c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 17 Feb 2025 22:31:15 +0000 Subject: [PATCH 32/38] fix illegal memory access Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 3628c4287ade..e13514cbe187 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1375,10 +1375,10 @@ def _forward_prefill( q=q, k=k, v=v_padded, - cu_seqlens_q=attn_metadata.query_start_loc, - cu_seqlens_k=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_prefill_seq_len, - max_seqlen_k=attn_metadata.max_prefill_seq_len, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, return_softmax_lse=has_context, From b6655759337aba7460f32466f5ac4c7425009f2c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 18 Feb 2025 02:38:42 +0000 Subject: [PATCH 33/38] format Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index e2fd5538c0b0..65283a09ca81 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -241,7 +241,6 @@ # For rocm use upstream flash attention from flash_attn import flash_attn_varlen_func - if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) @@ -516,8 +515,8 @@ class MLACommonMetadata(AttentionMetadata): # The dimension of the 
attention heads head_dim: Optional[int] = None - # Used when chunked prefill is enabled to simulate worst case workspace - # allocations, hopefully to avoid going OOM + # Used when chunked prefill is enabled to simulate worst case workspace + # allocations, hopefully to avoid going OOM is_profile_run: bool = False # New for MLA (compared to FlashAttention) @@ -526,7 +525,7 @@ class MLACommonMetadata(AttentionMetadata): context_chunk_starts: Optional[torch.Tensor] = None context_chunk_seq_tot: Optional[List[int]] = None context_chunk_max_seq_lens: Optional[List[int]] = None - # Set by MLAAttentionState in `begin_forward` so it doesnt get broadcasted + # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted chunked_prefill_workspace: Optional[torch.Tensor] = None def __post_init__(self): @@ -1432,7 +1431,7 @@ def forward( # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` # since this can be large _ = torch.empty( - (attn_metadata.chunked_prefill_workspace.shape[0], + (attn_metadata.chunked_prefill_workspace.shape[0], self.num_heads, self.qk_nope_head_dim + self.v_head_dim), device=k_c_normed.device, dtype=k_c_normed.dtype, From 3a0ae51c829f225b11ed52a5a2471cc5894d6124 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 18 Feb 2025 04:19:19 +0000 Subject: [PATCH 34/38] format Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/triton_mla.py | 5 +++-- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index d2358287c39e..0c349c4239c2 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -5,8 +5,9 @@ import torch from vllm.attention.backends.abstract import AttentionType -from vllm.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, - MLACommonMetadata) +from vllm.attention.backends.mla.common import 
(MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9e4875e3168a..891edf23010c 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -184,7 +184,7 @@ def _per_token_group_quant_fp8( row = g_id // groups_per_row row_g_id = g_id % groups_per_row - y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_ptr += (row * y_row_stride) + (row_g_id * group_size) y_q_ptr += g_id * group_size y_s_ptr += g_id @@ -232,7 +232,7 @@ def _per_token_group_quant_fp8_colmajor( row = g_id // groups_per_row row_g_id = g_id % groups_per_row - y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_ptr += (row * y_row_stride) + (row_g_id * group_size) y_q_ptr += g_id * group_size # Convert g_id the flattened block coordinate to 2D so we can index From 28464b50603e001934432e1d3670351c81f971de Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 18 Feb 2025 05:58:15 +0000 Subject: [PATCH 35/38] mypy pass Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 18 +++++++++++------- vllm/attention/backends/triton_mla.py | 19 +++++++------------ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 65283a09ca81..c3dbbdb86823 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -199,7 +199,7 @@ from dataclasses import dataclass from itertools import accumulate from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, - Type) + Type, TypeVar) import torch from compressed_tensors.quantization import QuantizationStrategy @@ -209,8 +209,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, 
AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, - AttentionState, MLAAttentionImpl, - T) + AttentionState, MLAAttentionImpl) from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, get_flash_attn_version, @@ -723,6 +722,9 @@ def advance_step(self, block_tables=self.block_tables) +T = TypeVar("T", bound=MLACommonMetadata) + + class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): """ NOTE: Please read the comment at the top of the file before trying to @@ -1268,12 +1270,15 @@ def _compute_prefill_context( assert prefill_metadata.context_chunk_cu_seq_lens is not None assert prefill_metadata.context_chunk_starts is not None assert prefill_metadata.context_chunk_max_seq_lens is not None - # assert prefill_metadata.block_tables is not None assert prefill_metadata.context_lens_tensor is not None output = None iters = len(prefill_metadata.context_chunk_seq_tot) - assert hasattr(attn_metadata, "chunked_prefill_workspace") + + # Fetch from attn_metadata directly, since it is late bound by + # MLAAttentionState; grabbing it directly from `attn_metadata` can avoid + # any weirdness around prefill_metadata caching + assert attn_metadata.chunked_prefill_workspace is not None workspace = attn_metadata.chunked_prefill_workspace for i in range(iters): @@ -1345,9 +1350,8 @@ def _forward_prefill( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: T, + attn_metadata: MLACommonMetadata, ) -> torch.Tensor: - assert isinstance(attn_metadata, MLACommonMetadata) prefill_metadata = attn_metadata.prefill_metadata assert prefill_metadata is not None diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 0c349c4239c2..08e8226ab04c 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -80,12 +80,14 @@ def _forward_decode( dtype=q.dtype, device=q.device) + num_kv_splits = 4 # TODO: 
heuristic + # TODO(lucas) Allocate ahead of time attn_logits = torch.empty( ( B, self.num_heads, - 4, #attn_metadata.num_kv_splits, + num_kv_splits, # NOTE(lucas) idk why the +1 is here but sglang has it so we # just mirror that self.kv_lora_rank + 1, @@ -100,16 +102,9 @@ def _forward_decode( PAGE_SIZE = kv_c_and_k_pe_cache.size(1) # Run MQA - decode_attention_fwd( - q, - kv_c_and_k_pe_cache, - kv_c_cache, - o, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - attn_logits, - 4, - self.scale, #attn_metadata.num_kv_splits - PAGE_SIZE) + decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, attn_logits, + num_kv_splits, self.scale, PAGE_SIZE) return self._v_up_proj_and_o_proj(o) From dfb3ada417bb9a213dc4efecc707deb0e9533df4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 19 Feb 2025 20:36:54 +0000 Subject: [PATCH 36/38] fix basic model test Signed-off-by: Lucas Wilkinson --- vllm/engine/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 46ab997699e4..1afcbdb74c1c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1169,7 +1169,7 @@ def create_engine_config(self, # For multimodal models and models with MLA, chunked prefill is # disabled by default in V0, but enabled by design in V1 - if model_config.is_multimodal_model and model_config.use_mla: + if model_config.is_multimodal_model or model_config.use_mla: self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) elif use_long_context: From 9ca182be4880fba64ec331672e81c0fdf590c079 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 19 Feb 2025 21:51:14 +0000 Subject: [PATCH 37/38] attempt to fix AMD build Signed-off-by: Lucas Wilkinson --- csrc/cuda_utils.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 19e8228510c2..bea3165ae83c 100644 --- a/csrc/cuda_utils.h 
+++ b/csrc/cuda_utils.h @@ -3,9 +3,9 @@ #include #if defined(__CUDACC__) || defined(_NVHPC_CUDA) - #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ - #define DEVICE_INLINE __forceinline__ __device__ - #define HOST_INLINE __forceinline__ __host__ + #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__ + #define DEVICE_INLINE __device__ __forceinline__ + #define HOST_INLINE __host__ __forceinline__ #else #define HOST_DEVICE_INLINE inline #define DEVICE_INLINE inline From d325935124626d0e6839dc767baf96c5cccbfe92 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 19 Feb 2025 22:00:16 +0000 Subject: [PATCH 38/38] attempt 2 fix amd build Signed-off-by: Lucas Wilkinson --- csrc/cuda_utils.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index bea3165ae83c..6e62ea208db8 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -2,7 +2,11 @@ #include -#if defined(__CUDACC__) || defined(_NVHPC_CUDA) +#if defined(__HIPCC__) + #define HOST_DEVICE_INLINE __host__ __device__ + #define DEVICE_INLINE __device__ + #define HOST_INLINE __host__ +#elif defined(__CUDACC__) || defined(_NVHPC_CUDA) #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__ #define DEVICE_INLINE __device__ __forceinline__ #define HOST_INLINE __host__ __forceinline__