forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkvCacheManager.h
More file actions
1992 lines (1622 loc) · 82.4 KB
/
kvCacheManager.h
File metadata and controls
1992 lines (1622 loc) · 82.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/kvCacheType.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/transferAgent.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntime.h>
#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
namespace kvc = tensorrt_llm::executor::kv_cache;
namespace tensorrt_llm::batch_manager::eviction_policy
{
class BaseEvictionPolicy;
} // namespace tensorrt_llm::batch_manager::eviction_policy
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
// Identifiers for the two memory levels blocks can live in: primary (fast) memory and secondary memory used for
// offloading (see KVCacheBlockPool::primaryPtr / secondaryPtr).
// NOTE(review): SizeType32 is used here before the local alias declared further below; this relies on a SizeType32
// already visible from an enclosing namespace through the includes — confirm before reordering.
static constexpr SizeType32 kPrimaryLevel = 0;
static constexpr SizeType32 kSecondaryLevel = 1;
// Extra block buffer allocated for SWA to be able to always keep "window size"
// tokens held in the blocks.
static constexpr SizeType32 kSWAExtraBlock = 1;
class KVCacheBlock;
class BlockManager;
class KVCacheManager;
class KVCacheTransferManager;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
// Token ids of one sequence; BeamTokens holds one such vector per beam.
using VecTokens = std::vector<TokenIdType>;
using BeamTokens = std::vector<VecTokens>;
using BlockPtr = std::shared_ptr<KVCacheBlock>;
// Queue of free blocks; KVCacheBlock keeps an iterator into it (mFreeBlockIterator).
using FreeBlocksQueue = std::list<BlockPtr>;
using UniqueToken = tensorrt_llm::runtime::UniqueToken;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
// Keyed by attention window size. NOTE(review): the tuple presumably holds (primary blocks, secondary blocks) for
// that window — confirm against the code that fills it.
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;
// Type alias for multimodal hash key (32-byte hash array + start offset of the item inside the block)
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
// Short alias for the project's optional-reference wrapper.
template <typename T>
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
//! \brief Split vector into list of blocks of given size.
//! \param vec vector to split
//! \param usableSize part of the vector that is processed
//! \param elementsPerBlock desired size of blocks
//! \param allowPartial whether to append a block smaller than `elementsPerBlock` at the end
//! \return list of blocks
template <typename T>
std::list<std::vector<T>> chopVectorIntoBlocks(
std::vector<T> const& vec, SizeType32 usableSize, SizeType32 elementsPerBlock, bool allowPartial)
{
TLLM_CHECK_WITH_INFO(
usableSize <= static_cast<SizeType32>(vec.size()), "usableSize=%d > %ld=vec.size()", usableSize, vec.size());
std::list<std::vector<T>> blockedVectors;
auto const vecEnd = vec.begin() + usableSize;
for (auto begin = vec.begin(); begin < vecEnd; begin += elementsPerBlock)
{
auto blockSize = std::min(elementsPerBlock, static_cast<SizeType32>(std::distance(begin, vecEnd)));
auto end = begin + blockSize;
if (blockSize == elementsPerBlock || allowPartial)
{
blockedVectors.emplace_back(begin, end);
}
}
return blockedVectors;
}
//! \brief Inputs used by WindowBlockManager::calculateTemporaryAttentionWindow.
//! \details Members are value-initialized so a default-constructed instance never carries indeterminate values.
struct TempAttentionWindowInputs
{
    bool pagedContextFMHA{};   // Whether paged context FMHA is enabled
    SizeType32 maxInputLen{};  // Maximum input length of a request
    SizeType32 maxNumTokens{}; // Maximum chunk size; caps the temporary attention window
};
//! \brief Block/pool bookkeeping for one attention window size.
//! \details Members are value-initialized so a partially-filled instance never carries indeterminate values.
struct WindowSizeMetadata
{
    SizeType32 allottedPrimaryBlocks{};    // Number of primary blocks allotted to the windowSize
    SizeType32 allottedSecondaryBlocks{};  // Number of secondary blocks allotted to the windowSize
    SizeType32 absolutePoolsOffset{};      // cumulative number of pools up to manager
    SizeType32 numPools{};                 // number of managed pools
    SizeType32 maxTokenNum{};              // Maximum token length per sequence (TODO: account for streamLLM)
    SizeType32 maxBlocksPerSeq{};          // Maximum number of blocks per sequence
    SizeType32 maxNumBlocks{};             // Number of primary+secondary blocks allotted to the windowSize
    SizeType32 temporaryAttentionWindow{}; // Temporary kv cache length per sequence.
                                           // Only needed when chunked context + sliding window attention are used
                                           // together. And it should only be considered when allocating blocks.
    SizeType32 windowSize{};               // Attention window size this metadata describes
    bool isSWA{};                          // Whether this window uses sliding window attention

    //! \brief Format every field for logging.
    //! \details const so it can be called on const metadata (e.g. entries of a const map).
    [[nodiscard]] std::string toString() const
    {
        return tensorrt_llm::common::fmtstr(
            "WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
            ".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d, "
            ".windowSize=%d, .isSWA=%d }",
            allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
            maxNumBlocks, temporaryAttentionWindow, windowSize, isSWA);
    }
};
//! \brief Produce the multimodal extra keys (MmKey) of llmRequest that apply to the token span
//! [startTokenIdx, endTokenIdx).
//! NOTE(review): the half-open range and the selection rule are inferred from the parameter names; confirm against
//! the definition.
std::vector<MmKey> generateBlockHashExtraKeys(
tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx);
//! \brief Content key of one KV-cache block, used to look blocks up for reuse.
//! \details partialMatch only compares token prefixes when the LoRA task id, multimodal extra keys and cache salt
//! of both keys are equal.
struct BlockKey
{
    bool usesExtraIds = false;
    std::optional<LoraTaskIdType> loraTaskId = std::nullopt;
    VecUniqueTokens uniqueTokens;
    // Extra keys for multimodal data (similar to VLLM's approach)
    // Each extra key is a pair of (mm_hash, start_offset_in_block)
    std::vector<MmKey> extraKeys;
    std::optional<CacheSaltIDType> cacheSaltID = std::nullopt;

    BlockKey() = default;

    //! \brief Key over plain token ids: each token is stored with extra id 0 and usesExtraIds stays false.
    explicit BlockKey(VecTokens const& tokens, std::optional<LoraTaskIdType> loraTaskId = std::nullopt)
        : loraTaskId(loraTaskId)
    {
        uniqueTokens.reserve(tokens.size());
        for (auto const tokenId : tokens)
        {
            uniqueTokens.push_back(UniqueToken{tokenId, 0});
        }
    }

    //! \brief Key with every field supplied by the caller.
    explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
        std::vector<MmKey> extraKeys = {}, std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
        : usesExtraIds(usesExtraIds)
        , loraTaskId(loraTaskId)
        , uniqueTokens(std::move(uniqueTokens))
        , extraKeys(std::move(extraKeys))
        , cacheSaltID(cacheSaltID)
    {
    }

    bool operator==(BlockKey const& other) const noexcept;

