-
-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[FEAT][ROCm]: Support AITER MLA on V1 Engine #17523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 33 commits
f782c66
42d5c62
565a3fd
645f400
22c8726
5dc1348
12f8023
da8c69f
1ea5718
681d777
9ada055
1ceb3b9
20a3f07
194a42a
a9a02d5
02a4fb3
95213e2
6e48433
0265f20
693c870
a5a1a54
38c67c7
74c9cb3
643d07f
6171e50
905cec9
20e769e
455bbf2
90daf6e
f68e926
821f475
1a6ba99
7cc28d2
29fc060
7f1ed77
11a8985
825f387
cbeb0df
2218bbc
cb98504
44d813f
423c0be
58d79bd
95644ea
f41d616
56d2254
3ee787e
f688418
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,9 @@ | |
|
|
||
| import torch | ||
|
|
||
| from vllm.platforms import current_platform | ||
| from vllm.utils import direct_register_custom_op | ||
|
|
||
|
|
||
| def get_aiter_mla_metadata(max_batch_size: int, block_size: int, | ||
| max_block_per_batch: int, | ||
|
|
@@ -20,7 +23,7 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, | |
| return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens | ||
|
|
||
|
|
||
| def aiter_mla_decode_fwd( | ||
| def aiter_mla_decode_forward( | ||
|
||
| q: torch.Tensor, | ||
| kv_buffer: torch.Tensor, | ||
| o: torch.Tensor, | ||
|
|
@@ -30,6 +33,28 @@ def aiter_mla_decode_fwd( | |
| kv_last_page_lens: Optional[torch.Tensor] = None, | ||
| logit_cap: float = 0.0, | ||
| ): | ||
|
|
||
| torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, | ||
| kv_buffer.view( | ||
| -1, 1, 1, q.shape[-1]), | ||
| o, | ||
| kv_indptr, | ||
| kv_indices, | ||
| kv_last_page_lens, | ||
| sm_scale=sm_scale, | ||
| logit_cap=logit_cap) | ||
|
|
||
|
|
||
| def mla_decode_fwd_impl( | ||
| q: torch.Tensor, | ||
| kv_buffer: torch.Tensor, | ||
| o: torch.Tensor, | ||
| kv_indptr: Optional[torch.Tensor] = None, | ||
| kv_indices: Optional[torch.Tensor] = None, | ||
| kv_last_page_lens: Optional[torch.Tensor] = None, | ||
| sm_scale: float = 1.0, | ||
| logit_cap: float = 0.0, | ||
| ) -> None: | ||
| from aiter.mla import mla_decode_fwd | ||
|
|
||
| mla_decode_fwd(q, | ||
|
|
@@ -40,3 +65,24 @@ def aiter_mla_decode_fwd( | |
| kv_last_page_lens, | ||
| sm_scale=sm_scale, | ||
| logit_cap=logit_cap) | ||
|
|
||
|
|
||
| def mla_decode_fwd_fake( | ||
| q: torch.Tensor, | ||
| kv_buffer: torch.Tensor, | ||
| o: torch.Tensor, | ||
| kv_indptr: Optional[torch.Tensor] = None, | ||
| kv_indices: Optional[torch.Tensor] = None, | ||
| kv_last_page_lens: Optional[torch.Tensor] = None, | ||
| sm_scale: float = 1.0, | ||
| logit_cap: float = 0.0, | ||
| ) -> None: | ||
| pass | ||
|
|
||
|
|
||
| if current_platform.is_rocm(): | ||
| direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", | ||
| op_func=mla_decode_fwd_impl, | ||
| mutates_args=["o"], | ||
| fake_impl=mla_decode_fwd_fake, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious is the fake_impl necessary here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @houseroad Yes without it there is error from torch dynamo while building cuda graphs. |
||
| tags=[torch.Tag.needs_fixed_stride_order]) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1298,6 +1298,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: | |
| "FLASHMLA", | ||
| "FLASHINFER", | ||
| "FLASHINFER_VLLM_V1", | ||
| "ROCM_AITER_MLA", | ||
| "ROCM_AITER_MLA_VLLM_V1", | ||
|
||
| ] | ||
| if (envs.is_set("VLLM_ATTENTION_BACKEND") | ||
| and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -494,11 +494,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, | |||||||||||
| max_context_chunk = (self.chunked_prefill_workspace_size // | ||||||||||||
| num_prefills_with_context_cpu) | ||||||||||||
|
|
||||||||||||
| # align max_context_chunk to page_size by rounding down, | ||||||||||||
| # currently the `gather_cache` kernel cannot handle | ||||||||||||
| # `context_chunk_starts` that are not aligned to page_size | ||||||||||||
| max_context_chunk = round_down(max_context_chunk, | ||||||||||||
| self.page_size) | ||||||||||||
| if self.aot_schedule: | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain this a bit? Why was this change necessary?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @SageMoore the vllm/vllm/v1/attention/backends/mla/common.py Lines 355 to 359 in ba7703e
You may want to ask the author about this as these line changes were added in this PR. anyways if |
||||||||||||
| # align max_context_chunk to page_size by rounding down, | ||||||||||||
| # currently the `gather_cache` kernel cannot handle | ||||||||||||
| # `context_chunk_starts` that are not aligned to page_size | ||||||||||||
| max_context_chunk = round_down(max_context_chunk, | ||||||||||||
| self.page_size) | ||||||||||||
|
|
||||||||||||
| assert max_context_chunk > 0 | ||||||||||||
| num_chunks = cdiv(max_context_len_cpu, max_context_chunk) | ||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,198 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Any, Optional | ||
|
|
||
| import torch | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_forward | ||
| # yapf conflicts with isort for this docstring | ||
| # yapf: disable | ||
| from vllm.v1.attention.backends.mla.common import (MLACommonBackend, | ||
| MLACommonDecodeMetadata, | ||
| MLACommonImpl, | ||
| MLACommonMetadata, | ||
| MLACommonMetadataBuilder) | ||
|
|
||
| # yapf: enable | ||
|
|
||
|
|
||
| def is_aiter_mla_enabled() -> bool: | ||
| return envs.VLLM_ROCM_USE_AITER \ | ||
| and envs.VLLM_ROCM_USE_AITER_MLA | ||
|
|
||
|
|
||
| class AiterMLABackend(MLACommonBackend): | ||
|
|
||
| @staticmethod | ||
| def get_name() -> str: | ||
| return "ROCM_AITER_MLA_VLLM_V1" | ||
|
|
||
| @staticmethod | ||
| def get_impl_cls() -> type["AiterMLAImpl"]: | ||
| return AiterMLAImpl | ||
|
|
||
| @staticmethod | ||
| def get_metadata_cls() -> type["AiterMLAMetadata"]: | ||
| return AiterMLAMetadata | ||
|
|
||
| @staticmethod | ||
| def get_builder_cls() -> type["AiterMLAMetadataBuilder"]: | ||
| return AiterMLAMetadataBuilder | ||
|
|
||
|
|
||
| @dataclass | ||
| class AiterMLADecodeMetadata(MLACommonDecodeMetadata): | ||
| # The indptr of the paged kv cache, shape: [batch_size + 1] | ||
| paged_kv_indptr: Optional[torch.Tensor] = None | ||
| # The page indices of the paged kv cache | ||
| paged_kv_indices: Optional[torch.Tensor] = None | ||
| # The number of entries in the last page of each request in | ||
| # the paged kv cache, shape: [batch_size] | ||
| paged_kv_last_page_len: Optional[torch.Tensor] = None | ||
|
|
||
|
|
||
| class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): | ||
| pass | ||
|
|
||
|
|
||
| class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): | ||
|
|
||
| def __init__(self, runner): | ||
| super().__init__(runner) | ||
| max_model_len = self.runner.model_config.max_model_len | ||
| assert max_model_len == 32768,\ | ||
| "AITER MLA requires max_model_len=32768" | ||
| assert self.runner.block_size == 1, "AITER MLA" \ | ||
| "requires only block size 1." | ||
|
||
|
|
||
| self.block_size = self.runner.block_size | ||
|
|
||
| def _get_paged_kv_tensors( | ||
| self, block_table: torch.Tensor, | ||
| seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: | ||
| page_size = self.runner.block_size | ||
| block_table_bounds = (seq_lens + page_size - 1) // page_size | ||
|
|
||
| mask = (torch.arange(block_table.size(1), | ||
| dtype=block_table.dtype, | ||
| device=block_table.device).unsqueeze(0) | ||
| < block_table_bounds.unsqueeze(1)) | ||
| paged_kv_indices = block_table[mask] | ||
|
|
||
| paged_kv_indptr = torch.cat([ | ||
| torch.zeros(1, | ||
| dtype=block_table_bounds.dtype, | ||
| device=block_table_bounds.device), | ||
| block_table_bounds.cumsum(dim=0, dtype=torch.int32) | ||
| ]) | ||
|
|
||
| paged_kv_last_page_len = seq_lens % page_size | ||
| paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, | ||
| page_size, paged_kv_last_page_len) | ||
| return ( | ||
| paged_kv_indices, | ||
| paged_kv_indptr, | ||
| paged_kv_last_page_len, | ||
| ) | ||
|
|
||
| def _build_decode(self, input_positions: torch.Tensor, | ||
| block_table: torch.Tensor, | ||
| seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: | ||
|
|
||
| ( | ||
| paged_kv_indices, | ||
| paged_kv_indptr, | ||
| paged_last_page_len, | ||
| ) = self._get_paged_kv_tensors(block_table, seq_lens) | ||
|
|
||
| attn_metadata = AiterMLADecodeMetadata( | ||
| input_positions=input_positions, | ||
| block_table=block_table, | ||
| seq_lens=seq_lens, | ||
| paged_kv_indptr=paged_kv_indptr, | ||
| paged_kv_indices=paged_kv_indices, | ||
| paged_kv_last_page_len=paged_last_page_len) | ||
|
|
||
| return attn_metadata | ||
|
|
||
|
|
||
| class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): | ||
|
|
||
| def __init__( | ||
| self, | ||
| num_heads: int, | ||
| head_size: int, | ||
| scale: float, | ||
| num_kv_heads: int, | ||
| alibi_slopes: Optional[list[float]], | ||
| sliding_window: Optional[int], | ||
| kv_cache_dtype: str, | ||
| blocksparse_params: Optional[dict[str, Any]], | ||
| logits_soft_cap: Optional[float], | ||
| attn_type: str, | ||
| # MLA Specific Arguments | ||
| **mla_args) -> None: | ||
| super().__init__(num_heads, head_size, scale, num_kv_heads, | ||
| alibi_slopes, sliding_window, kv_cache_dtype, | ||
| blocksparse_params, logits_soft_cap, attn_type, | ||
| **mla_args) | ||
|
|
||
| unsupported_features = [ | ||
| alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap | ||
| ] | ||
| if any(unsupported_features): | ||
| raise NotImplementedError( | ||
| "Aiter MLA does not support one of the following: " | ||
| "alibi_slopes, sliding_window, blocksparse_params, " | ||
| "logits_soft_cap") | ||
|
|
||
| from aiter import flash_attn_varlen_func | ||
| self.flash_attn_varlen_func = flash_attn_varlen_func | ||
|
|
||
| def _flash_attn_varlen_diff_headdims(self, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is this used?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @SageMoore you may want to check
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see now. I must have mistyped the string when I searched for it :). |
||
| q, | ||
| k, | ||
| v, | ||
| return_softmax_lse=False, | ||
| softmax_scale=None, | ||
| **kwargs): | ||
| output = self.flash_attn_varlen_func( | ||
| q=q, | ||
| k=k, | ||
| v=v, | ||
| softmax_scale=softmax_scale, | ||
| return_lse=return_softmax_lse, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| return output | ||
|
|
||
| def _forward_decode( | ||
| self, | ||
| q_nope: torch.Tensor, | ||
| q_pe: torch.Tensor, | ||
| kv_c_and_k_pe_cache: torch.Tensor, | ||
| attn_metadata: AiterMLAMetadata, | ||
| ) -> torch.Tensor: | ||
| assert kv_c_and_k_pe_cache.numel() > 0 | ||
| assert attn_metadata.decode is not None | ||
|
|
||
| B = q_nope.shape[0] | ||
|
|
||
| q = torch.cat([q_nope, q_pe], dim=-1) | ||
| o = torch.zeros(B, | ||
| self.num_heads, | ||
| self.kv_lora_rank, | ||
| dtype=q.dtype, | ||
| device=q.device) | ||
|
|
||
| kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) | ||
|
|
||
| aiter_mla_decode_forward(q, kv_buffer, o, self.scale, | ||
| attn_metadata.decode.paged_kv_indptr, | ||
| attn_metadata.decode.paged_kv_indices, | ||
| attn_metadata.decode.paged_kv_last_page_len) | ||
|
|
||
| return self._v_up_proj_and_o_proj(o) | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this change from fwd -> forward seems not necessary, in order to minimize the number of files changed in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it has been addressed in the latest commit.