 #include <algorithm>
 #include <memory>
 #include <numeric>
-#include <tuple>
 #include <unordered_map>
 #include <vector>
 
@@ -38,14 +37,13 @@ class StaticKVCache {
    * caches.
    */
   StaticKVCache(
-      size_t n_caches,
-      size_t cache_len,
+      const std::vector<size_t>& cache_lengths,
       size_t head_dim,
-      size_t max_input_len = 1,
-      size_t n_heads_per_cache = 1,
+      size_t max_input_len,
+      size_t n_heads_per_cache,
       StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
-      : n_caches_(n_caches),
-        cache_len_(n_caches_, cache_len),
+      : n_caches_(cache_lengths.size()),
+        cache_lengths_(cache_lengths),
         cache_pos_(n_caches_, 0),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
@@ -54,7 +52,7 @@ class StaticKVCache {
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
     size_t total_cache_len =
-        std::accumulate(cache_len_.begin(), cache_len_.end(), 0);
+        std::accumulate(cache_lengths_.begin(), cache_lengths_.end(), 0);
     cache_data_size_ = total_cache_len * n_heads_per_cache_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
@@ -83,12 +81,12 @@ class StaticKVCache {
    */
   void prepare(
       torch::executor::Method& method,
-      const std::vector<size_t>& inputIndices,
+      const std::vector<size_t>& input_indices,
       const std::vector<size_t>& output_indices) {
-    ET_CHECK(inputIndices.size() == output_indices.size());
+    ET_CHECK(input_indices.size() == output_indices.size());
     auto methodMeta = method.method_meta();
     for (size_t i = 0; i < n_caches_; i++) {
-      auto inIdx = inputIndices[i];
+      auto inIdx = input_indices[i];
       auto outIdx = output_indices[i];
       auto inMeta = methodMeta.input_tensor_meta(inIdx);
       auto outMeta = methodMeta.output_tensor_meta(outIdx);
@@ -126,7 +124,7 @@ class StaticKVCache {
       ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
       ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
       ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_[i], "Cache length dim mismatch.");
+          inSizes[ndim - 2] == cache_lengths_[i], "Cache length dim mismatch.");
 
       auto impl = ::executorch::runtime::etensor::TensorImpl(
           inMeta->scalar_type(),
@@ -167,7 +165,7 @@ class StaticKVCache {
           update_n,
           update_pos,
           input_ptrs_[i],
-          cache_len_[i],
+          cache_lengths_[i],
           cache_pos_[i]);
     }
   }
@@ -187,7 +185,7 @@ class StaticKVCache {
     size_t cache_data_offset = 0;
     for (size_t i = 0; i < n_caches_; i++) {
       input_ptrs_[i] = cache_data_ + cache_data_offset;
-      cache_data_offset += cache_len_[i] * n_heads_per_cache_ * head_dim_;
+      cache_data_offset += cache_lengths_[i] * n_heads_per_cache_ * head_dim_;
       output_ptrs_[i] =
           update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
     }
@@ -217,24 +215,25 @@ class StaticKVCache {
           update_head + (update_pos + update_n) * head_dim_,
           cache_head + cache_pos * head_dim_);
     }
-    cache_pos += update_n;
+    cache_pos = (cache_pos + update_n) % cache_len;
 
     if (wrap_n > 0) {
+      ET_CHECK(cache_pos == 0);
       return update_one_cache(
           update,
           update_len,
           wrap_n,
           update_pos + contiguous_n,
           cache,
           cache_len,
-          0);
+          cache_pos);
     }
 
     return cache_pos;
   }
 
   size_t n_caches_;
-  std::vector<size_t> cache_len_;
+  std::vector<size_t> cache_lengths_;
   std::vector<size_t> cache_pos_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
@@ -415,11 +414,11 @@ class StaticAttentionIOManager {
  public:
   struct StaticAttentionIOConfig {
     size_t n_caches{};
-    size_t cache_len{};
+    std::vector<size_t> cache_lengths{};
     size_t head_dim{};
     size_t max_input_len{};
     size_t n_heads_per_cache{};
-    size_t attn_mask_input_index{};
+    std::unordered_map<size_t, size_t> cache_len_to_mask_idx;
     size_t rope_freqs_cos_input_index{};
     size_t rope_freqs_sin_input_index{};
     std::vector<size_t> k_cache_input_indices;
@@ -433,50 +432,55 @@ class StaticAttentionIOManager {
 
   StaticAttentionIOManager(StaticAttentionIOConfig config)
       : config_(std::move(config)),
-        kCaches_(
-            config_.n_caches,
-            config_.cache_len,
+        k_caches_(
+            config_.cache_lengths,
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
             config_.style),
-        vCaches_(
-            config_.n_caches,
-            config_.cache_len,
+        v_caches_(
+            config_.cache_lengths,
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
             config_.style) {
     ET_LOG(
         Info,
-        "Created StaticAttentionIOManager with"
-        " max input length = %zu cache length = %zu",
-        config_.max_input_len,
-        config_.cache_len);
+        "Created StaticAttentionIOManager with max input length = %zu",
+        config_.max_input_len);
+    for (auto cache_len : config_.cache_lengths) {
+      ET_LOG(Info, "Cache length = %zu", cache_len);
+    }
   }
 
+  using PerCacheLenMasks = std::vector<std::pair<
+      size_t,
+      std::unique_ptr<StaticAttentionMask<MaskT, MaskAllocatorT>>>>;
+
   /**
-   * Create a new StaticAttentionMask that will be managed by this object.
+   * Create a new StaticAttentionMask for each cache length used.
    */
-  StaticAttentionMask<MaskT, MaskAllocatorT>&
-  add_mask(size_t input_len, MaskT zero_val, MaskT mask_val) {
-    auto it = attentionMasks_.emplace(
-        std::piecewise_construct,
-        std::forward_as_tuple(input_len),
-        std::forward_as_tuple(
-            config_.cache_len,
-            input_len,
-            config_.head_dim,
-            zero_val,
-            mask_val,
-            config_.style));
+  PerCacheLenMasks& add_mask(size_t input_len, MaskT zero_val, MaskT mask_val) {
+    PerCacheLenMasks masks;
+    for (auto& pair : config_.cache_len_to_mask_idx) {
+      masks.emplace_back(
+          pair.first,
+          std::make_unique<StaticAttentionMask<MaskT, MaskAllocatorT>>(
+              pair.first,
+              input_len,
+              config_.head_dim,
+              zero_val,
+              mask_val,
+              config_.style));
+    }
+    auto it = attentionMasks_.emplace(input_len, std::move(masks));
     return it.first->second;
   }
 
   /**
    * Retrieve a mask suitable for given input length.
    */
-  StaticAttentionMask<MaskT, MaskAllocatorT>& get_mask(size_t input_len) {
+  PerCacheLenMasks& get_mask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }
 
@@ -487,9 +491,9 @@ class StaticAttentionIOManager {
       torch::executor::Method& method,
       std::optional<const executorch::runtime::Span<size_t>> pos_offsets =
           std::nullopt) {
-    kCaches_.prepare(
+    k_caches_.prepare(
         method, config_.k_cache_input_indices, config_.k_cache_output_indices);
-    vCaches_.prepare(
+    v_caches_.prepare(
         method, config_.v_cache_input_indices, config_.v_cache_output_indices);
 
     size_t rope_dim = config_.head_dim / 2;
@@ -538,12 +542,14 @@ class StaticAttentionIOManager {
       size_t update_len,
       size_t cache_update_pos = 0) {
     input_pos_ += update_len;
-    kCaches_.update(
+    k_caches_.update(
         method, k_cache_output_indices, update_len, cache_update_pos);
-    vCaches_.update(
+    v_caches_.update(
         method, v_cache_output_indices, update_len, cache_update_pos);
     for (auto& it : attentionMasks_) {
-      it.second.unmask(update_len);
+      for (auto& mask : it.second) {
+        mask.second->unmask(update_len);
+      }
     }
   }
 
@@ -552,10 +558,12 @@ class StaticAttentionIOManager {
    */
   void reset() {
     input_pos_ = 0;
-    kCaches_.reset();
-    vCaches_.reset();
+    k_caches_.reset();
+    v_caches_.reset();
     for (auto& it : attentionMasks_) {
-      it.second.reset();
+      for (auto& mask : it.second) {
+        mask.second->reset();
+      }
     }
   }
 
@@ -570,7 +578,12 @@ class StaticAttentionIOManager {
       executorch::runtime::Span<TokenT> input_buffer,
       executorch::runtime::Method& method) {
     size_t input_len = input_buffer.size();
-    get_mask(input_buffer.size()).set_causal_mask();
+    auto& masks = get_mask(input_buffer.size());
+    for (auto& pair : masks) {
+      auto& mask = *pair.second;
+      mask.set_causal_mask();
+      set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
+    }
 
     size_t batch_len = 0;
     for (size_t i = 0; i < tokens.size(); i += input_len) {
@@ -600,8 +613,12 @@ class StaticAttentionIOManager {
       std::function<TokenT(executorch::runtime::Method&)>& sample,
       std::function<bool(TokenT)>& should_stop) {
     set_input(method, 0, input_buffer.data());
-    auto& mask = get_mask(input_buffer.size());
-    set_input(method, config_.attn_mask_input_index, mask.get());
+    auto& masks = get_mask(input_buffer.size());
+    for (auto& pair : masks) {
+      auto& mask = *pair.second;
+      mask.set_causal_mask();
+      set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
+    }
 
     std::vector<TokenT> generated_tokens;
     while (true) {
@@ -642,10 +659,18 @@ class StaticAttentionIOManager {
     size_t input_len = input_buffer.size();
 
     // Set up attention mask for current input length.
-    auto& mask = get_mask(input_buffer.size());
-    set_lookahead_decoding_mask(
-        mask, input_len, ngram_size, window_size, n_verifications);
-    set_input(method, config_.attn_mask_input_index, mask.get());
+    auto& masks = get_mask(input_buffer.size());
+    for (auto& pair : masks) {
+      auto& mask = *pair.second;
+      set_lookahead_decoding_mask(
+          mask,
+          input_len,
+          pair.first,
+          ngram_size,
+          window_size,
+          n_verifications);
+      set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
+    }
 
     // Position offsets relative to current position, for indexing RoPE
     // frequence tensors.
@@ -793,12 +818,14 @@ class StaticAttentionIOManager {
         const_cast<executorch::aten::TensorImpl::DimOrderType*>(
             inputMeta->dim_order().data()));
     executorch::runtime::etensor::Tensor t(&impl);
+    ET_CHECK(data != nullptr);
     ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok);
   }
 
   void set_lookahead_decoding_mask(
       StaticAttentionMask<MaskT, MaskAllocatorT>& mask,
       size_t input_len,
+      size_t cache_len,
       size_t ngram_size,
       size_t window_size,
       size_t n_verifications) {
@@ -815,8 +842,8 @@ class StaticAttentionIOManager {
       size_t stride_;
     };
 
-    size_t stride = config_.cache_len + input_len;
-    auto input_submask = SubMask(mask.get() + config_.cache_len, stride);
+    size_t stride = cache_len + input_len;
+    auto input_submask = SubMask(mask.get() + cache_len, stride);
     input_submask.at(0, 0) = mask.zero_val();
 
     // Fill entire input mask first.
@@ -895,10 +922,9 @@ class StaticAttentionIOManager {
 
   StaticAttentionIOConfig config_;
   size_t input_pos_ = 0;
-  StaticKVCache<CacheT, CacheAllocatorT> kCaches_;
-  StaticKVCache<CacheT, CacheAllocatorT> vCaches_;
-  std::unordered_map<size_t, StaticAttentionMask<MaskT, MaskAllocatorT>>
-      attentionMasks_;
+  StaticKVCache<CacheT, CacheAllocatorT> k_caches_;
+  StaticKVCache<CacheT, CacheAllocatorT> v_caches_;
+  std::unordered_map<size_t, PerCacheLenMasks> attentionMasks_;
   std::vector<RopeT> rope_freqs_cos_override_;
   std::vector<RopeT> rope_freqs_sin_override_;
 };
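
A caller-side sketch of the updated interface, for illustration only: IOManager stands in for some concrete StaticAttentionIOManager instantiation, and the cache lengths and tensor indices below are made-up values, not taken from this commit.

// Hypothetical alias for a concrete StaticAttentionIOManager instantiation.
IOManager::StaticAttentionIOConfig config;
// One entry per KV cache; caches may now have different lengths,
// e.g. sliding-window layers vs. global-attention layers.
config.cache_lengths = {256, 256, 2048, 2048};
config.head_dim = 64;
config.max_input_len = 32;
config.n_heads_per_cache = 1;
// Each distinct cache length gets its own attention mask input
// (the indices here are illustrative).
config.cache_len_to_mask_idx = {{256, 1}, {2048, 2}};

IOManager mgr(std::move(config));
// add_mask() now builds one StaticAttentionMask per distinct cache length
// and returns them as PerCacheLenMasks for this input length.
auto& masks = mgr.add_mask(/*input_len=*/32, /*zero_val=*/0.0f, /*mask_val=*/-65504.0f);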