diff --git a/custom_ops/gpu_ops/append_attention/attention_func.cuh b/custom_ops/gpu_ops/append_attention/attention_func.cuh new file mode 100644 index 00000000000..ee74570e5d8 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/attention_func.cuh @@ -0,0 +1,1231 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "mma_tensor_op.cuh" +#include "utils.cuh" + +template +__device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = 0.f; + } + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + if constexpr (std::is_same::value) { + m[fx][j] = -5e4f; + } else if constexpr (std::is_same::value) { + m[fx][j] = -3.0e+30f; + } + d[fx][j] = 1.f; + } + } +} + +template +__device__ __forceinline__ void load_block_table_per_chunk( + const int32_t* block_table_chunk_start, + int32_t* block_table_smem, + uint32_t chunk_start, + uint32_t chunk_end, + uint32_t tid, + uint32_t wid) { + uint32_t len = chunk_end / BLOCK_SIZE - chunk_start / BLOCK_SIZE; + for (uint32_t i = 0; i < div_up(len, 128); i++) { + uint32_t offset = wid * 
kWarpSize + tid + i * 128; + if (offset < len) { + block_table_smem[offset] = block_table_chunk_start[offset]; + } + } +} + +// load q from global memory to shared memory +template +__device__ __forceinline__ void load_q_global_smem_multi_warps( + T* q_ptr_base, + smem_t* q_smem, + uint32_t q_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] + smem_t::get_permuted_offset(ty * 4 + tx / 8, + tx % 8); // 4 * 64 + + const uint32_t tx_offset = tx / 8; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + q_smem->load_128b_async( + q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); + q_smem_offset_w = + q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); + q_ptr += 8 * num_elems_per_128b(); + } + q_smem_offset_w = + q_smem->advance_offset_by_row<16, num_vecs_per_head>(q_smem_offset_w) - + 2 * num_frags_y; + } +} + +template +__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( + smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] + const float sm_scale) { + constexpr int vec_size = 16 / sizeof(T); + using LoadT = AlignedVector; + LoadT tmp_vec; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + +#pragma unroll + for (uint32_t i = 0; i < 
num_frags_x * 16 * head_dim / 1024; ++i) { + const int offset = i * 1024 + ty * 256 + tx * 8; + Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + tmp_vec[reg_id] *= sm_scale; + } + Store(tmp_vec, reinterpret_cast(q_smem->base) + offset); + } +} + +template +__device__ __forceinline__ void produce_k_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_k, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_k_offset) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = + head_dim / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; +#pragma unroll + for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { + smem.load_128b_async(*smem_offset, cache_k_now, true); + *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( + *smem_offset, j); + cache_k_now += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_y; // num_frags_y / 4 * 4 + cache_k_now += num_warps * 4 * kv_b_stride - + num_frags_y * num_elems_per_128b(); + } + } + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void produce_v_blockwise_c8( + smem_t smem, + 
uint32_t* smem_offset, + CacheT* cache_v, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_d_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_v_offset) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b(); + +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; + +#pragma unroll + for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; + ++i) { // m (num_frags_y * 16 / (num_warps * 8)) +#pragma unroll + for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) { + smem.load_128b_async(*smem_offset, cache_v_now, true); + *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + *smem_offset, j); + cache_v_now += 4 * num_elems_per_128b(); + kv_idx += 4 * num_elems_per_128b(); + } + kv_idx -= 2 * num_frags_z * num_elems_per_128b(); + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_z; // num_frags_z / 4 * 4 + cache_v_now += num_warps * 8 * kv_d_stride - + 2 * num_frags_z * num_elems_per_128b(); + } + kv_idx += block_size; + } + *smem_offset -= NUM_WARP_KV / 2 * num_frags_y * 16 * num_vecs_per_blocksize; +} + +template +__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async( + smem_t kv_scale_smem, + const int* block_table_now, + const T* cache_kv_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t tid = ty * 32 + tx; + // 1 warp 32 tokens + if (tid < 
block_size / 8 * 2) { + const uint32_t kv_idx_now = kv_idx + block_size * tid / 8; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const int kv_idx_this_thread = kv_idx + tid * 8; + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid % 8 * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, kv_idx_this_thread < chunk_end); + } +} + +template +__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg( + T* k_smem_scale, T* cache_k_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + // 1 warp 32 tokens + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8]; + } +} + +template +__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg( + T* v_smem_scale, T* cache_v_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + // 1 warp 32 tokens + const uint32_t row_id = tx % 4 * 2; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; + } +} + +template +__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + const T* cache_k_scale, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + head_dim / num_elems_per_128b(); + + uint32_t a_frag[num_frags_x][2][4], b_frag[4], b_frag_dq[4]; + +#pragma unroll + for (uint32_t ky = 0; 
ky < num_frags_y / 2; ++ky) { // k + // load q +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]); + + *q_smem_offset_r = + q_smem->advance_offset_by_row<16, num_vecs_per_head_q>( + *q_smem_offset_r); + } + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 2 + fy) - + num_frags_x * 16 * num_vecs_per_head_q; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + // load + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fy * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + const int scale_col = (ky * 2 + fy) * 4; + b_frag_dq_T[0] *= cache_k_scale[scale_col]; + b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_k_scale[scale_col]; + b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[0]; + } + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4]; + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (ky == 0 && fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } + } + } + } + *k_smem_offset_r = 
k_smem->advance_offset_by_column<2, num_vecs_per_head_k>( + *k_smem_offset_r, ky) - + num_frags_z * 16 * num_vecs_per_head_k; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y / 2 * 2; +} + +template +__device__ __forceinline__ void mask_s(const bool* attn_mask, + const uint32_t qo_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint32_t chunk_end, + const uint32_t attn_mask_len, + float (*s_frag)[num_frags_z][8], + const int* mask_offset = nullptr, + const int sliding_window = 0) { + const uint32_t tx = threadIdx.x; +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 + + 8 * ((reg_id % 4) / 2)) / + group_size, + kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + bool out_of_boundary; + if (mask_offset) { + const int2 mo = reinterpret_cast(mask_offset)[q_idx]; + out_of_boundary = + q_idx < qo_len ? (kv_idx >= mo.y || kv_idx < mo.x) : true; + } else if (sliding_window > 0) { + bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - + (int)qo_len - sliding_window; + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + out_of_window || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } else { + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && + kv_idx < chunk_end && q_idx < attn_mask_len) { + const int32_t mask_idx = + q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + bool mask = attn_mask[mask_idx]; + out_of_boundary |= mask; + } + } + + if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? 
-5e4f : s_frag[fx][fz][reg_id]; + } else if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id]; + } + } + } + } +} + +template +__device__ __forceinline__ void update_mdo_states( + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t j_id = j * 2; + float m_prev = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_tmp = s_frag[fx][fz] + j_id; + float m_local = max(max(s_frag_tmp[0], s_frag_tmp[1]), + max(s_frag_tmp[4], s_frag_tmp[5])); + m[fx][j] = max(m[fx][j], m_local); + } + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x2, 32)); + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x1, 32)); + float o_scale = expf(m_prev - m[fx][j]); + d[fx][j] *= o_scale; + float2 fp2_scale = make_float2(o_scale, o_scale); +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_frag_ptr = reinterpret_cast(o_frag[fx][fy] + j_id); + o_frag_ptr[0] = fast_float2_mul(o_frag_ptr[0], fp2_scale); + o_frag_ptr[2] = fast_float2_mul(o_frag_ptr[2], fp2_scale); + } + float tmp_m = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_ptr = s_frag[fx][fz] + j_id; + s_frag_ptr[0] = __expf(s_frag_ptr[0] - tmp_m); + s_frag_ptr[1] = __expf(s_frag_ptr[1] - tmp_m); + s_frag_ptr[4] = __expf(s_frag_ptr[4] - tmp_m); + s_frag_ptr[5] = __expf(s_frag_ptr[5] - tmp_m); + } + } + } +} + +template +__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2], + T* cache_v_scale) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; + uint32_t 
b_frag[4], b_frag_dq[4]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag); + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>( + *v_smem_offset_r); +#pragma unroll + for (uint32_t fz = 0; fz < 2; ++fz) { + // dequant b_frag -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } + } + } else { + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], + (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), + b_frag_dq); + } + } + } + *v_smem_offset_r -= num_frags_y * 16 * num_vecs_per_blocksize; + } +} + 
+template +__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], + float* md_smem, + float (*m)[2], + float (*d)[2], + const uint32_t wid, + const uint32_t tid, + const bool normalize = false) { + // Padded row stride (33 instead of 32) to avoid cross-row bank conflicts. + constexpr uint32_t kRowStride = 33; + // o_smem row stride in floats: kRowStride * 8 = 264 + constexpr uint32_t kORowStride = kRowStride * 8; + // md_smem base offset: after all o_smem data + // NUM_WARPS(4) * num_frags_x * num_frags_y * kORowStride floats + constexpr uint32_t kOMemFloats = 4 * num_frags_x * num_frags_y * kORowStride; + float2* smem_md = reinterpret_cast(md_smem + kOMemFloats); + + // Phase 1: Write m/d to smem only (2KB, no o data yet) +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + smem_md[((wid * num_frags_x + fx) * 2 + j) * kRowStride + tid] = + make_float2(m[fx][j], d[fx][j]); + } + } + __syncthreads(); + + // Phase 2: Compute global m/d and scale own o_frag in registers + float scale_j[2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_new; + float d_new = 1.f; + if constexpr (std::is_same::value) { + m_new = -5e4f; + } else { + m_new = -3.0e+30f; + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = + smem_md[((i * num_frags_x + fx) * 2 + j) * kRowStride + tid]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = fmaf(d_prev, expf(m_prev - m_new), md.y * expf(md.x - m_new)); + } + float own_scale = expf(m[fx][j] - m_new); + m[fx][j] = m_new; + d[fx][j] = d_new; + float d_rcp = normalize ? 
(1.f / d_new) : 1.f; + scale_j[j] = own_scale * d_rcp; + } + // Apply scale to o_frag using WGMMA fragment layout: + // regs 0,1→j=0, 2,3→j=1, 4,5→j=0, 6,7→j=1 + // i.e., float2 index k → j = k % 2 +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t k = 0; k < 4; ++k) { + float s = scale_j[k % 2]; + o_frag[fx][fy][2 * k + 0] *= s; + o_frag[fx][fy][2 * k + 1] *= s; + } + } + } + + // Phase 3: Write pre-scaled o_frag to smem with padded stride +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_smem_start = + (float2*)(md_smem + + ((wid * num_frags_x + fx) * num_frags_y + fy) * + kORowStride + + tid * 2); +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + o_smem_start[i * kRowStride] = ((float2*)(&o_frag[fx][fy][0]))[i]; + } + } + } + __syncthreads(); + + // Phase 4: Accumulate all warps' scaled o_frag +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_new_fp2 = reinterpret_cast(&o_frag[fx][fy][0]); +#pragma unroll + for (uint32_t o_id = 0; o_id < 4; ++o_id) { + o_new_fp2[o_id] = make_float2(0.f, 0.f); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + AlignedVector oi_fp2; + float2* o_smem_start = + (float2*)(md_smem + + ((i * num_frags_x + fx) * num_frags_y + fy) * + kORowStride + + tid * 2); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + oi_fp2[reg_id] = o_smem_start[reg_id * kRowStride]; + } +#pragma unroll + for (uint32_t reg_fp2_id = 0; reg_fp2_id < 4; ++reg_fp2_id) { + o_new_fp2[reg_fp2_id].x += oi_fp2[reg_fp2_id].x; + o_new_fp2[reg_fp2_id].y += oi_fp2[reg_fp2_id].y; + } + } + } + } +} + +template +__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + float d_rcp[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma 
unroll + for (uint32_t j = 0; j < 2; ++j) { + d_rcp[fx][j] = 1.f / d[fx][j]; + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = + o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2]; + } + } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem_multi_warps( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + OutT* o_ptr_base, + uint32_t o_idx_base, + const uint32_t q_head_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr int VEC_SIZE = 16 / sizeof(T); + // [num_warps * num_frags_x * 16, num_frags_y * 16] + if (ty == 0) { + // [num_frags_x * 16, num_frags_y * 16] +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = + o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3]; + } + } + } + __syncthreads(); + + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); + + const uint32_t tx_offset = tx / 8; +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + 
const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + + OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + if (n_offset < qo_upper_bound) { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + o_ptr += 8 * num_elems_per_128b(); + o_smem_offset_w = + o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + } + o_smem_offset_w = + o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) - + 2 * num_frags_y; + } +} + +template +struct prefill_softmax_state_t { + AlignedVector o; + float m; + float d; + + __device__ __forceinline__ void init() { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0); + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.38953e38f; + } + } + + __device__ __forceinline__ void merge( + const AlignedVector& other_o, float other_m, float other_d) { + float m_prev = m, d_prev = d; + m = m_prev > other_m ? 
m_prev : other_m; + const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d_prev * scale1 + other_d * scale2; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * scale1_T + other_o[i] * scale2_T; + } + } + + __device__ __forceinline__ void normalize() { + const T d_t = static_cast(d); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } + + __device__ __forceinline__ void normalize(float current_sink) { + const T d_t = static_cast(d + __expf(current_sink - m)); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } +}; + +// C16 (fp16/bf16 KV cache) helper functions + +template +__device__ __forceinline__ void produce_kv_blockwise(smem_t smem, + uint32_t* smem_offset, + T** gptr, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 4; ++j) { + smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_y; + *gptr += + num_warps * 4 * kv_b_stride - 2 * num_frags_y * num_elems_per_128b(); + } + *gptr -= NUM_WARP_KV * num_frags_z * 16 * kv_b_stride; + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void compute_qk(smem_t* q_smem, + uint32_t* q_smem_offset_r, + 
smem_t* k_smem, + uint32_t* k_smem_offset_r, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + uint32_t a_frag[num_frags_x][4], b_frag[4]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); + *q_smem_offset_r = q_smem->advance_offset_by_row<16, num_vecs_per_head>( + *q_smem_offset_r); + } + + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, fy) - + num_frags_x * 16 * num_vecs_per_head; + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } + } + } + *k_smem_offset_r = + k_smem->advance_offset_by_column<2>(*k_smem_offset_r, fy) - + num_frags_z * 16 * num_vecs_per_head; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y * 2; +} + +template +__device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for 
(uint32_t fz = 0; fz < num_frags_z; ++fz) {
      // Accumulate the row sums of the f16 score fragments into d (the
      // softmax denominator accumulator) before the PV matmul below.
      rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]);
    }
  }

  // PV matmul: o_frag += s_frag_f16 @ V, iterating KV tiles (fz) x head-dim
  // tiles (fy), loading V fragments transposed from shared memory.
#pragma unroll
  for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
    for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
      uint32_t b_frag[4];
      v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag);
#pragma unroll
      for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
        mma_sync_m16n16k16_row_col_f16f16f32(
            o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag);
      }
      *v_smem_offset_r =
          v_smem->advance_offset_by_column<2>(*v_smem_offset_r, fy);
    }
    *v_smem_offset_r =
        v_smem->advance_offset_by_row<16, num_vecs_per_head>(*v_smem_offset_r) -
        2 * num_frags_y;
  }
  // Rewind the V read offset to the start of the tile for the next iteration.
  *v_smem_offset_r -= 16 * num_frags_z * num_vecs_per_head;
}

// Merges per-chunk split-KV attention partials (multi_out / multi_m / multi_d)
// into the final output via online-softmax combination.
//
// Grid:  (x = token tiles, y = head id).  Block: (blockDim.x = HEAD_DIM /
// vec_size lanes over the head dim, blockDim.y = bdy workers over chunks).
//
// NOTE(review): the template parameter list below was reconstructed — the
// original text was garbled in transit; confirm against upstream.
template <typename T, int vec_size, int bdy, int HEAD_DIM>
__global__ void merge_chunks_kernel(
    const T* __restrict__ multi_out,    // [token_num, num_chunks, num_heads,
                                        // head_dim]
    const float* __restrict__ multi_m,  // [token_num, num_chunks, num_heads]
    const float* __restrict__ multi_d,  // [token_num, num_chunks, num_heads]
    const int* __restrict__ seq_lens_q,
    const int* __restrict__ seq_lens_kv,
    const int* __restrict__ seq_lens_encoder,
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ cu_seqlens_q,
    const T* __restrict__ shift_bias,     // [q_num_heads * HEAD_DIM]
    const T* __restrict__ smooth_weight,  // [q_num_heads * HEAD_DIM]
    const T* __restrict__ sinks,          // [q_num_heads]
    const int* __restrict__ chunk_size_ptr,
    T* __restrict__ out,
    const float quant_max_bound,
    const float quant_min_bound,
    const float in_scale,
    const int max_seq_len,
    const int num_chunks,
    const int num_heads,
    const int head_dim,
    const int token_num,
    const int max_tokens_per_batch = 5) {
  const int vid = threadIdx.x, ty = threadIdx.y;
  const int hid = blockIdx.y;
  // After intra-warp reduction, only bdy/2 results need smem storage
  __shared__ T smem[(bdy / 2) * HEAD_DIM];
  __shared__ float md_smem[(bdy / 2) * 2];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Programmatic dependent launch (SM90+): wait for the producer grid.
  cudaGridDependencySynchronize();
#endif
  // Phase 1: Fast path — all ty participate independently (no smem, no
  // syncthreads) Each ty handles a different qid with stride gridDim.x * bdy
  using LoadT = AlignedVector<T, vec_size>;
  for (int qid = blockIdx.x + ty * gridDim.x; qid < token_num;
       qid += gridDim.x * bdy) {
    const uint32_t bid = batch_id_per_token[qid];
    if (bid == (uint32_t)-1) continue;
    if (seq_lens_encoder[bid] > 0) continue;  // skip prefill batches
    const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
    const int seq_len_q = seq_lens_q[bid];
    if (seq_len_q == 0) continue;
    int seq_len_kv = seq_lens_kv[bid];
    if (seq_len_kv == 0) continue;
    seq_len_kv += seq_len_q;
    const int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr);
    if (num_chunks_this_seq != 1) continue;  // handled in Phase 2

    // Single chunk: no merge needed, copy the partial straight to out.
    LoadT load_vec;
    uint32_t offset =
        ((bid * max_tokens_per_batch + local_seq_id) * num_chunks * num_heads +
         hid) *
            head_dim +
        vid * vec_size;
    Load<T, vec_size>(&multi_out[offset], &load_vec);
    Store<T, vec_size>(
        load_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
  }

  // Phase 2: Slow path — merge multi-chunk results
  // Optimization: use warp-shuffle reduction within each warp, then cross-warp
  // via smem. This eliminates the large smem[bdy * HEAD_DIM] buffer and reduces
  // syncthreads from 2 per qid to 1 per qid.
  // Block layout: (blockx=16, bdy=8) => 4 warps, each warp has 2 ty values
  // Warp 0: ty=0,1  Warp 1: ty=2,3  Warp 2: ty=4,5  Warp 3: ty=6,7
  // Lane layout within warp: lanes 0-15 = (ty_low, vid), lanes 16-31 =
  // (ty_high, vid)
  const int lane_id = (ty * blockDim.x + vid) % 32;

  for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
    // All of these guards depend only on qid/bid, which are uniform across
    // the block, so the continues are block-uniform and the __syncthreads()
    // further down is safe.
    const uint32_t bid = batch_id_per_token[qid];
    if (bid == (uint32_t)-1) continue;  // uniform skip — no syncthreads needed
    if (seq_lens_encoder[bid] > 0) continue;
    const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
    const int seq_len_q = seq_lens_q[bid];
    if (seq_len_q == 0) continue;
    int seq_len_kv = seq_lens_kv[bid];
    if (seq_len_kv == 0) continue;
    seq_len_kv += seq_len_q;
    const int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr);
    if (num_chunks_this_seq == 1) continue;  // handled in Phase 1

    LoadT load_vec;
    LoadT res_vec;
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((half2*)(&res_vec) + i) = make_half2(0, 0);
      }
    } else {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
      }
    }
    // Running online-softmax state: m = running max, d = running denominator.
    float m;
    float d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = -5e4f;
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
      m = -3.0e+30f;
    }

    // Step 1: Each ty iterates over its chunk subset and does local online
    // softmax merge
