Commit cd9b9de

[BugFix] Fix IMA FlashMLA full cuda-graph and DP + Update FlashMLA (#21691)
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: yewentao256 <[email protected]>
Co-authored-by: Wentao Ye <[email protected]>
1 parent: fe6d825

File tree: 3 files changed, +42 −27 lines


cmake/external_projects/flashmla.cmake

Lines changed: 4 additions & 4 deletions

@@ -19,7 +19,7 @@ else()
     FetchContent_Declare(
         flashmla
         GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-        GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
+        GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
         GIT_PROGRESS TRUE
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""
@@ -37,9 +37,9 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     set(FlashMLA_SOURCES
         ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
+        ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)

     set(FlashMLA_INCLUDES
         ${flashmla_SOURCE_DIR}/csrc/cutlass/include

vllm/attention/ops/flashmla.py

Lines changed: 0 additions & 1 deletion

@@ -91,7 +91,6 @@ def flash_mla_with_kvcache(
     out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
         q,
         k_cache,
-        None,
         head_dim_v,
         cache_seqlens,
         block_table,

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 38 additions & 22 deletions

@@ -70,6 +70,22 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.cg_buf_tile_scheduler_metadata = None
         self.cg_buf_num_splits = None

+        device_properties = torch.cuda.get_device_properties(self.device)
+        num_sms = device_properties.multi_processor_count
+
+        if self.compilation_config.full_cuda_graph:
+            self.cg_buf_tile_scheduler_metadata = torch.zeros(
+                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
+                # TileSchedulerMetaDataSize = 8
+                (num_sms, 8),
+                device=self.device,
+                dtype=torch.int32,
+            )
+            self.cg_buf_num_splits = torch.empty(
+                (vllm_config.scheduler_config.max_num_seqs + 1),
+                device=self.device,
+                dtype=torch.int32)
+
     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
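
The hunk above sizes both CUDA-graph buffers to fixed upper bounds at init time (#SMs rows for the tile-scheduler metadata, max_num_seqs + 1 entries for num_splits) instead of adopting whatever size the first captured batch happened to produce. A minimal sketch of that static-buffer pattern, separate from vLLM's code; the names max_batch_size and stage are illustrative only:

# Minimal sketch (not vLLM code): allocate once at an upper-bound size, then
# copy each step's values into a slice of that same storage, so CUDA-graph
# replay always reads the same device addresses.
import torch

max_batch_size = 64  # upper bound known at init time (illustrative value)
static_buf = torch.zeros(max_batch_size + 1, dtype=torch.int32, device="cuda")

def stage(step_values: torch.Tensor) -> torch.Tensor:
    n = step_values.size(0)
    assert n <= static_buf.size(0)  # the upper bound must hold for every step
    view = static_buf[:n]
    view.copy_(step_values)  # write into the captured storage
    return view  # callers keep pointing at static_buf's memory
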
@@ -80,28 +96,28 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
         )

         if self.compilation_config.full_cuda_graph:
-            # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_tile_scheduler_metadata is None:
-                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
-                self.cg_buf_num_splits = num_splits
-            else:
-                assert self.cg_buf_num_splits is not None
-
-                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata.size() ==
-                        tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata.\
-                    copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
-
-                # Num splits is per-batch, varying size (batch_size,)
-                n = num_splits.size(0)
-                # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits.size(0)
-                num_splits_view = self.cg_buf_num_splits[:n]
-                num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
-                num_splits = num_splits_view
+            assert self.cg_buf_tile_scheduler_metadata is not None
+            assert self.cg_buf_num_splits is not None
+
+            sm_parts = tile_scheduler_metadata.size(0)
+            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
+            assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
+            tile_scheduler_metadata_view = \
+                self.cg_buf_tile_scheduler_metadata[:sm_parts]
+            tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
+            tile_scheduler_metadata = tile_scheduler_metadata_view
+
+            # Num splits is per-batch, varying size (batch_size,)
+            n = num_splits.size(0)
+            # make sure static buffer is large enough
+            assert n <= self.cg_buf_num_splits.size(0)
+            num_splits_view = self.cg_buf_num_splits[:n]
+            num_splits_view.copy_(num_splits)
+            # Num splits needs to be monotonically increasing
+            # (with https://github.com/vllm-project/FlashMLA/pull/3; otherwise
+            #  it needs to increase monotonically by exactly 1)
+            self.cg_buf_num_splits[n:].fill_(num_splits[-1])
+            num_splits = num_splits_view

         return FlashMLADecodeMetadata(
             block_table=block_table_tensor,
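
The tail padding is the key behavioral change in this hunk: the unused slots of the persistent num_splits buffer are filled with the last real value rather than zeros, keeping the sequence non-decreasing as the updated FlashMLA kernel (with vllm-project/FlashMLA#3) requires. A small, self-contained illustration; the buffer size and values below are made up:

# Illustrative only; not vLLM code.
import torch

cg_buf_num_splits = torch.zeros(8, dtype=torch.int32)        # persistent buffer
num_splits = torch.tensor([0, 2, 5, 9], dtype=torch.int32)   # this step's batch

n = num_splits.size(0)
cg_buf_num_splits[:n].copy_(num_splits)
cg_buf_num_splits[n:].fill_(num_splits[-1])   # pad tail with 9, not 0
print(cg_buf_num_splits.tolist())             # [0, 2, 5, 9, 9, 9, 9, 9]
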
