-
-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[FEAT][ROCm]: Support AITER MLA on V1 Engine #17523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 38 commits
f782c66
42d5c62
565a3fd
645f400
22c8726
5dc1348
12f8023
da8c69f
1ea5718
681d777
9ada055
1ceb3b9
20a3f07
194a42a
a9a02d5
02a4fb3
95213e2
6e48433
0265f20
693c870
a5a1a54
38c67c7
74c9cb3
643d07f
6171e50
905cec9
20e769e
455bbf2
90daf6e
f68e926
821f475
1a6ba99
7cc28d2
29fc060
7f1ed77
11a8985
825f387
cbeb0df
2218bbc
cb98504
44d813f
423c0be
58d79bd
95644ea
f41d616
56d2254
3ee787e
f688418
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,9 @@ | |
|
|
||
| import torch | ||
|
|
||
| from vllm.platforms import current_platform | ||
| from vllm.utils import direct_register_custom_op | ||
|
|
||
|
|
||
| def get_aiter_mla_metadata(max_batch_size: int, block_size: int, | ||
| max_block_per_batch: int, | ||
|
|
@@ -20,7 +23,7 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int, | |
| return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens | ||
|
|
||
|
|
||
| def aiter_mla_decode_fwd( | ||
| def aiter_mla_decode_forward( | ||
|
||
| q: torch.Tensor, | ||
| kv_buffer: torch.Tensor, | ||
| o: torch.Tensor, | ||
|
|
@@ -30,6 +33,28 @@ def aiter_mla_decode_fwd( | |
| kv_last_page_lens: Optional[torch.Tensor] = None, | ||
| logit_cap: float = 0.0, | ||
| ): | ||
|
|
||
| torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, | ||
| kv_buffer.view( | ||
| -1, 1, 1, q.shape[-1]), | ||
| o, | ||
| kv_indptr, | ||
| kv_indices, | ||
| kv_last_page_lens, | ||
| sm_scale=sm_scale, | ||
| logit_cap=logit_cap) | ||
|
|
||
|
|
||
| def mla_decode_fwd_impl( | ||
| q: torch.Tensor, | ||
| kv_buffer: torch.Tensor, | ||
| o: torch.Tensor, | ||
| kv_indptr: Optional[torch.Tensor] = None, | ||
| kv_indices: Optional[torch.Tensor] = None, | ||
| kv_last_page_lens: Optional[torch.Tensor] = None, | ||
| sm_scale: float = 1.0, | ||
| logit_cap: float = 0.0, | ||
| ) -> None: | ||
| from aiter.mla import mla_decode_fwd | ||
|
|
||
| mla_decode_fwd(q, | ||
|
|
@@ -40,3 +65,24 @@ def aiter_mla_decode_fwd( | |
| kv_last_page_lens, | ||
| sm_scale=sm_scale, | ||
| logit_cap=logit_cap) | ||
|
|
||
|
|
||
| def mla_decode_fwd_fake( | ||
| q: torch.Tensor, | ||
| kv_buffer: torch.Tensor, | ||
| o: torch.Tensor, | ||
| kv_indptr: Optional[torch.Tensor] = None, | ||
| kv_indices: Optional[torch.Tensor] = None, | ||
| kv_last_page_lens: Optional[torch.Tensor] = None, | ||
| sm_scale: float = 1.0, | ||
| logit_cap: float = 0.0, | ||
| ) -> None: | ||
| pass | ||
|
|
||
|
|
||
| if current_platform.is_rocm(): | ||
| direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", | ||
| op_func=mla_decode_fwd_impl, | ||
| mutates_args=["o"], | ||
| fake_impl=mla_decode_fwd_fake, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious is the fake_impl necessary here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @houseroad Yes without it there is error from torch dynamo while building cuda graphs. |
||
| tags=[torch.Tag.needs_fixed_stride_order]) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -84,6 +84,7 @@ | |
| VLLM_ROCM_FP8_PADDING: bool = True | ||
| VLLM_ROCM_MOE_PADDING: bool = True | ||
| VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True | ||
| VLLM_ROCM_EXECUTE_MODEL_TIMEOUT: int = 250 #s | ||
| VLLM_ENABLE_V1_MULTIPROCESSING: bool = True | ||
| VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 | ||
| VLLM_DISABLE_COMPILE_CACHE: bool = False | ||
|
|
@@ -488,6 +489,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: | |
| "VLLM_RPC_TIMEOUT": | ||
| lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), | ||
|
|
||
| # Time in seconds for the model execution in ROCm platforms. | ||
| "VLLM_ROCM_EXECUTE_MODEL_TIMEOUT": | ||
|
||
| lambda: int(os.getenv("VLLM_ROCM_EXECUTE_MODEL_TIMEOUT", "250")), | ||
|
|
||
| # a list of plugin names to load, separated by commas. | ||
| # if this is not set, it means all plugins will be loaded | ||
| # if this is set to an empty string, no plugins will be loaded | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -494,11 +494,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, | |||||||||||
| max_context_chunk = (self.chunked_prefill_workspace_size // | ||||||||||||
| num_prefills_with_context_cpu) | ||||||||||||
|
|
||||||||||||
| # align max_context_chunk to page_size by rounding down, | ||||||||||||
| # currently the `gather_cache` kernel cannot handle | ||||||||||||
| # `context_chunk_starts` that are not aligned to page_size | ||||||||||||
| max_context_chunk = round_down(max_context_chunk, | ||||||||||||
| self.page_size) | ||||||||||||
| if self.aot_schedule: | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain this a bit? Why was this change necessary?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @SageMoore the vllm/vllm/v1/attention/backends/mla/common.py Lines 355 to 359 in ba7703e
You may want to ask the author about this as these line changes were added in this PR. anyways if |
||||||||||||
| # align max_context_chunk to page_size by rounding down, | ||||||||||||
| # currently the `gather_cache` kernel cannot handle | ||||||||||||
| # `context_chunk_starts` that are not aligned to page_size | ||||||||||||
| max_context_chunk = round_down(max_context_chunk, | ||||||||||||
| self.page_size) | ||||||||||||
|
|
||||||||||||
| assert max_context_chunk > 0 | ||||||||||||
| num_chunks = cdiv(max_context_len_cpu, max_context_chunk) | ||||||||||||
|
|
||||||||||||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this change from fwd -> forward seems not necessary, in order to minimize the number of files changed in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it has been addressed in the latest commit.