#pragma unroll 2
    for (int i = ty; i < num_chunks_this_seq; i += bdy) {
      uint32_t offset;

      offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks + i) *
                   num_heads +
               hid;
      float m_prev = m;
      float d_prev = d;
      const float m_now = multi_m[offset];
      const float d_now = multi_d[offset];
      m = max(m_prev, m_now);

      offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks *
                    num_heads +
                i * num_heads + hid) *
                   head_dim +
               vid * vec_size;
      Load<T, vec_size>(&multi_out[offset], &load_vec);
      // Rescale both the running result and the incoming chunk to the new max.
      const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
      const T scale1_T = static_cast<T>(scale1),
              scale2_T = static_cast<T>(scale2);
      d = d * scale1 + d_now * scale2;
#pragma unroll
      for (int j = 0; j < vec_size; j++) {
        res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T;
      }
    }

    // Step 2: Intra-warp reduction via warp shuffle
    // Each warp has 2 ty values: ty_low at lanes 0-15, ty_high at lanes 16-31
    // Merge ty_high into ty_low using shuffle
    const int partner_lane = lane_id ^ 16;  // flip bit 4 to swap low/high ty
    const float m_partner = __shfl_sync(0xffffffff, m, partner_lane);
    const float d_partner = __shfl_sync(0xffffffff, d, partner_lane);
    // Pack adjacent 16-bit pairs into 32-bit for efficient shuffle.
    // AlignedVector alignment >= 4 bytes, so uint32 reinterpret is safe
    // — no OOB read, no type confusion. This halves shuffle count vs
    // per-element memcpy for bf16/fp16.
    constexpr int PACKED_SIZE = vec_size * sizeof(T) / sizeof(unsigned);
    const unsigned* packed_res = reinterpret_cast<const unsigned*>(&res_vec);
    unsigned packed_partner[PACKED_SIZE];
#pragma unroll
    for (int j = 0; j < PACKED_SIZE; j++) {
      packed_partner[j] = __shfl_sync(0xffffffff, packed_res[j], partner_lane);
    }
    LoadT partner_vec;
    memcpy(&partner_vec, packed_partner, sizeof(partner_vec));

    // Merge partner into self (only the "low ty" keeps the result)
    float m_new = max(m, m_partner);
    const float scale1 = __expf(m - m_new);
    const float scale2 = __expf(m_partner - m_new);
    float d_new = d * scale1 + d_partner * scale2;
    if ((ty & 1) == 0) {  // low ty keeps merged result
      m = m_new;
      d = d_new;
      const T scale1_T = static_cast<T>(scale1);
      const T scale2_T = static_cast<T>(scale2);
#pragma unroll
      for (int j = 0; j < vec_size; j++) {
        res_vec[j] = res_vec[j] * scale1_T + partner_vec[j] * scale2_T;
      }
    }

    // Cross-warp: only even ty (0,2,4,6) write to smem
    if ((ty & 1) == 0) {
      Store<T, vec_size>(res_vec,
                         &smem[(ty / 2) * head_dim + vid * vec_size]);
      // md_smem layout: [m, d] pairs at even/odd indices; ty is even here so
      // slot ty holds m and ty+1 holds d for worker ty/2.
      md_smem[ty] = m;
      md_smem[ty + 1] = d;
    }
    __syncthreads();

    if (ty == 0) {
      // Final merge of the bdy/2 per-warp partials, then normalization and
      // (optional) shift/smooth + quantized store.
      // NOTE(review): template arguments on prefill_softmax_state_t /
      // Load / Store / StoreFunc were reconstructed — confirm upstream.
      prefill_softmax_state_t<vec_size, T> st;
      st.init();
#pragma unroll
      for (int i = 0; i < bdy / 2; i++) {
        Load<T, vec_size>(&smem[i * head_dim + vid * vec_size], &load_vec);
        const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
        st.merge(load_vec, m_tmp, d_tmp);
      }

      if (sinks) {
        // Attention-sink variant: the sink logit joins the denominator.
        float current_sink = static_cast<float>(sinks[hid]);
        st.normalize(current_sink);
      } else {
        st.normalize();
      }

      const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
      AlignedVector<T, vec_size> shift_bias_vec;
      AlignedVector<T, vec_size> smooth_weight_vec;
      AlignedVector<T, vec_size> out_vec;
      if (shift_bias) {
        Load<T, vec_size>(shift_bias + shift_smooth_offset, &shift_bias_vec);
        Load<T, vec_size>(smooth_weight + shift_smooth_offset,
                          &smooth_weight_vec);
      }

#pragma unroll
      for (int i = 0; i < vec_size; ++i) {
        StoreFunc<T, vec_size, T>()(st.o,
                                    shift_bias_vec,
                                    smooth_weight_vec,
                                    out_vec,
                                    quant_max_bound,
                                    quant_min_bound,
                                    in_scale,
                                    i);
      }
      Store<T, vec_size>(
          out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
    }
    __syncthreads();
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // PDL: signal dependent grids that this kernel's output is complete.
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}
diff --git a/custom_ops/gpu_ops/append_attention/config_for_attention.cu b/custom_ops/gpu_ops/append_attention/config_for_attention.cu
new file mode 100644
index 00000000000..5e753aa93d9
--- /dev/null
+++ b/custom_ops/gpu_ops/append_attention/config_for_attention.cu
@@ -0,0 +1,406 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "paddle/phi/core/memory/memcpy.h"
#endif
#include "utils.cuh"

// Single-block reduction computing six batch-wide maxima over the sequence
// length arrays:
//   max_lens[0] = max seq_lens_this_time
//   max_lens[1] = max seq_lens_encoder
//   max_lens[2] = max seq_lens_decoder
//   max_lens[3] = max (seq_len_decoder + seq_len_this_time), active batches
//   max_lens[4] = max decoder length over decode-only (non-prefill) batches
//   max_lens[5] = max KV length over batches with a non-empty KV cache
//
// Launch: <<<1, THREADBLOCK_SIZE>>> on the model stream.
// NOTE(review): template parameter reconstructed from the
// GetMaxLenKernel<1024> call site — confirm the original name upstream.
template <int THREADBLOCK_SIZE>
__global__ void GetMaxLenKernel(const int *seq_lens_decoder,
                                const int *seq_lens_this_time,
                                const int *seq_lens_encoder,
                                int *max_lens,
                                const int batch_size) {
  const int tid = threadIdx.x;

  typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int max_len_this_time_this_thread = 0;
  int max_len_encoder_this_thread = 0;
  int max_len_decoder_this_thread = 0;
  int max_len_this_thread = 0;
  int max_just_dec_len_this_thread = 0;
  int max_len_kv_this_thread = 0;
  for (int i = tid; i < batch_size; i += blockDim.x) {
    const int seq_len_this_time = seq_lens_this_time[i];
    const int seq_len_decoder = seq_lens_decoder[i];
    max_len_this_time_this_thread =
        max(seq_len_this_time, max_len_this_time_this_thread);
    max_len_encoder_this_thread =
        max(seq_lens_encoder[i], max_len_encoder_this_thread);
    max_len_decoder_this_thread =
        max(seq_len_decoder, max_len_decoder_this_thread);
    if (seq_len_this_time <= 0) continue;  // inactive batch slot
    // A batch currently prefilling contributes 0 to the decode-only max.
    const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
    max_len_this_thread =
        max(seq_len_decoder + seq_len_this_time, max_len_this_thread);
    max_just_dec_len_this_thread =
        max(max_just_dec_len_this_thread, max_just_dec_len_now);

    if (seq_len_decoder == 0) continue;  // no KV cache yet
    max_len_kv_this_thread =
        max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread);
  }
  // BUG FIX: cub::BlockReduce requires a __syncthreads() barrier before its
  // TempStorage may be reused by another collective. The original chained six
  // Reduce() calls on the same temp_storage with no barrier — a shared-memory
  // data race that can silently corrupt the reductions.
  int total_max_len_this_time =
      BlockReduce(temp_storage)
          .Reduce(max_len_this_time_this_thread, MaxOp<int>());
  __syncthreads();
  int total_max_len_encoder =
      BlockReduce(temp_storage)
          .Reduce(max_len_encoder_this_thread, MaxOp<int>());
  __syncthreads();
  int total_max_len_decoder =
      BlockReduce(temp_storage)
          .Reduce(max_len_decoder_this_thread, MaxOp<int>());
  __syncthreads();
  int total =
      BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
  __syncthreads();
  int total_just_dec = BlockReduce(temp_storage)
                           .Reduce(max_just_dec_len_this_thread, MaxOp<int>());
  __syncthreads();
  int total_max_len_kv =
      BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp<int>());
  // Reduce() returns the valid aggregate only on thread 0, which is the only
  // writer below.
  if (tid == 0) {
    max_lens[0] = total_max_len_this_time;
    max_lens[1] = total_max_len_encoder;
    max_lens[2] = total_max_len_decoder;
    max_lens[3] = total;
    max_lens[4] = total_just_dec;
    max_lens[5] = total_max_len_kv;
  }
}

