@@ -70,6 +70,22 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
70
70
self .cg_buf_tile_scheduler_metadata = None
71
71
self .cg_buf_num_splits = None
72
72
73
+ device_properties = torch .cuda .get_device_properties (self .device )
74
+ num_sms = device_properties .multi_processor_count
75
+
76
+ if self .compilation_config .full_cuda_graph :
77
+ self .cg_buf_tile_scheduler_metadata = torch .zeros (
78
+ # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
79
+ # TileSchedulerMetaDataSize = 8
80
+ (num_sms , 8 ),
81
+ device = self .device ,
82
+ dtype = torch .int32 ,
83
+ )
84
+ self .cg_buf_num_splits = torch .empty (
85
+ (vllm_config .scheduler_config .max_num_seqs + 1 ),
86
+ device = self .device ,
87
+ dtype = torch .int32 )
88
+
73
89
def _build_decode (self , block_table_tensor : torch .Tensor ,
74
90
seq_lens : torch .Tensor ) -> FlashMLADecodeMetadata :
75
91
tile_scheduler_metadata , num_splits = \
@@ -80,28 +96,28 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
80
96
)
81
97
82
98
if self .compilation_config .full_cuda_graph :
83
- # First time around (CUDAGraph capture), allocate the static buffer
84
- if self .cg_buf_tile_scheduler_metadata is None :
85
- self . cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
86
- self . cg_buf_num_splits = num_splits
87
- else :
88
- assert self .cg_buf_num_splits is not None
89
-
90
- # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
91
- assert ( self . cg_buf_tile_scheduler_metadata . size () ==
92
- tile_scheduler_metadata . size ())
93
- self . cg_buf_tile_scheduler_metadata .\
94
- copy_ ( tile_scheduler_metadata )
95
- tile_scheduler_metadata = self . cg_buf_tile_scheduler_metadata
96
-
97
- # Num splits is per-batch, varying size (batch_size, )
98
- n = num_splits . size ( 0 )
99
- # make sure static buffer is large enough
100
- assert n <= self . cg_buf_num_splits . size ( 0 )
101
- num_splits_view = self . cg_buf_num_splits [: n ]
102
- num_splits_view . copy_ ( num_splits )
103
- self .cg_buf_num_splits [n :].fill_ (0 ) # fill the rest with 0s
104
- num_splits = num_splits_view
99
+ assert self . cg_buf_tile_scheduler_metadata is not None
100
+ assert self .cg_buf_num_splits is not None
101
+
102
+ sm_parts = tile_scheduler_metadata . size ( 0 )
103
+ # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
104
+ assert sm_parts <= self .cg_buf_tile_scheduler_metadata . size ( 0 )
105
+ tile_scheduler_metadata_view = \
106
+ self . cg_buf_tile_scheduler_metadata [: sm_parts ]
107
+ tile_scheduler_metadata_view . copy_ ( tile_scheduler_metadata )
108
+ tile_scheduler_metadata = tile_scheduler_metadata_view
109
+
110
+ # Num splits is per-batch, varying size (batch_size, )
111
+ n = num_splits . size ( 0 )
112
+ # make sure static buffer is large enough
113
+ assert n <= self . cg_buf_num_splits . size ( 0 )
114
+ num_splits_view = self . cg_buf_num_splits [: n ]
115
+ num_splits_view . copy_ ( num_splits )
116
+ # Num splits needs to monotonically increasing
117
+ # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
118
+ # it needs to monotonically increasing by 1 )
119
+ self .cg_buf_num_splits [n :].fill_ (num_splits [ - 1 ])
120
+ num_splits = num_splits_view
105
121
106
122
return FlashMLADecodeMetadata (
107
123
block_table = block_table_tensor ,
0 commit comments