    //! \brief Number of leading tokens shared with `other`.
    //! \return 0 when the LoRA task id, extra keys or cache salt differ; otherwise the common-prefix length.
    int partialMatch(BlockKey const& other) const noexcept
    {
        bool const sameContext
            = loraTaskId == other.loraTaskId && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID;
        if (!sameContext)
        {
            return 0;
        }
        auto const firstDifference
            = std::mismatch(
                uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end())
                  .first;
        SizeType32 const numMatched = std::distance(uniqueTokens.begin(), firstDifference);
        return numMatched;
    }
};
//! \brief Build one BlockKey per entry of blockedUniqueTokens, using request-level data from llmRequest.
//! NOTE(review): blockedUniqueTokens is taken by non-const reference — check in the definition whether it is
//! modified; which request fields feed the keys is not visible here.
std::vector<BlockKey> buildBlockKeys(std::list<VecUniqueTokens>& blockedUniqueTokens, LlmRequest const& llmRequest);
// Hash functor for BlockKey, so BlockKey can be used as the key of an unordered_map (see NextBlockMap).
// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933
struct BlockKeyHasher
{
//! \brief Hash blockKey, combined with parentHash (which defaults to 0).
[[nodiscard]] static size_t hash(BlockKey const& blockKey, std::size_t parentHash = 0) noexcept;
//! \brief Functor form of hash(), as required by std::unordered_map.
std::size_t operator()(BlockKey const& blockKey, std::size_t parentHash = 0) const noexcept
{
return hash(blockKey, parentHash);
}
};
// Map from BlockKey to the following block; the type of KVCacheBlock::mNextBlocks.
using NextBlockMap = std::unordered_map<BlockKey, BlockPtr, BlockKeyHasher>;
//! \brief Snapshot of KV-cache block usage and reuse statistics.
//! \details Every member is value-initialized (matching allocatedBytes), so a field a producer forgets to set reads
//! as 0 instead of an indeterminate value.
struct KvCacheStats
{
    // Number of maximum available blocks in the primary memory pool. This is determined and set by available primary
    // memory. See calculateMaxNumBlocks for details.
    SizeType32 maxNumBlocks{};
    // Number of free blocks in the primary memory pool.
    SizeType32 freeNumBlocks{};
    // Number of used blocks in the primary memory pool. usedNumBlocks = maxNumBlocks - freeNumBlocks.
    SizeType32 usedNumBlocks{};
    // Number of tokens stored per block.
    SizeType32 toksPerBlock{};
    // Total number of blocks allocated by all requests.
    SizeType32 allocTotalBlocks{};
    // Number of new blocks that were allocated.
    SizeType32 allocNewBlocks{};
    // Number of blocks that were matched and reused.
    SizeType32 reusedBlocks{};
    // Number of blocks that were not matched and not reused.
    SizeType32 missedBlocks{};
    // Measuring the KV Cache reuse rate. cacheHitRate = reusedBlocks / (reusedBlocks + missedBlocks).
    float cacheHitRate{};
    // Number of free blocks for every configured attention-window size.
    std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
    // GPU bytes allocated for KV-cache
    std::size_t allocatedBytes{};
};
// Basic building block of a paged KV cache: a single cache block.
// This class only holds metadata (ids, reference counts, reuse-tree links, retention settings). It holds no pointer
// into pool memory; the location is described by mMemoryPoolBlockIndex. One block's metadata is shared by all layers.
class KVCacheBlock
{
public:
using IdType = std::int32_t;
// Block id given to the sentinel root block of the cached-blocks (reuse) tree.
static constexpr IdType kCachedBlocksRootId = -1;
explicit KVCacheBlock(IdType blockId, kernels::KVCacheIndex blockIdx);
void startScheduling();
[[nodiscard]] IdType getBlockId() const;
// NOTE(review): returns the map by value, i.e. copies every entry of mNextBlocks.
[[nodiscard]] NextBlockMap getNextBlocks() const;
[[nodiscard]] kernels::KVCacheIndex::UnderlyingType getMemoryPoolBlockIndex() const;
[[nodiscard]] bool isPrimary() const;
// Exchange the memory-pool location with otherBlock (presumably used when moving data between memory levels).
void swapMemoryPoolBlockOffset(std::shared_ptr<KVCacheBlock> otherBlock);
// Regular reference count (mRefCount).
void incRefCount();
void decRefCount();
// Scheduling reference count (mSchedulingRefCount), tracked separately from mRefCount.
void decSchedulingRefCount();
[[nodiscard]] bool hasRefs() const;
[[nodiscard]] bool hasSchedulingRefs() const;
void setBlockKey(BlockKey const& blockKey, bool isFull);
BlockKey getBlockKey();
[[nodiscard]] VecUniqueTokens const& getUniqueTokens() const;
// Previous block in the reuse tree (mPrevBlock).
BlockPtr const& getPrevBlock() const;
void setPrevBlock(BlockPtr prevBlock);
// Previous block in the sequence (mPrevBlockInSeq).
BlockPtr const& getPrevBlockInSeq() const;
void setPrevBlockInSeq(BlockPtr prevBlock);
void addNextBlock(BlockKey const& blockKey, BlockPtr block);
void removeNextBlock(BlockKey const& blockKey);
//! \brief Find block matching blockKey. If enablePartialReuse is true, the returned block may match only a prefix of
//! blockKey.
//! @return tuple of [partialMatch, numMatched, block], partialMatch is true if not all the tokens of the block were
//! matched.
[[nodiscard]] std::tuple<bool, SizeType32, BlockPtr> findMatchingBlock(
BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const;
//! \brief Free block from previous block if present.
void freeLeafBlock();
[[nodiscard]] bool isFull() const;
[[nodiscard]] bool isShared() const;
[[nodiscard]] bool isLeaf() const;
// Retention settings: priority, how long it applies, and the resulting expiration time.
void setPriority(executor::RetentionPriority priority);
[[nodiscard]] executor::RetentionPriority getPriority() const;
void setDurationMs(std::optional<std::chrono::milliseconds> durationMs);
[[nodiscard]] std::optional<std::chrono::milliseconds> getDurationMs() const;
void setExpirationTime(std::optional<std::chrono::steady_clock::time_point::duration> expirationTime);
[[nodiscard]] std::optional<std::chrono::steady_clock::time_point::duration> getExpirationTime() const;
void setHash(size_t hash);
// Set the hash automatically from the block key and the previous block in the sequence.
void setHash();
size_t getHash() const;
private:
// Linear ID of block independent of pool
IdType mBlockId;
// Index of block in memory pool backing this block
// Choice of pool is encoded into the type
kernels::KVCacheIndex mMemoryPoolBlockIndex;
// Number of references to the block
SizeType32 mRefCount;
// Number of references to the block used for scheduling (see startScheduling / decSchedulingRefCount);
// kept separately from mRefCount
SizeType32 mSchedulingRefCount;
// Key of this block in mNextBlocks map in block pointed to by mPrevBlock
BlockKey mBlockKey;
// Previous block in reuse tree, or nullptr if not reusing
BlockPtr mPrevBlock;
// Previous block in sequence, == nullptr for first block, == mPrevBlock if reusing and not first
BlockPtr mPrevBlockInSeq;
// Next block(s) in sequence(s)
NextBlockMap mNextBlocks;
// Iterator pointing to this block in mFreeBlocks.
std::optional<FreeBlocksQueue::iterator> mFreeBlockIterator;
// Flag indicating if block is full
bool mIsFull;
// Priority of the block
executor::RetentionPriority mPriority;
// Duration that the block's priority level applies for
std::optional<std::chrono::milliseconds> mDurationMs;
// Expiration time of the block
std::optional<std::chrono::steady_clock::time_point::duration> mExpirationTime;
// Hash for the event manager
size_t mHash;
};
class GenerationRequest
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
explicit GenerationRequest(LlmRequest::RequestIdType requestId, SizeType32 numTokens, SizeType32 beamWidth,
std::map<SizeType32, WindowSizeMetadata> const& windowSizeToMetadata,
executor::KvCacheRetentionConfig kvCacheRetentionConfig = executor::KvCacheRetentionConfig())
: mRequestId(requestId)
, mNumTokens(numTokens)
, mBeamWidth(beamWidth)
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
, mNumFrontBlocksRemoved(0)
{
auto const numWindowSizes = windowSizeToMetadata.size();
mCacheBlockIds.reserve(numWindowSizes);
mCacheBlockIndices.reserve(numWindowSizes);
for (auto const [windowSize, metadata] : windowSizeToMetadata)
{
mCacheBlockIds[windowSize] = std::vector<std::vector<KVCacheBlock::IdType>>(beamWidth);
auto const numPools = metadata.numPools;
auto const maxBlocks = metadata.maxBlocksPerSeq;
mCacheBlockIndices[windowSize]
= runtime::BufferManager::cpu(runtime::ITensor::makeShape({numPools, beamWidth, 2, maxBlocks}),
runtime::TRTDataType<tensorrt_llm::kernels::KVCacheIndex>::value);
auto cacheBlockIdsRange
= runtime::BufferRange<tensorrt_llm::kernels::KVCacheIndex>(*mCacheBlockIndices.at(windowSize));
std::fill(cacheBlockIdsRange.begin(), cacheBlockIdsRange.end(),
tensorrt_llm::kernels::KVCacheIndex{
std::numeric_limits<tensorrt_llm::kernels::KVCacheIndex::UnderlyingType>::max()});
}
}
void addNewTokens(SizeType32 n)
{
mNumTokens += n;
}
void removeTokens(SizeType32 n)
{
TLLM_CHECK(n <= mNumTokens);
TLLM_CHECK(mNumTokens - n >= 0);
mNumTokens -= n;
}
[[nodiscard]] LlmRequest::RequestIdType getRequestId() const
{
return mRequestId;
}
[[nodiscard]] SizeType32 getNumTokens() const
{
return mNumTokens;
}
[[nodiscard]] SizeType32 getNumFrontBlocksRemoved() const
{
return mNumFrontBlocksRemoved;
}
[[nodiscard]] SizeType32 getBeamWidth() const
{
return mBeamWidth;
}
[[nodiscard]] std::vector<std::vector<SizeType32>> const& getCacheBlockIds(SizeType32 windowSize) const
{
return mCacheBlockIds.at(windowSize);
}
[[nodiscard]] runtime::ITensor& getCacheBlockIndices(SizeType32 windowSize)
{
return *(mCacheBlockIndices.at(windowSize));
}
[[nodiscard]] runtime::ITensor const& getCacheBlockIndices(SizeType32 windowSize) const
{
return *(mCacheBlockIndices.at(windowSize));
}
void addCacheBlock(SizeType32 windowSize, SizeType32 beamIdx, KVCacheBlock::IdType blockId)
{
mCacheBlockIds.at(windowSize).at(beamIdx).push_back(blockId);
}
void changeCacheBlock(
SizeType32 windowSize, SizeType32 beamIdx, SizeType32 pagedBlockIdx, KVCacheBlock::IdType blockId)
{
mCacheBlockIds.at(windowSize).at(beamIdx).at(pagedBlockIdx) = blockId;
}
void clearCacheBlocks(SizeType32 windowSize)
{
for (auto& beamBlockIds : mCacheBlockIds.at(windowSize))
{
beamBlockIds.clear();
}
mNumFrontBlocksRemoved = 0;
}
void removeFrontBlock(SizeType32 windowSize)
{
++mNumFrontBlocksRemoved;
}
void removeLastBlock(SizeType32 windowSize)
{
for (auto& beamBlockIds : mCacheBlockIds.at(windowSize))
{
beamBlockIds.pop_back();
}
}
[[nodiscard]] executor::RetentionPriority getDecodeRetentionPriority() const
{
return mKvCacheRetentionConfig.getDecodeRetentionPriority();
}
[[nodiscard]] std::optional<std::chrono::milliseconds> getDecodeDurationMs() const
{
return mKvCacheRetentionConfig.getDecodeDurationMs();
}
[[nodiscard]] executor::KvCacheTransferMode getTransferMode() const
{
return mKvCacheRetentionConfig.getTransferMode();
}
[[nodiscard]] std::string const& getDirectory() const
{
return mKvCacheRetentionConfig.getDirectory();
}
private:
// Request id of the sequence
LlmRequest::RequestIdType mRequestId;
// Current number of generated tokens
SizeType32 mNumTokens;
// Number of beams
SizeType32 mBeamWidth;
// List of block ids allocated per each window size, for each beam of the sequence
std::unordered_map<SizeType32, std::vector<std::vector<KVCacheBlock::IdType>>> mCacheBlockIds;
// Tensor of block indices allocated per each window size, for each beam of the sequence
std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices;
// The retention priority to assign to decode blocks
executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
// Number of front blocks removed from the sequence
SizeType32 mNumFrontBlocksRemoved;
// Set of used blocks by the sequence
std::set<KVCacheBlock::IdType> mUsedBlocks;
};
// Metadata describing one KV-cache memory pool, attached to the pool's primary/secondary tensors.
class KVCacheBlockPool
{
public:
SizeType32 numLayers;
SizeType32 kvFactor;
SizeType32 numKvHeads;
SizeType32 sizePerHead;
SizeType32 tokensPerBlock;
// Elements per block for one layer and one of K/V: numKvHeads * sizePerHead * tokensPerBlock
// (numLayers and kvFactor are not included).
SizeType32 blockSize;
// Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
runtime::ITensor::SharedPtr primaryPtr;
runtime::ITensor::SharedPtr secondaryPtr;
// FP4 KV caches have extra pools that contain second level scales for dequantization.
bool containsBlockScales;
// Marks a pool holding the indexer K cache (presumably created by WindowBlockManager::createIndexerKCachePools —
// confirm).
bool containsIndexerKCache;
// blockSize is derived from the other shape parameters; both pool tensors default to null.
KVCacheBlockPool(SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, runtime::ITensor::SharedPtr primaryPtr = nullptr,
runtime::ITensor::SharedPtr secondaryPtr = nullptr, bool containsBlockScales = false,
bool containsIndexerKCache = false)
: numLayers(numLayers)
, kvFactor(kvFactor)
, numKvHeads(numKvHeads)
, sizePerHead(sizePerHead)
, tokensPerBlock(tokensPerBlock)
, blockSize(numKvHeads * sizePerHead * tokensPerBlock)
, primaryPtr(std::move(primaryPtr))
, secondaryPtr(std::move(secondaryPtr))
, containsBlockScales(containsBlockScales)
, containsIndexerKCache(containsIndexerKCache)
{
}
};
// The WindowBlockManager manages the metadata of KVCacheBlocks.
// It manages multiple arrays of cache blocks called pools.
// Layers with the same number of kv heads are grouped under the same pool.
// Each pool has shape [max_blocks, num_layers, 2, num_kv_heads, tokens_per_block, head_size], where num_layers refers
// to the number of layers with the same num_kv_heads that share that pool.
// The metadata of KVCacheBlocks is shared between layers, so each block spans all of the managed pool - an allocated
// block matches some chunk of memory in each pool. The shape of the chunk in every pool is [2, num_kv_heads,
// tokens_per_block, head_size]. The size per block and number of blocks are pre-determined and set in the constructor.
// WindowBlockManager maintains a list of free blocks at any time.
//
// FP4 KV caches allocate additional pools for block scale factors. These pools have the same
// shape as the regular KV pools, except that the last dim is head_size / N where N is determined
// by the precise FP4 format being used (16 for NVFP4). There is one block scale pool per normal pool.
//
// BlockManager maintains a list of free blocks at any time.
// Alloc pops off the block at the front, and Free pushes it back to the vector.
// WindowBlockManager maintains a vector of lists of request ids to allocated blocks
// per sequence. This can be used to Free all blocks belonging to a sequence.
class WindowBlockManager
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
using BaseEvictionPolicy = tensorrt_llm::batch_manager::eviction_policy::BaseEvictionPolicy;
using BlockMap = std::unordered_multimap<size_t, BlockPtr>;
using BlockMapIterRange = std::pair<BlockMap::const_iterator, BlockMap::const_iterator>;
//! \brief Manage the blocks of one attention window size (windowSize) for the layers in managedLayers.
//! \details blocksInPrimaryPool / blocksInSecondaryPool size the primary and offload pools; the indexer K-cache
//! parameters only matter when enableIndexerKCache is true (presumably — confirm in the definition).
explicit WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
std::shared_ptr<kvc::BaseLoopbackAgent> loopbackAgent = nullptr, bool enableIndexerKCache = false,
SizeType32 indexerKCacheQuantBlockSize = 128, SizeType32 indexerKCacheIndexHeadDim = 0);
~WindowBlockManager();
// Pool memory lifecycle. NOTE(review): useUvm presumably selects unified memory for the pools — confirm.
void allocatePools(bool useUvm);
void releasePools();
void createIndexerKCachePools();
void startScheduling();
//! \brief Assign blocks for new sequence. Try to reuse blocks.
void addSequence(
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock);
//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams);
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);
//! \brief Store the sequence's blocks for reuse; pinBlocks is forwarded to storing (pin on store).
//! \return Presumably the id of the last block stored, if any — confirm against storeBlocks.
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
//! \brief Pin blocks associated with a sequence to prevent eviction.
void pinBlocks(GenerationRequest& sequence);
//! \brief Release blocks of the sequence.
//! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
std::optional<KVCacheBlock::IdType> releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
//! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
//! \brief Update cache offsets for last block
void updateLastCacheBlockOffsets(GenerationRequest& seq);
//! \brief Release last block in the sequence
void releaseLastBlock(GenerationRequest& sequence);
//! \brief Detach front block from the sequence
void detachFrontBlock(GenerationRequest& sequence);
//! \brief Add/detach block(s) to/from the sequence if needed
//! \details When we need a new block, we add it. For sliding window
//! attention (SWA), when a block goes out-of-window (OOW), we detach it.
//! If this is called in the first step of the generation phase, we may detach
//! more than a single block since there may be more than one context block
//! that goes OOW.
void adjustBlocksIfNeeded(GenerationRequest& sequence);
//! \brief Attention window size served by this manager.
[[nodiscard]] SizeType32 getWindowSize() const noexcept
{
return mWindowSize;
}
//! \brief Prefix used for this manager's log messages.
[[nodiscard]] std::string const& getLogPrefix() const noexcept
{
return mLogPrefix;
}
[[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept;
// Allocation / reuse counters (mAllocTotalBlocks, mAllocNewBlocks, mReusedBlocks, mMissedBlocks).
[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
{
return mAllocTotalBlocks;
}
[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
{
return mAllocNewBlocks;
}
[[nodiscard]] SizeType32 getNumReusedBlocks() const noexcept
{
return mReusedBlocks;
}
//! \brief Blocks that are not free: getMaxNumBlocks() - getNumFreeBlocks().
[[nodiscard]] SizeType32 getNumAllocatedBlocks() const noexcept
{
return getMaxNumBlocks() - getNumFreeBlocks();
}
[[nodiscard]] SizeType32 getNumMissedBlocks() const noexcept
{
return mMissedBlocks;
}
//! \brief Whether at least numRequired blocks are currently free.
[[nodiscard]] bool hasFreeBlocks(SizeType32 numRequired = 1) const noexcept
{
return getNumFreeBlocks() >= numRequired;
}
//! \brief Same check against the scheduling-time free-block count (mSchedulingNumFreeBlocks).
[[nodiscard]] bool schedulingHasFreeBlocks(SizeType32 numRequired) const noexcept
{
return mSchedulingNumFreeBlocks >= numRequired;
}
//! \brief Total number of blocks tracked by this manager (size of mAllBlocksById).
[[nodiscard]] SizeType32 getMaxNumBlocks() const noexcept
{
return static_cast<SizeType32>(mAllBlocksById.size());
}
//! \brief Block with the given id; lookup goes through at(), so an unknown id is not silently accepted.
[[nodiscard]] BlockPtr const& getBlockById(KVCacheBlock::IdType blockId) const
{
return mAllBlocksById.at(blockId);
}
[[nodiscard]] SizeType32 getTokensPerBlock() const noexcept
{
return mTokensPerBlock;
}
//! \brief Get size of one K/V cache block in one layer for the specified pool.
//! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead] in the specified pool.
[[nodiscard]] SizeType32 getBlockSize(SizeType32 poolIdx) const
{
return mPools.at(poolIdx).blockSize;
}
//! \brief Number of values per storage element: 2 for kFP4 when built with ENABLE_FP4 (presumably two 4-bit
//! values per byte), otherwise 1.
[[nodiscard]] SizeType32 getNumEltsPerContainer() const
{
#ifdef ENABLE_FP4
return mDataType == nvinfer1::DataType::kFP4 ? 2 : 1;
#else
return 1;
#endif
}
//! \brief Number of pools, optionally excluding block-scale and/or indexer K-cache pools.
//! \details A regular pool (neither block scales nor indexer K cache) is always counted; a special pool is counted
//! when its kind is included.
[[nodiscard]] SizeType32 getNumPools(
    bool includeBlockScalePools = true, bool includeIndexerKCachePools = true) const noexcept
{
    // Fast path: nothing is excluded, so every pool counts.
    if (includeBlockScalePools && includeIndexerKCachePools)
    {
        return static_cast<SizeType32>(mPools.size());
    }
    SizeType32 numCounted = 0;
    for (auto const& pool : mPools)
    {
        bool const isRegularPool = !pool.containsBlockScales && !pool.containsIndexerKCache;
        bool const isIncludedBlockScalePool = includeBlockScalePools && pool.containsBlockScales;
        bool const isIncludedIndexerPool = includeIndexerKCachePools && pool.containsIndexerKCache;
        if (isRegularPool || isIncludedBlockScalePool || isIncludedIndexerPool)
        {
            ++numCounted;
        }
    }
    return numCounted;
}
//! \brief Pool metadata for poolIdx (bounds-checked via at()).
[[nodiscard]] KVCacheBlockPool const& getPool(SizeType32 poolIdx) const
{
return mPools.at(poolIdx);
}
//! \brief Whether pool poolIdx holds FP4 block scales.
[[nodiscard]] bool containsBlockScales(SizeType32 poolIdx) const
{
return mPools.at(poolIdx).containsBlockScales;
}
[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
{
return mNumPrimaryBlocks;
}
[[nodiscard]] SizeType32 getNumSecondaryBlocks() const
{
return mNumSecondaryBlocks;
}
//! \brief Index of the pool that stores layer layerIdx.
[[nodiscard]] SizeType32 getLayerPoolIdx(SizeType32 layerIdx) const
{
return mLayerToPoolIndex.at(layerIdx);
}
//! \brief Maps a global layer index to its layer index within its pool.
//! \details The result indexes layers inside pool getLayerPoolIdx(layerIdx). If we only have one pool, then
//! getPoolLayerIdx(i) == i.
[[nodiscard]] SizeType32 getPoolLayerIdx(SizeType32 layerIdx) const
{
return mLayerToIndexWithinPool.at(layerIdx);
}
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;
//! \brief Bring offloaded block from secondary to primary memory.
//! \details Does nothing if block is already in primary memory.
void onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
//! \brief Bring block from primary to secondary memory.
//! \details Does nothing if block is already in secondary memory.
void offloadBlock(BlockPtr const& block, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::string const& directory = "");
//! \brief Find first new block that must be allocated for context phase and return it's concatenated token vectors.
//! \details Only full blocks are considered.
[[nodiscard]] std::optional<BlockKey> findNewContextBlock(
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
[[nodiscard]] runtime::BufferManager const& getBufferManager() const
{
return mBufferManager;
}
//! \brief Sync internal streams used by transfer manager with buffer manager stream
void syncTransferManagerWithBufferManager();
//! \brief Perform per-request bookkeeping
void refreshBlocks();
[[nodiscard]] static bool blockInRadixTree(BlockPtr const& block);
//! \brief Store blocks in the tree of cached blocks so they can be reused.
//! \param blockKeys Key of each block.
//! \param blockIds Id of each block.
//! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
//! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
bool pinBlocks = false);
//! \brief Check the integrity of the internal block queues.
//! NOTE(review): presumably returns true when the queues are consistent — confirm against the definition.
[[nodiscard]] bool verifyQueueIntegrity();
//! \brief Size of the temporary attention window required for sliding window attention + paged context FMHA.
//! \details A temporary KV cache buffer of up to the maximum chunk size (maxNumTokens) is needed only when both
//!          features are used together.
//!          TODO: possible improvements:
//!          1. allocate the temporary KV cache dynamically based on the real chunk size;
//!          2. share the same temporary KV cache buffer among all layers in the same pool.
//! \return 0 if inputs is empty, paged context FMHA is off, or maxInputLen does not exceed the window size;
//!         otherwise min(maxNumTokens, maxInputLen - windowSize), clamped to be non-negative.
[[nodiscard]] SizeType32 calculateTemporaryAttentionWindow(
    std::optional<TempAttentionWindowInputs> const& inputs) const
{
    bool const needsTemporaryWindow
        = inputs.has_value() && inputs->pagedContextFMHA && inputs->maxInputLen > mWindowSize;
    if (!needsTemporaryWindow)
    {
        return 0;
    }
    auto const tokensBeyondWindow = inputs->maxInputLen - mWindowSize;
    auto const cappedWindow = std::min(inputs->maxNumTokens, tokensBeyondWindow);
    // Never report a negative window size.
    return std::max(cappedWindow, 0);
}
//! \brief True if this window uses sliding window attention (SWA), false for full attention.
[[nodiscard]] bool isSWA() const
{
    bool const usesSlidingWindow = mIsSWA;
    return usesSlidingWindow;
}
//! \brief Look up a block in the reuse tree by its key.
//! NOTE(review): returns a single block despite the plural name; presumably nullptr when no match — confirm.
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);
//! \brief Unpin blocks, starting from the block with id blockId and walking its prev pointers.
//! \param blockId Id of the block to start unpinning from.
void unpinBlocksById(KVCacheBlock::IdType blockId);
//! \brief Begin tracking requestId as a sequence whose blocks are valid for store-for-reuse.
//! \details Any existing entry for the request is overwritten with true.
void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
{
    mIsValidStoreForReuseSequence.insert_or_assign(requestId, true);
}
//! \brief Stop tracking requestId for store-for-reuse validity; does nothing if it is not tracked.
void releaseSequenceStorageValidity(LlmRequest::RequestIdType requestId)
{
    auto const trackedEntry = mIsValidStoreForReuseSequence.find(requestId);
    if (trackedEntry != mIsValidStoreForReuseSequence.end())
    {
        mIsValidStoreForReuseSequence.erase(trackedEntry);
    }
}
//! \brief Return whether this sequence is still valid for storing its blocks for reuse.
//! \details requestId must have an entry in mIsValidStoreForReuseSequence (e.g. added by
//!          initializeSequenceStorageValidity() and not yet removed by releaseSequenceStorageValidity());
//!          otherwise the TLLM_CHECK_WITH_INFO check fails.
//!          A single find() serves both the presence check and the value read, instead of the previous
//!          count() followed by at(), which searched the ordered map twice.
[[nodiscard]] bool isSequenceValidForStoreForReuse(LlmRequest::RequestIdType requestId) const
{
    auto const entry = mIsValidStoreForReuseSequence.find(requestId);
    TLLM_CHECK_WITH_INFO(entry != mIsValidStoreForReuseSequence.end(), "Sequence should be bookkeeped");
    return entry->second;
}
//! \brief Drop the reuse tree by installing a brand-new, empty root block.
//! \details The new root is built and assigned while mCachedBlocksRootMutex is held.
void resetReuseState()
{
    std::scoped_lock<std::mutex> const rootGuard(mCachedBlocksRootMutex);
    mCachedBlocksRoot = std::make_shared<KVCacheBlock>(
        KVCacheBlock::kCachedBlocksRootId, tensorrt_llm::kernels::KVCacheIndex{0});
}
private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
//! \param block Block to add.
//! \param sequence Sequence owning the beam.
//! \param beamIdx Index of the beam that receives the block.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
//! \brief Add single block to all beams of sequence.
//! \param block Block to add.
//! \param sequence Sequence whose beams receive the block.
void addBlockToAllBeams(BlockPtr& block, GenerationRequest& sequence);
//! \brief Try to load blocks from cache. Allocate new blocks if necessary.
//! \param blockKeys Key of each block.
//! \param numContextBlocks Number of context blocks to load or allocate (NOTE(review): inferred from name — confirm).
//! \param sequence Sequence to which blocks are assigned.
//! \param perBlockRetentions Retention priority and duration for each block.
//! \param mode Transfer mode used when blocks have to be moved; defaults to KvCacheTransferMode::DRAM.
//! \param directory Directory used by transfer modes that need one; empty by default.
//! \return Number of matched tokens from loaded blocks.
SizeType32 loadOrAllocateBlocks(std::vector<BlockKey> const& blockKeys, SizeType32 numContextBlocks,
GenerationRequest& sequence, std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
//! \brief Free block and all its descendants. This makes block a claimed leaf block.
void freeChildren(BlockPtr const& block);
//! \brief Find block least likely to be reused, free it if necessary and return it.
//! \param sequence Sequence which the free block is allocated for.
//! \param (unnamed) Retention priority for the returned block; defaults to
//!        KvCacheRetentionConfig::kDefaultRetentionPriority.
//! \param durationMs Optional retention duration; std::nullopt by default.
//! \param mode Transfer mode; defaults to KvCacheTransferMode::DRAM.
//! \param directory Directory used by transfer modes that need one; empty by default.
//! NOTE(review): mode/directory presumably apply if the chosen block must be offloaded first — confirm.
[[nodiscard]] BlockPtr getFreeBlock(GenerationRequest& sequence,
executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority,
std::optional<std::chrono::milliseconds> durationMs = std::nullopt,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
//! \brief Calls KVCacheBlock::freeLeafBlock to remove block from the search tree.
void freeLeafBlock(BlockPtr const& block);
//! \brief For FP4 quantization. Creates pool objects for FP4 block scalars.
//! \param blockSize Block size used to size the scale pools (NOTE(review): units not visible here — confirm).
void createBlockScalePools(SizeType32 blockSize);
private:
// Data type of the KV cache pools (NOTE(review): inferred from the name — confirm).
nvinfer1::DataType mDataType;
// Attention window size, in tokens, handled by this manager (compared against maxInputLen when sizing the
// temporary attention window).
SizeType32 mWindowSize;
// Number of blocks in the primary and secondary pools, respectively.
SizeType32 mNumPrimaryBlocks;
SizeType32 mNumSecondaryBlocks;
// List of allocated blocks for each sequence, keyed by request id.
std::unordered_map<LlmRequest::RequestIdType, std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
// Pool per unique numKvHeads in the model
std::vector<KVCacheBlockPool> mPools;
// Matching layers to their respective pools: {<layer #0>: <pool idx 2>, }, etc.
std::unordered_map<SizeType32, SizeType32> mLayerToPoolIndex;
// Matching layers to their index *within* their respective pools: {..., <layer 3>: <idx 2 within pool> }. See
// getPoolLayerIdx
std::unordered_map<SizeType32, SizeType32> mLayerToIndexWithinPool;
// Whether offloaded blocks should be onboarded before reuse.
bool mOnboardBlocks;
// Buffer manager (exposed read-only through getBufferManager()).
runtime::BufferManager mBufferManager;
// Used to keep track of number of free blocks during scheduling
SizeType32 mSchedulingNumFreeBlocks;
// Number of tokens per block
SizeType32 mTokensPerBlock;
// True if this window uses sliding window attention (SWA), false for full attention.
bool mIsSWA;
// List of all blocks, indexed by block id
std::vector<BlockPtr> mAllBlocksById;
// Dummy block acting as root for BlockToken searches; replaced by resetReuseState().
BlockPtr mCachedBlocksRoot;
// KV cache type (self or cross)
CacheType mCacheType;
// Eviction policy
std::shared_ptr<BaseEvictionPolicy> mEvictionPolicy;
// Event manager
std::shared_ptr<KVCacheEventManager> mEventManager;
// Pointer to parent loopback agent
std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
// Transfer manager
std::shared_ptr<KVCacheTransferManager> mTransferManager;
// Statistics for block allocations/reuse
// Total number of blocks allocated by all requests
SizeType32 mAllocTotalBlocks;
// Number of new blocks that were allocated
SizeType32 mAllocNewBlocks;
// Number of blocks that were reused
SizeType32 mReusedBlocks;
// Number of unique blocks that were reused
SizeType32 mReusedUniqueBlocks;
// Number of blocks that were not reused
SizeType32 mMissedBlocks;
// Can only be 1 or 2. If 2: general KV stored. If 1: K == V for any token, so only K is stored to optimize the
// max_num_tokens (for DeepSeek). Controlled by mCacheType
SizeType32 mKVFactor;
// Ids of blocks that were reused. NOTE(review): presumably backs the mReusedUniqueBlocks statistic — confirm.
// Unlike the other members, this name lacks the "m" prefix.
std::set<KVCacheBlock::IdType> reusedBlockIds;
// Prefix for log messages (NOTE(review): inferred from the name — confirm).
std::string const mLogPrefix;
// Number of reused tokens
double mReusedTokens;
// Total number of input tokens
double mTotalInputTokens;
// Whether blocks that are partially matched should be reused.
bool mEnablePartialReuse;
// Whether partially matched blocks that are already in use should be copied and reused.
bool mCopyOnPartialReuse;
// The kv cache connector manager
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
// Mutex guarding mCachedBlocksRoot (held by resetReuseState() while the root is replaced).
std::mutex mCachedBlocksRootMutex;
// Record which sequence is using each block
std::map<KVCacheBlock::IdType, LlmRequest::RequestIdType> mBlockToSequence;
// Record whether a sequence has all of its blocks held valid for store-for-reuse.
// The value is set to true upon first encounter of a new sequence (initializeSequenceStorageValidity()) and the
// entry is removed by releaseSequenceStorageValidity().
// It may be invalidated (set to false) when a block used by this sequence is acquired by another sequence.
std::map<LlmRequest::RequestIdType, bool> mIsValidStoreForReuseSequence;
// Whether to enable indexer K cache
bool mEnableIndexerKCache;