// NOTE(review): template parameters reconstructed from the
// config_decode_attn<256, 128> / <128, 128> call sites: block_size (=128)
// sizes the shared scan buffer indexed by lane_id, min_chunk_size scales the
// candidate chunk sizes. Confirm names upstream.
template <int min_chunk_size, int block_size>
__global__ void config_decode_attn(const int *__restrict__ seq_lens_this_time,
                                   const int *__restrict__ seq_lens_encoder,
                                   const int *__restrict__ seq_lens_decoder,
                                   int *__restrict__ block_indices,
                                   int *__restrict__ num_blocks,
                                   int *__restrict__ chunk_size,
                                   const int bsz,
                                   const int group_size,
                                   const int kv_num_heads,
                                   const int q_tile_size,
                                   const int max_tokens_per_batch,
                                   const int config_gridx) {
  // one block one warp
  const int tid = threadIdx.x, wid = threadIdx.y;
  const uint32_t warp_size = blockDim.x;
  __shared__ int num_block_all_shared[block_size];
  __shared__ int chunk_size_res[1];
  __shared__ int use_scheme_e_res[1];

  const int lane_id = tid + wid * warp_size;

  // Step 1: compute num_block_all WITHOUT chunk splitting (Scheme E)
  int num_block_no_chunk = 0;
  for
(int bid = 0; bid < bsz; bid++) { + if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0) { + continue; + } + int token_num_cur_batch = seq_lens_this_time[bid]; + int q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size); + num_block_no_chunk += q_tile_num * kv_num_heads; + } + + // Step 2: decide mode — Scheme E if enough blocks, else split-kv + // Adaptive strategy: prefer Scheme E (zero merge overhead) when blocks + // already fill all SMs. When splitting is needed, use the LARGEST + // chunk_size that still creates enough blocks to fill SMs, minimizing + // merge count while ensuring SM utilization. + // Target: at least sm_count*4 blocks to ensure 4+ waves for GPU utilization. + // Too few waves (e.g. 2 waves with target=sm_count*2) leaves SMs idle between + // waves; 4 waves is a balanced tradeoff between utilization and merge + // overhead. + const int target_blocks = config_gridx / 4; // sm_count * 4 + const bool use_scheme_e = (num_block_no_chunk >= target_blocks); + + if (use_scheme_e) { + // Scheme E: no chunk splitting, chunk_size = INT_MAX + if (tid == 0 && wid == 0) { + num_blocks[0] = num_block_no_chunk; + chunk_size[0] = INT_MAX; + chunk_size_res[0] = INT_MAX; + use_scheme_e_res[0] = 1; + } + } else { + // Split-kv: find the LARGEST chunk_size whose total blocks >= target_blocks + // This minimizes merge count while ensuring SM utilization. 
+ int cur_chunk_size = min_chunk_size * (lane_id + 1); + int num_block_all = 0; + for (int bid = 0; bid < bsz; bid++) { + if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0) { + continue; + } + int token_num_cur_batch = seq_lens_this_time[bid]; + int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch; + int q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size); + int kv_chunk_num = div_up(kv_len_cur_batch, cur_chunk_size); + num_block_all += q_tile_num * kv_chunk_num * kv_num_heads; + } + num_block_all_shared[lane_id] = num_block_all; + __syncthreads(); + + int chunk_size_best; + int num_block_all_best; + if (tid == 0 && wid == 0) { + // Search from largest chunk_size to smallest: + // pick the first (largest) chunk_size with enough blocks + chunk_size_best = min_chunk_size; // fallback: smallest chunk + num_block_all_best = num_block_all_shared[0]; + for (int i = block_size - 1; i >= 0; i--) { + if (num_block_all_shared[i] >= target_blocks) { + chunk_size_best = min_chunk_size * (i + 1); + num_block_all_best = num_block_all_shared[i]; + break; + } + } + // If even the smallest chunk doesn't reach target_blocks, + // use the smallest chunk to maximize parallelism + if (num_block_all_best < target_blocks) { + chunk_size_best = min_chunk_size; + num_block_all_best = num_block_all_shared[0]; + } + num_blocks[0] = num_block_all_best; + chunk_size[0] = chunk_size_best; + chunk_size_res[0] = chunk_size_best; + use_scheme_e_res[0] = 0; + } + } + + __syncthreads(); + if (wid == 0) { + const bool use_scheme_e_local = use_scheme_e_res[0]; + const int chunk_size_best = chunk_size_res[0]; + + // one block one warp + int prev_offset = 0; + // loop on warp tile:[base, base+32) + for (int base = 0; base < bsz; base += warp_size) { + const int bid = base + tid; + int q_tile_num = 0; + int kv_chunk_num = 0; + + // calculate loop_times for bid + int num_block_all = 0; + if (bid < bsz) { + int token_num_cur_batch = seq_lens_this_time[bid]; + if 
(seq_lens_encoder && seq_lens_encoder[bid] > 0) { + token_num_cur_batch = 0; + } + q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size); + if (use_scheme_e_local) { + num_block_all += q_tile_num * kv_num_heads; + } else { + int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch; + kv_chunk_num = div_up(kv_len_cur_batch, chunk_size_best); + num_block_all += q_tile_num * kv_chunk_num * kv_num_heads; + } + } + + // prefix sum for each lane, get the start offset in this tile + // inclusive scan + int x = num_block_all; + for (int offset = 1; offset < warp_size; offset <<= 1) { + int y = __shfl_up_sync(0xffffffff, x, offset); + if (tid >= offset) x += y; + } + // exclusive prefix sum + int bid_offset = x - num_block_all; + int tile_sum = __shfl_sync(0xffffffff, x, warp_size - 1); + + // write batch_ids and tile_ids_per_batch + if (bid < bsz && num_block_all > 0) { + int write_base = prev_offset + bid_offset; + if (use_scheme_e_local) { + for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) { + for (int q_tile_id = 0; q_tile_id < q_tile_num; q_tile_id++) { + int idx = + write_base * 4 + (kv_head_id * q_tile_num + q_tile_id) * 4; + block_indices[idx] = bid; + block_indices[idx + 1] = kv_head_id; + block_indices[idx + 2] = 0; + block_indices[idx + 3] = q_tile_id; + } + } + } else { + for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) { + for (int kv_chunk_id = 0; kv_chunk_id < kv_chunk_num; + kv_chunk_id++) { + for (int q_tile_id = 0; q_tile_id < q_tile_num; q_tile_id++) { + int idx = + write_base * 4 + + ((kv_head_id * kv_chunk_num + kv_chunk_id) * q_tile_num + + q_tile_id) * + 4; + block_indices[idx] = bid; + block_indices[idx + 1] = kv_head_id; + block_indices[idx + 2] = kv_chunk_id; + block_indices[idx + 3] = q_tile_id; + } + } + } + } + } + // for next warp tile + prev_offset += tile_sum; + } + } +} + +void ConfigForAttention( + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + 
const paddle::Tensor &seq_lens_this_time, + paddle::Tensor &block_indices, // Inplace, shape:[block_num,4], block's + // indices with 4 dimension[batch_idx, + // kv_head_idx, kv_chunk_idx, q_tile_idx] + paddle::Tensor &num_blocks, // Inplace + paddle::Tensor &chunk_size, // Inplace + paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + auto stream = seq_lens_encoder.stream(); + int bsz = seq_lens_this_time.shape()[0]; + + paddle::Tensor max_len_tensor_gpu = + GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, + paddle::DataType::INT32, + seq_lens_this_time.place()); + + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); + // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data + // is only for branching in attention. +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + auto max_len_cpu_ptr = max_len_tensor_cpu.data(); + int max_just_dec_len_this_time = max_len_cpu_ptr[4]; + + const uint32_t block_indices_ele_num = block_indices.size(); + + // decoder + if (max_just_dec_len_this_time > 0) { + CUDA_CHECK(cudaMemsetAsync(block_indices.data(), + 0, + block_indices_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK( + cudaMemsetAsync(num_blocks.data(), 0, sizeof(int32_t), stream)); + CUDA_CHECK( + cudaMemsetAsync(chunk_size.data(), 0, sizeof(int32_t), stream)); + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK(cudaDeviceGetAttribute( + &sm_cout, cudaDevAttrMultiProcessorCount, device)); + const int config_gridx = sm_cout * 8; + + // 选择最优的q_tile_size + int q_tile_size = 32; + if (group_size * max_tokens_per_batch <= 16) { + q_tile_size = 16; + } + 
dim3 blocks(32, 4); + if (cache_quant_type == "cache_int4_zp") { + config_decode_attn<256, 128> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices.data(), + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } else { + config_decode_attn<128, 128> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices.data(), + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } + } +} + +std::vector> ConfigForAttentionInferShape( + const std::vector &seq_lens_encoder_shape, + const std::vector &seq_lens_decoder_shape, + const std::vector &seq_lens_this_time_shape, + const std::vector &num_blocks_shape, + const std::vector &chunk_size_shape, + const std::vector &max_len_tensor_cpu_shape, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +std::vector ConfigForAttentionInferDtype( + const paddle::DataType &seq_lens_encoder_dtype, + const paddle::DataType &seq_lens_decoder_dtype, + const paddle::DataType &seq_lens_this_time_dtype, + const paddle::DataType &num_blocks_dtype, + const paddle::DataType &chunk_size_dtype, + const paddle::DataType &max_len_tensor_cpu_dtype, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +PD_BUILD_STATIC_OP(config_for_attention) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "block_indices", + "num_blocks", + "chunk_size", + "max_len_tensor_cpu", + }) + .Outputs({ + + }) + .Attrs({"cache_quant_type: std::string", + "group_size: int", + "kv_num_heads: int", + "max_tokens_per_batch: int"}) + 
.SetKernelFn(PD_KERNEL(ConfigForAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(ConfigForAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ConfigForAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/append_attention/cu_tensor_map.cuh b/custom_ops/gpu_ops/append_attention/cu_tensor_map.cuh new file mode 100644 index 00000000000..ff84e1cd3f6 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/cu_tensor_map.cuh @@ -0,0 +1,124 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#pragma once
// NOTE(review): the original include targets were garbled in transit; the set
// below is the standard bulk-tensor/TMA boilerplate these declarations need
// (CUtensorMap from the driver API, bf16/fp16 types, cuda::barrier). Confirm
// against the upstream file.
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <cuda/barrier>

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

// Maps an element type to the CUtensorMapDataType used when encoding a TMA
// descriptor. The unspecialized fallback is BFLOAT16.
// NOTE(review): specialization target types reconstructed from the data-type
// constants (bf16 -> BFLOAT16, half -> FLOAT16, {u,}int8_t -> UINT8).
template <typename T>
struct cu_tensor_map_type_traits {
  static const CUtensorMapDataType type =
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
};

template <>
struct cu_tensor_map_type_traits<__nv_bfloat16> {
  static const CUtensorMapDataType type =
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
};

template <>
struct cu_tensor_map_type_traits<half> {
  static const CUtensorMapDataType type =
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
};

template <>
struct cu_tensor_map_type_traits<uint8_t> {
  static const CUtensorMapDataType type =
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
};

template <>
struct cu_tensor_map_type_traits<int8_t> {
  static const CUtensorMapDataType type =
      CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
};

// Encodes a rank-4 TMA descriptor for a paged KV cache laid out as
// [block_num, kv_num_head, second_size, last_size] with the innermost
// dimension contiguous ("VLLM layout"). The TMA box covers one
// (second_size x last_size) tile; the swizzle mode is chosen from the
// innermost row byte-width (only 64B/128B rows are supported).
//
// Host-side only (driver API). Throws std::runtime_error if the row width is
// unsupported or cuTensorMapEncodeTiled fails.
template <typename T>
CUtensorMap makeTensorMapForKVCache(T const* addr,
                                    uint32_t block_num,
                                    uint32_t kv_num_head,
                                    uint32_t second_size,
                                    uint32_t last_size) {
  CUtensorMap tensorMap{};

  uint32_t elem_bytes = sizeof(T);

  uint32_t const last_size_bytes = elem_bytes * last_size;
  // VLLM Layout
  CUtensorMapDataType data_dtype = cu_tensor_map_type_traits<T>::type;
  constexpr uint32_t rank = 4;
  uint64_t global_dims[] = {last_size, second_size, kv_num_head, block_num};
  // cuTensorMapEncodeTiled takes rank-1 strides (in bytes); the stride of the
  // innermost dimension is implicitly the element size.
  uint64_t global_strides[] = {last_size_bytes,
                               second_size * last_size_bytes,
                               kv_num_head * second_size * last_size_bytes};

  uint32_t box_dims[] = {last_size, second_size, 1, 1};
  uint32_t elem_strides[] = {1, 1, 1, 1};

  auto const swizzle = [&] {
    switch (last_size_bytes) {
      case 128:
        return CU_TENSOR_MAP_SWIZZLE_128B;
      case 64:
        return CU_TENSOR_MAP_SWIZZLE_64B;
      default:
        throw std::runtime_error("unsupported cache last_size");
    }
  }();
  CUresult res = cuTensorMapEncodeTiled(
      &tensorMap,
      data_dtype,
      rank,
      reinterpret_cast<void*>(const_cast<T*>(addr)),
      global_dims,
      global_strides,
      box_dims,
      elem_strides,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      swizzle,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  // BUG FIX: the original printed "CUDA_SUCCESS!" on every call (debug spam on
  // a hot setup path) and, on failure, merely printf'd the error name and
  // still returned a partially initialized descriptor, which would fault later
  // inside the kernel. Fail fast instead, and stay silent on success.
  if (res != CUDA_SUCCESS) {
    const char* err_name = nullptr;
    // cuGetErrorName leaves err_name untouched on unrecognized codes.
    cuGetErrorName(res, &err_name);
    throw std::runtime_error(
        std::string("cuTensorMapEncodeTiled failed: ") +
        (err_name != nullptr ? err_name : "unknown CUresult"));
  }

  return tensorMap;
}
diff --git a/custom_ops/gpu_ops/append_attention/decode_append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attention/decode_append_attention_c16_impl.cuh
new file mode 100644
index 00000000000..ddf2c194c06
--- /dev/null
+++ b/custom_ops/gpu_ops/append_attention/decode_append_attention_c16_impl.cuh
@@ -0,0 +1,492 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+#pragma once +#include "utils.cuh" +#include "attention_func.cuh" + +template +__global__ void decode_append_attention_c16_kernel( + AttentionParams params) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // Cache loop-invariant params fields into registers. + // Pass-by-value (no __grid_constant__) allows the compiler to cache + // struct fields, and explicit local variables guarantee no constant + // cache pressure in the grid-stride loop. + // Only cache frequently-used fields; rarely-used ones are accessed + // via params.xxx to reduce register pressure (Scheme I-A.2). + const auto qkv = params.qkv; + const auto cache_k = params.cache_k; + const auto cache_v = params.cache_v; + const auto seq_lens_q = params.seq_lens_q; + const auto seq_lens_kv = params.seq_lens_kv; + const auto block_table = params.block_table; + const auto cu_seqlens_q = params.cu_seqlens_q; + const auto block_indices = params.block_indices; + const auto mask_offset = params.mask_offset; + const auto attn_mask = params.attn_mask; + const auto tmp_o = params.tmp_o; + const auto tmp_m = params.tmp_m; + const auto tmp_d = params.tmp_d; + const float softmax_scale = params.softmax_scale; + const int q_num_heads = params.q_num_heads; + const int kv_num_heads = params.kv_num_heads; + + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_frags_x * 16 + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + int4 indices = reinterpret_cast(block_indices)[lane_idx]; + int batch_idx = indices.x; + int kv_head_idx = indices.y; + int chunk_idx = indices.z; + int tile_idx = indices.w; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = seq_lens_q[batch_idx]; + const int *block_table_now = + block_table + 
batch_idx * params.max_blocks_per_seq; + + constexpr uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + const uint32_t kv_len = seq_lens_kv[batch_idx] + q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = qkv + q_offset; + + T *o_base_ptr_T = tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const bool *attn_mask_this_seq = + attn_mask ? 
attn_mask + + batch_idx * params.attn_mask_len * params.attn_mask_len + : nullptr; + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, softmax_scale); + + const uint32_t num_iterations = + div_up(CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + BLOCK_SIZE); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len, chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (BLOCK_SIZE); + + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid % 16, tid / 16); + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx = chunk_start; + int block_table_idx = kv_idx / BLOCK_SIZE; + int block_id = __ldg(&block_table_now[block_table_idx]); + int block_id_next = __ldg(&block_table_now[block_table_idx + 1]); + if (block_id_next < 0) { + block_id_next = 0; + } + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < 
num_iterations; ++iter) { + if (iter + 1 < num_iterations) { + block_id_next = __ldg(&block_table_now[block_table_idx + 1]); + if (block_id_next < 0) { + block_id_next = 0; + } + } + + wait_group<1>(); + __syncthreads(); + + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(attn_mask_this_seq, + q_base_seq_id_this_block, + kv_idx + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx += BLOCK_SIZE; + block_table_idx++; + + block_id = block_id_next; + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); + + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + __syncthreads(); + + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + const bool do_normalize = (num_chunks_this_seq <= 1); + merge_block_res( + o_frag, + reinterpret_cast(smem), + m_frag, + d_frag, + wid, + tid, + do_normalize); + + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + 
qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } + } +} + +template +void DecodeAppendC16Attention(const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::Tensor &tmp_workspace, + const paddle::Tensor &tmp_m, + const paddle::Tensor &tmp_d, + const paddle::optional &attn_mask, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &block_indices, + const paddle::Tensor &num_blocks, + const paddle::Tensor &chunk_size, + const int max_seq_len, + const int max_dec_len, + const int max_tokens_per_batch, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; + constexpr uint32_t smem_size_0 = + (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * + sizeof(NV_TYPE); + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 33 * 8 * sizeof(float) + + 
NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 33 * 8;
  constexpr uint32_t smem_size =
      smem_size_0 > smem_size_1 ? smem_size_0 : smem_size_1;

  auto split_kv_kernel = decode_append_attention_c16_kernel;
  if (smem_size >= 48 * 1024) {
    // Opt in to >48KB dynamic shared memory (required on Volta+ for
    // large dynamic smem footprints).
    cudaFuncSetAttribute(split_kv_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         smem_size);
  }
  // Query the SM count of the device the kernels will actually launch on,
  // instead of hard-coding device 0 (the old code queried device 0 here and
  // the current device further down via a second, duplicated query).
  int dev_id = 0;
  CUDA_CHECK(cudaGetDevice(&dev_id));
  int sm_count;
  CUDA_CHECK(
      cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id));

  const int max_num_chunks = div_up(max_seq_len, 128);
  // Signed so the "no mask" sentinel (-1) is well defined (the old code
  // stored -1 into a uint32_t and then re-read the shape inside a ternary
  // that contained an embedded assignment).
  const int attn_mask_len =
      attn_mask ? static_cast<int>(attn_mask.get().shape()[1]) : -1;

  AttentionParams params;
  memset(&params, 0, sizeof(AttentionParams));

  params.qkv = reinterpret_cast(const_cast(qkv.data()));
  params.cache_k =
      reinterpret_cast(const_cast(cache_k.data()));
  params.cache_v =
      reinterpret_cast(const_cast(cache_v.data()));
  params.seq_lens_q = const_cast(seq_lens_q.data());
  params.seq_lens_kv = const_cast(seq_lens_kv.data());
  params.block_indices = const_cast(block_indices.data());
  params.num_blocks_ptr = const_cast(num_blocks.data());
  params.chunk_size_ptr = const_cast(chunk_size.data());
  params.cu_seqlens_q = const_cast(cu_seqlens_q.data());
  params.block_table = const_cast(block_table.data());
  params.mask_offset = const_cast(meta_data.mask_offset);
  params.attn_mask =
      attn_mask ? const_cast(attn_mask.get().data()) : nullptr;
  // NOTE(review): both fields are seeded from max_dec_len — confirm
  // max_model_len is not meant to come from max_seq_len.
  params.max_model_len = max_dec_len;
  params.max_kv_len = max_dec_len;
  params.max_blocks_per_seq = max_blocks_per_seq;
  // Use the float sqrt; the double `sqrt` forced a silent double-precision
  // round trip for a float parameter.
  params.softmax_scale = 1.f / sqrtf(static_cast<float>(HEAD_DIM));
  params.tmp_o =
      reinterpret_cast(const_cast(tmp_workspace.data()));
  params.tmp_m = const_cast(tmp_m.data());
  params.tmp_d = const_cast(tmp_d.data());
  params.max_tokens_per_batch = max_tokens_per_batch;
  params.attn_mask_len = attn_mask_len;
  params.sliding_window = sliding_window;
  params.q_num_heads = num_heads;
  params.kv_num_heads = kv_num_heads;
  params.max_num_chunks = max_num_chunks;
  params.batch_size = meta_data.batch_size;

  // Persistent-style launch: 8 CTAs per SM; the kernel grid-strides over
  // the work items in block_indices.
  dim3 grids(sm_count * 8);
  dim3 blocks(32, NUM_WARPS_PER_BLOCK);

  launchWithPdlWhenEnabled(
      split_kv_kernel, grids, blocks, smem_size, stream, params);

  // Second pass: merge the per-chunk partial results (tmp_o/tmp_m/tmp_d)
  // into the final output.
  constexpr int vec_size = num_elems_per_128b();
  constexpr int blockx = HEAD_DIM / vec_size;
  constexpr int blocky = (128 + blockx - 1) / blockx;
  dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
  dim3 blocks_merge(blockx, blocky);
  launchWithPdlWhenEnabled(
      merge_chunks_kernel,
      grids_merge,
      blocks_merge,
      0,
      stream,
      params.tmp_o,
      params.tmp_m,
      params.tmp_d,
      seq_lens_q.data(),
      seq_lens_kv.data(),
      seq_lens_encoder.data(),
      batch_id_per_token.data(),
      cu_seqlens_q.data(),
      (NV_TYPE *)nullptr,
      (NV_TYPE *)nullptr,
      sinks
          ? reinterpret_cast(const_cast(sinks.get().data()))
          : nullptr,
      chunk_size.data(),
      reinterpret_cast(out->data()),
      0.f,
      0.f,
      -1,
      max_seq_len,
      max_num_chunks,
      num_heads,
      HEAD_DIM,
      token_num,
      max_tokens_per_batch);
}
diff --git a/custom_ops/gpu_ops/append_attention/decode_append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attention/decode_append_attention_c8_impl.cuh
new file mode 100644
index 00000000000..6029f58b8f7
--- /dev/null
+++ b/custom_ops/gpu_ops/append_attention/decode_append_attention_c8_impl.cuh
@@ -0,0 +1,708 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "utils.cuh" +// #include "cu_tensor_map.cuh" +#include "attention_func.cuh" + +template +void print_params(AttentionParams const params) { + printf("max_model_len: %d\n", params.max_model_len); + printf("max_kv_len: %d\n", params.max_kv_len); + printf("max_blocks_per_seq: %d\n", params.max_blocks_per_seq); + printf("softmax_scale: %f\n", params.softmax_scale); + printf("quant_max_bound: %f\n", params.quant_max_bound); + printf("quant_min_bound: %f\n", params.quant_min_bound); + printf("max_tokens_per_batch: %d\n", params.max_tokens_per_batch); + printf("attn_mask_len: %d\n", params.attn_mask_len); + printf("sliding_window: %d\n", params.sliding_window); + printf("q_num_heads: %d\n", params.q_num_heads); + printf("kv_num_heads: %d\n", params.kv_num_heads); + printf("max_num_chunks: %d\n", params.max_num_chunks); + printf("max_tile_q: %d\n", params.max_tile_q); + printf("batch_size: %d\n", params.batch_size); +} + +template +__global__ void decode_append_attention_c8_kernel( + AttentionParams params) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // Cache loop-invariant params fields into registers. + // Pass-by-value (no __grid_constant__) allows the compiler to cache + // struct fields, and explicit local variables guarantee no constant + // cache pressure in the grid-stride loop. + // Only cache frequently-used fields; rarely-used ones are accessed + // via params.xxx to reduce register pressure (Scheme I-A.2). 
+ const auto qkv = params.qkv; + const auto cache_k = params.cache_k; + const auto cache_v = params.cache_v; + const auto cache_k_scale = params.cache_k_scale; + const auto cache_v_scale = params.cache_v_scale; + const auto seq_lens_q = params.seq_lens_q; + const auto seq_lens_kv = params.seq_lens_kv; + const auto block_table = params.block_table; + const auto cu_seqlens_q = params.cu_seqlens_q; + const auto block_indices = params.block_indices; + const auto mask_offset = params.mask_offset; + const auto attn_mask = params.attn_mask; + const auto tmp_o = params.tmp_o; + const auto tmp_m = params.tmp_m; + const auto tmp_d = params.tmp_d; + const float softmax_scale = params.softmax_scale; + const int q_num_heads = params.q_num_heads; + const int kv_num_heads = params.kv_num_heads; + + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + smem_t k_scale_smem; + smem_t v_scale_smem; + T *k_smem_scale_ptr = nullptr; + T *v_smem_scale_ptr = nullptr; + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + int4 indices = reinterpret_cast(block_indices)[lane_idx]; + int batch_idx = indices.x; + int kv_head_idx = indices.y; + int chunk_idx = indices.z; + int tile_idx = indices.w; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = seq_lens_q[batch_idx]; + const int *block_table_now = + block_table + batch_idx * params.max_blocks_per_seq; + + T cache_k_scale_reg[IsDynamicC8 + ? num_frags_z * 2 + : (is_scale_channel_wise ? num_frags_y * 4 : 1)]; + T cache_v_scale_reg[IsDynamicC8 + ? num_frags_z * 4 + : (is_scale_channel_wise ? 
num_frags_y * 2 : 1)]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + } + } + constexpr uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + const uint32_t kv_len = seq_lens_kv[batch_idx] + q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = 
HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + T *o_base_ptr_T = nullptr; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = qkv + q_offset; + + o_base_ptr_T = tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const bool *attn_mask_this_seq = + attn_mask ? attn_mask + + batch_idx * params.attn_mask_len * params.attn_mask_len + : nullptr; + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, softmax_scale); + + if constexpr (IsDynamicC8) { + k_smem_scale_ptr = reinterpret_cast( + smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); + } + + const uint32_t num_iterations = + div_up(CAUSAL ? 
(min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_idx * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, + (wid % 2) * num_frags_z + (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + 
chunk_end); + } + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg(k_smem_scale_ptr, + cache_k_scale_reg); + } + + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(attn_mask_this_seq, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg(v_smem_scale_ptr, + cache_v_scale_reg); + } + + compute_sfm_v_c8_iter_sq_bvec( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + } + wait_group<0>(); + __syncthreads(); + const bool do_normalize = (num_chunks_this_seq <= 1); + merge_block_res( + o_frag, + 
reinterpret_cast(smem), + m_frag, + d_frag, + wid, + tid, + do_normalize); + + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } + } +} + +template +void DecodeAppendC8Attention(const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::Tensor &tmp_workspace, + const paddle::Tensor &tmp_m, + const paddle::Tensor &tmp_d, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &block_indices, + const paddle::Tensor &num_blocks, + const paddle::Tensor &chunk_size, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + 
auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + + auto *allocator = paddle::GetAllocator(qkv.place()); + + bool is_scale_channel_wise = false; + if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { + is_scale_channel_wise = true; + } + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; + constexpr uint32_t smem_size_0 = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 33 * 8 * sizeof(float) + + NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 33 * 8; + constexpr uint32_t smem_size = + smem_size_0 > smem_size_1 ? 
smem_size_0 : smem_size_1; + + auto split_kv_kernel = decode_append_attention_c8_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = decode_append_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + const int max_num_chunks = div_up(max_seq_len, 128); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + AttentionParams params; + memset(¶ms, 0, sizeof(AttentionParams)); + + params.qkv = reinterpret_cast(const_cast(qkv.data())); + params.cache_k = const_cast(cache_k.data()); + params.cache_v = const_cast(cache_v.data()); + params.cache_k_scale = + reinterpret_cast(const_cast(cache_k_scale.data())); + params.cache_v_scale = + reinterpret_cast(const_cast(cache_v_scale.data())); + params.seq_lens_q = const_cast(seq_lens_q.data()); + params.seq_lens_kv = const_cast(seq_lens_kv.data()); + params.block_indices = const_cast(block_indices.data()); + params.num_blocks_ptr = const_cast(num_blocks.data()); + params.chunk_size_ptr = const_cast(chunk_size.data()); + params.cu_seqlens_q = const_cast(cu_seqlens_q.data()); + params.block_table = const_cast(block_table.data()); + params.mask_offset = const_cast(meta_data.mask_offset); + params.attn_mask = + attn_mask ? 
const_cast(attn_mask.get().data()) : nullptr;
  // NOTE(review): both fields are seeded from max_dec_len — confirm
  // max_model_len is not meant to come from max_seq_len.
  params.max_model_len = max_dec_len;
  params.max_kv_len = max_dec_len;
  params.max_blocks_per_seq = max_blocks_per_seq;
  // Use the float sqrt; the double `sqrt` forced a silent double-precision
  // round trip for a float parameter.
  params.softmax_scale = 1.f / sqrtf(static_cast<float>(HEAD_DIM));
  params.quant_max_bound = quant_max_bound;
  params.quant_min_bound = quant_min_bound;
  params.tmp_o =
      reinterpret_cast(const_cast(tmp_workspace.data()));
  params.tmp_m = const_cast(tmp_m.data());
  params.tmp_d = const_cast(tmp_d.data());
  params.max_tokens_per_batch = max_tokens_per_batch;
  // attn_mask_len was already computed above from the mask shape; reuse it
  // instead of re-reading the shape through a ternary containing an
  // embedded assignment. The ternary guard keeps the -1 sentinel correct
  // even though the local is unsigned.
  params.attn_mask_len = attn_mask ? static_cast<int>(attn_mask_len) : -1;
  params.sliding_window = sliding_window;
  params.q_num_heads = num_heads;
  params.kv_num_heads = kv_num_heads;
  params.max_num_chunks = max_num_chunks;
  params.batch_size = meta_data.batch_size;

  // NOTE(review): this duplicates the earlier sm_count query (which used
  // hard-coded device 0); consider unifying the two — verify intent.
  int device;
  CUDA_CHECK(cudaGetDevice(&device));
  int sm_cout;
  CUDA_CHECK(
      cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device));

  // Persistent-style launch: 8 CTAs per SM; the kernel grid-strides over
  // the work items in block_indices.
  dim3 grids(sm_cout * 8);
  dim3 blocks(32, NUM_WARPS_PER_BLOCK);

  launchWithPdlWhenEnabled(
      split_kv_kernel, grids, blocks, smem_size, stream, params);

  // Second pass: merge the per-chunk partial results (tmp_o/tmp_m/tmp_d)
  // into the final output.
  constexpr int vec_size = num_elems_per_128b();
  constexpr int blockx = HEAD_DIM / vec_size;
  constexpr int blocky = (128 + blockx - 1) / blockx;
  dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
  dim3 blocks_merge(blockx, blocky);
  launchWithPdlWhenEnabled(
      merge_chunks_kernel,
      grids_merge,
      blocks_merge,
      0,
      stream,
      params.tmp_o,
      params.tmp_m,
      params.tmp_d,
      seq_lens_q.data(),
      seq_lens_kv.data(),
      seq_lens_encoder.data(),
      batch_id_per_token.data(),
      cu_seqlens_q.data(),
      (NV_TYPE *)nullptr,
      (NV_TYPE *)nullptr,
      sinks ?
reinterpret_cast(const_cast(sinks.get().data())) + : nullptr, + chunk_size.data(), + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + -1, + max_seq_len, + max_num_chunks, + num_heads, + HEAD_DIM, + token_num, + max_tokens_per_batch); +} diff --git a/custom_ops/gpu_ops/append_attention/mem_util.cuh b/custom_ops/gpu_ops/append_attention/mem_util.cuh new file mode 100644 index 00000000000..18788858923 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/mem_util.cuh @@ -0,0 +1,389 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include + +enum class SharedMemFillMode { kFillZero, kNoFill }; + +enum class PrefetchMode { kNoPrefetch, kPrefetch }; + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_impl(uint32_t* R, T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R, + T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +__device__ __forceinline__ void commit_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + {} +#else + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +template +__device__ __forceinline__ void wait_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + cooperative_groups::wait(cooperative_groups::this_thread_block()); +#else + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } +#else + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + 
"r"(16)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(16)); + } +#endif +} + +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 
16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(src_in_bytes)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(src_in_bytes)); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16)); + } + } +#endif +} + +template +__device__ __forceinline__ void pred_load_64b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 8); + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 
8 : 0; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(8), + "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(8)); + } +#endif +} + +template +__device__ __forceinline__ void pred_load_32b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 4); + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 
4 : 0; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(4), + "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(4)); + } +#endif +} + +template +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(num_bits == 128, "num_bits must be 128"); + load_128b(smem_ptr, gmem_ptr); +} + +template +__device__ __forceinline__ void pred_load(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + static_assert(num_bits == 128 || num_bits == 64 || num_bits == 32, + "num_bits must be 128, 64 or 32."); + if constexpr (num_bits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 64) { + pred_load_64b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 32) { + pred_load_32b(smem_ptr, gmem_ptr, predicate); + } +} + +using b32_t = uint32_t; +using b64_t = uint2; +using b128_t = uint4; + +template +constexpr __host__ __device__ __forceinline__ uint32_t num_elems_per_128b() { + return sizeof(b128_t) / sizeof(T); +} + +struct smem_t { + // The base pointer. 
+ b128_t* base; + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) {} + + template + static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, + uint32_t j) { + if constexpr (inv_stride <= 1) { + return i * stride + (j ^ (i % 8)); + } else { + return i / inv_stride * 8 + ((j + (i % inv_stride) * stride)) ^ + ((i / inv_stride) % 8); + } + } + + template + static __device__ __forceinline__ uint32_t + advance_offset_by_column(uint32_t offset, uint32_t step_idx) { + if constexpr (row_stride == 2) { + static_assert(step_size == 2, "Unsupported step size"); + return offset + step_size; + } else if constexpr (row_stride == 4) { + static_assert(step_size == 2 || step_size == 4, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; + } else { + return offset + step_size; + } + } else { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } + } + } + + template + static __device__ __forceinline__ uint32_t + advance_offset_by_row(uint32_t offset) { + if constexpr (row_stride == 2) { + static_assert(step_size == 16 || step_size % 32 == 0, + "Unsupported step size"); + if constexpr (step_size == 16) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 32 == 0 + return offset + step_size * row_stride; + } + } else if constexpr (row_stride == 4) { + static_assert(step_size == 8 || step_size % 16 == 0, + "Unsupported step size"); + if constexpr (step_size == 8) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 16 == 0 + return offset + step_size * 
row_stride; + } + } else { + static_assert(step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 8 == 0 + return offset + step_size * row_stride; + } + } + } + + __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, + uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_impl(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset, + uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_trans_impl(R, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, + const T* gptr, + bool predicate) { + b128_t* smem_ptr = base + offset; + pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, + const T* gptr) { + b128_t* smem_ptr = base + offset; + load_128b(smem_ptr, + reinterpret_cast(gptr)); + } + + template + __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) { + *reinterpret_cast(gptr) = *(base + offset); + } +}; diff --git a/custom_ops/gpu_ops/append_attention/mma_tensor_op.cuh b/custom_ops/gpu_ops/append_attention/mma_tensor_op.cuh new file mode 100644 index 00000000000..8662ee298d2 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/mma_tensor_op.cuh @@ -0,0 +1,296 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>

