diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
index d5cfa5de298..1614e58ebca 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -2107,7 +2107,6 @@ runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledReques
     for (SizeType32 vid = 0; vid < getNumVocabs(); vid++)
     {
         auto& decoderInputBuffers = mDecoderInputBuffers.at(getFusedBufferId());
-        auto& decoderOutputBuffers = mDecoderOutputBuffers[vid].at(getFusedBufferId());
         auto& decoderState = mDecoderStates[vid];

         auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
@@ -2433,7 +2432,8 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu
         // Terminate if request has finished or if it is speculative decoding target model
         if (decoderFinishedSumPtr[seqSlot] == reqBeamWidth
-            || (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal() && llmReq->hasDraftTokens()))
+            || (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal() && llmReq->hasDraftTokens())
+            || (mModelConfig.useAttentionPrior() && llmReq->isAttentionPriorFinished()))
         {
             postProcessRequest(*llmReq, numDroppedTokens);
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
index 21b9112b9fe..3d9714d4e92 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
@@ -1528,6 +1528,18 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     float const kv_scale_quant_orig_f = (ENABLE_8BITS_KV_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
     convert_from_float(&k_scale_quant_orig, k_scale_quant_orig_f);
     convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_KV_CACHE ? params.kv_scale_orig_quant[0] : 1.0f));
+    // parameters related to attention prior
+    int focus;
+    if (params.attention_prior_focus != nullptr)
+    {
+        focus = params.attention_prior_focus[batch_beam_idx];
+    }
+    bool const store_scores = params.attention_prior_scores != nullptr;
+    float* scores_ptr = nullptr;
+    if (store_scores)
+    {
+        scores_ptr = &params.attention_prior_scores[batch_beam_idx * params.attention_prior_lookahead];
+    }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     cudaGridDependencySynchronize();
@@ -1849,7 +1861,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     {
         relative_attention_bias = convert_to_float(relative_attention_bias_ptr[tlength]);
     }
-    if (has_attention_mask && tidx == 0)
+    if (has_attention_mask && tidx == 0 && !DO_CROSS_ATTENTION)
     {
         // Note: reuse the relative_attention_bias variable.
         // attention_mask = 1.0 means that the position is not masked.
@@ -2055,7 +2067,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     {
         relative_attention_bias = convert_to_float(relative_attention_bias_ptr[local_time_now]);
     }
-    if (is_active && has_attention_mask)
+    if (is_active && has_attention_mask && !DO_CROSS_ATTENTION)
     {
         // Note: reuse the relative_attention_bias variable.
         // attention_mask = 1.0 means that the position is not masked.
@@ -2268,13 +2280,30 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     float inv_sum = __fdividef(logit_scale, sum + 1.e-6f);

     int const normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
+    float sum_rescale = 0.0f;
     for (int ti = tidx; ti <= normlization_loop_end; ti += THREADS_PER_BLOCK)
     {
         int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;

         if (!MULTI_BLOCK_FLAG)
         {
-            convert_from_float(&logits_smem[ti], qk_smem[ti] * inv_sum);
+            float prob = qk_smem[ti] * inv_sum;
+            if (DO_CROSS_ATTENTION && params.attention_prior_focus != nullptr)
+            {
+                // apply the prior mask to the probability
+                if (ti < (focus - params.attention_prior_window_left)
+                    || ti > (focus + params.attention_prior_window_right))
+                {
+                    prob *= 0.1f;
+                }
+                // store back
+                qk_smem[ti] = prob;
+                sum_rescale += prob;
+            }
+            else
+            {
+                convert_from_float(&logits_smem[ti], prob);
+            }
         }
         else
         {
@@ -2290,6 +2319,27 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
         }
     }

+    // When the prior is applied, perform an additional normalization,
+    // dividing by the sum of the modified probabilities.
+    __syncthreads();
+    if (!MULTI_BLOCK_FLAG && DO_CROSS_ATTENTION && params.attention_prior_focus != nullptr)
+    {
+        sum_rescale = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum_rescale);
+
+        // Final pass: compute probabilities and store them to the scores buffer if needed.
+        float inv_sum_rescale = __fdividef(1.0f, sum_rescale + 1.e-6f);
+        for (int ti = tidx; ti <= kv_loop_length; ti += THREADS_PER_BLOCK)
+        {
+            float prob = qk_smem[ti] * inv_sum_rescale;
+            if (store_scores && ti >= focus && ti < focus + params.attention_prior_lookahead)
+            {
+                scores_ptr[ti - focus] = prob;
+            }
+            convert_from_float(&logits_smem[ti], prob);
+        }
+        __syncthreads();
+    }
+
     // Put Values part below so we leverage __syncthreads
     // from the previous step
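
The two kernel hunks above implement a windowed attention prior for cross-attention: after the regular softmax, probabilities outside [focus - window_left, focus + window_right] are down-weighted by 0.1, the row is renormalized, and the probabilities for positions [focus, focus + lookahead) are exported to attention_prior_scores. The host-side C++ sketch below reproduces that arithmetic on made-up logits; the variable names mirror the params.attention_prior_* fields from the patch, the toy values are purely illustrative, and the kernel's per-thread/block reductions and logit_scale handling are elided.

// Host-side sketch of the attention-prior re-weighting; not the kernel itself.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // Toy cross-attention logits over 8 encoder positions for one query.
    std::vector<float> logits = {0.2f, 1.5f, 3.0f, 2.5f, 0.7f, 0.1f, -0.5f, 0.0f};
    int const focus = 2;        // params.attention_prior_focus[batch_beam_idx]
    int const window_left = 1;  // params.attention_prior_window_left
    int const window_right = 2; // params.attention_prior_window_right
    int const lookahead = 3;    // params.attention_prior_lookahead

    // Pass 1: regular softmax (the kernel computes qk_smem[ti] * inv_sum).
    float max_logit = logits[0];
    for (float l : logits) max_logit = std::fmax(max_logit, l);
    float sum = 0.0f;
    std::vector<float> prob(logits.size());
    for (size_t ti = 0; ti < logits.size(); ++ti)
    {
        prob[ti] = std::exp(logits[ti] - max_logit);
        sum += prob[ti];
    }
    for (float& p : prob) p /= sum;

    // Pass 2: down-weight probabilities outside the window
    // [focus - window_left, focus + window_right] by 0.1, as in the diff,
    // and accumulate the rescale sum.
    float sum_rescale = 0.0f;
    for (size_t ti = 0; ti < prob.size(); ++ti)
    {
        int const t = static_cast<int>(ti);
        if (t < focus - window_left || t > focus + window_right)
        {
            prob[ti] *= 0.1f;
        }
        sum_rescale += prob[ti];
    }

    // Renormalize, and capture the [focus, focus + lookahead) slice that the
    // kernel writes to scores_ptr.
    std::vector<float> scores(lookahead, 0.0f);
    for (size_t ti = 0; ti < prob.size(); ++ti)
    {
        prob[ti] /= (sum_rescale + 1e-6f);
        int const t = static_cast<int>(ti);
        if (t >= focus && t < focus + lookahead)
        {
            scores[t - focus] = prob[ti];
        }
    }

    for (size_t ti = 0; ti < prob.size(); ++ti)
    {
        std::printf("pos %zu: %.4f\n", ti, prob[ti]);
    }
    return 0;
}

Down-weighting by 0.1 instead of hard-masking leaves some probability mass outside the window, so attention can still shift if the prior is misplaced, and the second normalization restores a proper distribution. The exported scores presumably let the runtime track and advance the focus, feeding the llmReq->isAttentionPriorFinished() termination check added to updateRequests().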