kvCacheTransferManager.cpp (forked from NVIDIA/TensorRT-LLM)
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// POSIX headers needed by the debug fallback path (::open, ::read, ::write, ::close, O_* flags).
#include <fcntl.h>
#include <unistd.h>

#include <cuda_runtime_api.h>
#include "tensorrt_llm/batch_manager/kvCacheTransferManager.h"
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/kernels/kvCachePartialCopy.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
namespace tr = tensorrt_llm::runtime;
namespace tk = tensorrt_llm::kernels;
namespace kvc = tensorrt_llm::executor::kv_cache;
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
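// Debug/fallback path: stage the device block in a host buffer and write it to `filename`
// with plain POSIX I/O. Used by copyBlock() for the POSIX_DEBUG_FALLBACK transfer mode.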
static bool gpuToFilePosix(tr::ITensor::SharedPtr const& srcPtr, std::string const& filename)
{
int fd = ::open(filename.c_str(), O_CREAT | O_WRONLY, 0664);
TLLM_CHECK_WITH_INFO(fd >= 0, "Failed to open '%s' for writing (POSIX fallback)", filename.c_str());
ssize_t numBytes = static_cast<ssize_t>(srcPtr->getSizeInBytes());
std::vector<uint8_t> hostBuffer(numBytes);
cudaError_t cpyErr = cudaMemcpy(hostBuffer.data(), srcPtr->data(), numBytes, cudaMemcpyDeviceToHost);
TLLM_CHECK_WITH_INFO(cpyErr == cudaSuccess, "cudaMemcpy to host failed, error=%d", cpyErr);
ssize_t written = ::write(fd, hostBuffer.data(), numBytes);
TLLM_CHECK_WITH_INFO(written == numBytes, "POSIX write wrote %zd of %zd bytes", written, numBytes);
TLLM_LOG_DEBUG("Wrote %zd bytes to %s (POSIX fallback)", written, filename.c_str());
::close(fd);
return true;
}
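// Debug/fallback path: read `filename` into a host buffer with plain POSIX I/O and copy it
// into the device block.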
static bool fileToGpuPosix(tr::ITensor::SharedPtr const& dstPtr, std::string const& filename)
{
int fd = ::open(filename.c_str(), O_RDONLY);
TLLM_CHECK_WITH_INFO(fd >= 0, "Failed to open '%s' for reading (POSIX fallback)", filename.c_str());
ssize_t numBytes = static_cast<ssize_t>(dstPtr->getSizeInBytes());
std::vector<uint8_t> hostBuffer(numBytes);
ssize_t bytesRead = ::read(fd, hostBuffer.data(), numBytes);
TLLM_CHECK_WITH_INFO(bytesRead == numBytes, "POSIX read returned %zd of %zd bytes", bytesRead, numBytes);
TLLM_LOG_DEBUG("Read %zd bytes from %s (POSIX fallback)", bytesRead, filename.c_str());
cudaError_t cpyErr = cudaMemcpy(dstPtr->data(), hostBuffer.data(), numBytes, cudaMemcpyHostToDevice);
TLLM_CHECK_WITH_INFO(cpyErr == cudaSuccess, "cudaMemcpy to device failed, error=%d", cpyErr);
::close(fd);
return true;
}
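// The manager keeps two dedicated CUDA streams, one for offload copies and one for onboard
// copies, so block transfers can be scheduled independently of the buffer manager's stream.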
KVCacheTransferManager::KVCacheTransferManager(
tr::BufferManager const& bufferManager, std::shared_ptr<kvc::BaseLoopbackAgent> loopbackAgent)
: mBufferManager{bufferManager}
, mOnboardManager(std::make_shared<tr::CudaStream>())
, mOffloadManager(std::make_shared<tr::CudaStream>())
, mLoopbackAgent{loopbackAgent}
{
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
TLLM_CHECK(mDeviceId != -1);
}
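// Returns a one-block slice of the pool tensor backing `block`, taken from the primary pool if
// the block is resident there and from the secondary pool otherwise.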
tr::ITensor::SharedPtr KVCacheTransferManager::computeBlockPointer(
BlockPtr const& block, std::vector<KVCacheBlockPool> const& pools, size_t poolIdx)
{
TLLM_CHECK_WITH_INFO(poolIdx < pools.size(), "Pool index %zu is out of bounds", poolIdx);
auto const& pool = pools.at(poolIdx);
auto ptr = block->isPrimary() ? pool.primaryPtr : pool.secondaryPtr;
auto const blockOffset = block->getMemoryPoolBlockIndex();
tr::ITensor::SharedPtr blockTensor{tr::ITensor::slice(ptr, blockOffset, 1)};
return blockTensor;
}
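// Copies a single block from `src` to `dst` across all pools. DRAM mode copies between GPU and
// CPU memory on the offload/onboard stream, optionally restricted to the first numTokensToCopy
// tokens; GDS and POSIX_DEBUG_FALLBACK modes serialize one file per pool under `directory`.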
void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
std::vector<KVCacheBlockPool> const& pools, bool isOffload, int numTokensToCopy, executor::KvCacheTransferMode mode,
std::string const& directory)
{
TLLM_LOG_DEBUG("copyBlock entered: srcId=%d, dstId=%d, isOffload=%s, mode=%d", src->getBlockId(), dst->getBlockId(),
(isOffload ? "true" : "false"), static_cast<int>(mode));
if (mode == executor::KvCacheTransferMode::DRAM)
{
TLLM_LOG_DEBUG("Using DRAM-based copy (GPU <-> CPU) for this block.");
// Iterate over all pools, applying partial-copy logic where supported.
for (size_t poolIdx = 0; poolIdx < pools.size(); ++poolIdx)
{
auto srcPtr = computeBlockPointer(src, pools, poolIdx);
auto dstPtr = computeBlockPointer(dst, pools, poolIdx);
// If there are no partial tokens, or the data type does not support partial copy, copy the entire block.
if (numTokensToCopy <= 0 || srcPtr->getDataType() == nvinfer1::DataType::kINT4
|| srcPtr->getDataType() == nvinfer1::DataType::kFP4)
{
// Partial copy is not implemented for these data types, so fall back to a full copy.
(isOffload ? mOffloadManager : mOnboardManager).copy(*srcPtr, *dstPtr);
}
else
{
int const tokensPerBlock = pools[poolIdx].tokensPerBlock;
if (numTokensToCopy >= tokensPerBlock)
{
// If requested tokens >= entire block, just do a full copy.
(isOffload ? mOffloadManager : mOnboardManager).copy(*srcPtr, *dstPtr);
}
else
{
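// Only the first numTokensToCopy tokens are needed, so launch the partial-copy kernel
// instead of copying the whole block.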
auto stream = (isOffload ? mOffloadManager : mOnboardManager).getStream().get();
int const numLayers = pools[poolIdx].numLayers;
int const kvFactor = pools[poolIdx].kvFactor;
int const numHeads = pools[poolIdx].numKvHeads;
int const sizePerHead = pools[poolIdx].sizePerHead;
auto shape = srcPtr->getShape();
TLLM_CHECK_WITH_INFO(
shape.nbDims == 4, "Expected KVCache block to have 4 dims, got %d", shape.nbDims);
tk::kvCacheBlockPartialCopy(*dstPtr, *srcPtr, numLayers, numHeads, tokensPerBlock, sizePerHead,
numTokensToCopy, kvFactor, stream);
}
}
}
TLLM_LOG_DEBUG("copyBlock: DRAM mode complete. Returning...");
return;
}
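// File-based modes: build one file per (block, pool) pair under `directory` and collect the
// matching device memory descriptors so GDS can issue a single loopback request.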
std::vector<kvc::FileDesc> fileBlobs;
std::vector<kvc::MemoryDesc> memoryBlobs;
for (size_t poolIdx = 0; poolIdx < pools.size(); ++poolIdx)
{
auto ptr = isOffload ? computeBlockPointer(src, pools, poolIdx) : computeBlockPointer(dst, pools, poolIdx);
auto block_id = src->getBlockId();
TLLM_CHECK_WITH_INFO(
!directory.empty(), "Expected a directory path for KVCache offload, but none was provided.");
int size = std::snprintf(nullptr, 0, "%s/block_%d_pool_%zu.bin", directory.c_str(), block_id, poolIdx);
std::string filename(size, '\0');
std::snprintf(
    filename.data(), size + 1, "%s/block_%d_pool_%zu.bin", directory.c_str(), block_id, poolIdx);
if (mode == executor::KvCacheTransferMode::POSIX_DEBUG_FALLBACK)
{
TLLM_LOG_INFO("Forcing POSIX fallback for file: %s", filename.c_str());
if (isOffload)
{
gpuToFilePosix(ptr, filename);
}
else
{
fileToGpuPosix(ptr, filename);
}
continue;
}
else if (mode == executor::KvCacheTransferMode::GDS)
{
int openFlags = isOffload ? (O_CREAT | O_WRONLY) : O_RDONLY;
fileBlobs.emplace_back(filename, openFlags, 0664, ptr->getSizeInBytes());
memoryBlobs.emplace_back(ptr->data(), ptr->getSizeInBytes(), mDeviceId);
}
}
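// GDS: submit all collected (memory, file) descriptor pairs as one loopback request,
// creating the agent lazily on first use.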
if (mode == executor::KvCacheTransferMode::GDS)
{
if (mLoopbackAgent == nullptr)
{
TLLM_LOG_DEBUG("KVCacheTransferManager: creating mLoopbackAgent lazily");
kvc::BaseAgentConfig config{std::string("GDSAgent"), true, true};
mLoopbackAgent = kvc::makeLoopbackAgent("nixl", &config);
}
kvc::FileDescs fileDescs(std::move(fileBlobs));
kvc::MemoryDescs memoryDescs(kvc::MemoryType::kVRAM, memoryBlobs);
mLoopbackAgent->executeLoopbackRequest(memoryDescs, fileDescs, isOffload);
}
}
//
// Note about recording events to wait for cudaMemcpyAsync calls between blocks:
// The memory copy involves raw memory blocks, which are pointed to by the
// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
// as the raw memory block identifier. Using getBlockId() when recording events is wrong.
// getBlockId() returns the logical block id, which has nothing to do with the raw memory
// block pointers involved in a cudaMemcpy.
//
//
// Notes about need for synchronization:
//
// Relying on the decoder to synchronize the GPU with the CPU before blocks are ready for
// offload/onboard/partial copy is dangerous: the decoder is asynchronous and may not synchronize,
// or may synchronize only at a later point in the execution stream. To stay robust against changes
// in the decoder design, we rely on KVCacheTransferManager::syncWithBufferManager(), which ensures
// that the internal copy streams wait for prefill and decode kernels that have already been
// scheduled.
//
// Earlier versions of this code did not account for every case where a new block copy must wait
// for a previously scheduled copy to finish. For instance, two primary blocks can be offloaded to
// the same secondary block in a single step; scheduling the second offload without waiting for the
// first to finish leaves a corrupted block after offloading. Likewise, partial reuse can copy from
// a block that is still being onboarded; scheduling the partial copy without waiting for the
// onboarding to finish also corrupts the block. To handle every case that needs synchronization,
// we record separate events for reads and writes to a block. When a new block copy is scheduled,
// we wait for all writes to the source block and for all reads and writes to the destination block.
//
// As before, syncTransfers() must be called after the last call to KVCacheManager::addSequence;
// failing to do so will eventually corrupt blocks.
//
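// Typical per-step call order implied by the notes above (a sketch, not an enforced contract;
// variable names are illustrative):
//
//   transferManager.syncWithBufferManager();  // copy streams wait for already-scheduled kernels
//   transferManager.offload(block, offloadBlock, pools, numTokens, mode, directory);
//   transferManager.onboard(offloadedBlock, block, pools, numTokens, mode, directory);
//   // ... last KVCacheManager::addSequence(...) of the step ...
//   transferManager.syncTransfers();           // main stream waits for the scheduled copies
//
// Onboard: copy an offloaded block back into a primary block, waiting on the onboard stream for
// pending writes to the source block and pending reads/writes to the destination block.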
void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
std::string const& directory)
{
// Wait for any pending writes before reading from offloadedBlock
auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
if (offloadedBlockPendingWriteItr != mPendingWrites.end())
{
mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second);
// Don't erase, we are not changing state of offloadedBlock
}
// Wait for any pending reads before overwriting block
auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex());
if (blockPendingReadItr != mPendingReads.end())
{
mOnboardManager.getStream().wait(blockPendingReadItr->second);
mPendingReads.erase(blockPendingReadItr);
}
// Wait for any pending writes before overwriting block
auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
if (blockPendingWriteItr != mPendingWrites.end())
{
mOnboardManager.getStream().wait(blockPendingWriteItr->second);
mPendingWrites.erase(blockPendingWriteItr);
}
copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);
// Record new pending read from offloadedBlock
mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
// Record new pending write to block
mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
}
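// Offload: copy a primary block out to its offload destination, with the mirror-image event
// bookkeeping on the offload stream.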
void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
std::string const& directory)
{
// Wait for any pending writes before reading from block
auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
if (blockPendingWriteItr != mPendingWrites.end())
{
mOffloadManager.getStream().wait(blockPendingWriteItr->second);
// Don't erase, we are not changing state of block
}
// Wait for any pending reads before overwriting offloadBlock
auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex());
if (offloadBlockPendingReadItr != mPendingReads.end())
{
mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
mPendingReads.erase(offloadBlockPendingReadItr);
}
// Wait for any pending writes before overwriting offloadBlock
auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex());
if (offloadBlockPendingWriteItr != mPendingWrites.end())
{
mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
mPendingWrites.erase(offloadBlockPendingWriteItr);
}
copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory);
// Record new pending read from block
mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]);
// Record new pending write to offloadBlock
mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]);
}
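// Make both copy streams wait for everything already recorded on the buffer manager's stream, so
// transfers never run ahead of the prefill and decode kernels that have already been scheduled.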
void KVCacheTransferManager::syncWithBufferManager()
{
tr::CudaEvent readyForOffloadEvent;
mBufferManager.getStream().record(readyForOffloadEvent);
mOffloadManager.getStream().wait(readyForOffloadEvent);
tr::CudaEvent readyForOnboardEvent;
mBufferManager.getStream().record(readyForOnboardEvent);
mOnboardManager.getStream().wait(readyForOnboardEvent);
// Once we synchronize, clear our list of pending transfers.
mPendingReads.clear();
mPendingWrites.clear();
}
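// Make the buffer manager's stream wait for all outstanding offload/onboard copies; per the note
// above, this must be called after the last KVCacheManager::addSequence of the step.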
void KVCacheTransferManager::syncTransfers()
{
tr::CudaEvent offloadEvent;
mOffloadManager.getStream().record(offloadEvent);
mBufferManager.getStream().wait(offloadEvent);
tr::CudaEvent onboardEvent;
mOnboardManager.getStream().record(onboardEvent);
mBufferManager.getStream().wait(onboardEvent);
// Once we synchronize, clear our list of pending transfers.
mPendingReads.clear();
mPendingWrites.clear();
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager