From f782c66432771f4eda255f544034382faf46fba6 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 28 Mar 2025 08:19:50 +0000 Subject: [PATCH 01/39] add AITER MLA implementation in attention backend Co-authored-by: qli88 Signed-off-by: vllmellm --- vllm/attention/backends/mla/common.py | 46 +- vllm/attention/backends/rocm_aiter_mla.py | 580 ++++++++++++++++++++++ vllm/attention/ops/rocm_aiter_mla.py | 44 ++ vllm/engine/arg_utils.py | 2 +- vllm/envs.py | 7 + vllm/platforms/rocm.py | 8 +- vllm/worker/model_runner.py | 6 + 7 files changed, 676 insertions(+), 17 deletions(-) create mode 100644 vllm/attention/backends/rocm_aiter_mla.py create mode 100644 vllm/attention/ops/rocm_aiter_mla.py diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 8d70afe282d6..525e56e1747f 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -650,16 +650,11 @@ def decode_metadata(self): is_profile_run=self.is_profile_run) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ + def advance_step_assertions( + self, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False) -> None: # When using cudagraph, the num_seqs is padded to the next captured # batch sized, but num_queries tracks the actual number of requests in # the batch. For --enforce-eager mode, num_seqs == num_queries @@ -712,6 +707,22 @@ def advance_step(self, self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + self.advance_step_assertions( + num_seqs=num_seqs, + num_queries=num_queries, + turn_prefills_into_decodes=turn_prefills_into_decodes) + + # here we use advance_step_flashinfo to update the paged_kv_* tensors ops.advance_step_flashattn(num_seqs=num_seqs, num_queries=num_queries, block_size=block_size, @@ -839,6 +850,15 @@ def _get_graph_runner_block_tables( return torch.from_numpy(graph_block_tables).to( device=self.runner.device, non_blocking=True) + def get_block_tables_with_captured_graph(self, cuda_graph_pad_size: int, + num_seqs: int) -> torch.Tensor: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + + return block_tables + def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): """Build attention metadata with on-device tensors. 
@@ -877,11 +897,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_seqs = len(seq_lens) if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) + block_tables = self.get_block_tables_with_captured_graph( + cuda_graph_pad_size, num_seqs) else: block_tables = make_tensor_with_pad( self.block_tables, diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py new file mode 100644 index 000000000000..5620338c6e13 --- /dev/null +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -0,0 +1,580 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Type + +import torch + +import vllm._custom_ops as ops +import vllm.envs as envs +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "AITER_MLA" + + @staticmethod + def get_impl_cls() -> Type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["AiterMLAState"]: + return AiterMLAState + + +@dataclass +class AiterMLAMetadata(MLACommonMetadata): + # The following 4 tensors are for current version of AITER MLA + block_table_bound: Optional[torch.Tensor] = None + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self): + prefill_metadata = super().prefill_metadata + self._cached_prefill_metadata = prefill_metadata + + if prefill_metadata is not None: + prefill_metadata.paged_kv_indptr = self.paged_kv_indptr + prefill_metadata.paged_kv_indices = self.paged_kv_indices + prefill_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + prefill_metadata.block_table_bound = self.block_table_bound + + self._cached_prefill_metadata = self.__class__( + **prefill_metadata.__dict__) + + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + + self._cached_decode_metadata = decode_metadata + + if decode_metadata is not None: + decode_metadata.paged_kv_indptr = self.paged_kv_indptr + decode_metadata.paged_kv_indices 
= self.paged_kv_indices + decode_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + decode_metadata.block_table_bound = self.block_table_bound + + self._cached_decode_metadata = self.__class__( + **decode_metadata.__dict__) + + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + self.advance_step_assertions( + num_seqs=num_seqs, + num_queries=num_queries, + turn_prefills_into_decodes=turn_prefills_into_decodes) + + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_lens=self.paged_kv_last_page_lens, + block_table_bound=self.block_table_bound) + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + super().__init__(input_builder) + assert self.block_size == 1, "AITER MLA requires only block size 1." + + def prepare(self): + super().prepare() + self.paged_kv_indices: list[int] = [] + self.paged_kv_indptr: list[int] = [0] + self.paged_kv_last_page_lens: list[int] = [] + self.total_blocks = 0 + + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. 
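+            # An empty block table marks the memory-profiling run, in which
+            # case the paged_kv_* bookkeeping below is skipped.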
+ is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + if is_profile_run: + return + + # Update paged_kv_* tensors only for non-profile run + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_lens.append(last_page_len) + + def get_block_tables_with_captured_graph(self, cuda_graph_pad_size: int, + num_seqs: int) -> torch.Tensor: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([[]] * cuda_graph_pad_size) + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + + return block_tables + + def build(self, seq_lens: list[int], query_lens: list[int], + cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: + metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + if use_captured_graph: + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) + + # For current version of AITER MLA + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device=device, + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device=device, + dtype=torch.int) + paged_kv_last_page_lens_tensor = torch.tensor( + self.paged_kv_last_page_lens, device=device, dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device=device, + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_lens_tensor = None + block_table_bound_tensor = None + + metadata.paged_kv_indptr = paged_kv_indptr_tensor + metadata.paged_kv_indices = paged_kv_indices_tensor + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor + metadata.block_table_bound = block_table_bound_tensor + + return metadata + + +class AiterMLAState(MLACommonState[AiterMLAMetadata]): + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._paged_kv_indices_tensor = torch.from_numpy( + self.runner.paged_kv_indices).to(device=self.runner.device) + self._paged_kv_indptr_tensor = torch.zeros(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) + 
self._paged_kv_last_page_lens_tensor = torch.full( + (max_batch_size, ), self.runner.block_size, dtype=torch.int32) + with super().graph_capture(max_batch_size): + yield + + del self._paged_kv_indices_tensor + del self._paged_kv_indptr_tensor + del self._paged_kv_last_page_lens_tensor + + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: + + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + + paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] + paged_kv_indices = self._paged_kv_indices_tensor + paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: + batch_size] + + metadata.paged_kv_indptr = paged_kv_indptr + metadata.paged_kv_indices = paged_kv_indices + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens + + return metadata + + def get_graph_input_buffers(self, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers[ + 'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr + input_buffers[ + "paged_kv_indices"] = attn_metadata.\ + decode_metadata.paged_kv_indices + input_buffers[ + "paged_kv_last_page_lens"] = attn_metadata.\ + decode_metadata.paged_kv_last_page_lens + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ + 0] + input_buffers["paged_kv_indptr"].copy_( + attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) + input_buffers["paged_kv_indices"][:num_total_blocks].copy_( + attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) + input_buffers["paged_kv_last_page_lens"].copy_( + attn_metadata.decode_metadata.paged_kv_last_page_lens, + non_blocking=True) + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + assert is_aiter_mla_enabled( + ), "Aiter MLA is initialized without being enabled properly." 
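+        # AITER MLA has no ALiBi, sliding-window, blocksparse or logits
+        # soft-cap support, so such configurations are rejected up front.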
+ + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ): + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + assert prefill_metadata.context_chunk_seq_tot is not None + assert prefill_metadata.context_chunk_cu_seq_lens is not None + assert prefill_metadata.context_chunk_starts is not None + assert prefill_metadata.context_chunk_max_seq_lens is not None + assert prefill_metadata.context_lens_tensor is not None + + output = None + iters = len(prefill_metadata.context_chunk_seq_tot) + + # Fetch from attn_metadata directly, since it late bound by + # MLAAttentionState, grabbing it directly `attn_metadata` can avoid + # any weirdness around prefill_metadata caching + assert attn_metadata.context_chunk_workspace is not None + workspace = attn_metadata.context_chunk_workspace + + for i in range(iters): + toks = prefill_metadata.context_chunk_seq_tot[i] + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], + batch_size=prefill_metadata.num_prefills, + seq_starts=prefill_metadata.context_chunk_starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank] + k_pe = workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, + [0, q.shape[-1] - v.shape[-1]], + value=0) + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_lse=True, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + + kv_nope = 
self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_lse=has_context, + ) + + if not has_context: + output = output[0] + + if has_context: + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + # slice by `:v.shape[-1]` in order to remove v headdim padding + output = output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + return self.o_proj(output)[0] + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + num_kv_splits = 4 # TODO: heuristic + + # TODO(lucas) Allocate ahead of time + attn_logits = torch.empty( + ( + B, + self.num_heads, + num_kv_splits, + # NOTE(lucas) idk why the +1 is here but sglang has it so we + # just mirror that + self.kv_lora_rank + 1, + ), + dtype=torch.float32, + device=q.device, + ) + + # Add a head dim of 1 + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) + + aiter_mla_decode_fwd(q, kv_c_and_k_pe_cache, o, attn_logits, + num_kv_splits, self.scale, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_lens) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py new file mode 100644 index 000000000000..57ead453839e --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + + +def get_aiter_mla_metadata(max_batch_size: int, block_size: int, + device: torch.device) -> tuple[torch.Tensor, ...]: + + paged_kv_indptr_tensor = torch.zeros(max_batch_size + 1, + dtype=torch.int32, + device=device) + paged_kv_last_page_lens_tensor = torch.full((max_batch_size, ), + block_size, + dtype=torch.int32) + return paged_kv_indptr_tensor, paged_kv_last_page_lens_tensor + + +def aiter_mla_decode_fwd( + q, + k_buffer, + o, + attn_logits, + num_kv_splits, + sm_scale, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + logit_cap: float = 0.0, +): + assert num_kv_splits == attn_logits.shape[2] + + from aiter.mla 
import mla_decode_fwd + + mla_decode_fwd(q, + k_buffer.view(-1, 1, 1, q.shape[-1]), + o, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap, + num_kv_splits=num_kv_splits) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 75ac326aaa3d..7be40b4c1dc9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -461,7 +461,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32, 64, 128], + choices=[1, 8, 16, 32, 64, 128], help='Token block size for contiguous chunks of ' 'tokens. This is ignored on neuron devices and ' 'set to ``--max-model-len``. On CUDA devices, ' diff --git a/vllm/envs.py b/vllm/envs.py index 4c413006a641..6505eab0a35a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -76,6 +76,7 @@ VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True @@ -533,6 +534,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")), + # Whether to use aiter mla ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MLA": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in + ("true", "1")), + # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d196e24ac7ac..894017851a82 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -116,8 +116,12 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: if use_mla: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" + if envs.VLLM_ROCM_USE_AITER_MLA and envs.VLLM_ROCM_USE_AITER: + logger.info("Using AITER MLA backend.") + return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" + else: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if envs.VLLM_USE_V1: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index edbafb48c938..2c910c6c9049 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1049,6 +1049,12 @@ def __init__( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) + # paged_kv_indices: flattened graph_block_tables + # used by AITER MLA + self.paged_kv_indices = np.zeros(self.max_batchsize_to_capture * + self.get_max_block_per_batch(), + dtype=np.int32) + # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal state. 
# However we must bypass attention selection altogether for some models From 42d5c620213808c16389dd53341c3a31a1242375 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 28 Mar 2025 10:40:27 +0000 Subject: [PATCH 02/39] remove unused arguments in aiter mla decode fwd kernel Co-authored-by: qli88 Signed-off-by: vllmellm --- vllm/attention/backends/rocm_aiter_mla.py | 22 ++-------------------- vllm/attention/ops/rocm_aiter_mla.py | 11 +++-------- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 5620338c6e13..4e9c1f15663c 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -552,27 +552,9 @@ def _forward_decode( dtype=q.dtype, device=q.device) - num_kv_splits = 4 # TODO: heuristic - - # TODO(lucas) Allocate ahead of time - attn_logits = torch.empty( - ( - B, - self.num_heads, - num_kv_splits, - # NOTE(lucas) idk why the +1 is here but sglang has it so we - # just mirror that - self.kv_lora_rank + 1, - ), - dtype=torch.float32, - device=q.device, - ) - - # Add a head dim of 1 - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - aiter_mla_decode_fwd(q, kv_c_and_k_pe_cache, o, attn_logits, - num_kv_splits, self.scale, + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, attn_metadata.paged_kv_indptr, attn_metadata.paged_kv_indices, attn_metadata.paged_kv_last_page_lens) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 57ead453839e..f27bfe96a811 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -19,26 +19,21 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, def aiter_mla_decode_fwd( q, - k_buffer, + kv_buffer, o, - attn_logits, - num_kv_splits, sm_scale, kv_indptr: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, kv_last_page_lens: Optional[torch.Tensor] = None, logit_cap: float = 0.0, ): - assert num_kv_splits == attn_logits.shape[2] - from aiter.mla import mla_decode_fwd mla_decode_fwd(q, - k_buffer.view(-1, 1, 1, q.shape[-1]), + kv_buffer.view(-1, 1, 1, q.shape[-1]), o, kv_indptr, kv_indices, kv_last_page_lens, sm_scale=sm_scale, - logit_cap=logit_cap, - num_kv_splits=num_kv_splits) + logit_cap=logit_cap) From 565a3fdad5afe68337e4acdd3237f766b65db2ac Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sat, 29 Mar 2025 09:20:02 +0000 Subject: [PATCH 03/39] add unittest for AITER MLA backend in attention selector Signed-off-by: vllmellm --- tests/kernels/test_attention_selector.py | 63 ++++++++++++++++++----- vllm/attention/backends/rocm_aiter_mla.py | 4 +- vllm/platforms/interface.py | 1 + vllm/platforms/rocm.py | 27 ++++++++-- 4 files changed, 73 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index a51e70d45ee0..cf5169bb706b 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -19,13 +19,22 @@ def clear_cache(): _cached_get_attn_backend.cache_clear() -@pytest.mark.parametrize( - "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) +CUDA_SUPPORTED_BACKENDS = ["XFORMERS", "FLASHINFER"] +HIP_SUPPORTED_BACKENDS = ["ROCM_FLASH", "TRITON_MLA", "ROCM_AITER_MLA"] +CPU_SUPPORTED_BACKENDS = ["TORCH_SDPA"] + + +@pytest.mark.parametrize("name", [ + "TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "TRITON_MLA", + 
"ROCM_AITER_MLA" +]) @pytest.mark.parametrize("use_v1", [True, False]) +@pytest.mark.parametrize("use_mla", [True, False]) @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) def test_env( name: str, use_v1: bool, + use_mla: bool, device: str, monkeypatch: pytest.MonkeyPatch, ): @@ -36,22 +45,48 @@ def test_env( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") m.setenv(STR_BACKEND_ENV_VAR, name) + m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") if device == "cpu": - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - 16, False) - assert backend.get_name() == "TORCH_SDPA" + if name in CPU_SUPPORTED_BACKENDS: + with patch("vllm.attention.selector.current_platform", + CpuPlatform()): + backend = get_attn_backend(16, torch.float16, + torch.float16, 16, False) + assert backend.get_name() == "TORCH_SDPA" elif device == "hip": - with patch("vllm.attention.selector.current_platform", - RocmPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - 16, False) - EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" - assert backend.get_name() == EXPECTED + if name in HIP_SUPPORTED_BACKENDS: + with patch("vllm.attention.selector.current_platform", + RocmPlatform()): + backend_1 = get_attn_backend(16, + torch.float16, + torch.float16, + 16, + False, + use_mla=use_mla) + backend_2 = get_attn_backend(16, + torch.float16, + torch.float16, + 1, + False, + use_mla=use_mla) + if use_mla: + # Only two MLA backends TRITON_MLA and ROCM_AITER_MLA. + # 1st case always selects TRITON_MLA. + EXPECTED_1 = "TRITON_MLA" + # 2nd case always selects ROCM_AITER_MLA unless + # TRITON_MLA is specified as backend. + EXPECTED_2 = ("ROCM_AITER_MLA" + if name != "TRITON_MLA" else "TRITON_MLA") + else: + EXPECTED_1 = ("TRITON_ATTN_VLLM_V1" + if use_v1 else "ROCM_FLASH") + EXPECTED_2 = EXPECTED_1 + + assert backend_1.get_name() == EXPECTED_1 + assert backend_2.get_name() == EXPECTED_2 else: - if name in ["XFORMERS", "FLASHINFER"]: + if name in CUDA_SUPPORTED_BACKENDS: with patch("vllm.attention.selector.current_platform", CudaPlatform()): backend = get_attn_backend(16, torch.float16, diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 4e9c1f15663c..ae9a40d70123 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -33,7 +33,7 @@ class AiterMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: - return "AITER_MLA" + return "ROCM_AITER_MLA" @staticmethod def get_impl_cls() -> Type["AiterMLAImpl"]: @@ -367,8 +367,6 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, **mla_args) - assert is_aiter_mla_enabled( - ), "Aiter MLA is initialized without being enabled properly." 
unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 36db70681a19..122582432ff4 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -33,6 +33,7 @@ class _Backend(enum.Enum): TRITON_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() + ROCM_AITER_MLA = enum.auto() TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 894017851a82..e2e02c6bbdfe 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -116,12 +116,29 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: if use_mla: - if envs.VLLM_ROCM_USE_AITER_MLA and envs.VLLM_ROCM_USE_AITER: - logger.info("Using AITER MLA backend.") - return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" - else: + from vllm.attention.backends.rocm_aiter_mla import ( + is_aiter_mla_enabled) + + if selected_backend is None: + selected_backend = (_Backend.ROCM_AITER_MLA + if is_aiter_mla_enabled() else + _Backend.TRITON_MLA) + + if selected_backend == _Backend.TRITON_MLA: logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + + else: + if block_size == 1: + logger.info("Using AITER MLA backend.") + return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + else: + logger.warning( + "AITER MLA backend is not supported for block size %d." + "(currently only supports block size 1)", block_size) + logger.info("Falling back to use Triton MLA backend") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if envs.VLLM_USE_V1: From 645f4009a4dfd1112113105253dc0c59ba9661bb Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 1 Apr 2025 10:46:43 +0000 Subject: [PATCH 04/39] add unittest for MLA attention backend selector Signed-off-by: vllmellm --- tests/kernels/test_attention_selector.py | 190 ++++++++++++++++------- vllm/platforms/rocm.py | 25 +-- 2 files changed, 145 insertions(+), 70 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index cf5169bb706b..2e7d10156377 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -19,80 +19,152 @@ def clear_cache(): _cached_get_attn_backend.cache_clear() -CUDA_SUPPORTED_BACKENDS = ["XFORMERS", "FLASHINFER"] -HIP_SUPPORTED_BACKENDS = ["ROCM_FLASH", "TRITON_MLA", "ROCM_AITER_MLA"] -CPU_SUPPORTED_BACKENDS = ["TORCH_SDPA"] - - -@pytest.mark.parametrize("name", [ - "TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "TRITON_MLA", - "ROCM_AITER_MLA" -]) +# Define MLA and non-MLA backends separately +DEVICE_MLA_BACKENDS = { + "cuda": ["TRITON_MLA", "FLASHMLA"], + "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], + "cpu": [], +} + +DEVICE_NON_MLA_BACKENDS = { + "cuda": ["XFORMERS", "FLASHINFER"], + "hip": ["ROCM_FLASH"], + "cpu": ["TORCH_SDPA"], +} + +DEVICE_MLA_BLOCK_SIZES = { + "cuda": [16, 64], # CUDA supports both standard and extended block sizes + "hip": [16, 1], # HIP requires special handling for block_size=1 + "cpu": [16] # CPU uses fixed block size from test cases +} + + +def generate_params(): + 
params = [] + for use_mla in [True, False]: + for device in ["cuda", "hip", "cpu"]: + backends = DEVICE_MLA_BACKENDS[ + device] if use_mla else DEVICE_NON_MLA_BACKENDS[device] + for name in backends: + block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [ + 16 + ] + for block_size in block_sizes: + params.append( + pytest.param( + device, + name, + use_mla, + block_size, + id= + f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}" + )) + return params + + +@pytest.mark.parametrize("device, name, use_mla, block_size", + generate_params()) @pytest.mark.parametrize("use_v1", [True, False]) -@pytest.mark.parametrize("use_mla", [True, False]) -@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) def test_env( + device: str, name: str, - use_v1: bool, use_mla: bool, - device: str, + block_size: int, + use_v1: bool, monkeypatch: pytest.MonkeyPatch, ): - """Test that the attention selector can be set via environment variable. - Note that we do not test FlashAttn because it is the default backend. - """ - + """Test attention backend selection with valid device-backend pairs.""" with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") m.setenv(STR_BACKEND_ENV_VAR, name) m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") if device == "cpu": - if name in CPU_SUPPORTED_BACKENDS: - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): - backend = get_attn_backend(16, torch.float16, - torch.float16, 16, False) - assert backend.get_name() == "TORCH_SDPA" + with patch("vllm.attention.selector.current_platform", + CpuPlatform()): + backend = get_attn_backend(16, torch.float16, torch.float16, + block_size, False) + assert backend.get_name() == "TORCH_SDPA" + elif device == "hip": - if name in HIP_SUPPORTED_BACKENDS: - with patch("vllm.attention.selector.current_platform", - RocmPlatform()): - backend_1 = get_attn_backend(16, - torch.float16, - torch.float16, - 16, - False, - use_mla=use_mla) - backend_2 = get_attn_backend(16, - torch.float16, - torch.float16, - 1, - False, - use_mla=use_mla) + with patch("vllm.attention.selector.current_platform", + RocmPlatform()): if use_mla: - # Only two MLA backends TRITON_MLA and ROCM_AITER_MLA. - # 1st case always selects TRITON_MLA. - EXPECTED_1 = "TRITON_MLA" - # 2nd case always selects ROCM_AITER_MLA unless - # TRITON_MLA is specified as backend. 
- EXPECTED_2 = ("ROCM_AITER_MLA" - if name != "TRITON_MLA" else "TRITON_MLA") + # Validate HIP MLA backend-block_size combinations + valid_combination = ( + (name == "TRITON_MLA" and block_size != 1) + or (name == "ROCM_AITER_MLA" and block_size == 1)) + + if valid_combination: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert backend.get_name() == name + else: + with pytest.raises(ValueError) as exc_info: + get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert f"The selected backend, {name}" in str( + exc_info.value) else: - EXPECTED_1 = ("TRITON_ATTN_VLLM_V1" - if use_v1 else "ROCM_FLASH") - EXPECTED_2 = EXPECTED_1 - - assert backend_1.get_name() == EXPECTED_1 - assert backend_2.get_name() == EXPECTED_2 - else: - if name in CUDA_SUPPORTED_BACKENDS: - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): - backend = get_attn_backend(16, torch.float16, - torch.float16, 16, False) - EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name - assert backend.get_name() == EXPECTED + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" + assert backend.get_name() == expected + + elif device == "cuda": + with patch("vllm.attention.selector.current_platform", + CudaPlatform()): + if use_mla: + if name == "FLASHMLA" and block_size == 64: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + + # only on cuda platforms with specific capability. + is_supported, _ = is_flashmla_supported() + + if not is_supported: + # if platform is not supported then skip this case. + pytest.skip() + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = ("TRITON_MLA_VLLM_V1" + if use_v1 else "TRITON_MLA") + assert backend.get_name() == expected + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected def test_flash_attn(monkeypatch: pytest.MonkeyPatch): diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index e2e02c6bbdfe..80100049bcf1 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -120,24 +120,27 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, is_aiter_mla_enabled) if selected_backend is None: - selected_backend = (_Backend.ROCM_AITER_MLA - if is_aiter_mla_enabled() else - _Backend.TRITON_MLA) + selected_backend = (_Backend.ROCM_AITER_MLA if + is_aiter_mla_enabled() or block_size == 1 + else _Backend.TRITON_MLA) if selected_backend == _Backend.TRITON_MLA: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 - + if block_size != 1: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + else: + raise ValueError( + f" The selected backend, {selected_backend.name}," + "does not support block size {block_size}.") else: if block_size == 1: logger.info("Using AITER MLA backend.") return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # 
noqa: E501 else: - logger.warning( - "AITER MLA backend is not supported for block size %d." - "(currently only supports block size 1)", block_size) - logger.info("Falling back to use Triton MLA backend") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + raise ValueError( + f" The selected backend, {selected_backend.name}," + "does not support block size {block_size}." + "(currently only supports block size 1)") selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) From 22c8726064954193702357c5f81a78f999a2fef4 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 1 Apr 2025 11:56:29 +0000 Subject: [PATCH 05/39] code cleaning Signed-off-by: vllmellm --- vllm/attention/backends/rocm_aiter_mla.py | 23 +++++++++++++++-------- vllm/attention/ops/rocm_aiter_mla.py | 19 +++++++++++-------- vllm/worker/model_runner.py | 6 ------ 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index ae9a40d70123..c58430e91587 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -16,7 +16,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, + get_aiter_mla_metadata) from vllm.attention.ops.triton_merge_attn_states import merge_attn_states if TYPE_CHECKING: @@ -76,6 +77,7 @@ def prefill_metadata(self): .paged_kv_last_page_lens = self.paged_kv_last_page_lens prefill_metadata.block_table_bound = self.block_table_bound + # update the cache self._cached_prefill_metadata = self.__class__( **prefill_metadata.__dict__) @@ -94,6 +96,7 @@ def decode_metadata(self): .paged_kv_last_page_lens = self.paged_kv_last_page_lens decode_metadata.block_table_bound = self.block_table_bound + # update the cache self._cached_decode_metadata = self.__class__( **decode_metadata.__dict__) @@ -134,6 +137,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): super().__init__(input_builder) + assert self.runner.model_config.max_model_len == 32768,\ + "AITER MLA requires max model len to be set to 32768" assert self.block_size == 1, "AITER MLA requires only block size 1." 
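+        # NOTE: block size 1 is also what vllm/platforms/rocm.py requires
+        # before selecting this backend, which is why `1` was added to the
+        # --block-size choices in vllm/engine/arg_utils.py.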
def prepare(self): @@ -280,13 +285,15 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]): @contextmanager def graph_capture(self, max_batch_size: int): - self._paged_kv_indices_tensor = torch.from_numpy( - self.runner.paged_kv_indices).to(device=self.runner.device) - self._paged_kv_indptr_tensor = torch.zeros(max_batch_size + 1, - dtype=torch.int32, - device=self.runner.device) - self._paged_kv_last_page_lens_tensor = torch.full( - (max_batch_size, ), self.runner.block_size, dtype=torch.int32) + kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata( + max_batch_size=max_batch_size, + block_size=self.runner.block_size, + max_block_per_batch=self.runner.get_max_block_per_batch(), + device=self.runner.device) + self._paged_kv_indices_tensor = kv_indices + self._paged_kv_indptr_tensor = kv_indptr + self._paged_kv_last_page_lens_tensor = last_page_lens + with super().graph_capture(max_batch_size): yield diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index f27bfe96a811..6d51bf1d8644 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -6,15 +6,18 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, + max_block_per_batch: int, device: torch.device) -> tuple[torch.Tensor, ...]: - - paged_kv_indptr_tensor = torch.zeros(max_batch_size + 1, - dtype=torch.int32, - device=device) - paged_kv_last_page_lens_tensor = torch.full((max_batch_size, ), - block_size, - dtype=torch.int32) - return paged_kv_indptr_tensor, paged_kv_last_page_lens_tensor + paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch, + dtype=torch.int32, + device=device) + paged_kv_indptr = torch.zeros(max_batch_size + 1, + dtype=torch.int32, + device=device) + paged_kv_last_page_lens = torch.full((max_batch_size, ), + block_size, + dtype=torch.int32) + return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens def aiter_mla_decode_fwd( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2c910c6c9049..edbafb48c938 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1049,12 +1049,6 @@ def __init__( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) - # paged_kv_indices: flattened graph_block_tables - # used by AITER MLA - self.paged_kv_indices = np.zeros(self.max_batchsize_to_capture * - self.get_max_block_per_batch(), - dtype=np.int32) - # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal state. 
# However we must bypass attention selection altogether for some models From 5dc13489ead2efd3cb0e2ff3e22778bc4357cd03 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 1 Apr 2025 14:59:49 +0000 Subject: [PATCH 06/39] update AITER version Signed-off-by: vllmellm --- Dockerfile.rocm_base | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base index 38d6a33636eb..fabb61ec3ca8 100644 --- a/Dockerfile.rocm_base +++ b/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="b7d29fb" ARG FA_REPO="https://github.com/ROCm/flash-attention.git" -ARG AITER_BRANCH="21d47a9" +ARG AITER_BRANCH="5208527" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base From da8c69f90e2cf9815fe0d659f3acd17c54209083 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 2 Apr 2025 04:51:02 +0000 Subject: [PATCH 07/39] add ck flash attn in prefill mla computation Signed-off-by: vllmellm --- vllm/attention/backends/mla/common.py | 156 ++++++++-------- vllm/attention/backends/rocm_aiter_mla.py | 210 +++++++--------------- vllm/envs.py | 7 + 3 files changed, 161 insertions(+), 212 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 525e56e1747f..9a08605ee8c2 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1147,6 +1147,38 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0) + def _get_prefill_ctx_attn_output( + self, index: int, q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + metadata: MLACommonMetadata) -> tuple[torch.Tensor, ...]: + if is_vllm_fa: + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], + max_seqlen_q=metadata.max_query_len, + max_seqlen_k=metadata.context_chunk_max_seq_lens[index], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + else: + attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], + max_seqlen_q=metadata.max_query_len, + max_seqlen_k=metadata.context_chunk_max_seq_lens[index], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_attn_probs=True, + ) + return attn_output, attn_softmax_lse + def _compute_prefill_context( self, q: torch.Tensor, @@ -1201,34 +1233,8 @@ def _compute_prefill_context( [0, q.shape[-1] - v.shape[-1]], value=0) - if is_vllm_fa: - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. - context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - else: - attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. 
- context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_attn_probs=True, - ) + attn_output, attn_softmax_lse = self._get_prefill_ctx_attn_output( + i, q, k, v_padded, prefill_metadata) if output is None: output = attn_output @@ -1249,43 +1255,20 @@ def _compute_prefill_context( return output, output_lse - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - - has_context = prefill_metadata.context_lens_tensor is not None \ - and prefill_metadata.context_lens_tensor.max() > 0 - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - + def _get_fwd_prefill_attn_output( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + metadata: MLACommonMetadata, + has_context: bool) -> tuple[torch.Tensor, ...]: if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: output = self.triton_fa_func( q, k, - v_padded, + v, None, - prefill_metadata.query_start_loc, - prefill_metadata.query_start_loc, - prefill_metadata.max_prefill_seq_len, - prefill_metadata.max_prefill_seq_len, + metadata.query_start_loc, + metadata.query_start_loc, + metadata.max_prefill_seq_len, + metadata.max_prefill_seq_len, True, # causal self.scale, None, # attn_mask is None unless applying ALiBi mask @@ -1297,11 +1280,11 @@ def _forward_prefill( output = self.flash_attn_varlen_func( q=q, k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, + v=v, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.query_start_loc, + max_seqlen_q=metadata.max_prefill_seq_len, + max_seqlen_k=metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, return_softmax_lse=has_context, @@ -1310,16 +1293,49 @@ def _forward_prefill( output = self.flash_attn_varlen_func( q=q, k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, + v=v, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.query_start_loc, + max_seqlen_q=metadata.max_prefill_seq_len, + max_seqlen_k=metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, return_attn_probs=has_context, ) + return output + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + 
self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = self._get_fwd_prefill_attn_output(q, k, v_padded, + prefill_metadata, + has_context) + if has_context: # ROCm flash_attn_varlen_func will return 3 objects instead of 2 suffix_output, suffix_lse, *rest = output diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index c58430e91587..60a99794f693 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from contextlib import contextmanager +from contextlib import contextmanager, suppress from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Type @@ -18,7 +18,6 @@ is_block_tables_empty) from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, get_aiter_mla_metadata) -from vllm.attention.ops.triton_merge_attn_states import merge_attn_states if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -30,6 +29,11 @@ def is_aiter_mla_enabled() -> bool: and envs.VLLM_ROCM_USE_AITER_MLA +def is_aiter_fa_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_FA + + class AiterMLABackend(MLACommonBackend): @staticmethod @@ -384,158 +388,80 @@ def __init__( "alibi_slopes, sliding_window, blocksparse_params, " "logits_soft_cap") - from aiter import flash_attn_varlen_func + flash_attn_varlen_func = None + + if is_aiter_fa_enabled(): + from aiter import flash_attn_varlen_func + else: + with suppress(ImportError): + from flash_attn import flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func - def _compute_prefill_context( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: AiterMLAMetadata, - ): - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - assert prefill_metadata.context_chunk_seq_tot is not None - assert prefill_metadata.context_chunk_cu_seq_lens is not None - assert prefill_metadata.context_chunk_starts is not None - assert prefill_metadata.context_chunk_max_seq_lens is not None - assert prefill_metadata.context_lens_tensor is not None - - output = None - iters = len(prefill_metadata.context_chunk_seq_tot) - - # Fetch from attn_metadata directly, since it late bound by - # MLAAttentionState, grabbing it directly `attn_metadata` can avoid - # any weirdness around prefill_metadata caching - assert attn_metadata.context_chunk_workspace is not None - workspace = attn_metadata.context_chunk_workspace - - for i in range(iters): - toks = prefill_metadata.context_chunk_seq_tot[i] - - ops.gather_cache( - src_cache=kv_c_and_k_pe_cache, - dst=workspace, - block_table=prefill_metadata.block_tables, - cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], - batch_size=prefill_metadata.num_prefills, - seq_starts=prefill_metadata.context_chunk_starts[i], + def _get_fwd_prefill_attn_output( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + metadata: AiterMLAMetadata, + has_context: bool) -> tuple[torch.Tensor, ...]: + if is_aiter_fa_enabled(): + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=metadata.query_start_loc, + 
cu_seqlens_k=metadata.query_start_loc, + max_seqlen_q=metadata.max_prefill_seq_len, + max_seqlen_k=metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_lse=has_context, + ) + if not has_context: + return output[0] + else: + return self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.query_start_loc, + max_seqlen_q=metadata.max_prefill_seq_len, + max_seqlen_k=metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_attn_probs=has_context, ) - kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank] - k_pe = workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + def _get_prefill_ctx_attn_output( + self, index: int, q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + metadata: AiterMLAMetadata) -> tuple[torch.Tensor, ...]: + if is_aiter_fa_enabled(): + return self.flash_attn_varlen_func( q=q, k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + v=v, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], + max_seqlen_q=metadata.max_query_len, + max_seqlen_k=metadata.context_chunk_max_seq_lens[index], softmax_scale=self.scale, causal=False, # Context is unmasked return_lse=True, ) - - if output is None: - output = attn_output - output_lse = attn_softmax_lse - else: - output_tmp = torch.empty_like(output) - output_lse_tmp = torch.empty_like(output_lse) - merge_attn_states( - output=output_tmp, - output_lse=output_lse_tmp, - prefix_output=output, - prefix_lse=output_lse, - suffix_output=attn_output, - suffix_lse=attn_softmax_lse, - ) - output = output_tmp - output_lse = output_lse_tmp - - return output, output_lse - - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - - has_context = prefill_metadata.context_lens_tensor is not None \ - and prefill_metadata.context_lens_tensor.max() > 0 - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - 
max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_lse=has_context, - ) - - if not has_context: - output = output[0] - - if has_context: - suffix_output, suffix_lse = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) - - output = torch.empty_like(suffix_output) - merge_attn_states( - output=output, - prefix_output=context_output, - prefix_lse=context_lse, - suffix_output=suffix_output, - suffix_lse=suffix_lse, + else: + attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], + max_seqlen_q=metadata.max_query_len, + max_seqlen_k=metadata.context_chunk_max_seq_lens[index], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_attn_probs=True, ) - - # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(output)[0] + return attn_output, attn_softmax_lse def _forward_decode( self, diff --git a/vllm/envs.py b/vllm/envs.py index ab6e9833531f..b5277acc2fe6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -79,6 +79,7 @@ VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_AITER_FA: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -557,6 +558,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")), + # Whether to use aiter fa ops in aiter mla. + # By default is enabled. + "VLLM_ROCM_USE_AITER_FA": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_FA", "True").lower() in + ("true", "1")), + # Whether to use aiter mla ops. # By default is enabled. 
"VLLM_ROCM_USE_AITER_MLA": From 1ea5718e2ecd55d3b87bacea87cc7df83a3eaafd Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 2 Apr 2025 08:00:34 +0000 Subject: [PATCH 08/39] further code cleaning Signed-off-by: vllmellm --- vllm/attention/backends/mla/common.py | 23 +++++++++-------------- vllm/attention/backends/rocm_aiter_mla.py | 18 +++++------------- 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 9a08605ee8c2..82197acfb341 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -191,7 +191,7 @@ from dataclasses import dataclass from itertools import accumulate from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, - Type, TypeVar) + Type, TypeVar, Union) import torch @@ -739,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + BLOCK_TABLE_EXTENDER: Union[List[List[int]], List[int]] = [] def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.input_builder = input_builder @@ -850,15 +851,6 @@ def _get_graph_runner_block_tables( return torch.from_numpy(graph_block_tables).to( device=self.runner.device, non_blocking=True) - def get_block_tables_with_captured_graph(self, cuda_graph_pad_size: int, - num_seqs: int) -> torch.Tensor: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - - return block_tables - def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): """Build attention metadata with on-device tensors. 
@@ -898,8 +890,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_seqs = len(seq_lens) if use_captured_graph: num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self.get_block_tables_with_captured_graph( - cuda_graph_pad_size, num_seqs) + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend(self.BLOCK_TABLE_EXTENDER * + cuda_graph_pad_size) + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -1150,7 +1145,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _get_prefill_ctx_attn_output( self, index: int, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - metadata: MLACommonMetadata) -> tuple[torch.Tensor, ...]: + metadata: MLACommonMetadata) -> Tuple[torch.Tensor, ...]: if is_vllm_fa: attn_output, attn_softmax_lse = self.flash_attn_varlen_func( q=q, @@ -1258,7 +1253,7 @@ def _compute_prefill_context( def _get_fwd_prefill_attn_output( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, metadata: MLACommonMetadata, - has_context: bool) -> tuple[torch.Tensor, ...]: + has_context: bool) -> Tuple[torch.Tensor, ...]: if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: output = self.triton_fa_func( q, diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 60a99794f693..a3b5ec4c5341 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -2,7 +2,7 @@ from contextlib import contextmanager, suppress from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Type +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union import torch @@ -13,7 +13,7 @@ MLACommonMetadata, MLACommonMetadataBuilder, MLACommonState) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, +from vllm.attention.backends.utils import (compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, @@ -138,6 +138,7 @@ def advance_step(self, class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + BLOCK_TABLE_EXTENDER: Union[List[List[int]], List[int]] = [[]] def __init__(self, input_builder: "ModelInputForGPUBuilder"): super().__init__(input_builder) @@ -231,15 +232,6 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): last_page_len = self.block_size self.paged_kv_last_page_lens.append(last_page_len) - def get_block_tables_with_captured_graph(self, cuda_graph_pad_size: int, - num_seqs: int) -> torch.Tensor: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([[]] * cuda_graph_pad_size) - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - - return block_tables - def build(self, seq_lens: list[int], query_lens: list[int], cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, @@ -401,7 +393,7 @@ def __init__( def _get_fwd_prefill_attn_output( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, metadata: AiterMLAMetadata, - has_context: bool) -> tuple[torch.Tensor, ...]: + has_context: bool) -> Tuple[torch.Tensor, ...]: if is_aiter_fa_enabled(): output = self.flash_attn_varlen_func( q=q, @@ -434,7 +426,7 @@ def _get_fwd_prefill_attn_output( def _get_prefill_ctx_attn_output( self, index: 
int, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - metadata: AiterMLAMetadata) -> tuple[torch.Tensor, ...]: + metadata: AiterMLAMetadata) -> Tuple[torch.Tensor, ...]: if is_aiter_fa_enabled(): return self.flash_attn_varlen_func( q=q, From 9ada055ac9628e627981c8e49e95fd797110868b Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 3 Apr 2025 04:41:43 +0000 Subject: [PATCH 09/39] fix mypy typing errors Signed-off-by: vllmellm --- vllm/attention/backends/mla/common.py | 76 ++++++++++++----------- vllm/attention/backends/rocm_aiter_mla.py | 70 ++++++++++++++------- 2 files changed, 88 insertions(+), 58 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 82197acfb341..cbadb8d5db18 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -190,8 +190,8 @@ from contextlib import contextmanager from dataclasses import dataclass from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, - Type, TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, + Optional, Tuple, Type, TypeVar) import torch @@ -739,7 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): NOTE: Please read the comment at the top of the file before trying to understand this class """ - BLOCK_TABLE_EXTENDER: Union[List[List[int]], List[int]] = [] + BLOCK_TABLE_EXTENDER: Iterable[list[int]] = [] def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.input_builder = input_builder @@ -1146,6 +1146,9 @@ def _get_prefill_ctx_attn_output( self, index: int, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, metadata: MLACommonMetadata) -> Tuple[torch.Tensor, ...]: + assert metadata.context_chunk_cu_seq_lens is not None + assert metadata.context_chunk_max_seq_lens is not None + if is_vllm_fa: attn_output, attn_softmax_lse = self.flash_attn_varlen_func( q=q, @@ -1250,20 +1253,22 @@ def _compute_prefill_context( return output, output_lse - def _get_fwd_prefill_attn_output( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - metadata: MLACommonMetadata, - has_context: bool) -> Tuple[torch.Tensor, ...]: + def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + prefill_metadata: MLACommonMetadata, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + has_context: bool) -> torch.Tensor: if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: output = self.triton_fa_func( q, k, v, None, - metadata.query_start_loc, - metadata.query_start_loc, - metadata.max_prefill_seq_len, - metadata.max_prefill_seq_len, + prefill_metadata.query_start_loc, + prefill_metadata.query_start_loc, + prefill_metadata.max_prefill_seq_len, + prefill_metadata.max_prefill_seq_len, True, # causal self.scale, None, # attn_mask is None unless applying ALiBi mask @@ -1276,10 +1281,10 @@ def _get_fwd_prefill_attn_output( q=q, k=k, v=v, - cu_seqlens_q=metadata.query_start_loc, - cu_seqlens_k=metadata.query_start_loc, - max_seqlen_q=metadata.max_prefill_seq_len, - max_seqlen_k=metadata.max_prefill_seq_len, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, return_softmax_lse=has_context, @@ -1289,15 +1294,30 @@ def _get_fwd_prefill_attn_output( q=q, k=k, v=v, - 
cu_seqlens_q=metadata.query_start_loc, - cu_seqlens_k=metadata.query_start_loc, - max_seqlen_q=metadata.max_prefill_seq_len, - max_seqlen_k=metadata.max_prefill_seq_len, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, return_attn_probs=has_context, ) + if has_context: + # ROCm flash_attn_varlen_func will return 3 objects instead of 2 + suffix_output, suffix_lse, *rest = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + return output def _forward_prefill( @@ -1329,22 +1349,8 @@ def _forward_prefill( output = self._get_fwd_prefill_attn_output(q, k, v_padded, prefill_metadata, - has_context) - - if has_context: - # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse, *rest = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata) - - output = torch.empty_like(suffix_output) - merge_attn_states( - output=output, - prefix_output=context_output, - prefix_lse=context_lse, - suffix_output=suffix_output, - suffix_lse=suffix_lse, - ) + kv_c_and_k_pe_cache, + attn_metadata, has_context) # slice by `:v.shape[-1]` in order to remove v headdim padding output = output\ diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index a3b5ec4c5341..dfb1e5034948 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -2,7 +2,7 @@ from contextlib import contextmanager, suppress from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Iterable, Optional, Tuple, Type import torch @@ -18,6 +18,7 @@ is_block_tables_empty) from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, get_aiter_mla_metadata) +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -138,7 +139,7 @@ def advance_step(self, class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - BLOCK_TABLE_EXTENDER: Union[List[List[int]], List[int]] = [[]] + BLOCK_TABLE_EXTENDER: Iterable[list[int]] = [[]] def __init__(self, input_builder: "ModelInputForGPUBuilder"): super().__init__(input_builder) @@ -380,29 +381,32 @@ def __init__( "alibi_slopes, sliding_window, blocksparse_params, " "logits_soft_cap") - flash_attn_varlen_func = None - - if is_aiter_fa_enabled(): - from aiter import flash_attn_varlen_func - else: - with suppress(ImportError): + flash_attn_func = None + with suppress(ImportError): + if is_aiter_fa_enabled(): + from aiter import flash_attn_varlen_func + flash_attn_func = flash_attn_varlen_func + else: from flash_attn import flash_attn_varlen_func + flash_attn_func = flash_attn_varlen_func - self.flash_attn_varlen_func = flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_func - def _get_fwd_prefill_attn_output( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - metadata: AiterMLAMetadata, - has_context: bool) -> Tuple[torch.Tensor, ...]: + def 
_get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + prefill_metadata: MLACommonMetadata, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + has_context: bool) -> torch.Tensor: if is_aiter_fa_enabled(): output = self.flash_attn_varlen_func( q=q, k=k, v=v, - cu_seqlens_q=metadata.query_start_loc, - cu_seqlens_k=metadata.query_start_loc, - max_seqlen_q=metadata.max_prefill_seq_len, - max_seqlen_k=metadata.max_prefill_seq_len, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, return_lse=has_context, @@ -410,23 +414,43 @@ def _get_fwd_prefill_attn_output( if not has_context: return output[0] else: - return self.flash_attn_varlen_func( + output = self.flash_attn_varlen_func( q=q, k=k, v=v, - cu_seqlens_q=metadata.query_start_loc, - cu_seqlens_k=metadata.query_start_loc, - max_seqlen_q=metadata.max_prefill_seq_len, - max_seqlen_k=metadata.max_prefill_seq_len, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, softmax_scale=self.scale, causal=True, return_attn_probs=has_context, ) + if has_context: + # ROCm flash_attn_varlen_func will return 3 objects instead of 2 + suffix_output, suffix_lse, *rest = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + return output + def _get_prefill_ctx_attn_output( self, index: int, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - metadata: AiterMLAMetadata) -> Tuple[torch.Tensor, ...]: + metadata: MLACommonMetadata) -> Tuple[torch.Tensor, ...]: + assert metadata.context_chunk_cu_seq_lens is not None + assert metadata.context_chunk_max_seq_lens is not None + if is_aiter_fa_enabled(): return self.flash_attn_varlen_func( q=q, From 20a3f074933ffe494d834dd6729e997bbbb968d5 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 3 Apr 2025 05:24:47 +0000 Subject: [PATCH 10/39] fix mypy error on Iterable typing error Signed-off-by: vllmellm --- vllm/attention/backends/mla/common.py | 8 ++++---- vllm/attention/backends/rocm_aiter_mla.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index cbadb8d5db18..ea8ff78e1946 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -190,8 +190,8 @@ from contextlib import contextmanager from dataclasses import dataclass from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, - Optional, Tuple, Type, TypeVar) +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, + Type, TypeVar) import torch @@ -739,7 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): NOTE: Please read the comment at the top of the file before trying to understand this class """ - BLOCK_TABLE_EXTENDER: Iterable[list[int]] = [] + BLOCK_TABLE_EXTENDER: list[list[int]] = [] def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.input_builder = 
input_builder @@ -891,7 +891,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], if use_captured_graph: num_decode_tokens = batch_size - self.num_prefill_tokens self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend(self.BLOCK_TABLE_EXTENDER * + self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * cuda_graph_pad_size) block_tables = self._get_graph_runner_block_tables( num_seqs, self.block_tables) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index dfb1e5034948..4a3da4fc8e4b 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -2,7 +2,7 @@ from contextlib import contextmanager, suppress from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Iterable, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Optional, Tuple, Type import torch @@ -139,7 +139,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - BLOCK_TABLE_EXTENDER: Iterable[list[int]] = [[]] + BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] def __init__(self, input_builder: "ModelInputForGPUBuilder"): super().__init__(input_builder) From 194a42a137e25b83580ac5ef6fda8fb22f3b7103 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 15 Apr 2025 11:17:07 +0000 Subject: [PATCH 11/39] remove padding for v tensor in AITER MLA which improves performance Co-authored-by: ArthurAMD <yajhuang@amd.com> Signed-off-by: vllmellm --- vllm/attention/backends/mla/common.py | 44 +++++++++++------------ vllm/attention/backends/rocm_aiter_mla.py | 22 +++++++++--- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index ea8ff78e1946..6441134f4019 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1149,11 +1149,15 @@ def _get_prefill_ctx_attn_output( assert metadata.context_chunk_cu_seq_lens is not None assert metadata.context_chunk_max_seq_lens is not None + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) if is_vllm_fa: attn_output, attn_softmax_lse = self.flash_attn_varlen_func( q=q, k=k, - v=v, + v=v_padded, cu_seqlens_q=metadata.query_start_loc, cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], max_seqlen_q=metadata.max_query_len, @@ -1166,7 +1170,7 @@ def _get_prefill_ctx_attn_output( attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( q=q, k=k, - v=v, + v=v_padded, cu_seqlens_q=metadata.query_start_loc, cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], max_seqlen_q=metadata.max_query_len, @@ -1225,14 +1229,8 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - attn_output, attn_softmax_lse = self._get_prefill_ctx_attn_output( - i, q, k, v_padded, prefill_metadata) + i, q, k, v, prefill_metadata) if output is None: output = attn_output @@ -1259,11 +1257,15 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, has_context: bool) -> torch.Tensor: + # For MLA the v head dim is smaller than
qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: output = self.triton_fa_func( q, k, - v, + v_padded, None, prefill_metadata.query_start_loc, prefill_metadata.query_start_loc, @@ -1280,7 +1282,7 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, output = self.flash_attn_varlen_func( q=q, k=k, - v=v, + v=v_padded, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.query_start_loc, max_seqlen_q=prefill_metadata.max_prefill_seq_len, @@ -1293,7 +1295,7 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, output = self.flash_attn_varlen_func( q=q, k=k, - v=v, + v=v_padded, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.query_start_loc, max_seqlen_q=prefill_metadata.max_prefill_seq_len, @@ -1318,6 +1320,11 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, suffix_lse=suffix_lse, ) + # slice by `:v.shape[-1]` in order to remove v headdim padding + output = output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + return output def _forward_prefill( @@ -1342,21 +1349,10 @@ def _forward_prefill( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - output = self._get_fwd_prefill_attn_output(q, k, v_padded, - prefill_metadata, + output = self._get_fwd_prefill_attn_output(q, k, v, prefill_metadata, kv_c_and_k_pe_cache, attn_metadata, has_context) - # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - return self.o_proj(output)[0] @abstractmethod diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 4a3da4fc8e4b..01956a8af822 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -399,6 +399,7 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, attn_metadata: MLACommonMetadata, has_context: bool) -> torch.Tensor: if is_aiter_fa_enabled(): + is_v_padded = False output = self.flash_attn_varlen_func( q=q, k=k, @@ -411,13 +412,18 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, causal=True, return_lse=has_context, ) - if not has_context: - return output[0] + else: + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, + [0, q.shape[-1] - v.shape[-1]], + value=0) + is_v_padded = True output = self.flash_attn_varlen_func( q=q, k=k, - v=v, + v=v_padded, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.query_start_loc, max_seqlen_q=prefill_metadata.max_prefill_seq_len, @@ -441,8 +447,16 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, suffix_output=suffix_output, suffix_lse=suffix_lse, ) + else: + output = output[0] - return output + if is_v_padded: + # slice by `:v.shape[-1]` in order to remove v headdim padding + return output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * 
v.shape[-1]) + else: + return output.reshape(-1, self.num_heads * v.shape[-1]) def _get_prefill_ctx_attn_output( self, index: int, q: torch.Tensor, k: torch.Tensor, From a9a02d5936ed7d4e89d2b5cab03e878abe18176a Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 15 Apr 2025 11:18:50 +0000 Subject: [PATCH 12/39] upgrade aiter package version Signed-off-by: vllmellm --- docker/Dockerfile.rocm_base | 2 +- vllm/attention/backends/rocm_aiter_mla.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 61894918de2f..05192eb69b54 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="5208527" +ARG AITER_BRANCH="5a77249" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 01956a8af822..87a3e4fbf8d2 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -447,8 +447,6 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, suffix_output=suffix_output, suffix_lse=suffix_lse, ) - else: - output = output[0] if is_v_padded: # slice by `:v.shape[-1]` in order to remove v headdim padding From 02a4fb321a4d67f67310cd91cb6678018670774c Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 15 Apr 2025 11:25:15 +0000 Subject: [PATCH 13/39] only support AITER FA in AITER MLA backend to avoid latency caused by if/else statements Signed-off-by: vllmellm --- vllm/attention/backends/rocm_aiter_mla.py | 115 ++++++---------------- vllm/envs.py | 7 -- 2 files changed, 29 insertions(+), 93 deletions(-) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 87a3e4fbf8d2..bd70cd101461 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from contextlib import contextmanager, suppress +from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Tuple, Type @@ -30,11 +30,6 @@ def is_aiter_mla_enabled() -> bool: and envs.VLLM_ROCM_USE_AITER_MLA -def is_aiter_fa_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_ROCM_USE_AITER_FA - - class AiterMLABackend(MLACommonBackend): @staticmethod @@ -381,16 +376,8 @@ def __init__( "alibi_slopes, sliding_window, blocksparse_params, " "logits_soft_cap") - flash_attn_func = None - with suppress(ImportError): - if is_aiter_fa_enabled(): - from aiter import flash_attn_varlen_func - flash_attn_func = flash_attn_varlen_func - else: - from flash_attn import flash_attn_varlen_func - flash_attn_func = flash_attn_varlen_func - - self.flash_attn_varlen_func = flash_attn_func + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -398,44 +385,21 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, has_context: bool) -> torch.Tensor: - if is_aiter_fa_enabled(): - is_v_padded = False - output = 
self.flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_lse=has_context, - ) - - else: - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - is_v_padded = True - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_attn_probs=has_context, - ) + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_lse=has_context, + ) if has_context: - # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse, *rest = output + suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ q, kv_c_and_k_pe_cache, attn_metadata) @@ -448,13 +412,7 @@ def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, suffix_lse=suffix_lse, ) - if is_v_padded: - # slice by `:v.shape[-1]` in order to remove v headdim padding - return output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - else: - return output.reshape(-1, self.num_heads * v.shape[-1]) + return output.reshape(-1, self.num_heads * v.shape[-1]) def _get_prefill_ctx_attn_output( self, index: int, q: torch.Tensor, k: torch.Tensor, @@ -463,33 +421,18 @@ def _get_prefill_ctx_attn_output( assert metadata.context_chunk_cu_seq_lens is not None assert metadata.context_chunk_max_seq_lens is not None - if is_aiter_fa_enabled(): - return self.flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=metadata.query_start_loc, - cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], - max_seqlen_q=metadata.max_query_len, - max_seqlen_k=metadata.context_chunk_max_seq_lens[index], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_lse=True, - ) - else: - attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=metadata.query_start_loc, - cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], - max_seqlen_q=metadata.max_query_len, - max_seqlen_k=metadata.context_chunk_max_seq_lens[index], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_attn_probs=True, - ) - return attn_output, attn_softmax_lse + return self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=metadata.query_start_loc, + cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], + max_seqlen_q=metadata.max_query_len, + max_seqlen_k=metadata.context_chunk_max_seq_lens[index], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_lse=True, + ) def _forward_decode( self, diff --git a/vllm/envs.py b/vllm/envs.py index cd3d1afd7e72..cafb0bbb75a7 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -79,7 +79,6 @@ VLLM_ROCM_USE_AITER_MOE: bool = True 
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False VLLM_ROCM_USE_AITER_RMSNORM: bool = True - VLLM_ROCM_USE_AITER_FA: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -559,12 +558,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")), - # Whether to use aiter fa ops in aiter mla. - # By default is enabled. - "VLLM_ROCM_USE_AITER_FA": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_FA", "True").lower() in - ("true", "1")), - # Whether to use aiter mla ops. # By default is enabled. "VLLM_ROCM_USE_AITER_MLA": From 6e4843376f88b41dea4ead5293aa7cb0fe26d1d9 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 16 Apr 2025 03:58:32 +0000 Subject: [PATCH 14/39] add missing data types of arguments in aiter_mla_decode_fwd Signed-off-by: vllmellm --- vllm/attention/ops/rocm_aiter_mla.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 6d51bf1d8644..1c90f8c19b09 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -21,10 +21,10 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, def aiter_mla_decode_fwd( - q, - kv_buffer, - o, - sm_scale, + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, kv_indptr: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, kv_last_page_lens: Optional[torch.Tensor] = None, From 0265f201a358f929f4c6ea1c3540473d2dc82d4c Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 16 Apr 2025 13:40:23 +0000 Subject: [PATCH 15/39] support AITER MLA backend on V1 Signed-off-by: vllmellm --- requirements/common.txt | 8 +- vllm/attention/backends/rocm_aiter_mla.py | 10 +- vllm/attention/ops/rocm_aiter_mla.py | 48 +++- vllm/engine/arg_utils.py | 3 +- vllm/platforms/interface.py | 3 +- vllm/platforms/rocm.py | 16 +- vllm/v1/attention/backends/mla/common.py | 127 +++++---- .../attention/backends/mla/rocm_aiter_mla.py | 253 ++++++++++++++++++ 8 files changed, 400 insertions(+), 68 deletions(-) create mode 100644 vllm/v1/attention/backends/mla/rocm_aiter_mla.py diff --git a/requirements/common.txt b/requirements/common.txt index 4df32460c2db..14a4d2669c46 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -43,7 +43,7 @@ watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/other/logging_configuration.md scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu -opentelemetry-sdk>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-api>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-exporter-otlp>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0 # vllm.tracing +# opentelemetry-sdk>=1.26.0,<1.27.0 # vllm.tracing +# opentelemetry-api>=1.26.0,<1.27.0 # vllm.tracing +# opentelemetry-exporter-otlp>=1.26.0,<1.27.0 # vllm.tracing +# opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0 # vllm.tracing diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index bd70cd101461..60dd7c11ce60 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -16,7 +16,7 @@ from vllm.attention.backends.utils import (compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from 
vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_forward, get_aiter_mla_metadata) from vllm.attention.ops.triton_merge_attn_states import merge_attn_states @@ -456,9 +456,9 @@ def _forward_decode( kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_lens) + aiter_mla_decode_forward(q, kv_buffer, o, self.scale, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_lens) return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 1c90f8c19b09..e9306794e02d 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -4,6 +4,9 @@ import torch +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + def get_aiter_mla_metadata(max_batch_size: int, block_size: int, max_block_per_batch: int, @@ -20,7 +23,7 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens -def aiter_mla_decode_fwd( +def aiter_mla_decode_forward( q: torch.Tensor, kv_buffer: torch.Tensor, o: torch.Tensor, @@ -30,6 +33,28 @@ def aiter_mla_decode_fwd( kv_last_page_lens: Optional[torch.Tensor] = None, logit_cap: float = 0.0, ): + + torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, + kv_buffer.view( + -1, 1, 1, q.shape[-1]), + o, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: from aiter.mla import mla_decode_fwd mla_decode_fwd(q, @@ -40,3 +65,24 @@ def aiter_mla_decode_fwd( kv_last_page_lens, sm_scale=sm_scale, logit_cap=logit_cap) + + +def mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +if current_platform.is_rocm(): + direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", + op_func=mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=mla_decode_fwd_fake, + tags=[torch.Tag.needs_fixed_stride_order]) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 921d7bcd9f82..9e4561113dbf 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1532,7 +1532,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # No FlashInfer or XFormers so far. 
V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", - "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" + "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA", "ROCM_AITER_MLA", + "ROCM_AITER_MLA_VLLM_V1" ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d28a8633e5d7..ce39a86c63b5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -39,7 +39,8 @@ class _Backend(enum.Enum): TRITON_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() - ROCM_AITER_MLA = enum.auto() + ROCM_AITER_MLA = enum.auto() # Supported by V1 + ROCM_AITER_MLA_VLLM_V1 = enum.auto() TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 70da611b4d86..5052bad6230b 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -147,16 +147,24 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if selected_backend == _Backend.TRITON_MLA: if block_size != 1: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 + if use_v1: + logger.info("Using Triton MLA backend.") + return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" # noqa: E501 + else: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 else: raise ValueError( f" The selected backend, {selected_backend.name}," "does not support block size {block_size}.") else: if block_size == 1: - logger.info("Using AITER MLA backend.") - return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + if use_v1: + logger.info("Using AITER MLA backend.") + return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + else: + logger.info("Using AITER MLA backend.") + return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: raise ValueError( f" The selected backend, {selected_backend.name}," diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 8c7179ba0a8a..b8e351fce688 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -708,6 +708,50 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0) + def _get_prefill_ctx_attn_output( + self, index: int, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + prefill_metadata: MLACommonPrefillMetadata + ) -> tuple[torch.Tensor, ...]: + output = None + output_lse = None + assert prefill_metadata.chunked_context is not None + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[index], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[index], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + 
output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + def _compute_prefill_context( self, q: torch.Tensor, @@ -747,62 +791,17 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - - if output is None: - output = attn_output - output_lse = attn_softmax_lse - else: - output_tmp = torch.empty_like(output) - output_lse_tmp = torch.empty_like(output_lse) - merge_attn_states( - output=output_tmp, - output_lse=output_lse_tmp, - prefix_output=output, - prefix_lse=output_lse, - suffix_output=attn_output, - suffix_lse=attn_softmax_lse, - ) - output = output_tmp - output_lse = output_lse_tmp + output, output_lse = self._get_prefill_ctx_attn_output( + i, q, k, v, prefill_metadata) return output, output_lse - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: + def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + has_context: bool) -> torch.Tensor: assert attn_metadata.prefill is not None - - has_context = attn_metadata.prefill.chunked_context is not None - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], @@ -840,6 +839,30 @@ def _forward_prefill( .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ .reshape(-1, self.num_heads * v.shape[-1]) + return output + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert attn_metadata.prefill is not None + + has_context = attn_metadata.prefill.chunked_context is not None + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + output = self._get_fwd_prefill_attn_output(q, k, v, + kv_c_and_k_pe_cache, + attn_metadata, has_context) + return self.o_proj(output)[0] @abstractmethod diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py 
b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py new file mode 100644 index 000000000000..26c41d3091e0 --- /dev/null +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +import vllm.envs as envs +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_forward, + get_aiter_mla_metadata) +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +# yapf conflicts with isort for this docstring +# yapf: disable +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonPrefillMetadata) + +# yapf: enable + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + +@dataclass +class AiterMLADecodeMetadata(MLACommonDecodeMetadata): + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_lens: Optional[torch.Tensor] = None + + +class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): + pass + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + + def __init__(self, runner): + super().__init__(runner) + max_model_len = self.runner.model_config.max_model_len + assert max_model_len == 32768,\ + "AITER MLA requires max_model_len=32768" + assert self.runner.block_size == 1, "AITER MLA" \ + "requires only block size 1." + + self.paged_kv_indices: list[int] = [] + self.paged_kv_indptr: list[int] = [0] + self.paged_kv_last_page_lens: list[int] = [] + self.total_blocks = 0 + + def _update_paged_kv_tensors(self, block_table: torch.Tensor, + seq_lens: torch.Tensor): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. 
+ for idx in range(seq_lens.shape[0]): + seq_len = seq_lens[idx] + self.total_blocks += len(block_table[idx]) + block_table_bound = seq_len // self.runner.block_size + 1 \ + if seq_len % self.runner.block_size != 0 \ + else seq_len // self.runner.block_size + self.paged_kv_indices.extend(block_table[idx][:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.runner.block_size + if last_page_len == 0: + last_page_len = self.runner.block_size + self.paged_kv_last_page_lens.append(last_page_len) + + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + self._update_paged_kv_tensors(block_table, seq_lens) + + if len(self.paged_kv_indptr) > 0: + device = self.runner.device + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices = torch.tensor(self.paged_kv_indices, + device=device, + dtype=torch.int) + paged_kv_indptr = torch.tensor(self.paged_kv_indptr, + device=device, + dtype=torch.int) + paged_last_page_lens = torch.tensor(self.paged_kv_last_page_lens, + device=device, + dtype=torch.int) + else: + ( + paged_kv_indices, + paged_kv_indptr, + paged_last_page_lens + ) = get_aiter_mla_metadata( + max_batch_size=self.runner.max_num_reqs, + block_size=self.runner.block_size, + max_block_per_batch=self.runner\ + .max_num_blocks_per_req, + device=self.runner.device) + + return AiterMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_lens=paged_last_page_lens) + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _get_fwd_prefill_attn_output(self, q: torch.Tensor, k: torch.Tensor, + v: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + has_context: bool) -> torch.Tensor: + assert attn_metadata.prefill is not None + + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=attn_metadata.prefill.query_start_loc, + cu_seqlens_k=attn_metadata.prefill.query_start_loc, + max_seqlen_q=attn_metadata.prefill.max_query_len, + max_seqlen_k=attn_metadata.prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_lse=has_context, + ) + + if has_context: + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + 
prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + return output.reshape(-1, self.num_heads * v.shape[-1]) + + def _get_prefill_ctx_attn_output( + self, index: int, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + prefill_metadata: MLACommonPrefillMetadata + ) -> tuple[torch.Tensor, ...]: + assert prefill_metadata.chunked_context is not None + + return self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[index], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[index], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_lse=True, + ) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + aiter_mla_decode_forward(q, kv_buffer, o, self.scale, + attn_metadata.decode.paged_kv_indptr, + attn_metadata.decode.paged_kv_indices, + attn_metadata.decode.paged_kv_last_page_lens) + + return self._v_up_proj_and_o_proj(o) From 693c8709deb7cfad64aae5d41ddb6538960edf17 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 16 Apr 2025 13:43:09 +0000 Subject: [PATCH 16/39] uncomment the required packages in common.txt Signed-off-by: vllmellm --- requirements/common.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 14a4d2669c46..4df32460c2db 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -43,7 +43,7 @@ watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/other/logging_configuration.md scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu -# opentelemetry-sdk>=1.26.0,<1.27.0 # vllm.tracing -# opentelemetry-api>=1.26.0,<1.27.0 # vllm.tracing -# opentelemetry-exporter-otlp>=1.26.0,<1.27.0 # vllm.tracing -# opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0 # vllm.tracing +opentelemetry-sdk>=1.26.0,<1.27.0 # vllm.tracing +opentelemetry-api>=1.26.0,<1.27.0 # vllm.tracing +opentelemetry-exporter-otlp>=1.26.0,<1.27.0 # vllm.tracing +opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0 # vllm.tracing From a5a1a54e2c01a2c29a1b472c75c0fe7c481533b9 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 23 Apr 2025 15:46:41 +0000 Subject: [PATCH 17/39] bugfix in building decode metadata for AITER MLA decode forward pass Signed-off-by: vllmellm --- .../attention/backends/mla/rocm_aiter_mla.py | 104 +++++++++--------- 1 file changed, 53 insertions(+), 51 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 26c41d3091e0..9314cc069986 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,8 +6,7 @@ import torch import vllm.envs as envs -from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_forward, - get_aiter_mla_metadata) +from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_forward from 
vllm.attention.ops.triton_merge_attn_states import merge_attn_states # yapf conflicts with isort for this docstring # yapf: disable @@ -70,62 +69,65 @@ def __init__(self, runner): assert self.runner.block_size == 1, "AITER MLA" \ "requires only block size 1." - self.paged_kv_indices: list[int] = [] - self.paged_kv_indptr: list[int] = [0] - self.paged_kv_last_page_lens: list[int] = [] - self.total_blocks = 0 - - def _update_paged_kv_tensors(self, block_table: torch.Tensor, - seq_lens: torch.Tensor): - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. + self.block_size = self.runner.block_size + + def _get_paged_kv_tensors( + self, block_table: torch.Tensor, + seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: + paged_kv_indices: list[int] = [] + paged_kv_indptr: list[int] = [0] + paged_kv_last_page_lens: list[int] = [] + + device = self.runner.device + block_size = self.runner.block_size + total_blocks = 0 for idx in range(seq_lens.shape[0]): seq_len = seq_lens[idx] - self.total_blocks += len(block_table[idx]) - block_table_bound = seq_len // self.runner.block_size + 1 \ - if seq_len % self.runner.block_size != 0 \ - else seq_len // self.runner.block_size - self.paged_kv_indices.extend(block_table[idx][:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) - - last_page_len = seq_len % self.runner.block_size + total_blocks += seq_len + block_table_list = block_table.tolist()[idx] + block_table_list.insert(0, 0) + block_table_bound = seq_len // block_size + 1 \ + if seq_len % block_size != 0 \ + else seq_len // block_size + paged_kv_indices.extend(block_table_list[:block_table_bound]) + paged_kv_indptr.append(paged_kv_indptr[-1] + block_table_bound) + last_page_len = seq_len % block_size if last_page_len == 0: - last_page_len = self.runner.block_size - self.paged_kv_last_page_lens.append(last_page_len) + last_page_len = block_size + paged_kv_last_page_lens.append(last_page_len) + + if self.runner.use_cuda_graph: + cudagraph_batch_size = self.runner.cudagraph_batch_sizes[-1] + last_paged_kv_indptr = paged_kv_indptr[-1] + paged_kv_indptr.extend([last_paged_kv_indptr] * + cudagraph_batch_size) + paged_kv_last_page_lens.extend([0] * cudagraph_batch_size) + + paged_kv_indices.extend([0] * (total_blocks - len(paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + device=device, + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, + device=device, + dtype=torch.int) + paged_last_page_lens_tensor = torch.tensor(paged_kv_last_page_lens, + device=device, + dtype=torch.int) + return ( + paged_kv_indices_tensor, + paged_kv_indptr_tensor, + paged_last_page_lens_tensor, + ) def _build_decode(self, input_positions: torch.Tensor, block_table: torch.Tensor, seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: - self._update_paged_kv_tensors(block_table, seq_lens) - - if len(self.paged_kv_indptr) > 0: - device = self.runner.device - self.paged_kv_indices.extend( - [0] * (self.total_blocks - len(self.paged_kv_indices))) - paged_kv_indices = torch.tensor(self.paged_kv_indices, - device=device, - dtype=torch.int) - paged_kv_indptr = torch.tensor(self.paged_kv_indptr, - device=device, - dtype=torch.int) - paged_last_page_lens = torch.tensor(self.paged_kv_last_page_lens, - device=device, - dtype=torch.int) - else: - ( - paged_kv_indices, - 
paged_kv_indptr, - paged_last_page_lens - ) = get_aiter_mla_metadata( - max_batch_size=self.runner.max_num_reqs, - block_size=self.runner.block_size, - max_block_per_batch=self.runner\ - .max_num_blocks_per_req, - device=self.runner.device) + + ( + paged_kv_indices, + paged_kv_indptr, + paged_last_page_lens, + ) = self._get_paged_kv_tensors(block_table, seq_lens) return AiterMLADecodeMetadata( input_positions=input_positions, From 38c67c769ab335e702534fffb077ea655ec2abdd Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 24 Apr 2025 06:02:47 +0000 Subject: [PATCH 18/39] optimize the AITER decode metadata build Signed-off-by: vllmellm --- .../attention/backends/mla/rocm_aiter_mla.py | 94 ++++++++++++------- 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 9314cc069986..3210ea5a27c4 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -74,49 +74,75 @@ def __init__(self, runner): def _get_paged_kv_tensors( self, block_table: torch.Tensor, seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: - paged_kv_indices: list[int] = [] - paged_kv_indptr: list[int] = [0] - paged_kv_last_page_lens: list[int] = [] - device = self.runner.device block_size = self.runner.block_size - total_blocks = 0 - for idx in range(seq_lens.shape[0]): - seq_len = seq_lens[idx] - total_blocks += seq_len - block_table_list = block_table.tolist()[idx] - block_table_list.insert(0, 0) - block_table_bound = seq_len // block_size + 1 \ - if seq_len % block_size != 0 \ - else seq_len // block_size - paged_kv_indices.extend(block_table_list[:block_table_bound]) - paged_kv_indptr.append(paged_kv_indptr[-1] + block_table_bound) + batch_size = seq_lens.shape[0] + + # Calculate total_blocks and number of blocks per sequence + total_blocks = seq_lens.sum().item() + max_blocks = ((seq_lens + block_size - 1) // block_size) + max_blocks_sum = max_blocks.sum().item() + + # Initialize tensors with precomputed sizes + paged_kv_indices = torch.zeros(max_blocks_sum, + dtype=torch.int, + device=device) + paged_kv_indptr = torch.zeros(batch_size + 1, + dtype=torch.int, + device=device) + paged_kv_last_page_lens = torch.zeros(batch_size, + dtype=torch.int, + device=device) + + current_index = 0 + for idx in range(batch_size): + seq_len = seq_lens[idx].item() + block_table_slice = block_table[idx].tolist() + block_table_slice.insert(0, 0) + + block_table_bound = (seq_len + block_size - 1) // block_size + + # Fill paged_kv_indices + paged_kv_indices[current_index:current_index + + block_table_bound] = torch.tensor( + block_table_slice[:block_table_bound], + dtype=torch.int, + device=device) + + # Fill paged_kv_indptr + paged_kv_indptr[idx + 1] = paged_kv_indptr[idx] + block_table_bound + + # Fill paged_kv_last_page_lens last_page_len = seq_len % block_size if last_page_len == 0: last_page_len = block_size - paged_kv_last_page_lens.append(last_page_len) + paged_kv_last_page_lens[idx] = last_page_len + + current_index += block_table_bound if self.runner.use_cuda_graph: cudagraph_batch_size = self.runner.cudagraph_batch_sizes[-1] - last_paged_kv_indptr = paged_kv_indptr[-1] - paged_kv_indptr.extend([last_paged_kv_indptr] * - cudagraph_batch_size) - paged_kv_last_page_lens.extend([0] * cudagraph_batch_size) - - paged_kv_indices.extend([0] * (total_blocks - len(paged_kv_indices))) - paged_kv_indices_tensor = torch.tensor(paged_kv_indices, - device=device, - 
dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, - device=device, - dtype=torch.int) - paged_last_page_lens_tensor = torch.tensor(paged_kv_last_page_lens, - device=device, - dtype=torch.int) + last_paged_kv_indptr = paged_kv_indptr[-1].item() + paged_kv_indptr = torch.cat([ + paged_kv_indptr, + paged_kv_indptr.new_full((cudagraph_batch_size, ), + last_paged_kv_indptr) + ]) + paged_kv_last_page_lens = torch.cat([ + paged_kv_last_page_lens, + paged_kv_last_page_lens.new_zeros((cudagraph_batch_size, )) + ]) + + pad_size = total_blocks - paged_kv_indices.size(0) + paged_kv_indices = torch.cat([ + paged_kv_indices, + torch.zeros(pad_size, dtype=torch.int, device=device) + ]) + return ( - paged_kv_indices_tensor, - paged_kv_indptr_tensor, - paged_last_page_lens_tensor, + paged_kv_indices, + paged_kv_indptr, + paged_kv_last_page_lens, ) def _build_decode(self, input_positions: torch.Tensor, From 643d07f456bcf288d56fe87f1302ee3d296b323c Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 24 Apr 2025 07:02:46 +0000 Subject: [PATCH 19/39] bugfix caused by merging with main Signed-off-by: vllmellm --- vllm/attention/backends/rocm_aiter_mla.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 853b6dc9b64a..cb16fdd88f74 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -16,7 +16,7 @@ from vllm.attention.backends.utils import (compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_forward, get_aiter_mla_metadata) if TYPE_CHECKING: @@ -404,9 +404,9 @@ def _forward_decode( kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_lens) + aiter_mla_decode_forward(q, kv_buffer, o, self.scale, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_lens) - return self._v_up_proj_and_o_proj(o) \ No newline at end of file + return self._v_up_proj_and_o_proj(o) From 6171e5089b5f0dcaeae70b196048df3f5b9c49f1 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 24 Apr 2025 09:26:08 +0000 Subject: [PATCH 20/39] Handle v1 AITER MLA backend in rocm platform Signed-off-by: vllmellm --- vllm/platforms/rocm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 885fb9f663d1..b7ab255caa9f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -158,13 +158,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}.") - elif selected_backend == _Backend.ROCM_AITER_MLA: + elif selected_backend == _Backend.ROCM_AITER_MLA \ + or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1: if block_size == 1: if use_v1: logger.info("Using AITER MLA backend.") return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: - logger.info("Using AITER MLA backend.") + logger.info("Using AITER MLA V1 backend.") return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: raise ValueError( From 905cec9333836d9cfb2004b9cd74b626542a2af5 Mon Sep 17 
00:00:00 2001 From: vllmellm Date: Thu, 1 May 2025 07:16:05 +0000 Subject: [PATCH 21/39] update AITER MLA decode metadata build Signed-off-by: vllmellm --- .../attention/backends/mla/rocm_aiter_mla.py | 94 +++++-------------- 1 file changed, 25 insertions(+), 69 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 84d72a37a121..320db75eaa9f 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -50,7 +50,7 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): paged_kv_indices: Optional[torch.Tensor] = None # The number of entries in the last page of each request in # the paged kv cache, shape: [batch_size] - paged_kv_last_page_lens: Optional[torch.Tensor] = None + paged_kv_last_page_len: Optional[torch.Tensor] = None class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): @@ -72,75 +72,29 @@ def __init__(self, runner): def _get_paged_kv_tensors( self, block_table: torch.Tensor, seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: - device = self.runner.device - block_size = self.runner.block_size - batch_size = seq_lens.shape[0] - - # Calculate total_blocks and number of blocks per sequence - total_blocks = seq_lens.sum().item() - max_blocks = ((seq_lens + block_size - 1) // block_size) - max_blocks_sum = max_blocks.sum().item() - - # Initialize tensors with precomputed sizes - paged_kv_indices = torch.zeros(max_blocks_sum, - dtype=torch.int, - device=device) - paged_kv_indptr = torch.zeros(batch_size + 1, - dtype=torch.int, - device=device) - paged_kv_last_page_lens = torch.zeros(batch_size, - dtype=torch.int, - device=device) - - current_index = 0 - for idx in range(batch_size): - seq_len = seq_lens[idx].item() - block_table_slice = block_table[idx].tolist() - block_table_slice.insert(0, 0) - - block_table_bound = (seq_len + block_size - 1) // block_size - - # Fill paged_kv_indices - paged_kv_indices[current_index:current_index + - block_table_bound] = torch.tensor( - block_table_slice[:block_table_bound], - dtype=torch.int, - device=device) - - # Fill paged_kv_indptr - paged_kv_indptr[idx + 1] = paged_kv_indptr[idx] + block_table_bound - - # Fill paged_kv_last_page_lens - last_page_len = seq_len % block_size - if last_page_len == 0: - last_page_len = block_size - paged_kv_last_page_lens[idx] = last_page_len - - current_index += block_table_bound - - if self.runner.use_cuda_graph: - cudagraph_batch_size = self.runner.cudagraph_batch_sizes[-1] - last_paged_kv_indptr = paged_kv_indptr[-1].item() - paged_kv_indptr = torch.cat([ - paged_kv_indptr, - paged_kv_indptr.new_full((cudagraph_batch_size, ), - last_paged_kv_indptr) - ]) - paged_kv_last_page_lens = torch.cat([ - paged_kv_last_page_lens, - paged_kv_last_page_lens.new_zeros((cudagraph_batch_size, )) - ]) - - pad_size = total_blocks - paged_kv_indices.size(0) - paged_kv_indices = torch.cat([ - paged_kv_indices, - torch.zeros(pad_size, dtype=torch.int, device=device) + page_size = self.runner.block_size + block_table_bounds = (seq_lens + page_size - 1) // page_size + + mask = (torch.arange(block_table.size(1), + dtype=block_table.dtype, + device=block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = block_table[mask] + + paged_kv_indptr = torch.cat([ + torch.zeros(1, + dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) + paged_kv_last_page_len = seq_lens % page_size + 
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) return ( paged_kv_indices, paged_kv_indptr, - paged_kv_last_page_lens, + paged_kv_last_page_len, ) def _build_decode(self, input_positions: torch.Tensor, @@ -150,16 +104,18 @@ def _build_decode(self, input_positions: torch.Tensor, ( paged_kv_indices, paged_kv_indptr, - paged_last_page_lens, + paged_last_page_len, ) = self._get_paged_kv_tensors(block_table, seq_lens) - return AiterMLADecodeMetadata( + attn_metadata = AiterMLADecodeMetadata( input_positions=input_positions, block_table=block_table, seq_lens=seq_lens, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, - paged_kv_last_page_lens=paged_last_page_lens) + paged_kv_last_page_len=paged_last_page_len) + + return attn_metadata class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): @@ -238,6 +194,6 @@ def _forward_decode( aiter_mla_decode_forward(q, kv_buffer, o, self.scale, attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indices, - attn_metadata.decode.paged_kv_last_page_lens) + attn_metadata.decode.paged_kv_last_page_len) return self._v_up_proj_and_o_proj(o) From 20e769ee08710edbde1ced97c46767c96ca4b1f2 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 1 May 2025 07:20:54 +0000 Subject: [PATCH 22/39] update AITER commit Signed-off-by: vllmellm --- docker/Dockerfile.rocm_base | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 1776b26d445c..05192eb69b54 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="7e1ed08" +ARG AITER_BRANCH="5a77249" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base From 90daf6edb9fb239f1016b28fb5d9bb72573d0ecf Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 1 May 2025 08:05:33 +0000 Subject: [PATCH 23/39] update proper logging info in selected backend as well as updating attention selector backend Signed-off-by: vllmellm --- tests/kernels/attention/test_attention_selector.py | 5 ++++- vllm/platforms/rocm.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index b0414244c215..436cb430817e 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -102,7 +102,10 @@ def test_env( block_size, False, use_mla=use_mla) - assert backend.get_name() == name + if use_v1 and name != "TRITON_MLA": + assert backend.get_name() == f"{name}_VLLM_V1" + else: + assert backend.get_name() == name else: with pytest.raises(ValueError) as exc_info: get_attn_backend(16, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 08410cf50108..b8797a121b87 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -162,7 +162,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using AITER MLA backend.") return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: - logger.info("Using AITER MLA V1 backend.") + logger.info("Using AITER MLA backend on V1 engine.") return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: raise ValueError( From 
f68e926597fb5488acca613cf41e0befac770ee4 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 1 May 2025 08:50:58 +0000 Subject: [PATCH 24/39] fix wrong sync merge to main Signed-off-by: vllmellm --- vllm/attention/backends/mla/common.py | 41 +-------------------------- 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 05457b50d292..bfcda587727a 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1214,45 +1214,6 @@ def get_and_maybe_dequant_weights(layer: LinearBase): # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0) - def _get_prefill_ctx_attn_output( - self, index: int, q: torch.Tensor, k: torch.Tensor, - v: torch.Tensor, - metadata: MLACommonMetadata) -> Tuple[torch.Tensor, ...]: - assert metadata.context_chunk_cu_seq_lens is not None - assert metadata.context_chunk_max_seq_lens is not None - - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - if is_vllm_fa: - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=metadata.query_start_loc, - cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], - max_seqlen_q=metadata.max_query_len, - max_seqlen_k=metadata.context_chunk_max_seq_lens[index], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - else: - attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=metadata.query_start_loc, - cu_seqlens_k=metadata.context_chunk_cu_seq_lens[index], - max_seqlen_q=metadata.max_query_len, - max_seqlen_k=metadata.context_chunk_max_seq_lens[index], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_attn_probs=True, - ) - return attn_output, attn_softmax_lse - def _compute_prefill_context( self, q: torch.Tensor, @@ -1480,4 +1441,4 @@ def forward( output[num_prefill_tokens:] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) - return output + return output \ No newline at end of file From 821f475c8d320b125970ce8c68faa6c66dff5996 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 1 May 2025 08:55:22 +0000 Subject: [PATCH 25/39] fix pre-commit Signed-off-by: vllmellm --- vllm/v1/attention/backends/mla/common.py | 17 ++++++++--------- .../v1/attention/backends/mla/rocm_aiter_mla.py | 17 ++++++++--------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 05617d704611..b31a8024b5dc 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -187,7 +187,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import torch @@ -649,14 +649,13 @@ def __init__( self.vllm_flash_attn_version == 3 and current_platform.get_device_capability()[0] == 9) - def _flash_attn_varlen_diff_headdims( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale: float, - return_softmax_lse: bool = False, - **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, 
+ softmax_scale=None, + **kwargs): maybe_padded_v = v if self._pad_v: maybe_padded_v = torch.nn.functional.pad( diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 320db75eaa9f..fa96958548f4 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Optional import torch @@ -151,14 +151,13 @@ def __init__( from aiter import flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func - def _flash_attn_varlen_diff_headdims( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale: float, - return_softmax_lse: bool = False, - **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): output = self.flash_attn_varlen_func( q=q, k=k, From 7cc28d28517c993d131390a2e092a4eed1aa72ee Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 1 May 2025 09:09:40 +0000 Subject: [PATCH 26/39] add the missing line in common.py Signed-off-by: vllmellm --- vllm/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index bfcda587727a..382a9a6d44d8 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1441,4 +1441,4 @@ def forward( output[num_prefill_tokens:] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) - return output \ No newline at end of file + return output From 29fc06002b897cd16cdb908128b323673b2c75a2 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 2 May 2025 16:00:51 +0000 Subject: [PATCH 27/39] fix wrong logger info message Signed-off-by: vllmellm --- vllm/platforms/rocm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b8797a121b87..55563312a36a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -159,10 +159,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1: if block_size == 1: if use_v1: - logger.info("Using AITER MLA backend.") + logger.info("Using AITER MLA backend on V1 engine.") return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: - logger.info("Using AITER MLA backend on V1 engine.") + logger.info("Using AITER MLA backend") return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: raise ValueError( From 7f1ed77970dafcc52e6544f9a978371fa328c8ac Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 5 May 2025 09:12:44 +0000 Subject: [PATCH 28/39] clean code and fix AITER block scaled kernel fake impl in v1 engine Signed-off-by: vllmellm --- vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 2 +- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index acaa93f5a23e..7d7bce9ec6ab 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -145,7 +145,7 @@ def 
rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( block_shape: List[int], smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - return torch.empty_like(a1, dtype=torch.bf16) + return torch.empty_like(a1, dtype=hidden_states_dtype) def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index fa96958548f4..3cc814b5a44c 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -67,8 +67,6 @@ def __init__(self, runner): assert self.runner.block_size == 1, "AITER MLA" \ "requires only block size 1." - self.block_size = self.runner.block_size - def _get_paged_kv_tensors( self, block_table: torch.Tensor, seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: From 11a89852bd39effe419f5301df7204b94c26d9da Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 5 May 2025 10:14:58 +0000 Subject: [PATCH 29/39] use env variable to adjust timeout for model execution Signed-off-by: vllmellm --- vllm/envs.py | 5 +++++ vllm/v1/executor/multiproc_executor.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index ea40bfff11b5..7a0402b0b67c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -67,6 +67,7 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms + VLLM_EXECUTE_MODEL_TIMEOUT: int = 40 #s VLLM_PLUGINS: Optional[list[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False @@ -488,6 +489,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), + # Time in seconds for the model execution. + "VLLM_EXECUTE_MODEL_TIMEOUT": + lambda: int(os.getenv("VLLM_EXECUTE_MODEL_TIMEOUT", "40")), + # a list of plugin names to load, separated by commas. 
# if this is not set, it means all plugins will be loaded # if this is set to an empty string, no plugins will be loaded diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index cb125bf4bf17..248071a36863 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -19,6 +19,7 @@ import cloudpickle +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) @@ -38,7 +39,7 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -EXECUTE_MODEL_TIMEOUT_S = 40 +EXECUTE_MODEL_TIMEOUT_S = envs.VLLM_EXECUTE_MODEL_TIMEOUT class MultiprocExecutor(Executor): From 825f387cf161426f1effe5570b53ac9352b4d9f1 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 5 May 2025 13:01:58 +0000 Subject: [PATCH 30/39] remove unnecessary backend check Signed-off-by: vllmellm --- vllm/engine/arg_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6ddc10c79b3d..a26019e5eb41 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1299,7 +1299,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "FLASHINFER", "FLASHINFER_VLLM_V1", "ROCM_AITER_MLA", - "ROCM_AITER_MLA_VLLM_V1", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): From cbeb0df5833eec3f4a73ff39646de083bb2789b5 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 5 May 2025 14:00:54 +0000 Subject: [PATCH 31/39] make the model execution timeout env variable ROCm-specific and include a warning message in the ROCm platform. Signed-off-by: vllmellm --- vllm/envs.py | 8 ++++---- vllm/platforms/rocm.py | 7 +++++++ vllm/v1/executor/multiproc_executor.py | 4 +++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 7a0402b0b67c..d21ddafe1895 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -67,7 +67,6 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms - VLLM_EXECUTE_MODEL_TIMEOUT: int = 40 #s VLLM_PLUGINS: Optional[list[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False @@ -85,6 +84,7 @@ VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_ROCM_EXECUTE_MODEL_TIMEOUT: int = 250 #s VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -489,9 +489,9 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), - # Time in seconds for the model execution. - "VLLM_EXECUTE_MODEL_TIMEOUT": - lambda: int(os.getenv("VLLM_EXECUTE_MODEL_TIMEOUT", "40")), + # Time in seconds for the model execution in ROCm platforms. + "VLLM_ROCM_EXECUTE_MODEL_TIMEOUT": + lambda: int(os.getenv("VLLM_ROCM_EXECUTE_MODEL_TIMEOUT", "250")), # a list of plugin names to load, separated by commas. 
# if this is not set, it means all plugins will be loaded diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 55563312a36a..32a1c4be0bbc 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -160,6 +160,13 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if block_size == 1: if use_v1: logger.info("Using AITER MLA backend on V1 engine.") + logger.warning( + "Increasing the model execution timeout" + "using the VLLM_ROCM_EXECUTE_MODEL_TIMEOUT" + "environment variable is recommended" + "if timeout error is encountered" + "when running %s" + "backend on V1 engine.", selected_backend) return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: logger.info("Using AITER MLA backend") diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 248071a36863..0643c7f50183 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -28,6 +28,7 @@ from vllm.executor.multiproc_worker_utils import ( _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils import (get_distributed_init_method, get_mp_context, get_open_port) from vllm.v1.executor.abstract import Executor, FailureCallback @@ -39,7 +40,8 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -EXECUTE_MODEL_TIMEOUT_S = envs.VLLM_EXECUTE_MODEL_TIMEOUT +EXECUTE_MODEL_TIMEOUT_S = (envs.VLLM_ROCM_EXECUTE_MODEL_TIMEOUT + if current_platform.is_rocm() else 40) class MultiprocExecutor(Executor): From 2218bbcf58f644da5a2e7a67d551c78c170c5243 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 6 May 2025 03:30:11 +0000 Subject: [PATCH 32/39] fix unit-test Signed-off-by: vllmellm --- tests/kernels/attention/test_rocm_attention_selector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index 4cf7bcb01d4d..6ffe27abf709 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) - assert backend.get_name() == "ROCM_AITER_MLA" + assert (backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") # If attention backend is None # If use_mla is true @@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv("VLLM_ROCM_USE_AITER", "1") backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) - assert backend.get_name() == "ROCM_AITER_MLA" + assert (backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") From 44d813f2796c128d278fc8cb230165775317850d Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 6 May 2025 03:55:46 +0000 Subject: [PATCH 33/39] bugfix to update AITER MLA V1 decode forward after sync with main Signed-off-by: vllmellm --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 3cc814b5a44c..f1f8adff4434 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ 
b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -193,4 +193,4 @@ def _forward_decode( attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len) - return self._v_up_proj_and_o_proj(o) + return self._v_up_proj(o) From 423c0befc60cbddf74f933444e9f7957ef912439 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 6 May 2025 16:47:33 +0800 Subject: [PATCH 34/39] Update vllm/platforms/rocm.py Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> --- vllm/platforms/rocm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 895ad720f76c..30e31781c5f6 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -173,11 +173,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if use_v1: logger.info("Using AITER MLA backend on V1 engine.") logger.warning( - "Increasing the model execution timeout" - "using the VLLM_ROCM_EXECUTE_MODEL_TIMEOUT" - "environment variable is recommended" - "if timeout error is encountered" - "when running %s" + "Increasing the model execution timeout " + "using the VLLM_ROCM_EXECUTE_MODEL_TIMEOUT " + "environment variable is recommended " + "if timeout error is encountered " + "when running %s " "backend on V1 engine.", selected_backend) return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: From 58d79bd5e51b6851a1e78933ce4e29e71b8c2369 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 6 May 2025 08:54:10 +0000 Subject: [PATCH 35/39] address PR comments Signed-off-by: vllmellm --- vllm/attention/backends/rocm_aiter_mla.py | 10 +++++----- vllm/attention/ops/rocm_aiter_mla.py | 2 +- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 0ed1a6e6f0be..2984bc1dad64 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -16,7 +16,7 @@ from vllm.attention.backends.utils import (compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_forward, +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, get_aiter_mla_metadata) if TYPE_CHECKING: @@ -404,9 +404,9 @@ def _forward_decode( kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - aiter_mla_decode_forward(q, kv_buffer, o, self.scale, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_lens) + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_lens) return self._v_up_proj(o) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index e9306794e02d..3348d18804aa 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -23,7 +23,7 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens -def aiter_mla_decode_forward( +def aiter_mla_decode_fwd( q: torch.Tensor, kv_buffer: torch.Tensor, o: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index f1f8adff4434..0c6e908afab3 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ 
b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,7 +6,7 @@ import torch import vllm.envs as envs -from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_forward +from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd # yapf conflicts with isort for this docstring # yapf: disable from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -188,9 +188,9 @@ def _forward_decode( kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - aiter_mla_decode_forward(q, kv_buffer, o, self.scale, - attn_metadata.decode.paged_kv_indptr, - attn_metadata.decode.paged_kv_indices, - attn_metadata.decode.paged_kv_last_page_len) + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.decode.paged_kv_indptr, + attn_metadata.decode.paged_kv_indices, + attn_metadata.decode.paged_kv_last_page_len) return self._v_up_proj(o) From 95644ea818313d0577047a7d2f86e14e70446c43 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 7 May 2025 07:32:12 +0000 Subject: [PATCH 36/39] update assertion message Signed-off-by: vllmellm --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 0c6e908afab3..909e737dcea8 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -65,7 +65,7 @@ def __init__(self, runner): assert max_model_len == 32768,\ "AITER MLA requires max_model_len=32768" assert self.runner.block_size == 1, "AITER MLA" \ - "requires only block size 1." + "only supports block size 1." def _get_paged_kv_tensors( self, block_table: torch.Tensor, From f41d616c3ffd241645b528455bc3b2065bdae253 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 8 May 2025 04:13:07 +0000 Subject: [PATCH 37/39] remove env variable for model execution timeout Signed-off-by: vllmellm --- vllm/envs.py | 5 ----- vllm/v1/executor/multiproc_executor.py | 5 +---- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index d21ddafe1895..ea40bfff11b5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -84,7 +84,6 @@ VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True - VLLM_ROCM_EXECUTE_MODEL_TIMEOUT: int = 250 #s VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -489,10 +488,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), - # Time in seconds for the model execution in ROCm platforms. - "VLLM_ROCM_EXECUTE_MODEL_TIMEOUT": - lambda: int(os.getenv("VLLM_ROCM_EXECUTE_MODEL_TIMEOUT", "250")), - # a list of plugin names to load, separated by commas. 
# if this is not set, it means all plugins will be loaded # if this is set to an empty string, no plugins will be loaded diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0643c7f50183..f925463d518b 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -19,7 +19,6 @@ import cloudpickle -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) @@ -28,7 +27,6 @@ from vllm.executor.multiproc_worker_utils import ( _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils import (get_distributed_init_method, get_mp_context, get_open_port) from vllm.v1.executor.abstract import Executor, FailureCallback @@ -40,8 +38,7 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -EXECUTE_MODEL_TIMEOUT_S = (envs.VLLM_ROCM_EXECUTE_MODEL_TIMEOUT - if current_platform.is_rocm() else 40) +EXECUTE_MODEL_TIMEOUT_S = 250 class MultiprocExecutor(Executor): From 3ee787ecc3b1f1d7168edb7f9aec7d02f0317d32 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 8 May 2025 04:38:06 +0000 Subject: [PATCH 38/39] remove unnecessary warning Signed-off-by: vllmellm --- vllm/platforms/rocm.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 310b93f4853e..a0403e5325cb 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -172,13 +172,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if block_size == 1: if use_v1: logger.info("Using AITER MLA backend on V1 engine.") - logger.warning( - "Increasing the model execution timeout " - "using the VLLM_ROCM_EXECUTE_MODEL_TIMEOUT " - "environment variable is recommended " - "if timeout error is encountered " - "when running %s " - "backend on V1 engine.", selected_backend) return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: logger.info("Using AITER MLA backend") From f68841875ed3e9ce0e73ba70775a1443f127329f Mon Sep 17 00:00:00 2001 From: vllmellm Date: Thu, 8 May 2025 05:38:09 +0000 Subject: [PATCH 39/39] keep model execution timeout as original value in main branch Signed-off-by: vllmellm --- vllm/v1/executor/multiproc_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0222c598e265..74b226b45424 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -38,7 +38,7 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -EXECUTE_MODEL_TIMEOUT_S = 250 +EXECUTE_MODEL_TIMEOUT_S = 40 class MultiprocExecutor(Executor):
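
For reference, below is a minimal standalone sketch (not part of the patch series or of vLLM's API; the function name and the toy inputs are invented for illustration) of the vectorized paged-KV decode-metadata construction that PATCH 18/39 and PATCH 21/39 converge on. It is written for an arbitrary page_size, whereas the AITER MLA backend in these patches asserts block size 1.

import torch

def build_paged_kv_metadata(block_table: torch.Tensor,
                            seq_lens: torch.Tensor,
                            page_size: int):
    # Pages actually used by each sequence (ceil division).
    block_table_bounds = (seq_lens + page_size - 1) // page_size

    # Keep only the valid leading entries of each block-table row and
    # flatten them into a single index list.
    mask = (torch.arange(block_table.size(1),
                         dtype=block_table.dtype,
                         device=block_table.device).unsqueeze(0)
            < block_table_bounds.unsqueeze(1))
    paged_kv_indices = block_table[mask]

    # CSR-style offsets into paged_kv_indices, shape [batch_size + 1].
    paged_kv_indptr = torch.cat([
        torch.zeros(1, dtype=torch.int32, device=seq_lens.device),
        block_table_bounds.cumsum(dim=0, dtype=torch.int32)
    ])

    # Number of valid entries in the last page of each sequence; a full
    # page when seq_len is an exact multiple of page_size.
    paged_kv_last_page_len = seq_lens % page_size
    paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                         page_size, paged_kv_last_page_len)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len

# Toy example: two sequences of length 5 and 8 with page_size=4 yields
# indices [10, 11, 20, 21], indptr [0, 2, 4], last_page_len [1, 4].
block_table = torch.tensor([[10, 11, 0], [20, 21, 22]], dtype=torch.int32)
seq_lens = torch.tensor([5, 8], dtype=torch.int32)
print(build_paged_kv_metadata(block_table, seq_lens, page_size=4))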