@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
  }
}

+template <typename scalar_t>
+__global__ void reshape_and_cache_flash_kernel(
+  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ k_cache,            // [num_blocks, block_size, num_heads, head_size]
+  scalar_t* __restrict__ v_cache,            // [num_blocks, block_size, num_heads, head_size]
+  const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+  const int block_stride,
+  const int key_stride,
+  const int value_stride,
+  const int num_heads,
+  const int head_size,
+  const int block_size) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  // NOTE: slot_idx can be -1 if the token is padded
+  if (slot_idx < 0) {
+    return;
+  }
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int64_t tgt_value_idx = block_idx * block_stride
+                                  + block_offset * num_heads * head_size
+                                  + head_idx * head_size
+                                  + head_offset;
+    k_cache[tgt_value_idx] = key[src_key_idx];
+    v_cache[tgt_value_idx] = value[src_value_idx];
+  }
+}

} // namespace vllm

#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE)                                \
@@ -275,6 +310,51 @@ void reshape_and_cache(
  }
}

+void reshape_and_cache_flash(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& k_cache,       // [num_blocks, block_size, num_heads, head_size]
+  torch::Tensor& v_cache,       // [num_blocks, block_size, num_heads, head_size]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype)
+{
+  // FIXME: only supports the "auto" kv_cache_dtype; fp8 is not supported
+  if (kv_cache_dtype != "auto") {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = k_cache.size(1);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+  int block_stride = k_cache.stride(0);
+  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    key.scalar_type(),
+    "reshape_and_cache_flash",
+    [&] {
+      vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        k_cache.data_ptr<scalar_t>(),
+        v_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int64_t>(),
+        block_stride,
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size);
+    });
+}
+

namespace vllm {

template <typename Tout, typename Tin>
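For context, here is a minimal host-side usage sketch of the new reshape_and_cache_flash entry point added above. It is not part of the patch: the function name and signature come from the diff, but everything else (the example_reshape_and_cache_flash driver, the tensor shapes, and the identity slot mapping) is an illustrative assumption, and it presumes the declaration is visible and the extension is built against libtorch with a CUDA device available.

// Hypothetical standalone driver; assumes reshape_and_cache_flash (declared above)
// is linkable and that a CUDA device is present.
#include <torch/torch.h>

void example_reshape_and_cache_flash() {
  const int num_tokens = 4, num_heads = 8, head_size = 64;
  const int num_blocks = 2, block_size = 16;
  const auto opts = torch::dtype(torch::kFloat).device(torch::kCUDA);

  // Shapes follow the comments in the kernel signature.
  torch::Tensor key     = torch::randn({num_tokens, num_heads, head_size}, opts);
  torch::Tensor value   = torch::randn({num_tokens, num_heads, head_size}, opts);
  torch::Tensor k_cache = torch::zeros({num_blocks, block_size, num_heads, head_size}, opts);
  torch::Tensor v_cache = torch::zeros({num_blocks, block_size, num_heads, head_size}, opts);

  // Identity slot mapping: token i is written to slot i.
  // Inside the kernel, slot / block_size selects the block and
  // slot % block_size selects the row within that block.
  torch::Tensor slot_mapping =
      torch::arange(num_tokens, torch::dtype(torch::kInt64).device(torch::kCUDA));

  reshape_and_cache_flash(key, value, k_cache, v_cache, slot_mapping, "auto");
}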