// Selects whether an MMA overwrites the accumulator (kInit: C = A * B)
// or accumulates into it (kInplaceUpdate: C += A * B).
enum class MMAMode {
  kInit = 0U,
  kInplaceUpdate = 1U,
};

// Warp-level 16x16x32 s8 * s8 -> s32 MMA, built from two
// mma.sync.aligned.m16n8k32 tensor-core instructions.  // requires SM80+
//   C: 8  s32 accumulator registers (two adjacent 16x8 output tiles).
//   A: 4  b32 registers holding the 16x32 s8 A fragment (row major).
//   B: 4  b32 registers holding the 32x16 s8 B fragment (column major);
//      B[0..1] feed the first n8 tile, B[2..3] the second.
template <MMAMode mma_mode = MMAMode::kInplaceUpdate>
__device__ __forceinline__ void mma_sync_m16n16k32_row_col_i8i8i32(
    int* C,         // 8 accumulators
    uint32_t* A,    // 4 registers
    uint32_t* B) {  // 4 registers
  if constexpr (mma_mode == MMAMode::kInit) {
    // First 16x8 tile; accumulator operands seeded with zero.
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(B[0]), "r"(B[1]),
          "r"(0), "r"(0), "r"(0), "r"(0));
    // Second 16x8 tile.
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(B[2]), "r"(B[3]),
          "r"(0), "r"(0), "r"(0), "r"(0));
  } else {
    // In-place update: feed the current accumulators back in.
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(B[0]), "r"(B[1]),
          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
        "{%0, %1, %2, %3},"
        "{%4, %5, %6, %7},"
        "{%8, %9},"
        "{%10, %11, %12, %13};\n"
        : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(B[2]), "r"(B[3]),
          "r"(C[4]), "r"(C[5]), "r"(C[6]), "r"(C[7]));
  }
}

// Warp-level 16x16x16 f16/bf16 * f16/bf16 -> f32 MMA, built from two
// mma.sync.aligned.m16n8k16 instructions.  // bf16 path requires SM80+
// T selects the operand type: half -> .f16 MMA, otherwise .bf16 MMA.
//   C: 8 f32 accumulators, A: 4 b32 registers (16x16 A fragment, row major),
//   B: 4 b32 registers (16x16 B fragment, column major; B[0..1] first n8
//   tile, B[2..3] second).
template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(
    float* C, uint32_t* A, uint32_t* B) {
  if constexpr (mma_mode == MMAMode::kInit) {
    if constexpr (std::is_same<T, half>::value) {  // fp16
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
            "r"(B[0]), "r"(B[1]),
            "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
            "r"(B[2]), "r"(B[3]),
            "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
    } else {  // bf16
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
            "r"(B[0]), "r"(B[1]),
            "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
            "r"(B[2]), "r"(B[3]),
            "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
    }
  } else {
    if constexpr (std::is_same<T, half>::value) {  // fp16, accumulate
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
            "r"(B[0]), "r"(B[1]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
            "r"(B[2]), "r"(B[3]),
            "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
    } else {  // bf16, accumulate
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
            "r"(B[0]), "r"(B[1]),
            "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
            "r"(B[2]), "r"(B[3]),
            "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
    }
  }
}

// Per-row sum of a 16x16 f16/bf16 fragment `s`, accumulated into d[0]/d[1].
// Implemented as s * ones with a single m16n8k16 MMA: the B operand is a
// register pair packed with 1.0 (1006648320 = 0x3C003C00 = half2(1, 1);
// 1065369472 = 0x3F803F80 = bfloat162(1, 1)).  The two accumulator lanes
// that are not row sums are discarded into the scratch register `ph`.
template <typename DType>
__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) {
  static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type");
  uint32_t* s_u32 = (uint32_t*)(s);
  if constexpr (std::is_same<DType, half>::value) {
    asm volatile(
        "{\n"
        ".reg .f32 ph;\n"
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, ph, %1, ph},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, 0., %9, 0.};\n"
        "}\n"
        : "=f"(d[0]), "=f"(d[1])
        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]),
          "r"(1006648320), "r"(1006648320),
          "f"(d[0]), "f"(d[1]));
  } else {
    asm volatile(
        "{\n"
        ".reg .f32 ph;\n"
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0, ph, %1, ph},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, 0., %9, 0.};\n"
        "}\n"
        : "=f"(d[0]), "=f"(d[1])
        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]),
          "r"(1065369472), "r"(1065369472),
          "f"(d[0]), "f"(d[1]));
  }
}
diff --git a/custom_ops/gpu_ops/append_attention/template_config.json
new file mode 100644
index 00000000000..044e0f149bb
--- /dev/null
+++ 
b/custom_ops/gpu_ops/append_attention/template_config.json @@ -0,0 +1,78 @@ +{ + "multiquery_attention_c8": { + "name": "decode_append_attention_c8_kernel", + "function_name": "decode_append_attention_c8_kernel", + "impl_file": "decode_append_attention_c8_impl.cuh", + "template_params": [ + "T", + "CacheT", + "GROUP_SIZE", + "CAUSAL", + "NUM_WARPS", + "NUM_WARP_Q", + "NUM_WARP_KV", + "HEAD_DIM", + "BLOCK_SIZE", + "num_frags_x", + "num_frags_y", + "num_frags_z", + "is_scale_channel_wise", + "IsFP8", + "IsDynamicC8" + ], + "dispatch_params": { + "T": ["half", "__nv_bfloat16"], + "CacheT": ["uint8_t"], + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "CAUSAL": [0, 1], + "NUM_WARPS": [4], + "NUM_WARP_Q": [1], + "NUM_WARP_KV": [4], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "num_frags_x": [1, 2], + "num_frags_y": [8], + "num_frags_z": [1], + "is_scale_channel_wise": [0, 1], + "IsFP8": [0, 1], + "IsDynamicC8": [0, 1] + }, + "max_instances_per_file": 80, + "file_prefix": "decode_append_attention_c8", + "function_signature": "template __global__ void {function_name}{template_args}(AttentionParams{params_template_args} params);\n\n" + }, + "multiquery_attention_c16": { + "name": "decode_append_attention_c16_kernel", + "function_name": "decode_append_attention_c16_kernel", + "impl_file": "decode_append_attention_c16_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "CAUSAL", + "NUM_WARPS", + "NUM_WARP_Q", + "NUM_WARP_KV", + "HEAD_DIM", + "BLOCK_SIZE", + "num_frags_x", + "num_frags_z", + "num_frags_y" + ], + "dispatch_params": { + "T": ["half", "__nv_bfloat16"], + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "CAUSAL": [0, 1], + "NUM_WARPS": [4], + "NUM_WARP_Q": [1], + "NUM_WARP_KV": [4], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "num_frags_x": [1, 2], + "num_frags_z": [1], + "num_frags_y": [8] + }, + "max_instances_per_file": 80, + "file_prefix": "decode_append_attention_c16", + "function_signature": "template __global__ void 
{function_name}{template_args}(AttentionParams{params_template_args} params);\n\n" + } +} diff --git a/custom_ops/gpu_ops/append_attention/utils.cuh b/custom_ops/gpu_ops/append_attention/utils.cuh new file mode 100644 index 00000000000..536867eee16 --- /dev/null +++ b/custom_ops/gpu_ops/append_attention/utils.cuh @@ -0,0 +1,710 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include +#include "helper.h" +#include "mem_util.cuh" + +#define NUM_WARPS_PER_BLOCK 4 +#define NUM_THREADS_PER_BLOCK 128 +#define kWarpSize 32 + +#define HOSTDEVICE __host__ __device__ + +/*-------------------------------------traits-----------------------------------------*/ +template +struct type_traits { + using paddle_type = T; + using phi_type = T; + using nv_type = T; + using nv2_type = T; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT16; +// using phi_type = phi::dtype::float16; +// using nv_type = half; +// using nv2_type = half2; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +template <> 
+struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT16; +// using phi_type = phi::dtype::bfloat16; +// using nv_type = __nv_bfloat16; +// using nv2_type = __nv_bfloat162; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +template <> +struct type_traits<__nv_bfloat16> { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +template <> +struct type_traits<__nv_bfloat162> { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT8_E4M3FN; +// using phi_type = phi::dtype::float8_e4m3fn; +// using nv_type = __nv_fp8_e4m3; +// using nv2_type = __nv_fp8x2_e4m3; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; + +template <> +struct type_traits<__nv_fp8_e4m3> { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; + +template <> +struct type_traits<__nv_fp8x2_e4m3> { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; +/*---------------------------------1. 
type
 * traits--------------------------------------*/

/*---------------------------------2. fast
 * convert--------------------------------------*/
// Unpack four packed FP8 (e4m3) values in `source` into four halves.
// requires SM89+ (cvt.rn.f16x2.e4m3x2); traps on older architectures.
inline __device__ static void convert_fp8(half* result,
                                          const uint32_t& source) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  // Split the 32-bit word into two e4m3x2 byte pairs and convert each pair
  // directly to a half2; no widening needed for the fp16 destination.
  uint32_t dest0;
  uint32_t dest1;
  asm volatile(
      "{\n"
      ".reg .b16 lo, hi;\n"
      "mov.b32 {lo, hi}, %2;\n"
      "cvt.rn.f16x2.e4m3x2 %0, lo;\n"
      "cvt.rn.f16x2.e4m3x2 %1, hi;\n"
      "}\n"
      : "=r"(dest0), "=r"(dest1)
      : "r"(source));
  ((half2*)(result))[0] = ((half2*)(&dest0))[0];
  ((half2*)(result))[1] = ((half2*)(&dest1))[0];
#else
  printf("Do not support fp8 in arch < 890\n");
  asm("trap;");
#endif
}

// Unpack four packed FP8 (e4m3) values in `source` into four bfloat16s.
// The hardware converter only produces f16x2, so each pair is widened to
// float2 and re-rounded to bfloat162.  requires SM89+.
inline __device__ static void convert_fp8(__nv_bfloat16* result,
                                          const uint32_t& source) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  uint32_t dest0;
  uint32_t dest1;
  asm volatile(
      "{\n"
      ".reg .b16 lo, hi;\n"
      "mov.b32 {lo, hi}, %2;\n"
      "cvt.rn.f16x2.e4m3x2 %0, lo;\n"
      "cvt.rn.f16x2.e4m3x2 %1, hi;\n"
      "}\n"
      : "=r"(dest0), "=r"(dest1)
      : "r"(source));

  ((nv_bfloat162*)(result))[0] =
      __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0]));
  ((nv_bfloat162*)(result))[1] =
      __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0]));
#else
  printf("Do not support fp8 in arch < 890\n");
  asm("trap;");
#endif
}

