-
Notifications
You must be signed in to change notification settings - Fork 743
[Cherry-Pick][Feature] support decode attention for mix(#7688) #7729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
lizhenyun01
wants to merge
13
commits into
PaddlePaddle:release/2.6
Choose a base branch
from
lizhenyun01:dec_attn_2.6
base: release/2.6
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
2265dc9
support c8 decode attention
lizhenyun01 4c922bc
support c16 attention && backend
lizhenyun01 de6450d
opt kernel
lizhenyun01 111230a
fix
lizhenyun01 03263a0
opt larger batch
lizhenyun01 cb64cb3
inplace out
lizhenyun01 b1acb37
fix input_batch && remove fast_math
lizhenyun01 a5e394f
fix xpu
lizhenyun01 6a5b3c6
fix bug
lizhenyun01 307e5a8
fix ci
lizhenyun01 3f29b01
opt and fix mtp
lizhenyun01 35a876a
fix merge
lizhenyun01 9ab4e13
clean code
lizhenyun01 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
1,228 changes: 1,228 additions & 0 deletions
1,228
custom_ops/gpu_ops/append_attention/attention_func.cuh
Large diffs are not rendered by default.
Oops, something went wrong.
406 changes: 406 additions & 0 deletions
406
custom_ops/gpu_ops/append_attention/config_for_attention.cu
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
| #pragma once | ||
| #include <cuda.h> | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_runtime.h> | ||
| #include <cuda_runtime_api.h> | ||
| #include <cuda/barrier> | ||
| #include <stdexcept> | ||
|
|
||
| using barrier = cuda::barrier<cuda::thread_scope_block>; | ||
| namespace cde = cuda::device::experimental; | ||
|
|
||
| template <typename T> | ||
| struct cu_tensor_map_type_traits { | ||
| static const CUtensorMapDataType type = | ||
| CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; | ||
| }; | ||
|
|
||
| template <> | ||
| struct cu_tensor_map_type_traits<phi::dtype::bfloat16> { | ||
| static const CUtensorMapDataType type = | ||
| CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; | ||
| }; | ||
|
|
||
| template <> | ||
| struct cu_tensor_map_type_traits<phi::dtype::float16> { | ||
| static const CUtensorMapDataType type = | ||
| CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; | ||
| }; | ||
|
|
||
| template <> | ||
| struct cu_tensor_map_type_traits<uint8_t> { | ||
| static const CUtensorMapDataType type = | ||
| CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; | ||
| }; | ||
|
|
||
| template <> | ||
| struct cu_tensor_map_type_traits<phi::dtype::float8_e4m3fn> { | ||
| static const CUtensorMapDataType type = | ||
| CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; | ||
| }; | ||
|
|
||
| template <typename T> | ||
| CUtensorMap makeTensorMapForKVCache(T const* addr, | ||
| uint32_t block_num, | ||
| uint32_t kv_num_head, | ||
| uint32_t second_size, | ||
| uint32_t last_size) { | ||
| CUtensorMap tensorMap{}; | ||
|
|
||
| uint32_t elem_bytes = sizeof(T); | ||
|
|
||
| uint32_t const last_size_bytes = elem_bytes * last_size; | ||
| // VLLM Layout | ||
| CUtensorMapDataType data_dtype = cu_tensor_map_type_traits<T>::type; | ||
| constexpr uint32_t rank = 4; | ||
| uint64_t global_dims[] = {last_size, second_size, kv_num_head, block_num}; | ||
| uint64_t global_strides[] = {last_size_bytes, | ||
| second_size * last_size_bytes, | ||
| kv_num_head * second_size * last_size_bytes}; | ||
|
|
||
| uint32_t box_dims[] = {last_size, second_size, 1, 1}; | ||
| uint32_t elem_strides[] = {1, 1, 1, 1}; | ||
|
|
||
| auto const swizzle = [&] { | ||
| switch (last_size_bytes) { | ||
| case 128: | ||
| return CU_TENSOR_MAP_SWIZZLE_128B; | ||
| case 64: | ||
| return CU_TENSOR_MAP_SWIZZLE_64B; | ||
| default: | ||
| throw std::runtime_error("unsupported cache last_size"); | ||
| } | ||
| }(); | ||
| CUresult res = cuTensorMapEncodeTiled( | ||
| &tensorMap, | ||
| data_dtype, | ||
| rank, | ||
| reinterpret_cast<void*>(const_cast<T*>(addr)), | ||
| global_dims, | ||
| global_strides, | ||
| box_dims, | ||
| elem_strides, | ||
| CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, | ||
| swizzle, | ||
| CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B, | ||
| CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); | ||
| switch (res) { | ||
| case CUDA_SUCCESS: | ||
| printf("CUDA_SUCCESS!\n"); | ||
| break; | ||
| case CUDA_ERROR_INVALID_VALUE: | ||
| printf("CUDA_ERROR_INVALID_VALUE\n"); | ||
| break; | ||
| case CUDA_ERROR_OUT_OF_MEMORY: | ||
| printf("CUDA_ERROR_OUT_OF_MEMORY\n"); | ||
| break; | ||
| case CUDA_ERROR_NOT_INITIALIZED: | ||
| printf("CUDA_ERROR_NOT_INITIALIZED\n"); | ||
| break; | ||
| case CUDA_ERROR_DEINITIALIZED: | ||
| printf("CUDA_ERROR_DEINITIALIZED\n"); | ||
| break; | ||
| case CUDA_ERROR_PROFILER_DISABLED: | ||
| printf("CUDA_ERROR_PROFILER_DISABLED\n"); | ||
| break; | ||
| default: | ||
| throw std::runtime_error("unsupported res!"); | ||
| } | ||
|
|
||
| return tensorMap; | ||
| } | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❓ 疑问
cu_tensor_map.cuh已加入仓库,但当前没有被任何 kernel 引用——decode_append_attention_c8_impl.cuh第15行中对应的#include "cu_tensor_map.cuh"已注释掉,decode_append_attention_c16_impl.cuh也未包含此头文件。该文件内的
cuda::device::experimental命名空间(Hopper TMA API)和CUtensorMapDataType均为 SM90+ 专属能力。请确认:TODO: SM90+;