forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathallocateKvCache.cpp
More file actions
89 lines (76 loc) · 3.45 KB
/
allocateKvCache.cpp
File metadata and controls
89 lines (76 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/allocateKvCache.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
/// @brief Reserves KV-cache space for the given context and generation requests.
///
/// Context requests are handled only on their first context chunk. Each such request:
///   - gets a sequence added (prompt length, beam width);
///   - gets extra tokens appended for speculative decoding;
///   - gets a sequence in the cross-attention cache manager, when one is supplied.
/// Generation requests get their draft-token count plus one tokens appended. In Eagle mode both kinds
/// instead use the fixed budget maxPathLen + maxDecodingTokens.
///
/// @param kvCacheManager       Self-attention KV-cache manager that receives the sequences/tokens.
/// @param contextRequests      Requests in the context phase.
/// @param generationRequests   Requests in the generation phase.
/// @param modelConfig          Supplies the speculative decoding mode and, for Eagle, its module limits.
/// @param crossKvCacheManager  Optional cross-attention KV-cache manager (sized by encoder output length).
void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager& kvCacheManager,
    RequestVector& contextRequests, RequestVector const& generationRequests, runtime::ModelConfig const& modelConfig,
    OptionalRef<BaseKVCacheManager> crossKvCacheManager) const
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(allocateKvCache);

    // Done once, before any sequence or token is added.
    kvCacheManager.syncTransferManagerWithBufferManager();

    // EagleNet will increment kv cache up to maxPathLen to account for accepted tokens.
    // Then up to maxDecodingDraftTokens will be used to generate next draft tokens.
    // The resulting budget is the same for every request, so both loops share this helper.
    auto const eagleTokenBudget = [&modelConfig]()
    {
        auto const& specModule = modelConfig.getSpeculativeDecodingModule();
        return specModule.getMaxPathLen() + specModule.getMaxDecodingTokens();
    };

    // Appends `count` tokens to the sequence of `requestId`; a count <= 0 appends nothing.
    auto const appendTokens = [&kvCacheManager](auto const requestId, auto const count)
    {
        for (SizeType32 step = 0; step < count; ++step)
        {
            kvCacheManager.addToken(requestId);
        }
    };

    for (auto const& llmReq : contextRequests)
    {
        // Allocation for a context request is performed only on its first context chunk.
        if (!llmReq->isFirstContextChunk())
        {
            continue;
        }

        auto const reqId = llmReq->mRequestId;
        auto const promptLength = llmReq->mPromptLen;
        auto const beamWidth = llmReq->mSamplingConfig.beamWidth;
        auto extraTokens = llmReq->getNumDraftTokens();

        // Allocate/Reuse KV cache
        kvCacheManager.addSequence(reqId, promptLength, beamWidth, llmReq);

        // Extra positions for speculative decoding: the request's draft tokens, or the fixed Eagle budget.
        if (modelConfig.getSpeculativeDecodingMode().isEagle())
        {
            extraTokens = eagleTokenBudget();
        }
        appendTokens(reqId, extraTokens);

        if (crossKvCacheManager)
        {
            crossKvCacheManager->addSequence(reqId, llmReq->getEncoderOutputLen(), beamWidth, llmReq);
        }
    }

    for (auto const& llmReq : generationRequests)
    {
        auto const reqId = llmReq->mRequestId;

        // The request's draft tokens plus one, unless Eagle's fixed budget applies.
        auto tokensToAdd = llmReq->getNumDraftTokens() + 1;
        if (modelConfig.getSpeculativeDecodingMode().isEagle())
        {
            tokensToAdd = eagleTokenBudget();
        }
        appendTokens(reqId, tokensToAdd);
    }

    // Done once, after all sequences and tokens have been added.
    kvCacheManager.refreshBlocks();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}