diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 204ea33e50b..d7f53c46209 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -427,6 +427,7 @@ std::vector<paddle::Tensor> GetPaddingOffset( const paddle::Tensor& seq_len, const paddle::optional<paddle::Tensor>& draft_tokens, const paddle::optional<paddle::Tensor>& seq_lens_encoder, + const paddle::optional<paddle::Tensor>& seq_lens_decoder, const int64_t token_num_cpu); void SetValueByFlagsAndIdx(const paddle::Tensor& token_ids_all, diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu index 998d1760515..010eb709fe5 100644 --- a/custom_ops/gpu_ops/get_padding_offset.cu +++ b/custom_ops/gpu_ops/get_padding_offset.cu @@ -28,6 +28,7 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding, const int max_seq_len, const int64_t *draft_tokens, const int *seq_lens_encoder, + const int *seq_lens_decoder, const int max_draft_tokens_per_batch) { const int bi = blockIdx.x; const int tid = threadIdx.x; @@ -39,23 +40,37 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding, const int lane_id = threadIdx.x % 32; #endif - int cum_seq_len = 0; + // Independent q/k prefix sums: + // cu_seqlens_q = Σ seq_lens[j] (new tokens only, for Q/varlen indexing) + // cu_seqlens_k = Σ (seq_lens[j] + seq_lens_decoder[j]) (cached + new, for + // K/FA) + // When seq_lens_decoder == nullptr, cu_seqlens_k degenerates to cu_seqlens_q. + int cum_seq_len_q = 0; + int cum_seq_len_k = 0; - // compute sum of seq_lens[0,1,2,...,bi] for (int i = lane_id; i < bi + 1; i += WARP_SIZE) { - cum_seq_len += seq_lens[i]; + const int q_inc = seq_lens[i]; + const int k_inc = + q_inc + (seq_lens_decoder != nullptr ? seq_lens_decoder[i] : 0); + cum_seq_len_q += q_inc; + cum_seq_len_k += k_inc; } for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { - const int tmp = __shfl_up_sync(0xffffffff, cum_seq_len, offset); - if (lane_id >= offset) cum_seq_len += tmp; + const int tmp_q = __shfl_up_sync(0xffffffff, cum_seq_len_q, offset); + const int tmp_k = __shfl_up_sync(0xffffffff, cum_seq_len_k, offset); + if (lane_id >= offset) { + cum_seq_len_q += tmp_q; + cum_seq_len_k += tmp_k; + } } - cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, WARP_SIZE - 1); + cum_seq_len_q = __shfl_sync(0xffffffff, cum_seq_len_q, WARP_SIZE - 1); + cum_seq_len_k = __shfl_sync(0xffffffff, cum_seq_len_k, WARP_SIZE - 1); if (tid == 0) { - cu_seqlens_q[bi + 1] = cum_seq_len; - cu_seqlens_k[bi + 1] = cum_seq_len; + cu_seqlens_q[bi + 1] = cum_seq_len_q; + cu_seqlens_k[bi + 1] = cum_seq_len_k; } if (bi == 0 && tid == 0) { @@ -63,8 +78,9 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding, cu_seqlens_k[0] = 0; } + // Q-side token scatter uses cum_seq_len_q (new-tokens-only layout). for (int i = tid; i < seq_lens[bi]; i += blockDim.x) { - const int tgt_seq_id = cum_seq_len - seq_lens[bi] + i; + const int tgt_seq_id = cum_seq_len_q - seq_lens[bi] + i; if (max_draft_tokens_per_batch > 0 && seq_lens_encoder[bi] <= 0) { // speculative decoding const int src_seq_id = bi * max_draft_tokens_per_batch + i; @@ -84,6 +100,7 @@ std::vector<paddle::Tensor> GetPaddingOffset( const paddle::Tensor &seq_len, const paddle::optional<paddle::Tensor> &draft_tokens, const paddle::optional<paddle::Tensor> &seq_lens_encoder, + const paddle::optional<paddle::Tensor> &seq_lens_decoder, const int64_t cpu_token_num) { #ifdef PADDLE_WITH_CUSTOM_DEVICE auto dev_ctx = static_cast( @@ -131,6 +148,7 @@ std::vector<paddle::Tensor> GetPaddingOffset( max_seq_len, draft_tokens ? draft_tokens.get().data<int64_t>() : nullptr, seq_lens_encoder ? 
seq_lens_encoder.get().data<int>() : nullptr, + seq_lens_decoder ? seq_lens_decoder.get().data<int>() : nullptr, max_draft_tokens_per_batch); return {x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k}; @@ -156,7 +174,8 @@ PD_BUILD_STATIC_OP(get_padding_offset) .Inputs({"input_ids", "seq_len", paddle::Optional("draft_tokens"), - paddle::Optional("seq_lens_encoder")}) + paddle::Optional("seq_lens_encoder"), + paddle::Optional("seq_lens_decoder")}) .Outputs({"x_remove_padding", "batch_id_per_token", "cu_seqlens_q", diff --git a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu b/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu index 63bc77c9afc..1aeb85c4496 100644 --- a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu +++ b/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu @@ -17,7 +17,8 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel( const int* seq_lens_encoder, // [bsz] 每个批次的 encoder 长度 - const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度 + const int* seq_lens_decoder, // [bsz] per-batch decoder length (for MLA + // prefix cache, this is the cached KV length) const int* seq_lens_this_time, int* position_ids, // 输出的一维 position_ids const int bsz) { // 批次大小 @@ -25,11 +26,16 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel( int tid = threadIdx.x; if (tid >= bsz) return; - // 动态计算当前批次的偏移量 + // Dynamically compute the offset of the current batch. + // Each preceding batch contributes either encoder_len or seq_lens_this_time, + // not their sum (during chunked prefill both encoder_len > 0 and decoder_len > 0 + // hold at the same time, + // but that batch only holds encoder_len real tokens). int offset = 0; for (int i = 0; i < tid; i++) { - offset += seq_lens_encoder[i]; - if (seq_lens_decoder[i] > 0) { + if (seq_lens_encoder[i] > 0) { + offset += seq_lens_encoder[i]; + } else if (seq_lens_decoder[i] > 0) { offset += seq_lens_this_time[i]; } } @@ -39,16 +45,21 @@ __global__ void GetPositionIdsAndMaskEncoderBatchKernel( int decoder_len = seq_lens_decoder[tid]; int seq_len_this_time = seq_lens_this_time[tid]; - // 写入 encoder 的 position_ids - for (int i = 0; i < encoder_len; i++) { - position_ids[offset + i] = i; - } - offset += encoder_len; - - // 写入 decoder 的 position_ids - if (decoder_len > 0) { + // For MLA with prefix cache support: + // - Chunked prefill (encoder_len > 0 && decoder_len > 0): + // only writes encoder positions starting at decoder_len (cached length). + // - First-chunk prefill (encoder_len > 0 && decoder_len == 0): + // writes encoder positions starting at 0. + // - Pure decode (encoder_len == 0 && decoder_len > 0): + // writes seq_lens_this_time decode positions starting at decoder_len. + if (encoder_len > 0) { + int start_pos = (decoder_len > 0) ? 
decoder_len : 0; + for (int i = 0; i < encoder_len; i++) { + position_ids[offset + i] = start_pos + i; + } + } else if (decoder_len > 0) { for (int i = 0; i < seq_len_this_time; i++) { - position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身 + position_ids[offset + i] = decoder_len + i; } } } diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 7e2e1066f0b..b57618793b7 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -67,6 +67,221 @@ ) from fastdeploy.spec_decode import SpecMethod +# ============================================================================ +# Fused Read-Cache + Interleave Kernel for Prefix Cache Support +# +# For each output token position t in [0, total_tokens): +# - if cached: load latent vector from paged latent_cache (physical block) +# - if new: load from new_compressed_kv / new_k_pe at the given index +# A single kernel produces full_compressed_kv / full_k_pe directly, avoiding +# an intermediate (cached_kv_c, cached_k_pe) allocation and an extra launch. +# ============================================================================ + + +def fused_read_cache_and_interleave_naive( + latent_cache: paddle.Tensor, + block_tables: paddle.Tensor, + new_compressed_kv: paddle.Tensor, + new_k_pe: paddle.Tensor, + cu_seqlens_k: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + kv_lora_rank: int, + qk_rope_head_dim: int, + block_size: int, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Python/Paddle reference of the fused read+interleave path. + + Takes with-cache ``cu_seqlens_k`` (cumsum of ``cached + new`` + per batch) plus ``cu_seqlens_q`` (new-only); the cached length of each + batch is derived as ``k_len - new_len``. 
+ """ + bsz = cu_seqlens_k.shape[0] - 1 + cu_total = cu_seqlens_k.tolist() + cu_new = cu_seqlens_q.tolist() + total_tokens = int(cu_total[bsz]) + + full_compressed_kv = paddle.empty([total_tokens, kv_lora_rank], dtype=new_compressed_kv.dtype) + full_k_pe = paddle.empty([total_tokens, qk_rope_head_dim], dtype=new_k_pe.dtype) + if total_tokens == 0: + return full_compressed_kv, full_k_pe + + out_pos = 0 + for b in range(bsz): + k_len = int(cu_total[b + 1]) - int(cu_total[b]) + nn = int(cu_new[b + 1]) - int(cu_new[b]) + nc = k_len - nn + # cached tokens first + for t in range(nc): + block_idx = t // block_size + block_offset = t % block_size + physical_block_id = block_tables[b, block_idx].item() + latent_vec = latent_cache[physical_block_id, 0, block_offset, :] + full_compressed_kv[out_pos] = latent_vec[:kv_lora_rank] + full_k_pe[out_pos] = latent_vec[kv_lora_rank:] + out_pos += 1 + # new tokens after cached + new_base = int(cu_new[b]) + for t in range(nn): + full_compressed_kv[out_pos] = new_compressed_kv[new_base + t] + full_k_pe[out_pos] = new_k_pe[new_base + t] + out_pos += 1 + + assert ( + out_pos == total_tokens + ), f"fused_read_cache_and_interleave_naive: out_pos={out_pos} != total_tokens={total_tokens}" + return full_compressed_kv, full_k_pe + + +@triton.jit() +def _fused_read_interleave_kernel( + latent_cache_ptr, # [num_blocks, 1, block_size, LATENT_DIM] + new_kv_c_ptr, # [total_new, kv_lora_rank] + new_k_pe_ptr, # [total_new, qk_rope_head_dim] + cu_total_ptr, # [bsz+1] int32 (= forward_meta.cu_seqlens_k) + cu_new_ptr, # [bsz+1] int32 (= cu_seqlens_q) + block_tables_ptr, # [bsz, max_blocks_per_seq] int32 + out_kv_c_ptr, # [total_tokens, kv_lora_rank] + out_k_pe_ptr, # [total_tokens, qk_rope_head_dim] + total_tokens, + bsz, + max_blocks_per_seq: tl.constexpr, + block_size: tl.constexpr, + kv_lora_rank: tl.constexpr, + qk_rope_head_dim: tl.constexpr, + LATENT_DIM: tl.constexpr, + LOG2_MAX_BSZ: tl.constexpr, + BLOCK_M: tl.constexpr, +): + """Tiled fused read+interleave kernel. + + Each program handles ``BLOCK_M`` contiguous output tokens. The owning + batch of each token is recovered via an in-kernel binary search on the + cumsum ``cu_total`` (O(log bsz), fully L1-resident because bsz+1 is + small), avoiding a host-side ``repeat_interleave`` of shape + ``[total_tokens]``. + + kv_c / k_pe are loaded as two separate vectors because ``kv_lora_rank`` + and ``qk_rope_head_dim`` are individually power-of-2 but their sum + (``LATENT_DIM``) is generally not (e.g. 512+64=576). + """ + pid = tl.program_id(axis=0) + kv_c_offs = tl.arange(0, kv_lora_rank) + k_pe_offs = tl.arange(0, qk_rope_head_dim) + + # Unrolled at compile time - BLOCK_M is small + for m in tl.static_range(0, BLOCK_M): + token_idx = pid * BLOCK_M + m + if token_idx < total_tokens: + # Binary search: find largest b such that cu_total[b] <= token_idx. + # Loop bound LOG2_MAX_BSZ is compile-time; cu_total fits in L1. 
+ lo = 0 + hi = bsz - 1 + for _ in tl.static_range(0, LOG2_MAX_BSZ): + mid = (lo + hi + 1) // 2 + cu_mid = tl.load(cu_total_ptr + mid) + if cu_mid <= token_idx: + lo = mid + else: + hi = mid - 1 + b = lo + + ct_b = tl.load(cu_total_ptr + b) + ct_b1 = tl.load(cu_total_ptr + b + 1) + cn_b = tl.load(cu_new_ptr + b) + cn_b1 = tl.load(cu_new_ptr + b + 1) + + k_len = ct_b1 - ct_b + nn = cn_b1 - cn_b + nc = k_len - nn + local_t = token_idx - ct_b + + if local_t < nc: + # cached token: walk paged latent_cache via block_tables + block_idx = local_t // block_size + block_off = local_t % block_size + pbid = tl.load(block_tables_ptr + b * max_blocks_per_seq + block_idx) + base = latent_cache_ptr + (pbid * block_size + block_off) * LATENT_DIM + kv_c_val = tl.load(base + kv_c_offs) + k_pe_val = tl.load(base + kv_lora_rank + k_pe_offs) + else: + # new token: linear region in new_compressed_kv / new_k_pe + src = cn_b + (local_t - nc) + kv_c_val = tl.load(new_kv_c_ptr + src * kv_lora_rank + kv_c_offs) + k_pe_val = tl.load(new_k_pe_ptr + src * qk_rope_head_dim + k_pe_offs) + + tl.store(out_kv_c_ptr + token_idx * kv_lora_rank + kv_c_offs, kv_c_val) + tl.store(out_k_pe_ptr + token_idx * qk_rope_head_dim + k_pe_offs, k_pe_val) + + +def fused_read_cache_and_interleave_triton( + latent_cache: paddle.Tensor, + block_tables: paddle.Tensor, + new_compressed_kv: paddle.Tensor, + new_k_pe: paddle.Tensor, + cu_seqlens_k: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + kv_lora_rank: int, + qk_rope_head_dim: int, + block_size: int, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Triton-accelerated fused read+interleave. + + Single kernel launch with zero Python metadata computation. All per-batch + geometry is derived in-kernel from ``cu_seqlens_k`` and ``cu_seqlens_q`` + via a compile-time-bounded binary search. Only one scalar D2H is needed: + the final cumsum entry, used to size the output tensors. + + ``BLOCK_M=4`` and ``num_warps=8`` are tuned defaults (autotune is avoided + because ``triton.testing.do_bench`` is incompatible with paddle worker + subprocesses during ``profile_run``). + """ + bsz = cu_seqlens_k.shape[0] - 1 + total_tokens = int(cu_seqlens_k[-1]) + + full_compressed_kv = paddle.empty([total_tokens, kv_lora_rank], dtype=new_compressed_kv.dtype) + full_k_pe = paddle.empty([total_tokens, qk_rope_head_dim], dtype=new_k_pe.dtype) + if total_tokens == 0: + return full_compressed_kv, full_k_pe + + max_blocks_per_seq = block_tables.shape[1] + # Compile-time binary-search loop bound. ceil(log2(bsz)), clamped >=1. + log2_max_bsz = max(1, (bsz - 1).bit_length()) if bsz > 1 else 1 + + # Default tuned config: BLOCK_M=4 is robust across decode / prefill / + # chunk-prefill; num_warps=8 matches the 512-wide kv_lora_rank vector. + BLOCK_M = 4 + num_warps = 8 + + grid = ((total_tokens + BLOCK_M - 1) // BLOCK_M,) + _fused_read_interleave_kernel[grid]( + latent_cache, + new_compressed_kv, + new_k_pe, + cu_seqlens_k, + cu_seqlens_q, + block_tables, + full_compressed_kv, + full_k_pe, + total_tokens, + bsz, + max_blocks_per_seq=max_blocks_per_seq, + block_size=block_size, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + LATENT_DIM=kv_lora_rank + qk_rope_head_dim, + LOG2_MAX_BSZ=log2_max_bsz, + BLOCK_M=BLOCK_M, + num_warps=num_warps, + ) + return full_compressed_kv, full_k_pe + + +def fused_read_cache_and_interleave(*args, **kwargs): + """Unified entry. 
``FD_MLA_USE_NAIVE=1`` forces the Python reference path.""" + if os.environ.get("FD_MLA_USE_NAIVE", "0") == "1": + return fused_read_cache_and_interleave_naive(*args, **kwargs) + return fused_read_cache_and_interleave_triton(*args, **kwargs) + @enable_compat_on_triton_kernel @triton.jit() @@ -232,6 +447,11 @@ class MLAAttentionMetadata(AttentionMetadata): max_dec_len_this_time: Optional[paddle.Tensor] = None max_kv_len_this_time: Optional[paddle.Tensor] = None + # For prefix cache and chunked prefill support: maximum per-batch K length + # (cached + new tokens), following the ``cu_seqlens_k`` semantics produced by + # ``get_padding_offset``; the cumsum itself lives in forward_meta.cu_seqlens_k. + max_seqlen_k: int = 0 + class MLAAttentionBackend(AttentionBackend): """ @@ -312,12 +532,12 @@ def __init__( is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs()) if is_current_sm_supported and is_paddle_supported: self.flash_attn_func = flash_attention_v3_varlen - print("The current platform supports Flash Attention V3.") + logger.info("The current platform supports Flash Attention V3.") self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale} else: self.flash_attn_func = flash_attn_unpadded self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False} - print( + logger.info( "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead." ) @@ -364,6 +584,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1] metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2] metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[5] + metadata.max_seqlen_k = max(metadata.max_kv_len_this_time.item(), metadata.max_enc_len_this_time.item()) # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers @@ -411,7 +632,13 @@ def forward_extend( forward_meta: ForwardMeta, ) -> paddle.Tensor: """ - Prefill阶段的前向传播 + Forward pass for the prefill stage, with chunked prefill / prefix cache support. + + Chunked prefill / prefix cache support for MLA models: + 1. If chunked prefill / prefix cache is enabled: + - k and v are expected to already contain the concatenation of cached KV and new KV + - cu_seqlens_k is expected to already account for the cached tokens + 2. Without a prefix cache, behavior is unchanged. """ metadata = self.attention_metadata @@ -423,7 +650,7 @@ def forward_extend( latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None - # 写入缓存 + # Write the new KV into the cache (new tokens only; the cached part is not rewritten) prefill_mla_write_cache( compressed_kv, k_pe, @@ -438,7 +665,6 @@ def forward_extend( getattr(forward_meta, "max_input_length", -1), ) - # Flash注意力计算 fmha_out = self.flash_attn_func( q, k, @@ -446,7 +672,7 @@ def forward_extend( forward_meta.cu_seqlens_q, forward_meta.cu_seqlens_k, metadata.max_enc_len_this_time, - metadata.max_enc_len_this_time, + metadata.max_seqlen_k, causal=self.causal, **self.flash_attn_kwargs, )[0] @@ -553,7 +779,11 @@ def forward_mixed( forward_meta: ForwardMeta, ) -> paddle.Tensor: """ - Mixed模式的前向传播 + Forward pass for mixed mode, with chunked prefill / prefix cache support. + + Chunked prefill / prefix cache support for MLA models: + 1. Prefill branch: k and v are expected to already contain cached + new tokens + 2. 
Decode branch: the original latent attention logic is kept """ metadata = self.attention_metadata speculate_decoder = self.speculative_method is not None @@ -567,6 +797,7 @@ def forward_mixed( latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None + # Prefill branch: k is not None if k is not None: prefill_mla_write_cache( compressed_kv, @@ -581,7 +812,8 @@ def forward_mixed( "none", self.max_seq_len, ) - # FA + + # FlashAttention for prefill fmha_out = self.flash_attn_func( q, k, @@ -589,14 +821,14 @@ def forward_mixed( forward_meta.cu_seqlens_q, forward_meta.cu_seqlens_k, metadata.max_enc_len_this_time, - metadata.max_enc_len_this_time, + metadata.max_seqlen_k, causal=self.causal, **self.flash_attn_kwargs, )[0] return fmha_out - # Decode + # Decode branch: k is None if k is None: decode_mla_write_cache( compressed_kv, diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 0b11b0b0905..0ac87fa6dfa 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -208,6 +208,11 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None super().__init__() self.fd_config = fd_config + self.layer_id = layer_id + self.block_size: int = fd_config.cache_config.block_size + self.enable_chunked_prefill = self.fd_config.cache_config.enable_chunked_prefill + self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching + self.use_gated_attn = getattr(self.fd_config.model_config, "use_gated_attn", False) self.use_bias = getattr(self.fd_config.model_config, "use_bias", False) self.tp_size = fd_config.parallel_config.tensor_parallel_size @@ -351,7 +356,11 @@ def forward( forward_meta: ForwardMeta, hidden_states: paddle.Tensor, ): - """ """ + """MLA attention forward with prefix cache support.""" + + from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( + fused_read_cache_and_interleave, + ) attn_out = None if self.use_gated_attn: @@ -377,8 +386,27 @@ def forward( need_do_prefill = forward_meta.max_len_tensor_cpu[1] > 0 need_do_decode = forward_meta.max_len_tensor_cpu[2] > 0 - if need_do_prefill: # max_enc_len_this_time - key_value = self.kv_b_proj(compressed_kv) + if need_do_prefill: + # Handle prefix cache: read cached latent from paged cache and interleave + # with the new-token latent in a single fused kernel call. 
+ full_compressed_kv = compressed_kv + full_k_pe = key_pe.squeeze(1) + if self.enable_chunked_prefill or self.enable_prefix_caching: + + full_compressed_kv, full_k_pe = fused_read_cache_and_interleave( + forward_meta.caches[self.layer_id], + forward_meta.block_tables, + compressed_kv, + key_pe.squeeze(1), + forward_meta.cu_seqlens_k, + forward_meta.cu_seqlens_q, + self.kv_lora_rank, + self.qk_rope_head_dim, + self.block_size, + ) + + # Project latent KV to full key and value + key_value = self.kv_b_proj(full_compressed_kv) key_value.reshape_( [ -1, @@ -389,9 +417,9 @@ def forward( key_nope, value = key_value.split([self.qk_nope_head_dim, self.v_head_dim], axis=-1) query[..., self.qk_nope_head_dim :] = query_pe - key = paddle.empty_like(query) + key = paddle.empty([full_k_pe.shape[0], self.num_attention_heads_tp, self.qk_head_dim], dtype=query.dtype) key[..., : self.qk_nope_head_dim] = key_nope - key[..., self.qk_nope_head_dim :] = key_pe + key[..., self.qk_nope_head_dim :] = full_k_pe.unsqueeze(1) value = paddle.nn.functional.pad(value, [0, self.qk_head_dim - self.v_head_dim], value=0) fmha_out = self.mla_attn( @@ -399,8 +427,8 @@ def forward( k=key, v=value, qkv=None, - compressed_kv=compressed_kv, - k_pe=key_pe, + compressed_kv=compressed_kv, # Pass original (new only) for cache writing + k_pe=key_pe, # Pass original (new only) for cache writing forward_meta=forward_meta, ) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 3bf251009b3..50f764caa43 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -176,7 +176,7 @@ def pre_process( if specific_platform and not speculative_decoding: # Note(ZKK): This case's code is very simple! ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset( - input_ids, seq_lens_this_time, None, None, token_num_cpu + input_ids, seq_lens_this_time, None, None, seq_lens_decoder, token_num_cpu ) return ( ids_remove_padding, diff --git a/tests/layers/test_attention_layer.py b/tests/layers/test_attention_layer.py index 21d3deb5cff..56c6d6afc03 100644 --- a/tests/layers/test_attention_layer.py +++ b/tests/layers/test_attention_layer.py @@ -239,7 +239,7 @@ def create_forward_meta( input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64") token_num = np.sum(seq_lens_this_time) ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset( - input_ids, seq_lens_this_time, None, None, token_num + input_ids, seq_lens_this_time, None, None, None, int(token_num) ) forward_meta = ForwardMeta( diff --git a/tests/model_executor/test_mla_fused_read_interleave.py b/tests/model_executor/test_mla_fused_read_interleave.py new file mode 100644 index 00000000000..d1610bd8a8c --- /dev/null +++ b/tests/model_executor/test_mla_fused_read_interleave.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Precision parity tests for MLA fused read+interleave. + +Verifies that the Triton single-kernel implementation +``fused_read_cache_and_interleave_triton`` produces bit-level identical +results to the Python/Paddle reference ``fused_read_cache_and_interleave_naive`` +across a variety of batch/sequence configurations. +""" + +import numpy as np +import paddle +import pytest + +from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( + fused_read_cache_and_interleave_naive, + fused_read_cache_and_interleave_triton, +) + +# Typical DeepSeek V3 MLA geometry. +KV_LORA_RANK = 512 +QK_ROPE_HEAD_DIM = 64 +LATENT_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM +BLOCK_SIZE = 64 + +pytestmark = pytest.mark.skipif( + not paddle.is_compiled_with_cuda() or paddle.device.cuda.device_count() == 0, + reason="fused_read_cache_and_interleave_triton requires CUDA", +) + + +def _build_inputs(cached_lens, new_lens, dtype="float16", seed=0): + """Construct all tensors needed by the fused read+interleave entry.""" + assert len(cached_lens) == len(new_lens) + bsz = len(cached_lens) + total_new = int(sum(new_lens)) + + # Paged latent cache with enough blocks for all batches. + max_blocks_per_seq = max(1, max((c + BLOCK_SIZE - 1) // BLOCK_SIZE for c in cached_lens) if bsz else 1) + # Give extra blocks so block ids are not all sequential. + num_blocks = max(bsz * max_blocks_per_seq + 7, 8) + + rng = np.random.default_rng(seed) + latent_np = rng.standard_normal((num_blocks, 1, BLOCK_SIZE, LATENT_DIM)).astype( + np.float16 if dtype == "float16" else np.float32 + ) + latent_cache = paddle.to_tensor(latent_np).cast(dtype) + + # block_tables: assign a distinct physical block id per (batch, block_idx). + bt_np = np.zeros((bsz, max_blocks_per_seq), dtype=np.int32) + free_ids = list(range(num_blocks)) + rng.shuffle(free_ids) + cursor = 0 + for b in range(bsz): + nb = (cached_lens[b] + BLOCK_SIZE - 1) // BLOCK_SIZE + for i in range(nb): + bt_np[b, i] = free_ids[cursor] + cursor += 1 + block_tables = paddle.to_tensor(bt_np) + + new_kv_c_np = rng.standard_normal((max(total_new, 1), KV_LORA_RANK)).astype(np.float32) + new_k_pe_np = rng.standard_normal((max(total_new, 1), QK_ROPE_HEAD_DIM)).astype(np.float32) + new_compressed_kv = paddle.to_tensor(new_kv_c_np[:total_new] if total_new > 0 else new_kv_c_np[:0]).cast(dtype) + new_k_pe = paddle.to_tensor(new_k_pe_np[:total_new] if total_new > 0 else new_k_pe_np[:0]).cast(dtype) + + cu_total = np.zeros(bsz + 1, dtype=np.int32) + cu_new = np.zeros(bsz + 1, dtype=np.int32) + for i in range(bsz): + cu_total[i + 1] = cu_total[i] + cached_lens[i] + new_lens[i] + cu_new[i + 1] = cu_new[i] + new_lens[i] + cu_seqlens_k = paddle.to_tensor(cu_total) + cu_seqlens_q = paddle.to_tensor(cu_new) + + return { + "latent_cache": latent_cache, + "block_tables": block_tables, + "new_compressed_kv": new_compressed_kv, + "new_k_pe": new_k_pe, + "cu_seqlens_k": cu_seqlens_k, + "cu_seqlens_q": cu_seqlens_q, + } + + +def _run_both(inputs): + common = dict( + latent_cache=inputs["latent_cache"], + block_tables=inputs["block_tables"], + new_compressed_kv=inputs["new_compressed_kv"], + new_k_pe=inputs["new_k_pe"], + cu_seqlens_k=inputs["cu_seqlens_k"], + cu_seqlens_q=inputs["cu_seqlens_q"], + kv_lora_rank=KV_LORA_RANK, + qk_rope_head_dim=QK_ROPE_HEAD_DIM, + block_size=BLOCK_SIZE, + ) + + ref_kv, ref_pe = fused_read_cache_and_interleave_naive(**common) + out_kv, out_pe = fused_read_cache_and_interleave_triton(**common) + return ref_kv, ref_pe, out_kv, out_pe + + +def _assert_bitwise_equal(ref_kv, 
ref_pe, out_kv, out_pe): + assert tuple(out_kv.shape) == tuple(ref_kv.shape) + assert tuple(out_pe.shape) == tuple(ref_pe.shape) + # Both paths do pure copies (no arithmetic), so they should agree bit-for-bit. + np.testing.assert_array_equal(out_kv.cast("float32").numpy(), ref_kv.cast("float32").numpy()) + np.testing.assert_array_equal(out_pe.cast("float32").numpy(), ref_pe.cast("float32").numpy()) + + +# ----------------------------------------------------------------------------- +# Test cases: (name, cached_lens, new_lens) +# ----------------------------------------------------------------------------- +_CASES = [ + # Single batch, purely new tokens (no prefix cache). + ("single_no_cache", [0], [17]), + # Single batch, purely cached (edge case; triton path should still be correct). + ("single_all_cache_short", [13], [1]), + # Single batch spanning multiple cache blocks. + ("single_multi_block", [BLOCK_SIZE * 3 + 5], [9]), + # Single batch, cached aligned to exact block boundary. + ("single_cache_aligned", [BLOCK_SIZE * 2], [7]), + # Multi batch, uniform small lengths. + ("multi_uniform_small", [8, 8, 8, 8], [4, 4, 4, 4]), + # Multi batch, mixed: some with cache, some without. + ("multi_mixed_cache", [0, 33, 0, BLOCK_SIZE + 1], [5, 7, 11, 3]), + # Multi batch, long + short mixed. + ("multi_long_short", [BLOCK_SIZE * 5 + 3, 2, BLOCK_SIZE * 2, 0], [16, 1, 8, 12]), + # Larger batch with varied lengths. + ( + "multi_large", + [0, 17, BLOCK_SIZE, BLOCK_SIZE * 2 + 1, 5, 0, 64, BLOCK_SIZE * 4 + 30], + [8, 3, 16, 1, 12, 4, 9, 7], + ), + # Every cached_len > 0 (no batch without prefix cache). + ("multi_all_have_cache", [BLOCK_SIZE + 2, 7, BLOCK_SIZE * 2 + 17, 40], [6, 6, 6, 6]), + # bsz=1, minimal. + ("minimal", [1], [1]), +] + + +@pytest.mark.parametrize("name,cached_lens,new_lens", _CASES, ids=[c[0] for c in _CASES]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +def test_triton_matches_naive(name, cached_lens, new_lens, dtype): + inputs = _build_inputs(cached_lens, new_lens, dtype=dtype, seed=hash(name) & 0xFFFF) + ref_kv, ref_pe, out_kv, out_pe = _run_both(inputs) + _assert_bitwise_equal(ref_kv, ref_pe, out_kv, out_pe) + + +def test_total_tokens_zero(): + """All batches empty -> early return with empty tensors from both paths.""" + inputs = _build_inputs([0, 0], [0, 0], dtype="float16", seed=1) + ref_kv, ref_pe, out_kv, out_pe = _run_both(inputs) + assert ref_kv.shape[0] == 0 and out_kv.shape[0] == 0 + assert ref_pe.shape[0] == 0 and out_pe.shape[0] == 0 + + +# --------------------------------------------------------------------------- +# Hand-crafted baseline test: small tensors with deterministic integer values so +# every expected output entry can be written out and compared directly. This +# gives a ground-truth anchor independent of the naive reference, covering the +# real chunk-prefill / prefix-cache layout end-to-end. 
+# +# Scenario: +# bsz = 2 +# batch 0: cached=3, new=2 (spans 2 paged blocks: pbid 5 then pbid 2) +# batch 1: cached=0, new=4 +# block_size=2, kv_lora_rank=4, qk_rope_head_dim=2 +# +# Interleaved output layout (per-batch [cached, new]): +# t0..t2 : batch0 cached 0,1,2 (block_tables[0]=[5,2]) +# t3..t4 : batch0 new 0,1 +# t5..t8 : batch1 new 2,3,4,5 +# --------------------------------------------------------------------------- + + +def _fill_latent_cache(num_blocks, block_size, lora, rope): + """latent_cache[pbid, 0, off, d] = pbid * 1000 + off * 100 + d.""" + latent_dim = lora + rope + arr = np.zeros((num_blocks, 1, block_size, latent_dim), dtype=np.float32) + for p in range(num_blocks): + for o in range(block_size): + for d in range(latent_dim): + arr[p, 0, o, d] = p * 1000 + o * 100 + d + return arr + + +def test_manual_baseline(): + lora, rope = 4, 2 + block_size = 2 + num_blocks = 8 + + latent_np = _fill_latent_cache(num_blocks, block_size, lora, rope) + latent_cache = paddle.to_tensor(latent_np).cast("float32") + + # batch 0 cached tokens live in physical blocks (pbid=5 then pbid=2), + # batch 1 has no cached tokens so its row is unused. + block_tables = paddle.to_tensor(np.array([[5, 2], [0, 0]], dtype=np.int32)) + + # New tokens: 2 (batch 0) + 4 (batch 1) = 6 rows, deterministic values. + total_new = 6 + new_kv = np.zeros((total_new, lora), dtype=np.float32) + new_pe = np.zeros((total_new, rope), dtype=np.float32) + for i in range(total_new): + for d in range(lora): + new_kv[i, d] = -(i * 10 + d + 1) # negatives so they never clash with cache pattern + for d in range(rope): + new_pe[i, d] = -(i * 10 + d + 101) + new_compressed_kv = paddle.to_tensor(new_kv) + new_k_pe = paddle.to_tensor(new_pe) + + # cu_seqlens_k = cumsum(cached + new) per batch = [0, 5, 9] + # cu_seqlens_q = cumsum(new) per batch = [0, 2, 6] + cu_k = paddle.to_tensor(np.array([0, 5, 9], dtype=np.int32)) + cu_q = paddle.to_tensor(np.array([0, 2, 6], dtype=np.int32)) + + out_kv, out_pe = fused_read_cache_and_interleave_triton( + latent_cache, + block_tables, + new_compressed_kv, + new_k_pe, + cu_k, + cu_q, + lora, + rope, + block_size, + ) + + # Hand-built ground truth: total_tokens=9. + expected_kv = np.zeros((9, lora), dtype=np.float32) + expected_pe = np.zeros((9, rope), dtype=np.float32) + + # t0: batch0 cached 0 -> pbid=5, off=0 -> 5000 + 0 + d + # t1: batch0 cached 1 -> pbid=5, off=1 -> 5000 + 100 + d + # t2: batch0 cached 2 -> pbid=2, off=0 -> 2000 + 0 + d + cached_refs = [(5, 0), (5, 1), (2, 0)] + for t, (pbid, off) in enumerate(cached_refs): + for d in range(lora): + expected_kv[t, d] = pbid * 1000 + off * 100 + d + for d in range(rope): + expected_pe[t, d] = pbid * 1000 + off * 100 + lora + d + + # t3,t4: batch0 new tokens (src=0,1) + # t5..t8: batch1 new tokens (src=2,3,4,5) + new_sources = [0, 1, 2, 3, 4, 5] + for slot, src in enumerate(new_sources): + expected_kv[3 + slot] = new_kv[src] + expected_pe[3 + slot] = new_pe[src] + + np.testing.assert_array_equal(out_kv.numpy(), expected_kv) + np.testing.assert_array_equal(out_pe.numpy(), expected_pe) + + +def test_manual_baseline_no_cache(): + """All batches pure prefill (no prefix cache) -> kernel must still copy + new tokens in order with no off-by-one.""" + lora, rope = 4, 2 + block_size = 2 + # Any latent_cache / block_tables; kernel should never read from them. 
+ latent_cache = paddle.zeros([4, 1, block_size, lora + rope], dtype="float32") + block_tables = paddle.zeros([3, 1], dtype="int32") + + new_kv = np.arange(5 * lora, dtype=np.float32).reshape(5, lora) + new_pe = np.arange(5 * rope, dtype=np.float32).reshape(5, rope) + 1000 + new_compressed_kv = paddle.to_tensor(new_kv) + new_k_pe = paddle.to_tensor(new_pe) + + # 3 batches, new=[2,1,2], no cache + cu_k = paddle.to_tensor(np.array([0, 2, 3, 5], dtype=np.int32)) + cu_q = paddle.to_tensor(np.array([0, 2, 3, 5], dtype=np.int32)) + + out_kv, out_pe = fused_read_cache_and_interleave_triton( + latent_cache, + block_tables, + new_compressed_kv, + new_k_pe, + cu_k, + cu_q, + lora, + rope, + block_size, + ) + np.testing.assert_array_equal(out_kv.numpy(), new_kv) + np.testing.assert_array_equal(out_pe.numpy(), new_pe) + + +def test_manual_baseline_all_cached(): + """Single batch, all cached (decode-like K build with zero new tokens).""" + lora, rope = 4, 2 + block_size = 2 + num_blocks = 4 + latent_np = _fill_latent_cache(num_blocks, block_size, lora, rope) + latent_cache = paddle.to_tensor(latent_np) + + # batch 0 cached=4 spans two blocks; pick pbid=3 then pbid=1. + block_tables = paddle.to_tensor(np.array([[3, 1]], dtype=np.int32)) + + # No new tokens; still need a valid (0-row) tensor. + new_compressed_kv = paddle.zeros([0, lora], dtype="float32") + new_k_pe = paddle.zeros([0, rope], dtype="float32") + + cu_k = paddle.to_tensor(np.array([0, 4], dtype=np.int32)) + cu_q = paddle.to_tensor(np.array([0, 0], dtype=np.int32)) + + out_kv, out_pe = fused_read_cache_and_interleave_triton( + latent_cache, + block_tables, + new_compressed_kv, + new_k_pe, + cu_k, + cu_q, + lora, + rope, + block_size, + ) + + # expected: pbid=3 off=0,1 then pbid=1 off=0,1 + expected_kv = np.zeros((4, lora), dtype=np.float32) + expected_pe = np.zeros((4, rope), dtype=np.float32) + for t, (pbid, off) in enumerate([(3, 0), (3, 1), (1, 0), (1, 1)]): + for d in range(lora): + expected_kv[t, d] = pbid * 1000 + off * 100 + d + for d in range(rope): + expected_pe[t, d] = pbid * 1000 + off * 100 + lora + d + np.testing.assert_array_equal(out_kv.numpy(), expected_kv) + np.testing.assert_array_equal(out_pe.numpy(), expected_pe) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/operators/test_get_padding_offset.py b/tests/operators/test_get_padding_offset.py index 13c86c422d2..8d21660e89a 100644 --- a/tests/operators/test_get_padding_offset.py +++ b/tests/operators/test_get_padding_offset.py @@ -32,7 +32,9 @@ def test_get_padding_offset(self): batch_id_per_token, cu_seqlens_q, cu_seqlens_k, - ) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), None, None, token_num_cpu) + ) = get_padding_offset( + paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), None, None, None, int(token_num_cpu) + ) ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64") ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32") diff --git a/tests/operators/test_get_position_ids_and_mask_encoder_batch.py b/tests/operators/test_get_position_ids_and_mask_encoder_batch.py index 2d1dd8e2f7c..cb7e38633fb 100644 --- a/tests/operators/test_get_position_ids_and_mask_encoder_batch.py +++ b/tests/operators/test_get_position_ids_and_mask_encoder_batch.py @@ -37,7 +37,7 @@ def test_basic_functionality(self): # Call the custom operator get_position_ids_and_mask_encoder_batch(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) - 
expected_position_ids = np.array([0, 1, 2, 1, 0, 1, 2, 3], dtype=np.int32) + expected_position_ids = np.array([1, 2, 3, 2, 3, 0, 0, 0], dtype=np.int32) # Convert to numpy for comparison position_ids_np = position_ids.numpy() diff --git a/tests/operators/test_reasoning_phase_token_constraint.py b/tests/operators/test_reasoning_phase_token_constraint.py index 2c46b5d9c29..07fbf6e23ea 100644 --- a/tests/operators/test_reasoning_phase_token_constraint.py +++ b/tests/operators/test_reasoning_phase_token_constraint.py @@ -87,7 +87,8 @@ def setUp(self): seq_lens_output, None, None, - output_token_num.item(), + None, + int(output_token_num.item()), ) # self.output_padding_offset = paddle.zeros([self.token_num], dtype="int32") @@ -467,7 +468,8 @@ def test_perf_bsz128_vocab100k_status2(self): seq_lens_output, None, None, - output_token_num.item(), + None, + int(output_token_num.item()), ) # ------------------------
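
Note for reviewers: a minimal NumPy sketch of the cu_seqlens_q / cu_seqlens_k semantics that PrefixSumKernel produces after this patch. The helper name reference_cu_seqlens and the concrete lengths are illustrative only, not part of the patch; they mirror the manual-baseline geometry used in test_mla_fused_read_interleave.py (batch 0: 3 cached + 2 new, batch 1: 4 new).

import numpy as np

def reference_cu_seqlens(seq_lens_this_time, seq_lens_decoder=None):
    # CPU reference of the kernel semantics: cu_seqlens_q cumsums the new
    # tokens only, while cu_seqlens_k additionally counts each batch's
    # cached (seq_lens_decoder) tokens.
    new = np.asarray(seq_lens_this_time, dtype=np.int64)
    cached = np.zeros_like(new) if seq_lens_decoder is None else np.asarray(seq_lens_decoder, dtype=np.int64)
    cu_q = np.concatenate(([0], np.cumsum(new)))
    cu_k = np.concatenate(([0], np.cumsum(new + cached)))
    return cu_q, cu_k

# batch 0: chunked prefill with 3 cached + 2 new tokens; batch 1: first-chunk prefill with 4 new tokens
cu_q, cu_k = reference_cu_seqlens([2, 4], [3, 0])
assert cu_q.tolist() == [0, 2, 6]
assert cu_k.tolist() == [0, 5, 9]

# Without seq_lens_decoder, cu_seqlens_k degenerates to cu_seqlens_q.
cu_q, cu_k = reference_cu_seqlens([2, 4])
assert cu_k.tolist() == cu_q.tolist() == [0, 2, 6]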