// Unpack four packed uint8 values in `source` into four halves, using the
// classic prmt + magic-number trick: each byte b is spliced into a half
// whose bit pattern (exponent byte 0x64) encodes 1024 + b, then
// 0x6480 (= half 1152.0 = 1024 + 128) is subtracted, yielding b - 128.
// The -128 re-bias matches the +128 offset applied by QuantToC8's int8
// path below.
inline __device__ static void convert_int8(
    half* result, const uint32_t& source) {  // 4 int8 each time
  uint32_t* fp16_result_ptr = reinterpret_cast<uint32_t*>(result);
  uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
  static constexpr uint32_t mask_for_elt_01 = 0x5150;
  static constexpr uint32_t mask_for_elt_23 = 0x5352;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  // Splice bytes 0/1 and 2/3 of i8s into the low bytes of 0x64xx half pairs.
  asm volatile("prmt.b32 %0,%1,%2,%3;\n"
               : "=r"(fp16_result_ptr[0])
               : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
  asm volatile("prmt.b32 %0,%1,%2,%3;\n"
               : "=r"(fp16_result_ptr[1])
               : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

  // Subtract 1152.0 from both lanes of each half2 -> b - 128.
  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(fp16_result_ptr[0])
               : "r"(fp16_result_ptr[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
  asm volatile("sub.f16x2 %0, %1, %2;\n"
               : "=r"(fp16_result_ptr[1])
               : "r"(fp16_result_ptr[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
}

// Unpack four packed uint8 values in `source` into four bfloat16s via an
// fp32 intermediate: __byte_perm builds 2^23 + b (0x4B000000 = 8388608.0f),
// subtracting 2^23 + 128 leaves b - 128, and the high 16 bits of each
// float are then packed pairwise as (truncated) bfloat16.
inline __device__ static void convert_int8(
    __nv_bfloat16* result, const uint32_t& source) {  // 4 int8 each time
  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(result);
  uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

  static constexpr uint32_t fp32_base = 0x4B000000;
  float fp32_intermediates[4];

  // Place byte k of i8s into the low byte of 0x4B0000xx -> float 2^23 + b.
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);
  fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
  fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
  fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);

#pragma unroll
  for (int ii = 0; ii < 4; ++ii) {
    fp32_intermediates[ii] -= 8388736.f;  // (8388608.f + 128.f);
  }

  // Pack the high halves of each float pair into two bfloat162 words.
#pragma unroll
  for (int ii = 0; ii < 2; ++ii) {
    bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0],
                                      fp32_intermediates_casted[2 * ii + 1],
                                      0x7632);
  }
}
/*---------------------------------2. fast
 * convert--------------------------------------*/

/*---------------------------------3. 
vector + * cast--------------------------------------*/ +template +__forceinline__ HOSTDEVICE void vec_cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = src[i]; + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(float* dst, + const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(half* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast( + float* dst, const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(nv_bfloat16* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } +} +/*---------------------------------3. vector + * cast--------------------------------------*/ + +/*-------------------------------------4. + * func-----------------------------------------*/ +__forceinline__ HOSTDEVICE int div_up(int a, int b) { return (a + b - 1) / b; } + +template +__inline__ __device__ T Rsqrt(T x); + +template <> +__inline__ __device__ float Rsqrt(float x) { + return rsqrt(x); +} + +template <> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} + +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, + uint32_t y) { + return (x > y) ? x - y : 0U; +} + +template +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. 
+ T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +HOSTDEVICE __forceinline__ uint8_t QuantToC8(const T scale, + const T value, + const float max_bound, + const float min_bound) { + uint8_t eight_bits; + float quant_value; + if constexpr (is_need_kv_quant) { + quant_value = static_cast(scale * value); + } else { + quant_value = static_cast(value); + } + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + + if constexpr (IsFP8) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + quant_value = quant_value > 448.0f ? 448.0f : quant_value; + quant_value = quant_value < -448.0f ? -448.0f : quant_value; + auto tmp = static_cast<__nv_fp8_e4m3>(quant_value); + eight_bits = *(reinterpret_cast(&tmp)); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif + } else { + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + eight_bits = static_cast(quant_value + 128.0f); + } + return eight_bits; +} + +template +inline __device__ static void convert_c8(T* result, const uint32_t& source) { + if constexpr (IsFP8) { + convert_fp8(result, source); + } else { + convert_int8(result, source); + } +} + +template +inline __device__ void WelfordCombine1(T b_m2, T* m2) { + *m2 += b_m2; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { + *m2 = thread_m2; + for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) { + T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask); + WelfordCombine1(b_m2, m2); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) { + WelfordWarpReduce(thread_m2, m2); +} + +#define CHECK_CUDA_CALL(func, ...) 
\ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ + << ") " << __FILE__ << ": line " << __LINE__ \ + << " at function " << STR(func) << std::endl; \ + return e; \ + } \ + } + +__device__ __forceinline__ float2 fast_float2_mul(const float2& a, + const float2& b) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, 0.0;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, 0.0;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + return res; +} + +__device__ __forceinline__ float2 fast_float2_fma(float2& a, + const float2& b, + const float2& c) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, %6;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, %7;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), + "f"(a.y), + "f"(b.x), + "f"(b.y), + "f"(c.x), + "f"(c.y) // 输入操作数 + ); + return res; +} + +// __device__ __forceinline__ float2 fast_bfloat162_fma(__nv_bfloat162& a_bf162, +// const __nv_bfloat162& b_bf162, const __nv_bfloat162& c_bf162) { +// // 使用向量化PTX指令同时处理x/y分量 +// asm volatile ( +// "{\n" +// " fma.rn.b16 %0, %2, %4, %0;\n" // res.x = a.x * b.x +// " fma.rn.b16 %1, %3, %5, %1;\n" // res.y = a.y * b.y +// "}" +// : "=r"(a_bf162.x), "=r"(a_bf162.y) // 输出操作数 +// : "r"(b_bf162.x), "r"(b_bf162.y), +// "r"(c_bf162.x), "r"(c_bf162.y) // 输入操作数 +// ); +// float2 res = __bfloat1622float2_rn(a_bf162); +// return res; +// } + +__device__ __forceinline__ float2 fast_float2_sub_expf(const float2& a, + const float2& b) { + float2 res; + // 使用向量化减法指令(PTX sub.rn.f32) + asm volatile( + "{\n" + " sub.f32 %0, %2, %4;\n" // res.x = a.x - b.x + " sub.f32 %1, %3, %5;\n" // res.y = a.y - b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + res.x = expf(res.x); + res.y = 
expf(res.y); + return res; +} + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = static_cast(ori_out_vec[i]); + printf("Fatal! Unimplemented StoreFunc for cascade append attention\n"); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + 127.0f * + static_cast((ori_out_vec[i] + shift_bias_vec[i]) * + smooth_weight_vec[i]) * + in_scale; + quant_value = rintf(quant_value); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + out_vec[i] = static_cast(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector<__nv_fp8_e4m3, VEC_SIZE>& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + quant_max_bound * static_cast(ori_out_vec[i]) * in_scale; + quant_value = quant_value > quant_max_bound ? quant_max_bound : quant_value; + quant_value = quant_value < quant_min_bound ? 
quant_min_bound : quant_value; + out_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = ori_out_vec[i]; + } +}; +/*-------------------------------------4. + * func-----------------------------------------*/ + +/*-----------------------------------5. + * dispatch---------------------------------------*/ +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("not support the head_dim"); \ + } \ + } + +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 5) { \ + constexpr size_t GROUP_SIZE = 5; \ + __VA_ARGS__ \ + } else if (group_size == 6) { \ + constexpr size_t GROUP_SIZE = 6; \ + __VA_ARGS__ \ + } else if (group_size == 7) { \ + constexpr size_t GROUP_SIZE = 7; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size", group_size); \ + } + +#define 
DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size", group_size); \ + } + +#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_Q_TILE_SIZE( \ + group_size, max_tokens_per_batch, Q_TILE_SIZE, ...) \ + if (group_size * max_tokens_per_batch <= 16) { \ + constexpr size_t Q_TILE_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + constexpr size_t Q_TILE_SIZE = 32; \ + __VA_ARGS__ \ + } + +#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ + if (causal) { \ + constexpr bool CAUSAL = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool CAUSAL = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCKSHAPE_Q_SYSTEM( \ + block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr size_t BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } + +#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) 
\ + if (is_dynamic_cfp8) { \ + constexpr bool IsDynamicC8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IsDynamicC8 = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_IS_FP8(is_fp8, IS_FP8, ...) \ + if (is_fp8) { \ + constexpr bool IS_FP8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IS_FP8 = false; \ + __VA_ARGS__ \ + } + +struct AppendAttnMetaData { + int batch_size; + int block_size; + int q_num_heads; + int kv_num_heads; + int token_num; + int head_dims; + int head_dims_v; + int max_blocks_per_seq; + const int* mask_offset = nullptr; +}; + +template +struct AttentionParams { + T* __restrict__ qkv; + CacheT* __restrict__ cache_k; + CacheT* __restrict__ cache_v; + T* __restrict__ cache_k_scale; + T* __restrict__ cache_v_scale; + int* __restrict__ seq_lens_q; + int* __restrict__ seq_lens_kv; + int* __restrict__ block_indices; + int* __restrict__ num_blocks_ptr; + int* __restrict__ chunk_size_ptr; + int* __restrict__ cu_seqlens_q; + int* __restrict__ block_table; + int* __restrict__ mask_offset; + bool* __restrict__ attn_mask; + T* __restrict__ tmp_o; + float* __restrict__ tmp_m; + float* __restrict__ tmp_d; + int max_model_len; + int max_kv_len; + int max_blocks_per_seq; + float softmax_scale; + float quant_max_bound; + float quant_min_bound; + int num_blocks_x; + int attn_mask_len; + bool sliding_window; + int q_num_heads; + int kv_num_heads; + int max_num_chunks; + int max_tile_q; + int batch_size; + int token_num; + int head_dims; + int max_tokens_per_batch; +}; diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index bc6f7e0783a..30103c06f09 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -189,6 +189,84 @@ std::vector AppendAttentionWithOutput( const int sliding_window, const int sink_size); +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& 
seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder); + +std::vector DecodeAppendAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const 
paddle::optional& sinks, + paddle::Tensor& fmha_out, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window); + +void ConfigForAttention(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& block_indices, // Inplace + paddle::Tensor& num_blocks, // Inplace + paddle::Tensor& chunk_size, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch); + std::vector GQARopeWriteCacheKernel( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, @@ -1942,4 +2020,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("per_token_group_fp8_quant", &PerTokenGroupQuantFp8, "per_token_group_quant_fp8"); + + /** + * decoder_write_cache_with_rope.cu + * decoder_write_cache_with_rope + */ + m.def("decoder_write_cache_with_rope", + &DecoderWriteCacheWithRoPE, + "decoder write cache with RoPE function"); + + /** + * decode_append_attention.cu + * decode_append_attention + */ + m.def("decode_append_attention", + &DecodeAppendAttention, + "decoder append attention function"); + + /** + * config_for_attention.cu + * config_for_attention + */ + m.def("config_for_attention", + &ConfigForAttention, + "config for attention function"); } diff --git a/custom_ops/gpu_ops/decode_append_attention.cu b/custom_ops/gpu_ops/decode_append_attention.cu new file mode 100644 index 00000000000..fb4c8c0793b --- /dev/null +++ b/custom_ops/gpu_ops/decode_append_attention.cu @@ -0,0 +1,428 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "append_attention/decode_append_attention_c8_impl.cuh" +#include "append_attention/decode_append_attention_c16_impl.cuh" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecodeAppendAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + paddle::Tensor& fmha_out, + const 
std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_num = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + const int max_kv_len_this_time = set_max_lengths.data()[5]; + + auto stream = qkv.stream(); + bool is_fp8 = + cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8"; + bool is_dynamic_cfp8 = cache_quant_type == "block_wise_fp8"; + bool is_c16 = cache_quant_type == "none"; + + if (max_just_dec_len_this_time > 0) { + if (is_c16) { + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, max_tokens_per_batch, Q_TILE_SIZE, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecodeAppendC16Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + sinks, + seq_lens_this_time, + 
seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeAppendC16Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. "); + } + })})})})}) + } else { + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, + max_tokens_per_batch, + Q_TILE_SIZE, + {DISPATCH_DyCfp8( + is_dynamic_cfp8, + IsDynamicC8, + {DISPATCH_IS_FP8(is_fp8, IsFP8, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecodeAppendC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? 
cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeAppendC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_v_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. 
"); + } + })})})})})})}) + } + } + return {fmha_out}; +} + +std::vector> DecodeAppendAttentionInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& tmp_workspace_shape, + const std::vector& tmp_m_shape, + const std::vector& tmp_d_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& block_indices_shape, + const std::vector& num_blocks_shape, + const std::vector& chunk_size_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& mask_offset_shape, + const paddle::optional>& sinks_shape, + const std::vector& fmha_out_shape, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + return {fmha_out_shape}; +} + +std::vector DecodeAppendAttentionInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& tmp_workspace_dtype, + const paddle::DataType& tmp_m_dtype, + const paddle::DataType& tmp_d_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const 
paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& block_indices_dtype, + const paddle::DataType& num_blocks_dtype, + const paddle::DataType& chunk_size_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& mask_offset_dtype, + const paddle::optional& sinks_dtype, + const paddle::DataType& fmha_out_dtype, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + return {fmha_out_dtype}; +} + +PD_BUILD_STATIC_OP(decode_append_attention) + .Inputs({"qkv", + "key_cache", + "value_cache", + "tmp_workspace", + "tmp_m", + "tmp_d", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "block_indices", + "num_blocks", + "chunk_size", + "set_max_lengths", + paddle::Optional("attn_mask"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("mask_offset"), + paddle::Optional("sinks"), + "fmha_out"}) + .Outputs({"fmha_out_out"}) + .SetInplaceMap({{"fmha_out", "fmha_out_out"}}) + .Attrs({ + "cache_quant_type: std::string", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "max_tokens_per_batch: int", + "causal: bool", + "sliding_window: int", + }) + 
.SetKernelFn(PD_KERNEL(DecodeAppendAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(DecodeAppendAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecodeAppendAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu new file mode 100644 index 00000000000..7878e9926c5 --- /dev/null +++ b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu @@ -0,0 +1,326 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "append_attn/decoder_write_cache_with_rope_kernel.h" +#include "append_attn/speculate_write_cache_with_rope_kernel.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + auto stream = qkv.stream(); + + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_nums = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in 
the future + if (cache_quant_type_str == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + + if (max_just_dec_len_this_time > 0) { + if (speculate_decoder) { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. 
"); + } + } else { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. 
"); + } + } + } + return {qkv}; +} + +std::vector> DecoderWriteCacheWithRoPEInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& rotary_embs_shape, + const paddle::optional>& qkv_bias_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_shape}; +} + +std::vector DecoderWriteCacheWithRoPEInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& rotary_embs_dtype, + const paddle::optional& qkv_bias_dtype, + const paddle::optional& cache_k_quant_scales_dtype, 
+ const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_dtype}; +} + +PD_BUILD_STATIC_OP(decoder_write_cache_with_rope) + .Inputs({"qkv", + "key_cache", + "value_cache", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "set_max_lengths", + paddle::Optional("rotary_embs"), + paddle::Optional("qkv_bias"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight")}) + .Outputs({"qkv_out"}) + .SetInplaceMap({{"qkv", "qkv_out"}}) + .Attrs({ + "rms_norm_eps: float", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "speculate_decoder: bool", + }) + .SetKernelFn(PD_KERNEL(DecoderWriteCacheWithRoPE)) + .SetInferShapeFn(PD_INFER_SHAPE(DecoderWriteCacheWithRoPEInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecoderWriteCacheWithRoPEInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index b9a2fe90dbc..2bd73a5df70 100644 --- a/custom_ops/setup_ops.py +++ 
b/custom_ops/setup_ops.py @@ -543,6 +543,13 @@ def find_end_files(directory, end_str): sources += find_end_files(fp8_auto_gen_directory, ".cu") if cc >= 90 and nvcc_version >= 12.0: + # decode attention + os.system( + "python utils/auto_gen_template_attention.py --config gpu_ops/append_attention/template_config.json --output gpu_ops/append_attention/template_instantiation/autogen" + ) + sources += ["gpu_ops/decode_append_attention.cu"] + sources += ["gpu_ops/decoder_write_cache_with_rope.cu"] + sources += find_end_files("gpu_ops/append_attention", ".cu") # Hopper optimized mla sources += find_end_files("gpu_ops/mla_attn", ".cu") sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"] diff --git a/custom_ops/utils/auto_gen_template_attention.py b/custom_ops/utils/auto_gen_template_attention.py new file mode 100644 index 00000000000..5658f6645e7 --- /dev/null +++ b/custom_ops/utils/auto_gen_template_attention.py @@ -0,0 +1,227 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Universal template instantiation generator - fully based on configuration file template instantiation generation.""" + +import argparse +import json +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +@dataclass +class TemplateConfig: + """Template configuration class.""" + + name: str # Function name + function_name: str # Actual function name + impl_file: str # Implementation file path + template_params: List[str] # Template parameter list (in order) + dispatch_params: Dict[str, List[Any]] # Dispatch parameters + data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix) + max_instances_per_file: int = 60 # Maximum instances per file + file_prefix: str = "" # File prefix + function_signature: str = "" # Function signature template + + +class UniversalTemplateInstantiator: + """Universal template instantiator - fully based on configuration file.""" + + def __init__(self, config_file: str): + """Initialize the instantiator.""" + self.config_file = config_file + self.configs = self._load_configs() + + def _load_configs(self) -> Dict[str, TemplateConfig]: + """Load configuration file.""" + with open(self.config_file, "r", encoding="utf-8") as f: + config_data = json.load(f) + + configs = {} + for name, config_dict in config_data.items(): + config = TemplateConfig(**config_dict) + self._validate_config(config) + configs[name] = config + return configs + + def _validate_config(self, config: TemplateConfig): + """Validate configuration completeness.""" + for param_name in config.template_params: + if param_name not in config.dispatch_params: + raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params") + + def _build_template_args(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Build template arguments.""" + template_args_parts = [] + + for param_name in 
config.template_params: + if param_name in params: + template_args_parts.append(str(params[param_name])) + + else: + raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params") + + return f"<{', '.join(template_args_parts)}>" + + def _build_params_template_args(self, params: Dict[str, Any]) -> str: + """Build template arguments for AttentionParams.""" + params_template_args = [] + if "T" in params: + params_template_args.append(str(params["T"])) + else: + raise ValueError("Template parameter 'T' not found in dispatch_params") + + if "CacheT" in params: + params_template_args.append(str(params["CacheT"])) + else: + # C16 kernels use AttentionParams - T is repeated for both args + params_template_args.append(str(params["T"])) + + return f"<{', '.join(params_template_args)}>" + + def _generate_function_signature( + self, config: TemplateConfig, template_args: str, params_template_args: str + ) -> str: + """Generate function signature.""" + if config.function_signature: + signature = config.function_signature.format( + function_name=config.function_name, + template_args=template_args, + params_template_args=params_template_args, + ) + + return signature + else: + raise ValueError(f"Function signature not found for {config.name}") + + def _generate_file_header(self, config: TemplateConfig) -> str: + """Generate file header.""" + return f"""// Generated by autogen_template_instantiation.py - Do not edit. 
+ +#pragma once + +#include "../../{config.impl_file}" +""" + + def _generate_template_instantiation(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Generate template instantiation.""" + template_args = self._build_template_args(config, params) + params_template_args = self._build_params_template_args(params) + return self._generate_function_signature(config, template_args, params_template_args) + + def _clean_output_directory(self, output_dir: str): + """Clean output directory before generating new files.""" + output_path = Path(output_dir) + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + def generate_combinations_for_type(self, config: TemplateConfig) -> List[Dict[str, Any]]: + """Generate parameter combinations for specific type.""" + combinations = [] + + def _generate_recursive( + params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str] + ): + if not param_names: + combinations.append(current_params.copy()) + return + + param_name = param_names[0] + for value in params_dict[param_name]: + current_params[param_name] = value + _generate_recursive(params_dict, current_params, param_names[1:]) + + _generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys())) + + return combinations + + def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]: + """Split combinations into multiple files.""" + chunks = [] + for i in range(0, len(combinations), max_per_file): + chunk = combinations[i : i + max_per_file] + chunks.append(chunk) + return chunks + + def generate_file_content( + self, + config: TemplateConfig, + file_index: int, + combinations: List[Dict[str, Any]], + ) -> str: + """Generate file content.""" + content = self._generate_file_header(config) + + for params in combinations: + content += self._generate_template_instantiation(config, params) + + return content + + def 
generate_for_function_type(self, function_name: str, output_dir: str): + """Generate template instantiation files for specific function type.""" + if function_name not in self.configs: + raise ValueError(f"Function type '{function_name}' not found in config") + + config = self.configs[function_name] + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + combinations = self.generate_combinations_for_type(config) + if combinations: + chunks = self.split_combinations(combinations, config.max_instances_per_file) + for i, chunk in enumerate(chunks): + filename = f"{config.file_prefix}_part_{i:02d}.cu" + filepath = output_path / filename + content = self.generate_file_content(config, i, chunk) + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + def generate_all(self, output_dir: str): + """Generate all configured function types.""" + self._clean_output_directory(output_dir) + for function_name in self.configs.keys(): + print(f"Generating template instantiations for {function_name}...") + self.generate_for_function_type(function_name, output_dir) + print(f"Completed generating {function_name} template instantiations.") + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser(description="Universal template instantiation generator") + parser.add_argument( + "--config", + "-c", + type=str, + help="Configuration file path (JSON format)", + ) + parser.add_argument( + "--output", + "-o", + type=str, + help="Output directory", + ) + + args = parser.parse_args() + + try: + instantiator = UniversalTemplateInstantiator(args.config) + instantiator.generate_all(args.output) + except Exception as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + main() diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 7e0f809d5d3..a2136d2075b 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -277,6 +277,10 @@ def _validate_split_kv_size(value: int) -> int: "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: 
bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), # Whether to enable FP8 quantization with pow2scale. "FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))), + # enable kv cache manager v1 + "ENABLE_V1_KVCACHE_MANAGER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_MANAGER", "0")), + # enable decode attention + "USE_DECODE_ATTENTION": lambda: bool(int(os.getenv("USE_DECODE_ATTENTION", "0"))), } diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 15b657c249d..bd42e467afb 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -73,6 +73,8 @@ def allocate_launch_related_buffer( num_heads, kv_num_heads, block_size, + head_dim=128, + dtype="bfloat16", ): # Initialize AttentionBackend buffers assert num_heads % kv_num_heads == 0 @@ -107,6 +109,28 @@ def allocate_launch_related_buffer( res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + # Decode attention split ops buffers + if envs.USE_DECODE_ATTENTION: + min_chunk_size = 128 + max_num_chunk = (max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 if decoder_step_token_num * group_size <= 16 else 32 + q_tile_num = (decoder_step_token_num * group_size + q_tile_size - 1) // q_tile_size + res["decode_block_indices"] = paddle.full( + [max_batch_size * kv_num_heads * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + res["decode_num_blocks"] = paddle.full([1], 0, dtype="int32") + res["decode_chunk_size"] = paddle.full([1], 0, dtype="int32") + res["decode_tmp_workspace"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads * head_dim], 0, dtype=dtype + ) + 
res["decode_tmp_m"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + res["decode_tmp_d"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + return res diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 2549f9f5d87..771a7fc1104 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -43,6 +43,9 @@ ) from fastdeploy.model_executor.layers.attention.ops import ( append_attention, + config_for_attention, + decode_append_attention, + decoder_write_cache_with_rope, get_attn_mask_q, get_block_shape_and_split_kv_block, gqa_rope_write_cache, @@ -272,8 +275,10 @@ def __init__( self.rope_3d = False # Note(ZKK): here must be consistent with append_attn_backend.py self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024)) + self.max_tokens_per_batch: int = self.speculate_max_draft_token_num + 1 if FLASH_ATTN_VERSION is None: init_flash_attn_version() + print(f"num_heads: {self.num_heads}, kv_num_heads: {self.kv_num_heads}") def get_attention_meta(self): """get_attention_meta""" @@ -414,6 +419,20 @@ def forward_mixed( ) else: forward_meta.attn_mask_q = None + if envs.USE_DECODE_ATTENTION: + config_for_attention( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + getattr(layer, "cache_quant_type_str", "none"), + self.group_size, + self.kv_num_heads, + self.max_tokens_per_batch, + ) use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0 @@ -468,73 +487,148 @@ def forward_mixed( head_dim=self.head_dim, )[0].reshape([-1, self.attn_outputsize_tp]) - 
res_decoder = append_attention( - qkv, - cache_k, - cache_v, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - forward_meta.block_tables, - forward_meta.encoder_batch_ids, - forward_meta.encoder_tile_ids_per_batch, - forward_meta.encoder_num_blocks_x_cpu, - forward_meta.kv_batch_ids, - forward_meta.kv_tile_ids_per_batch, - forward_meta.kv_num_blocks_x_cpu, - forward_meta.decoder_batch_ids, - forward_meta.decoder_tile_ids_per_batch, - forward_meta.decoder_num_blocks_cpu, - forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu, - forward_meta.rotary_embs, - forward_meta.attn_mask, - layer.qkv_bias, - layer.qkv_scale, - cache_k_scales, - cache_v_scales, - getattr(layer, "cache_k_out_scale", None), - getattr(layer, "cache_v_out_scale", None), - getattr(layer, "cache_k_zp", None), - getattr(layer, "cache_v_zp", None), - layer.linear_shift, - layer.linear_smooth, - forward_meta.attn_mask_offsets, - metadata.kv_signal_data_list[layer.layer_id], - q_norm_weight, - k_norm_weight, - getattr(layer, "sinks", None), - getattr(layer, "rms_norm_eps", 1e-6), - metadata._fuse_kernel_compute_dtype, - getattr(layer, "cache_quant_type_str", "none"), - layer.use_neox_rotary_style, - self.rope_3d, - self.max_seq_len, - getattr(layer, "quant_max_bound", 0.0), - getattr(layer, "quant_min_bound", 0.0), - getattr(layer, "out_scale", -1.0), - self.encoder_block_shape_q, - self.decoder_block_shape_q, - self.max_partition_size, - self.max_seq_len, - self.speculate_max_draft_token_num + 1, - self.causal, - self.speculative_method is not None, - ) - - if use_fa_do_prefill: - merge_prefill_decode_output( - res_encoder, - res_decoder, + if envs.USE_DECODE_ATTENTION: + qkv_out = decoder_write_cache_with_rope( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + 
forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + layer.qkv_bias, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "rms_norm_eps", 1e-6), + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + self.speculative_method is not None, + ) + if use_fa_do_prefill: + res_decoder = res_encoder + else: + res_decoder = paddle.empty( + [qkv.shape[0], self.num_heads * self.head_dim], + dtype=qkv.dtype, + ) + decode_append_attention( + qkv_out, + cache_k, + cache_v, + forward_meta.decode_tmp_workspace, + forward_meta.decode_tmp_m, + forward_meta.decode_tmp_d, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, - self.num_heads, - self.head_dim, + forward_meta.block_tables, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + forward_meta.attn_mask, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + forward_meta.attn_mask_offsets, + getattr(layer, "sinks", None), + res_decoder, + getattr(layer, "cache_quant_type_str", "none"), + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), self.speculate_max_draft_token_num + 1, + self.causal, ) - return res_encoder - else: return res_decoder + else: + 
res_decoder = append_attention( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + forward_meta.attn_mask, + layer.qkv_bias, + layer.qkv_scale, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + layer.linear_shift, + layer.linear_smooth, + forward_meta.attn_mask_offsets, + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "sinks", None), + getattr(layer, "rms_norm_eps", 1e-6), + metadata._fuse_kernel_compute_dtype, + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + getattr(layer, "out_scale", -1.0), + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.max_partition_size, + self.max_seq_len, + self.speculate_max_draft_token_num + 1, + self.causal, + self.speculative_method is not None, + ) + + if use_fa_do_prefill: + merge_prefill_decode_output( + res_encoder, + res_decoder, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + self.num_heads, + self.head_dim, + self.speculate_max_draft_token_num + 1, + 
) + return res_encoder + else: + return res_decoder diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index e0175573fa3..ad0e38bc35c 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -15,6 +15,9 @@ """ from .append_attention import append_attention, append_attention_with_output +from .config_for_attention import config_for_attention +from .decode_append_attention import decode_append_attention +from .decoder_write_cache_with_rope import decoder_write_cache_with_rope from .flash_attn_v4 import flash_attn_v4 from .flash_mask_attention import flash_mask_attention from .get_attn_mask_q import get_attn_mask_q @@ -37,4 +40,7 @@ "flash_attn_v4", "flash_mask_attention", "get_attn_mask_q", + "config_for_attention", + "decoder_write_cache_with_rope", + "decode_append_attention", ] diff --git a/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py new file mode 100644 index 00000000000..d8226aad4b1 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py @@ -0,0 +1,58 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + config_for_attention as config_for_attention_cuda, + ) + + +def config_for_attention( + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + max_len_tensor_cpu: paddle.Tensor, + cache_quant_type: str = "none", + group_size: int = 1, + kv_num_heads: int = 1, + max_tokens_per_batch: int = 1, +): + """ + append_attention + """ + if current_platform.is_cuda(): + config_for_attention_cuda( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_indices, + num_blocks, + chunk_size, + max_len_tensor_cpu, + cache_quant_type, + group_size, + kv_num_heads, + max_tokens_per_batch, + ) + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decode_append_attention.py b/fastdeploy/model_executor/layers/attention/ops/decode_append_attention.py new file mode 100644 index 00000000000..ef92cc7383f --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decode_append_attention.py @@ -0,0 +1,105 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decode_append_attention as decode_append_attention_cuda, + ) + + +def decode_append_attention( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + tmp_workspace: paddle.Tensor, + tmp_m: paddle.Tensor, + tmp_d: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + set_max_lengths: paddle.Tensor, + attn_mask: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + mask_offset: Optional[paddle.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, + fmha_out: Optional[paddle.Tensor] = None, + cache_quant_type: str = "none", + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + max_tokens_per_batch: int = 1, + causal: bool = True, + sliding_window: int = 0, +) -> paddle.Tensor: + """ + append_attention + """ + if current_platform.is_cuda(): + out = decode_append_attention_cuda( + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + set_max_lengths, + attn_mask, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + mask_offset, + sinks, + fmha_out, + cache_quant_type, + 
max_input_length, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + causal, + sliding_window, + ) + return out + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py new file mode 100644 index 00000000000..b10f6cd1bf6 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py @@ -0,0 +1,97 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decoder_write_cache_with_rope as decoder_write_cache_with_rope_cuda, + ) + + +def decoder_write_cache_with_rope( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + set_max_lengths: paddle.Tensor, + rotary_embs: Optional[paddle.Tensor] = None, + qkv_bias: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + kv_signal_data: Optional[paddle.Tensor] = None, + q_norm_weight: Optional[paddle.Tensor] = None, + k_norm_weight: Optional[paddle.Tensor] = None, + rms_norm_eps: float = 1e-6, + cache_quant_type: str = "none", + use_neox_rotary_style: bool = False, + rope_3d: bool = False, + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + speculate_decoder: bool = False, +) -> paddle.Tensor: + """ + append_attention + """ + if current_platform.is_cuda(): + qkv_out = decoder_write_cache_with_rope_cuda( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + set_max_lengths, + rotary_embs, + qkv_bias, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + cache_quant_type, + use_neox_rotary_style, + rope_3d, + max_input_length, + 
quant_max_bound, + quant_min_bound, + speculate_decoder, + ) + return qkv_out + else: + raise NotImplementedError diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 2d8d310a469..6e4b081c394 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -404,6 +404,22 @@ def _initialize_attn_backend( self.target_model_inputs["kv_num_blocks_x_cpu"] ).cpu() + # Decode attention split ops buffers + if ( + "decode_block_indices" in self.target_model_inputs + and self.target_model_inputs["decode_block_indices"] is not None + ): + self.model_inputs["decode_block_indices"] = paddle.zeros_like( + self.target_model_inputs["decode_block_indices"] + ) + self.model_inputs["decode_num_blocks"] = paddle.zeros_like(self.target_model_inputs["decode_num_blocks"]) + self.model_inputs["decode_chunk_size"] = paddle.zeros_like(self.target_model_inputs["decode_chunk_size"]) + self.model_inputs["decode_tmp_workspace"] = paddle.zeros_like( + self.target_model_inputs["decode_tmp_workspace"] + ) + self.model_inputs["decode_tmp_m"] = paddle.zeros_like(self.target_model_inputs["decode_tmp_m"]) + self.model_inputs["decode_tmp_d"] = paddle.zeros_like(self.target_model_inputs["decode_tmp_d"]) + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -673,6 +689,15 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_ru attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None, ) + # Decode attention split ops buffers (assigned after construction due to ForwardMeta __getattr__) + if "decode_block_indices" in self.model_inputs: + self.forward_meta.decode_block_indices = self.model_inputs["decode_block_indices"] + self.forward_meta.decode_num_blocks = self.model_inputs["decode_num_blocks"] + self.forward_meta.decode_chunk_size = self.model_inputs["decode_chunk_size"] + self.forward_meta.decode_tmp_workspace = self.model_inputs["decode_tmp_workspace"] + 
self.forward_meta.decode_tmp_m = self.model_inputs["decode_tmp_m"] + self.forward_meta.decode_tmp_d = self.model_inputs["decode_tmp_d"] + # Initialzie attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index bca07c82170..8523d02776f 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1386,6 +1386,15 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): gpu_routing_buffer=gpu_routing_buffer, ) + # Decode attention split ops buffers (assigned after construction due to ForwardMeta __getattr__) + if "decode_block_indices" in self.share_inputs: + self.forward_meta.decode_block_indices = self.share_inputs["decode_block_indices"] + self.forward_meta.decode_num_blocks = self.share_inputs["decode_num_blocks"] + self.forward_meta.decode_chunk_size = self.share_inputs["decode_chunk_size"] + self.forward_meta.decode_tmp_workspace = self.share_inputs["decode_tmp_workspace"] + self.forward_meta.decode_tmp_m = self.share_inputs["decode_tmp_m"] + self.forward_meta.decode_tmp_d = self.share_inputs["decode_tmp_d"] + dist_status = self.collect_distributed_status() if_only_decode = dist_status.if_only_decode @@ -1614,6 +1623,8 @@ def _initialize_attn_backend(self) -> None: num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, + head_dim=head_dim, + dtype=self.model_config.dtype, ) self.share_inputs.update(res_buffer) diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index f47c7bccc6d..db31d242046 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -206,6 +206,13 @@ def init_share_inputs(self): self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU + # Decode attention split ops buffers (initialized by 
_initialize_attn_backend) + self.decode_block_indices = None + self.decode_num_blocks = None + self.decode_chunk_size = None + self.decode_tmp_workspace = None + self.decode_tmp_m = None + self.decode_tmp_d = None # Initialize thinking related buffers self.enable_thinking = paddle.full(shape=[max_num_seqs, 1], fill_value=True, dtype="bool") @@ -810,6 +817,13 @@ def init_share_inputs(self): self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU + # Decode attention split ops buffers + self.decode_block_indices = None + self.decode_num_blocks = None + self.decode_chunk_size = None + self.decode_tmp_workspace = None + self.decode_tmp_m = None + self.decode_tmp_d = None # Input tokens self.draft_tokens = paddle.full( diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 7e721107f1e..5e391ee9f1b 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -1452,6 +1452,8 @@ def _initialize_attn_backend(self) -> None: num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, + head_dim=head_dim, + dtype=self.model_config.dtype, ) self.share_inputs.update(res_buffer) diff --git a/tests/operators/attention/benchmark_decode_attention.py b/tests/operators/attention/benchmark_decode_attention.py new file mode 100644 index 00000000000..3afbff3b177 --- /dev/null +++ b/tests/operators/attention/benchmark_decode_attention.py @@ -0,0 +1,853 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark script comparing append_attention vs decode_append_attention (C16) performance. + +Each case runs append_attention once and decode_attention once, prints elapsed time. +Supports --op flag to run only one op (for ncu profiling). + +Usage: + python benchmark_decode_attention.py # run all cases, both ops + python benchmark_decode_attention.py --op append # run only append_attention + python benchmark_decode_attention.py --op decode # run only decode_attention + python benchmark_decode_attention.py --case 0 # run only case index 0 +""" + +import argparse +import copy +import time + +import numpy as np +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention as append_attention_op, +) +from fastdeploy.model_executor.layers.attention.ops import ( + config_for_attention, + decode_append_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, +) + +seed = 1000 +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + 
emb = paddle.unsqueeze(emb, 2) + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = int(paddle.sum(seq_lens_this_time).item()) + seq_lens_list = seq_lens_this_time.numpy().tolist() + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + batch_id_per_token = np.zeros(token_num, dtype="int32") + offset = 0 + for i in range(bsz): + sl = int(seq_lens_list[i]) + batch_id_per_token[offset : offset + sl] = i + offset += sl + cu_seqlens_q[i + 1] = offset + cu_seqlens_k[i + 1] = offset + return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) - 0.5 + key = np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) - 0.5 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) - 0.5 + token_num = bs * seq_len + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +def build_block_tables(batch_size, max_model_len, block_size): + block_num_per_seq = (max_model_len + block_size - 1) // block_size + max_block_num = block_num_per_seq * batch_size + # Assign each batch a contiguous range of blocks (descending order to match test) + block_ids = np.arange(max_block_num - 1, -1, -1, dtype="int32").reshape(batch_size, block_num_per_seq) + 
block_tables = paddle.to_tensor(block_ids, dtype="int32") + return block_tables, max_block_num + + +def build_append_attention_buffers(batch_size, max_model_len, group_size, block_size): + max_num_block_dec = batch_size * (max_model_len * group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + max_num_block = batch_size * (max_model_len * group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + return { + "decoder_batch_ids": decoder_batch_ids, + "decoder_tile_ids_per_batch": decoder_tile_ids_per_batch, + "decoder_num_blocks_cpu": decoder_num_blocks_cpu, + "decoder_num_blocks_device": decoder_num_blocks_device, + "decoder_chunk_size_device": decoder_chunk_size_device, + "encoder_batch_ids": encoder_batch_ids, + "encoder_tile_ids_per_batch": encoder_tile_ids_per_batch, + "encoder_num_blocks_cpu": encoder_num_blocks_cpu, + "kv_batch_ids": kv_batch_ids, + "kv_tile_ids_per_batch": kv_tile_ids_per_batch, + "kv_num_blocks_x_cpu": kv_num_blocks_x_cpu, + "max_len_tensor_cpu": max_len_tensor_cpu, + } + + +def build_decode_attention_buffers( + batch_size, max_model_len, kv_num_head, q_num_head, head_dim, max_tokens_per_batch, group_size, dtype +): + buffer = {} + min_chunk_size = 128 + max_num_chunk = 
(max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 if max_tokens_per_batch * group_size <= 16 else 32 + q_tile_num = (max_tokens_per_batch * group_size + q_tile_size - 1) // q_tile_size + buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + buffer["block_indices"] = paddle.full([batch_size * kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32") + buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + buffer["tmp_workspace"] = paddle.full( + [batch_size * max_tokens_per_batch, max_num_chunk, q_num_head * head_dim], + 0, + dtype=dtype, + ) + buffer["tmp_m"] = paddle.full([batch_size * max_tokens_per_batch, max_num_chunk, q_num_head], 0, dtype="float32") + buffer["tmp_d"] = paddle.full([batch_size * max_tokens_per_batch, max_num_chunk, q_num_head], 0, dtype="float32") + return buffer + + +class BenchmarkCase: + def __init__( + self, + name, + batch_size, + q_num_head, + kv_num_head, + head_dim, + seq_len, + max_model_len, + dtype="bfloat16", + max_tokens_per_batch=1, + block_size=64, + causal=True, + ): + self.name = name + self.batch_size = batch_size + self.q_num_head = q_num_head + self.kv_num_head = kv_num_head + self.head_dim = head_dim + self.seq_len = seq_len + self.max_model_len = max_model_len + self.dtype = dtype + self.max_tokens_per_batch = max_tokens_per_batch + self.block_size = block_size + self.causal = causal + self.group_size = q_num_head // kv_num_head + self.cache_quant_type = "none" + + def short_name(self): + return self.name + + +CASES = [ + BenchmarkCase( + "bs1_seq64", + batch_size=1, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=64, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs1_seq512", + batch_size=1, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=512, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs1_seq2048", + batch_size=1, + 
q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=2048, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs1_seq4096", + batch_size=1, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=4096, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs1_seq8192", + batch_size=1, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=8192, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs16_seq64", + batch_size=16, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=64, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs16_seq512", + batch_size=16, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=512, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs16_seq2048", + batch_size=16, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=2048, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs16_seq4096", + batch_size=16, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=4096, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs16_seq8192", + batch_size=16, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=8192, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs128_seq64", + batch_size=128, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=64, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs128_seq512", + batch_size=128, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=512, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs128_seq2048", + batch_size=128, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=2048, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs128_seq4096", + batch_size=128, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=4096, + max_model_len=22528, + max_tokens_per_batch=1, + 
), + BenchmarkCase( + "bs128_seq8192", + batch_size=128, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=8192, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs256_seq64", + batch_size=256, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=64, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs256_seq512", + batch_size=256, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=512, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs256_seq2048", + batch_size=256, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=2048, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs256_seq4096", + batch_size=256, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=4096, + max_model_len=22528, + max_tokens_per_batch=1, + ), + BenchmarkCase( + "bs256_seq8192", + batch_size=256, + q_num_head=12, + kv_num_head=1, + head_dim=128, + seq_len=8192, + max_model_len=22528, + max_tokens_per_batch=1, + ), + # BenchmarkCase("bs1_seq64_spec", batch_size=1, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=64, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs1_seq512_spec", batch_size=1, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=512, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs1_seq2048_spec", batch_size=1, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=2048, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs1_seq4096_spec", batch_size=1, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=4096, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs1_seq8192_spec", batch_size=1, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=8192, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs16_seq64_spec", batch_size=16, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=64, max_model_len=22528, max_tokens_per_batch=2), + # 
BenchmarkCase("bs16_seq512_spec", batch_size=16, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=512, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs16_seq2048_spec", batch_size=16, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=2048, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs16_seq4096_spec", batch_size=16, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=4096, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs16_seq8192_spec", batch_size=16, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=8192, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs128_seq64_spec", batch_size=128, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=64, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs128_seq512_spec", batch_size=128, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=512, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs128_seq2048_spec", batch_size=128, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=2048, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs128_seq4096_spec", batch_size=128, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=4096, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs128_seq8192_spec", batch_size=128, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=8192, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs256_seq64_spec", batch_size=256, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=64, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs256_seq512_spec", batch_size=256, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=512, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs256_seq2048_spec", batch_size=256, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=2048, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs256_seq4096_spec", batch_size=256, q_num_head=12, kv_num_head=1, head_dim=128, 
seq_len=4096, max_model_len=22528, max_tokens_per_batch=2), + # BenchmarkCase("bs256_seq8192_spec", batch_size=256, q_num_head=12, kv_num_head=1, head_dim=128, seq_len=8192, max_model_len=22528, max_tokens_per_batch=2), +] + + +def do_prefill(case, block_tables, rotary_embs, place): + """Run prefill and return cache_k, cache_v after prefill.""" + max_block_num = block_tables.shape[0] * block_tables.shape[1] + cache_shape = (max_block_num, case.kv_num_head, case.block_size, case.head_dim) + cache_k = paddle.zeros(shape=cache_shape, dtype=case.dtype) + cache_v = paddle.zeros(shape=cache_shape, dtype=case.dtype) + + _, _, _, enc_qkv = get_qkv_and_qkv_concat_tensor( + case.batch_size, + case.q_num_head, + case.kv_num_head, + case.seq_len, + case.head_dim, + place, + case.dtype, + ) + + enc_seq_lens_encoder = paddle.to_tensor([case.seq_len] * case.batch_size, "int32") + enc_seq_lens_decoder = paddle.to_tensor([0] * case.batch_size, "int32") + enc_seq_lens_this_time = copy.deepcopy(enc_seq_lens_encoder) + enc_batch_id_per_token, enc_cu_seqlens_q, _ = get_padding_offset(case.batch_size, enc_seq_lens_this_time) + + buffers = build_append_attention_buffers(case.batch_size, case.max_model_len, case.group_size, case.block_size) + get_block_shape_and_split_kv_block( + enc_seq_lens_encoder, + enc_seq_lens_decoder, + enc_seq_lens_this_time, + buffers["decoder_batch_ids"], + buffers["decoder_tile_ids_per_batch"], + buffers["decoder_num_blocks_cpu"], + buffers["decoder_num_blocks_device"], + buffers["decoder_chunk_size_device"], + buffers["max_len_tensor_cpu"], + buffers["encoder_batch_ids"], + buffers["encoder_tile_ids_per_batch"], + buffers["encoder_num_blocks_cpu"], + buffers["kv_batch_ids"], + buffers["kv_tile_ids_per_batch"], + buffers["kv_num_blocks_x_cpu"], + 64, + 16, + case.group_size, + case.block_size, + ) + + append_attention_op( + enc_qkv, + cache_k, + cache_v, + enc_seq_lens_encoder, + enc_seq_lens_decoder, + enc_seq_lens_this_time, + enc_batch_id_per_token, + 
enc_cu_seqlens_q, + block_tables, + buffers["encoder_batch_ids"], + buffers["encoder_tile_ids_per_batch"], + buffers["encoder_num_blocks_cpu"], + buffers["kv_batch_ids"], + buffers["kv_tile_ids_per_batch"], + buffers["kv_num_blocks_x_cpu"], + buffers["decoder_batch_ids"], + buffers["decoder_tile_ids_per_batch"], + buffers["decoder_num_blocks_cpu"], + buffers["max_len_tensor_cpu"], + rotary_embs, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + 1e-6, + "bf16", + case.cache_quant_type, + False, + False, + case.max_model_len, + 0.0, + 0.0, + -1, + 64, + 16, + 1024, + 22528, + case.max_tokens_per_batch, + case.causal, + case.max_tokens_per_batch > 1, + ) + return cache_k, cache_v + + +def get_decode_inputs(case, place): + """Return decode qkv and seq_lens tensors.""" + _, _, _, dec_qkv = get_qkv_and_qkv_concat_tensor( + case.batch_size, + case.q_num_head, + case.kv_num_head, + case.max_tokens_per_batch, + case.head_dim, + place, + case.dtype, + ) + dec_seq_lens_encoder = paddle.to_tensor([0] * case.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([case.seq_len] * case.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([case.max_tokens_per_batch] * case.batch_size, "int32") + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(case.batch_size, dec_seq_lens_this_time) + return ( + dec_qkv, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + +def run_append_attention( + case, + cache_k, + cache_v, + dec_qkv, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, +): + buffers = build_append_attention_buffers(case.batch_size, case.max_model_len, case.group_size, case.block_size) + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + 
buffers["decoder_batch_ids"], + buffers["decoder_tile_ids_per_batch"], + buffers["decoder_num_blocks_cpu"], + buffers["decoder_num_blocks_device"], + buffers["decoder_chunk_size_device"], + buffers["max_len_tensor_cpu"], + buffers["encoder_batch_ids"], + buffers["encoder_tile_ids_per_batch"], + buffers["encoder_num_blocks_cpu"], + buffers["kv_batch_ids"], + buffers["kv_tile_ids_per_batch"], + buffers["kv_num_blocks_x_cpu"], + 64, + 16, + case.group_size, + case.block_size, + ) + qkv_copy = copy.deepcopy(dec_qkv) + cache_k_copy = copy.deepcopy(cache_k) + cache_v_copy = copy.deepcopy(cache_v) + out = append_attention_op( + qkv_copy, + cache_k_copy, + cache_v_copy, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + buffers["encoder_batch_ids"], + buffers["encoder_tile_ids_per_batch"], + buffers["encoder_num_blocks_cpu"], + buffers["kv_batch_ids"], + buffers["kv_tile_ids_per_batch"], + buffers["kv_num_blocks_x_cpu"], + buffers["decoder_batch_ids"], + buffers["decoder_tile_ids_per_batch"], + buffers["decoder_num_blocks_cpu"], + buffers["max_len_tensor_cpu"], + rotary_embs, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + 1e-6, + "bf16", + case.cache_quant_type, + False, + False, + case.max_model_len, + 0.0, + 0.0, + -1, + 64, + 16, + 1024, + 22528, + case.max_tokens_per_batch, + case.causal, + case.max_tokens_per_batch > 1, + ) + return out + + +def run_decode_attention( + case, + cache_k, + cache_v, + dec_qkv, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, +): + buffer = build_decode_attention_buffers( + case.batch_size, + case.max_model_len, + case.kv_num_head, + case.q_num_head, + case.head_dim, + case.max_tokens_per_batch, + case.group_size, + case.dtype, + ) + config_for_attention( + seq_lens_encoder, + seq_lens_decoder, + 
seq_lens_this_time, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + case.cache_quant_type, + case.group_size, + case.kv_num_head, + case.max_tokens_per_batch, + ) + dec_cache_k = copy.deepcopy(cache_k) + dec_cache_v = copy.deepcopy(cache_v) + dec_qkv_copy = copy.deepcopy(dec_qkv) + decoder_write_cache_with_rope( + dec_qkv_copy, + dec_cache_k, + dec_cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + buffer["max_len_tensor_cpu"], + rotary_embs, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + 1e-6, + case.cache_quant_type, + False, + False, + case.max_model_len, + 0.0, + 0.0, + case.max_tokens_per_batch > 1, + ) + out = decode_append_attention( + dec_qkv_copy, + dec_cache_k, + dec_cache_v, + buffer["tmp_workspace"], + buffer["tmp_m"], + buffer["tmp_d"], + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + None, + None, + None, + None, + None, + None, + None, + None, + None, + paddle.empty([dec_qkv_copy.shape[0], case.q_num_head * case.head_dim], dtype=dec_qkv_copy.dtype), # fmha_out + case.cache_quant_type, + case.max_model_len, + 0.0, + 0.0, + case.max_tokens_per_batch, + case.causal, + ) + return out + + +# Cache rope embeddings keyed by max_model_len to avoid recomputation +_rope_cache = {} + + +def _get_rotary_embs(max_model_len, head_dim): + key = (max_model_len, head_dim) + if key not in _rope_cache: + rope = RopeEmbedding() + tmp_position_ids = paddle.arange(max_model_len).reshape((1, -1)) + _rope_cache[key] = rope.get_rotary_position_embedding(tmp_position_ids, head_dim) + return _rope_cache[key] + + +def benchmark_case(case, op="both"): + """Run a single case: prefill once, then run decode op(s) once each.""" + 
paddle.disable_static() + place = paddle.CUDAPlace(0) + + rotary_embs = _get_rotary_embs(case.max_model_len, case.head_dim) + + block_tables, max_block_num = build_block_tables(case.batch_size, case.max_model_len, case.block_size) + + # Prefill + cache_k, cache_v = do_prefill(case, block_tables, rotary_embs, place) + + # Decode inputs + dec_qkv, dec_sle, dec_sld, dec_slt, dec_bid, dec_csq = get_decode_inputs(case, place) + + results = {} + + if op in ("both", "append"): + paddle.device.cuda.synchronize() + t0 = time.perf_counter() + run_append_attention( + case, + cache_k, + cache_v, + dec_qkv, + dec_sle, + dec_sld, + dec_slt, + dec_bid, + dec_csq, + block_tables, + rotary_embs, + ) + paddle.device.cuda.synchronize() + results["append"] = (time.perf_counter() - t0) * 1000 + + if op in ("both", "decode"): + paddle.device.cuda.synchronize() + t0 = time.perf_counter() + run_decode_attention( + case, + cache_k, + cache_v, + dec_qkv, + dec_sle, + dec_sld, + dec_slt, + dec_bid, + dec_csq, + block_tables, + rotary_embs, + ) + paddle.device.cuda.synchronize() + results["decode"] = (time.perf_counter() - t0) * 1000 + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark append_attention vs decode_append_attention") + parser.add_argument( + "--op", choices=["both", "append", "decode"], default="both", help="Which op to run (default: both)" + ) + parser.add_argument( + "--case", type=int, default=-1, help="Run only case by index (0-based). Default: run all cases." 
+ ) + args = parser.parse_args() + + cases = CASES if args.case < 0 else [CASES[args.case]] + + if args.op == "both": + print(f"{'Case':<25} {'append_attn (ms)':>18} {'decode_attn (ms)':>18} {'Ratio':>10}") + elif args.op == "append": + print(f"{'Case':<25} {'append_attn (ms)':>18}") + else: + print(f"{'Case':<25} {'decode_attn (ms)':>18}") + print("-" * 75) + + for case in cases: + results = benchmark_case(case, op=args.op) + if args.op == "both": + a, d = results["append"], results["decode"] + ratio = a / d if d > 0 else float("inf") + print(f"{case.short_name():<25} {a:>18.3f} {d:>18.3f} {ratio:>10.2f}x") + elif args.op == "append": + print(f"{case.short_name():<25} {results['append']:>18.3f}") + else: + print(f"{case.short_name():<25} {results['decode']:>18.3f}") + + +if __name__ == "__main__": + main() diff --git a/tests/operators/attention/ncu.sh b/tests/operators/attention/ncu.sh new file mode 100644 index 00000000000..7c42a9ba31d --- /dev/null +++ b/tests/operators/attention/ncu.sh @@ -0,0 +1,5 @@ + + +ncu --clock-control=reset + +ncu --target-processes all --set full -f -o attn_v18_all -k regex:"multi_query_append_attention_warp1_4_kernel|decode_append_attention_c16_kernel" python tests/operators/attention/benchmark_decode_attention.py --op both diff --git a/tests/operators/attention/test_decode_append_attention.py b/tests/operators/attention/test_decode_append_attention.py new file mode 100644 index 00000000000..07619de2aa2 --- /dev/null +++ b/tests/operators/attention/test_decode_append_attention.py @@ -0,0 +1,979 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +import unittest + +import numpy as np +import paddle +from paddle.incubate.nn.functional import fused_rms_norm + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention, + config_for_attention, + decode_append_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, + gqa_rope_write_cache, + pre_cache_len_concat, +) + +seed = 1000 + +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_neox_style_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + # shape: [B, S, 1, D] + emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, head_dim)) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), 
inv_freq) + # shape: [B, S, D/2] + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + # shape: [B, S, 1, D/2] + emb = paddle.unsqueeze(emb, 2) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def _apply_rope(self, rotary_emb, q, k, cache_len): + # sin [sequence_length, embed_size_per_head//2] + # cos [sequence_length, embed_size_per_head//2] + # sin, cos = paddle.chunk(rp, 2, axis=-1) + seq, head_dim = q.shape[2], q.shape[3] + cos, sin = paddle.chunk(rotary_emb, 2, axis=0) + cos = cos[:, :, cache_len : cache_len + seq, ...] + sin = sin[:, :, cache_len : cache_len + seq, ...] + cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + + if self.use_neox_rotary_style: + sin_pos = sin + cos_pos = cos + # NeoX Stype:前后半部分分块旋转 + rotate_half_q = paddle.reshape( + paddle.concat( + [ + -q[:, :, :, q.shape[-1] // 2 :], + q[:, :, :, : q.shape[-1] // 2], + ], + axis=-1, + ), + paddle.shape(q), + ) + rotate_half_k = paddle.reshape( + paddle.concat( + [ + -k[:, :, :, k.shape[-1] // 2 :], + k[:, :, :, : k.shape[-1] // 2], + ], + axis=-1, + ), + paddle.shape(k), + ) + else: + sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim]) + # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] + cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim]) + # GPT Stype:奇偶位置分块旋转 + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + + query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos)) + + key = paddle.add(paddle.multiply(k, cos_pos), 
paddle.multiply(rotate_half_k, sin_pos)) + + return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype) + + +def create_attn_mask(mask_type, batch_size, seq_lens, pre_cache_length=0, sliding_window=0): + max_seq_len = max(seq_lens) + mask = paddle.zeros( + # [batch_size, 1, max_seq_len, max_seq_len + pre_cache_length], + [batch_size, 1, max_seq_len, max_seq_len], + dtype=mask_type, + ) + mask[:, :, :, :pre_cache_length] = 1 + for i in range(batch_size): + seq_len = seq_lens[i] + ones_tensor = paddle.ones(shape=(seq_len, seq_len), dtype=mask_type) + if sliding_window <= 0: + mask[i, 0, :seq_len, :seq_len] = (paddle.tril(ones_tensor) - 1) * 1e4 + else: + tmp_triu = paddle.triu(ones_tensor, -(sliding_window - 1)) + mask[i, 0, :seq_len, :seq_len] = (paddle.tril(ones_tensor) * tmp_triu - 1) * 1e4 + return mask + + +def naive_attention_impl( + query, + key, + value, + pre_key=None, + pre_value=None, + mask=None, + scale=1.0, + cache_k_dequant_scales=None, + cache_v_dequant_scales=None, + use_cachekv_int8="None", + q_norm_weight=None, + k_norm_weight=None, + sinks=None, +): + batch = query.shape[0] + heads = query.shape[1] + seq_len = query.shape[2] + head_dim = query.shape[3] + kv_head = key.shape[1] + + key = key.reshape([batch, kv_head, 1, seq_len, head_dim]) + key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1]) + key = key.reshape([batch, heads, seq_len, head_dim]) + + if pre_key is not None: + pre_key = pre_key.reshape([batch, kv_head, 1, -1, head_dim]) + pre_key = paddle.tile(pre_key, [1, 1, heads // kv_head, 1, 1]) + pre_key = pre_key.reshape([batch, heads, -1, head_dim]) + key = paddle.concat([pre_key, key], axis=2) + + value = value.reshape([batch, kv_head, 1, seq_len, head_dim]) + value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1]) + value = value.reshape([batch, heads, seq_len, head_dim]) + + if pre_value is not None: + pre_value = pre_value.reshape([batch, kv_head, 1, -1, head_dim]) + pre_value = paddle.tile(pre_value, [1, 1, heads // kv_head, 1, 
1]) + pre_value = pre_value.reshape([batch, heads, -1, head_dim]) + value = paddle.concat([pre_value, value], axis=2) + + qk_res = paddle.matmul(query, key, transpose_y=True) + attention = qk_res * scale + if mask is not None: + attention = attention + mask + + if sinks is not None: + kv_len = attention.shape[-1] + sinks_tiled = sinks.unsqueeze([0, 2, 3]).expand([batch, heads, seq_len, 1]) + attention = paddle.concat([attention, sinks_tiled], axis=-1) + softmax_result = paddle.nn.functional.softmax(attention, -1)[:, :, :, :kv_len] + else: + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value) + return result + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = paddle.sum(seq_lens_this_time) + batch_id_per_token = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + index = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i].item() + for j in range(seq_len_now): + batch_id_per_token[index] = i + index += 1 + cu_seqlens_q[i + 1] = index + cu_seqlens_k[i + 1] = index + return batch_id_per_token, cu_seqlens_q, cu_seqlens_k + + +def remove_padding(seq_lens, cu_seq_lens, inputs, token_num): + bsz, num_head, seq_len, head_dim = inputs.shape + output = paddle.zeros(shape=[token_num, num_head * head_dim], dtype=inputs.dtype) + inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1]) + for i in range(bsz): + seq_len_now = seq_lens[i] + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :] + return output + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) - 0.5 + key = 
np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) - 0.5 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) - 0.5 + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +class TestDecodeAppendAttention(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + def init_tensor(self): + # seq_lens + if self.seq_len_dec is None: + self.seq_lens_dec = [ + self.cache_len, + ] * self.batch_size + else: + self.batch_size = len(self.seq_lens_dec) + self.seq_lens_decoder = paddle.to_tensor( + self.seq_lens_dec, + "int32", + ) + if self.seq_lens_this_time is None: + self.seq_lens_this_time = [ + 
self.max_tokens_per_batch, + ] * self.batch_size + self.token_num = sum(self.seq_lens_this_time) + self.seq_lens_this_time = paddle.to_tensor(self.seq_lens_this_time, "int32") + + self.seq_lens_enc = [0] * self.batch_size + + self.seq_lens_encoder = paddle.to_tensor( + self.seq_lens_enc, + "int32", + ) + + # self.qkv = paddle.rand([self.token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim], dtype=self.dtype) + self.q, self.k, self.v, self.qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.max_tokens_per_batch, + self.head_dim, + self.place, + self.dtype, + ) + self.qkv = paddle.to_tensor(self.qkv, dtype=self.dtype) + + # qk_norm + self.q_norm_weight = None + self.k_norm_weight = None + if self.use_qk_norm: + q_norm_weight_np = np.random.random([self.head_dim]) / 10 + k_norm_weight_np = np.random.random([self.head_dim]) / 10 + self.q_norm_weight = paddle.to_tensor(q_norm_weight_np, dtype="float32") + self.k_norm_weight = paddle.to_tensor(k_norm_weight_np, dtype="float32") + + # rotary embedding + self.rope = RopeEmbedding(False) + tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1)) + self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim) + + # block_table + self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") + for i in range(self.batch_size): + need_block_num = (self.max_model_len + self.block_size - 1) // self.block_size + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + + # cache_kv && scale + self.cache_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + self.head_dim, + ) + + if self.use_dynamic_quant: + self.cache_scale_shape = ( + 
self.max_block_num, + self.kv_num_head, + self.block_size, + ) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_k_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) + self.cache_v_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) + self.cache_k_out_scale = None + self.cache_k_out_scale = None + else: + self.cache_k_scale = self.quant_max_bound / self.k.transpose([1, 0, 2, 3]).reshape( + [self.kv_num_head, -1] + ).abs().max(axis=1) + self.cache_v_scale = self.quant_max_bound / self.v.transpose([1, 0, 2, 3]).reshape( + [self.kv_num_head, -1] + ).abs().max(axis=1) + + self.cache_k_out_scale = ( + self.k.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).max(axis=1) / self.quant_max_bound + ) + self.cache_v_out_scale = ( + self.v.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).max(axis=1) / self.quant_max_bound + ) + + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8") + + ( + self.batch_id_per_token, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, self.seq_lens_this_time) + + # mask + if self.mask_matrix: + self.attn_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.max_tokens_per_batch, + ] + * self.batch_size, + sliding_window=self.sliding_window, + ) + else: + self.attn_mask = None + + # mask offset + self.mask_offset = None + if self.use_mask_offset: + self.mask_offset = paddle.full(self.batch_size * 2, 0, "int32") + for i in range(self.batch_size): + self.mask_offset[i * 2] = 0 + self.mask_offset[i * 2 + 1] = self.seq_lens_dec[i] + 1 + + if self.use_sinks: + self.sinks = paddle.to_tensor( + np.random.random([self.q_num_head]), 
place=self.place, dtype=self.dtype, stop_gradient=False + ) + else: + self.sinks = None + + # buffer + self.buffer = {} + min_chunk_size = 128 + max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size + self.group_size = self.q_num_head // self.kv_num_head + q_tile_size = 16 if self.max_tokens_per_batch * self.group_size <= 16 else 32 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + self.buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + # block_indices: Launched block's indices with 4 dimensions [batch_idx, kv_head_idx, chunk_idx, q_tile_idx] in decode append attention backend + self.buffer["block_indices"] = paddle.full( + [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + # num_blocks: Number of Launched blocks in decode append attention backend, researched by config_for_attention op + self.buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + # chunk_size: Chunk size for split kv cache in decode append attention backend, researched by config_for_attention op + self.buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + # tmp_workspace: Workspace tensor for temporary store the result before merging in decode append attention backend + self.buffer["tmp_workspace"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head * self.head_dim], + 0, + dtype=self.dtype, + ) + # tmp_m: Tmp_m tensor for temporary store the max value before merging in decode append attention backend + self.buffer["tmp_m"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + # tmp_d: Tmp_d tensor for temporary store the exponential sum before merging in decode append attention backend + self.buffer["tmp_d"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + + def 
apply_qk_norm(self, head_dim, dtype, q, k): + bs, q_num_head, seq_len, head_dim = q.shape + _, kv_num_head, _, _ = k.shape + + q = q.reshape([-1, head_dim]) + k = k.reshape([-1, head_dim]) + q = fused_rms_norm(q.astype("float32"), self.q_norm_weight, None, self.rms_norm_eps)[0].astype(dtype) + k = fused_rms_norm(k.astype("float32"), self.k_norm_weight, None, self.rms_norm_eps)[0].astype(dtype) + q = q.reshape([-1, q_num_head, seq_len, head_dim]) + k = k.reshape([-1, kv_num_head, seq_len, head_dim]) + return q, k + + def naive_attention(self, pre_k, pre_v): + q, k = self.rope._apply_rope(self.rotary_embs, self.q, self.k, self.cache_len) + if self.use_qk_norm: + q, k = self.apply_qk_norm(self.head_dim, self.dtype, q, k) + + out_ref = naive_attention_impl( + q, + k, + self.v, + pre_k, + pre_v, + self.attn_mask, + self.softmax_scale, + sinks=self.sinks, + ) + out_ref = remove_padding(self.seq_lens_this_time, self.cu_seqlens_q, out_ref, self.token_num) + return q, k, self.v, out_ref + + def append_attention(self): + # buffer + max_num_block_dec = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + 
max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + + get_block_shape_and_split_kv_block( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + qkv = copy.deepcopy(self.qkv) + cache_k = copy.deepcopy(self.cache_k) + cache_v = copy.deepcopy(self.cache_v) + _ = append_attention( + qkv, + cache_k, + cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + max_len_tensor_cpu, + self.rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + self.mask_offset, + None, # kv_signal_data + self.q_norm_weight, + self.k_norm_weight, + self.sinks, + self.rms_norm_eps, + "bf16", + self.cache_quant_type, + False, # use_neox_rotary_style + self.rope_3d, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + -1, + 64, + 16, + 32768, + 1024, + self.max_tokens_per_batch, + self.causal, + self.max_tokens_per_batch > 1, + self.sliding_window, + ) + + def decode_attention(self): + paddle.disable_static() + + config_for_attention( + 
self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], + self.cache_quant_type, + self.group_size, + self.kv_num_head, + self.max_tokens_per_batch, + ) + # print(f"num_blocks: {self.buffer['num_blocks']}") + decoder_write_cache_with_rope( + self.qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + self.buffer["max_len_tensor_cpu"], + self.rotary_embs, # rotary_embs + None, # qkv_bias + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + self.q_norm_weight, # q_norm_weight + self.k_norm_weight, # k_norm_weight + self.rms_norm_eps, + self.cache_quant_type, + False, # use_neox_rotary_style + self.rope_3d, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + self.max_tokens_per_batch > 1, # speculate_decoder + ) + + out = decode_append_attention( + self.qkv, + self.cache_k, + self.cache_v, + self.buffer["tmp_workspace"], + self.buffer["tmp_m"], + self.buffer["tmp_d"], + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], # max_len_tensor_cpu + None, # attn_mask + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + self.mask_offset, # mask_offset + 
self.sinks, # sinks + self.cache_quant_type, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, # causal + self.sliding_window, + ) + return self.qkv, out + + def prefill(self): + # init seq_len + seq_lens_encoder = copy.deepcopy(self.seq_lens_decoder) + seq_lens_decoder = paddle.zeros([self.batch_size], dtype="int32") + seq_lens_this_time = seq_lens_encoder + token_num = seq_lens_this_time.sum().item() + qkv_np = np.random.random([token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim]) - 0.5 + qkv = paddle.to_tensor(qkv_np, dtype=self.dtype) + + ( + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = get_padding_offset(self.batch_size, seq_lens_this_time) + # buffer + decode_max_tile_size = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + 
decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + ( + cu_seqlens_k, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, + kv_token_num_cpu, + ) = pre_cache_len_concat( + seq_lens_decoder, + seq_lens_this_time, + max_len_tensor_cpu[2], + self.block_size, + ) + q, k, v, _ = gqa_rope_write_cache( + qkv, + self.cache_k, + self.cache_v, + cu_seqlens_q, + cu_seqlens_k, + self.rotary_embs, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, + self.block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, + self.q_norm_weight, + self.k_norm_weight, + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + kv_token_num_cpu[0].item(), + self.max_model_len, + self.rms_norm_eps, + False, # use_neox_rotary_style + self.cache_quant_type, + self.rope_3d, + ) + + k = k.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3]) + v = v.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3]) + return k, v + + def test_all(self): + pre_k, pre_v = self.prefill() + + q_ref, k_ref, v_ref, out_ref = self.naive_attention(pre_k, pre_v) + qkv_out, out = self.decode_attention() + + np.testing.assert_allclose( + out.astype("float32").numpy(), + out_ref.astype("float32").numpy(), + rtol=1e-03, + atol=2e-03, + ) + + # profiler + def profile(self): + pre_k, pre_v = self.prefill() + 
paddle.device.synchronize() + self.append_attention() + paddle.device.synchronize() + qkv_out, out = self.decode_attention() + paddle.device.synchronize() + + +class TestDecodeAppendAttentionMultiBatch(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 60 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeAppendAttentionSpeculate(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = 
"cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeAppendAttentionMultiHead(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeAppendAttentionMultiSpeculate(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 4 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = 
self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeAppendAttentionQKNorm(TestDecodeAppendAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeAppendAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = True + self.use_mask_offset = False + self.mask_matrix = False + self.use_sinks = False + self.causal = False + self.use_dynamic_quant = False + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/attention/test_decode_append_attention_c16.py b/tests/operators/attention/test_decode_append_attention_c16.py new file mode 100644 index 00000000000..5cb869bf709 --- /dev/null +++ b/tests/operators/attention/test_decode_append_attention_c16.py @@ -0,0 +1,868 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention as append_attention_op, +) +from fastdeploy.model_executor.layers.attention.ops import ( + config_for_attention, + decode_append_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, +) + +seed = 1000 + +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + emb = paddle.unsqueeze(emb, 2) + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def _apply_rope(self, rotary_emb, q, k, start_pos=0): + seq, head_dim = q.shape[2], q.shape[3] + cos, sin = paddle.chunk(rotary_emb, 2, axis=0) + cos = cos[:, :, start_pos : start_pos + seq, ...] + sin = sin[:, :, start_pos : start_pos + seq, ...] 
+ cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + + sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim]) + cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim]) + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + + query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos)) + key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos)) + return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype) + + +def naive_attention_impl(query, key, value, cache_k=None, cache_v=None, mask=None, scale=1.0): + batch = query.shape[0] + heads = query.shape[1] + seq_len = query.shape[2] + head_dim = query.shape[3] + kv_head = key.shape[1] + + key = key.reshape([batch, kv_head, 1, seq_len, head_dim]) + key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1]) + key = key.reshape([batch, heads, seq_len, head_dim]) + + if cache_k is not None: + cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim]) + cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1]) + cache_k = cache_k.reshape([batch, heads, -1, head_dim]) + key = paddle.concat([cache_k, key], axis=2) + + value = value.reshape([batch, kv_head, 1, seq_len, head_dim]) + value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1]) + value = value.reshape([batch, heads, seq_len, head_dim]) + + if cache_v is not None: + cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim]) + cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1]) + cache_v = cache_v.reshape([batch, heads, -1, head_dim]) + value = paddle.concat([cache_v, value], axis=2) + + qk_res = paddle.matmul(query, key, transpose_y=True) + attention = 
qk_res * scale + if mask is not None: + attention = attention + mask + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value) + return result + + +def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len): + """Read K/V from paged cache and return as [batch, num_head, seq_len, dim_head].""" + _, num_head, blocksize, dim_head = cache_k.shape + out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype) + out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype) + for i in range(bsz): + for j in range(cache_seq_len): + out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :] + out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :] + return out_cache_k, out_cache_v + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = paddle.sum(seq_lens_this_time) + batch_id_per_token = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + index = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i].item() + for j in range(seq_len_now): + batch_id_per_token[index] = i + index += 1 + cu_seqlens_q[i + 1] = index + cu_seqlens_k[i + 1] = index + return batch_id_per_token, cu_seqlens_q, cu_seqlens_k + + +def remove_padding(seq_lens, cu_seq_lens, inputs, token_num): + bsz, num_head, seq_len, head_dim = inputs.shape + output = paddle.zeros(shape=[token_num, num_head * head_dim], dtype=inputs.dtype) + inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1]) + for i in range(bsz): + seq_len_now = seq_lens[i] + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :] + return output + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, 
kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) - 0.5 + key = np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) - 0.5 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) - 0.5 + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +class TestDecodeAppendAttentionC16(unittest.TestCase): + """Base test class for decode append attention with cache_quant_type='none' (fp16/bf16 KV cache). + + Uses append_attention for prefill (verified correct by test_append_attention_c16.py) + and then tests decode_attention (new split ops) against the same naive reference. + + Subclasses override setUp to vary batch_size, max_tokens_per_batch, dtype, etc. 
+ """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + + # Use small seq_len for fast testing; can increase later + self.seq_len = 64 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def init_tensor(self): + self.rope = RopeEmbedding(self.use_neox_rotary_style) + tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1)) + self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim) + + # block_table + self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") + for i in range(self.batch_size): + need_block_num = (self.max_model_len + self.block_size - 1) // self.block_size + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + + # cache + self.cache_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + self.head_dim, + ) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + + # Encoder phase: prefill with seq_len tokens + self.enc_q, self.enc_k, self.enc_v, self.enc_qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.seq_len, + self.head_dim, + self.place, + self.dtype, + ) + + # Decoder phase: max_tokens_per_batch decode tokens + 
self.dec_q, self.dec_k, self.dec_v, self.dec_qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.max_tokens_per_batch, + self.head_dim, + self.place, + self.dtype, + ) + + def _get_block_shape_buffers(self, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time): + max_num_block_dec = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + return { + "decoder_batch_ids": decoder_batch_ids, + "decoder_tile_ids_per_batch": decoder_tile_ids_per_batch, + "decoder_num_blocks_cpu": decoder_num_blocks_cpu, + "encoder_batch_ids": encoder_batch_ids, + 
"encoder_tile_ids_per_batch": encoder_tile_ids_per_batch, + "encoder_num_blocks_cpu": encoder_num_blocks_cpu, + "kv_batch_ids": kv_batch_ids, + "kv_tile_ids_per_batch": kv_tile_ids_per_batch, + "kv_num_blocks_x_cpu": kv_num_blocks_x_cpu, + "max_len_tensor_cpu": max_len_tensor_cpu, + } + + def run_append_attention( + self, + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run append_attention op.""" + buffers = self._get_block_shape_buffers(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time) + + qkv_copy = copy.deepcopy(qkv) + cache_k_copy = copy.deepcopy(cache_k) + cache_v_copy = copy.deepcopy(cache_v) + + out = append_attention_op( + qkv_copy, + cache_k_copy, + cache_v_copy, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffers["encoder_batch_ids"], + buffers["encoder_tile_ids_per_batch"], + buffers["encoder_num_blocks_cpu"], + buffers["kv_batch_ids"], + buffers["kv_tile_ids_per_batch"], + buffers["kv_num_blocks_x_cpu"], + buffers["decoder_batch_ids"], + buffers["decoder_tile_ids_per_batch"], + buffers["decoder_num_blocks_cpu"], + buffers["max_len_tensor_cpu"], + self.rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # mask_offset + None, # kv_signal_data + None, # q_norm_weight + None, # k_norm_weight + None, # sinks + self.rms_norm_eps, + "bf16", + self.cache_quant_type, + self.use_neox_rotary_style, + self.rope_3d, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + -1, + 64, + 16, + 32768, + 32768, + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, + self.max_tokens_per_batch > 1, # speculate_decoder 
+ ) + return out, cache_k_copy, cache_v_copy + + def _build_decode_buffer(self): + """Build buffer for new split decode ops.""" + buffer = {} + min_chunk_size = 128 + max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 if self.max_tokens_per_batch * self.group_size <= 16 else 32 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + buffer["block_indices"] = paddle.full( + [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + buffer["tmp_workspace"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head * self.head_dim], + 0, + dtype=self.dtype, + ) + buffer["tmp_m"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + buffer["tmp_d"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + return buffer + + def _run_decode_attention( + self, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run config_for_attention + decoder_write_cache_with_rope + decode_append_attention.""" + buffer = self._build_decode_buffer() + + config_for_attention( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + self.cache_quant_type, + self.group_size, + self.kv_num_head, + self.max_tokens_per_batch, + ) + + dec_cache_k = copy.deepcopy(cache_k) + dec_cache_v = copy.deepcopy(cache_v) + dec_qkv = copy.deepcopy(self.dec_qkv) + + decoder_write_cache_with_rope( + dec_qkv, + dec_cache_k, + dec_cache_v, + 
seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffer["max_len_tensor_cpu"], + self.rotary_embs, + None, # qkv_bias + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + None, # q_norm_weight + None, # k_norm_weight + self.rms_norm_eps, + self.cache_quant_type, + self.use_neox_rotary_style, + self.rope_3d, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + self.max_tokens_per_batch > 1, # speculate_decoder + ) + + out = decode_append_attention( + dec_qkv, + dec_cache_k, + dec_cache_v, + buffer["tmp_workspace"], + buffer["tmp_m"], + buffer["tmp_d"], + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + None, # attn_mask + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # mask_offset + None, # sinks + paddle.empty([dec_qkv.shape[0], self.q_num_head * self.head_dim], dtype=dec_qkv.dtype), # fmha_out + self.cache_quant_type, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, # causal + ) + return out, dec_cache_k, dec_cache_v + + def do_prefill_with_append_attention(self): + """Prefill using append_attention. 
Returns cache_k, cache_v after prefill.""" + seq_lens_encoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + seq_lens_decoder = paddle.to_tensor([0] * self.batch_size, "int32") + seq_lens_this_time = copy.deepcopy(seq_lens_encoder) + + batch_id_per_token, cu_seqlens_q, _ = get_padding_offset(self.batch_size, seq_lens_this_time) + + _, cache_k, cache_v = self.run_append_attention( + self.enc_qkv, + self.cache_k, + self.cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ) + return cache_k, cache_v + + def compute_naive_decode_ref(self, cache_k, cache_v): + """Compute naive reference for decode step using cache from paged cache.""" + # Read K/V from paged cache + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + cache_k, cache_v, self.batch_size, self.block_tables, self.seq_len + ) + + # Only use the first decode token (seq_lens_this_time=1 per batch) + dec_q = self.dec_q[:, :, :1, :] + dec_k = self.dec_k[:, :, :1, :] + dec_v = self.dec_v[:, :, :1, :] + + # Apply RoPE to decode Q/K at position seq_len + dec_q_rope, dec_k_rope = self.rope._apply_rope(self.rotary_embs, dec_q, dec_k, start_pos=self.seq_len) + + # Compute naive attention + out_ref = naive_attention_impl( + dec_q_rope, + dec_k_rope, + dec_v, + cache_k=naive_cache_k, + cache_v=naive_cache_v, + scale=self.softmax_scale, + ) + + dec_seq_lens_this_time = paddle.to_tensor([1] * self.batch_size, "int32") + dec_token_num = self.batch_size + _, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + out_ref = remove_padding(dec_seq_lens_this_time, dec_cu_seqlens_q, out_ref, dec_token_num) + return out_ref + + def test_naive_vs_append_attention_decode(self): + """Test: prefill with append_attention, then decode with append_attention. 
Compare to naive.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Naive reference for decode + out_ref = self.compute_naive_decode_ref(cache_k, cache_v) + + # Step 3: Decode with append_attention + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_dec, _, _ = self.run_append_attention( + self.dec_qkv, + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_ref_f = out_ref.astype("float32").numpy() + out_dec_f = out_dec.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_dec_f = out_dec_f[:dec_token_num] + + np.testing.assert_allclose( + out_dec_f, + out_ref_f, + rtol=1e-02, + atol=1e-02, + err_msg="append_attention decode output doesn't match naive reference", + ) + + def test_naive_vs_decode_attention(self): + """Test: prefill with append_attention, then decode with new split decode ops.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Naive reference for decode + out_ref = self.compute_naive_decode_ref(cache_k, cache_v) + + # Step 3: Decode with new split ops + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * 
self.batch_size, "int32") + + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out, _, _ = self._run_decode_attention( + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_ref_f = out_ref.astype("float32").numpy() + out_decode_f = out.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_decode_f = out_decode_f[:dec_token_num] + + np.testing.assert_allclose( + out_decode_f, + out_ref_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_append_attention output doesn't match naive reference", + ) + + def test_append_vs_decode_attention(self): + """Test: append_attention decode vs new split decode ops should produce same result.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Decode with append_attention + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_append, _, _ = self.run_append_attention( + self.dec_qkv, + copy.deepcopy(cache_k), + copy.deepcopy(cache_v), + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + # Step 3: Decode with new split ops + out_decode, _, _ = self._run_decode_attention( + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_append_f = out_append.astype("float32").numpy() + 
out_decode_f = out_decode.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_append_f = out_append_f[:dec_token_num] + out_decode_f = out_decode_f[:dec_token_num] + + np.testing.assert_allclose( + out_decode_f, + out_append_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_append_attention doesn't match append_attention decode", + ) + + +class TestDecodeAppendAttentionC16Speculate(TestDecodeAppendAttentionC16): + """Test with speculate decode: max_tokens_per_batch=2. + + When max_tokens_per_batch > 1, naive ref only computes 1 token while ops + compute multiple tokens. So naive comparison tests are skipped; only + append_attention vs decode_append_attention comparison is kept. + """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 2 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 64 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def test_naive_vs_append_attention_decode(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + def test_naive_vs_decode_attention(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + +class TestDecodeAppendAttentionC16MultiBatch(TestDecodeAppendAttentionC16): + """Test with multiple batches.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 4 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + 
self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 64 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeAppendAttentionC16MultiHead(TestDecodeAppendAttentionC16): + """Test with multiple KV heads (GQA).""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 2 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 64 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeAppendAttentionC16FP16(TestDecodeAppendAttentionC16): + """Test with float16 dtype.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "float16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 64 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeAppendAttentionC16NoCausal(TestDecodeAppendAttentionC16): + """Test with causal=False.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + 
self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = False + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 64 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeAppendAttentionC16MultiBatchSpeculate(TestDecodeAppendAttentionC16): + """Test with multi-batch + speculate decode. + + When max_tokens_per_batch > 1, the naive reference only computes 1 token + while ops compute multiple tokens. So we only compare append_attention vs + decode_append_attention (both should produce same result), and skip the + naive comparison tests. + """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 4 + self.max_tokens_per_batch = 2 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 64 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def test_naive_vs_append_attention_decode(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + def test_naive_vs_decode_attention(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + +if __name__ == "__main__": + unittest.main()