
Commit 08b2735

sxu authored and facebook-github-bot committed
Static attention: support local-global attention (pytorch#13043)
Summary:
Pull Request resolved: pytorch#13043

Runtime: support different cache lengths for different layers.
Python: add the sliding window cache update, which already existed in the runtime.

Reviewed By: billmguo

Differential Revision: D79267644
1 parent cf2f170 commit 08b2735
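For readers skimming the diff, here is a minimal, self-contained sketch of the sliding-window cache update the summary refers to: new rows are written into a fixed-length ring buffer and the write position wraps around, mirroring the cache_pos = (cache_pos + update_n) % cache_len change in update_one_cache below. This is not the ExecuTorch API; the function name, flat vector layout, and single-head shapes are illustrative assumptions.

```cpp
// Standalone sketch of a sliding-window (ring-buffer) KV-cache write.
// Assumed, simplified layout: one head, row-major [rows x head_dim] floats.
#include <cstddef>
#include <iostream>
#include <vector>

// Writes `update_n` rows of `head_dim` floats from `update` into `cache`
// (which holds `cache_len` rows), starting at `cache_pos` and wrapping
// around. Returns the new write position.
size_t sliding_window_update(
    const std::vector<float>& update,
    size_t update_n,
    std::vector<float>& cache,
    size_t cache_len,
    size_t head_dim,
    size_t cache_pos) {
  for (size_t i = 0; i < update_n; ++i) {
    size_t dst_row = (cache_pos + i) % cache_len;
    for (size_t d = 0; d < head_dim; ++d) {
      cache[dst_row * head_dim + d] = update[i * head_dim + d];
    }
  }
  return (cache_pos + update_n) % cache_len;
}

int main() {
  const size_t cache_len = 4;
  const size_t head_dim = 2;
  std::vector<float> cache(cache_len * head_dim, 0.0f);

  // Write 3 rows, then 3 more; the second write wraps around the window.
  std::vector<float> first = {1, 1, 2, 2, 3, 3};
  std::vector<float> second = {4, 4, 5, 5, 6, 6};
  size_t pos = sliding_window_update(first, 3, cache, cache_len, head_dim, 0);
  pos = sliding_window_update(second, 3, cache, cache_len, head_dim, pos);

  // The oldest rows were overwritten; the window now holds 5, 6, 3, 4.
  for (size_t r = 0; r < cache_len; ++r) {
    std::cout << cache[r * head_dim] << " ";
  }
  std::cout << std::endl;  // prints: 5 6 3 4
  return 0;
}
```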

3 files changed: +236 -116 lines changed

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 90 additions & 64 deletions
@@ -9,7 +9,6 @@
 #include <algorithm>
 #include <memory>
 #include <numeric>
-#include <tuple>
 #include <unordered_map>
 #include <vector>

@@ -38,14 +37,13 @@ class StaticKVCache {
    * caches.
    */
   StaticKVCache(
-      size_t n_caches,
-      size_t cache_len,
+      const std::vector<size_t>& cache_lengths,
       size_t head_dim,
-      size_t max_input_len = 1,
-      size_t n_heads_per_cache = 1,
+      size_t max_input_len,
+      size_t n_heads_per_cache,
       StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK)
-      : n_caches_(n_caches),
-        cache_len_(n_caches_, cache_len),
+      : n_caches_(cache_lengths.size()),
+        cache_lengths_(cache_lengths),
         cache_pos_(n_caches_, 0),
         max_input_len_(max_input_len),
         n_heads_per_cache_(n_heads_per_cache),
@@ -54,7 +52,7 @@ class StaticKVCache {
         input_ptrs_(n_caches_),
         output_ptrs_(n_caches_) {
     size_t total_cache_len =
-        std::accumulate(cache_len_.begin(), cache_len_.end(), 0);
+        std::accumulate(cache_lengths_.begin(), cache_lengths_.end(), 0);
     cache_data_size_ = total_cache_len * n_heads_per_cache_ * head_dim_;
     update_data_size_ =
         n_caches_ * n_heads_per_cache_ * max_input_len_ * head_dim_;
@@ -83,12 +81,12 @@ class StaticKVCache {
    */
   void prepare(
       torch::executor::Method& method,
-      const std::vector<size_t>& inputIndices,
+      const std::vector<size_t>& input_indices,
       const std::vector<size_t>& output_indices) {
-    ET_CHECK(inputIndices.size() == output_indices.size());
+    ET_CHECK(input_indices.size() == output_indices.size());
     auto methodMeta = method.method_meta();
     for (size_t i = 0; i < n_caches_; i++) {
-      auto inIdx = inputIndices[i];
+      auto inIdx = input_indices[i];
       auto outIdx = output_indices[i];
       auto inMeta = methodMeta.input_tensor_meta(inIdx);
       auto outMeta = methodMeta.output_tensor_meta(outIdx);
@@ -126,7 +124,7 @@ class StaticKVCache {
       ET_CHECK_MSG(inSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
       ET_CHECK_MSG(outSizes[ndim - 1] == head_dim_, "KV head dim mismatch.");
       ET_CHECK_MSG(
-          inSizes[ndim - 2] == cache_len_[i], "Cache length dim mismatch.");
+          inSizes[ndim - 2] == cache_lengths_[i], "Cache length dim mismatch.");

       auto impl = ::executorch::runtime::etensor::TensorImpl(
           inMeta->scalar_type(),
@@ -167,7 +165,7 @@ class StaticKVCache {
           update_n,
           update_pos,
           input_ptrs_[i],
-          cache_len_[i],
+          cache_lengths_[i],
           cache_pos_[i]);
     }
   }
@@ -187,7 +185,7 @@ class StaticKVCache {
     size_t cache_data_offset = 0;
     for (size_t i = 0; i < n_caches_; i++) {
       input_ptrs_[i] = cache_data_ + cache_data_offset;
-      cache_data_offset += cache_len_[i] * n_heads_per_cache_ * head_dim_;
+      cache_data_offset += cache_lengths_[i] * n_heads_per_cache_ * head_dim_;
       output_ptrs_[i] =
           update_data_ + i * n_heads_per_cache_ * max_input_len_ * head_dim_;
     }
@@ -217,24 +215,25 @@ class StaticKVCache {
           update_head + (update_pos + update_n) * head_dim_,
           cache_head + cache_pos * head_dim_);
     }
-    cache_pos += update_n;
+    cache_pos = (cache_pos + update_n) % cache_len;

     if (wrap_n > 0) {
+      ET_CHECK(cache_pos == 0);
       return update_one_cache(
           update,
           update_len,
           wrap_n,
           update_pos + contiguous_n,
           cache,
           cache_len,
-          0);
+          cache_pos);
     }

     return cache_pos;
   }

   size_t n_caches_;
-  std::vector<size_t> cache_len_;
+  std::vector<size_t> cache_lengths_;
   std::vector<size_t> cache_pos_;
   size_t max_input_len_;
   size_t n_heads_per_cache_;
@@ -415,11 +414,11 @@ class StaticAttentionIOManager {
  public:
   struct StaticAttentionIOConfig {
     size_t n_caches{};
-    size_t cache_len{};
+    std::vector<size_t> cache_lengths{};
     size_t head_dim{};
     size_t max_input_len{};
     size_t n_heads_per_cache{};
-    size_t attn_mask_input_index{};
+    std::unordered_map<size_t, size_t> cache_len_to_mask_idx;
     size_t rope_freqs_cos_input_index{};
     size_t rope_freqs_sin_input_index{};
     std::vector<size_t> k_cache_input_indices;
@@ -433,50 +432,55 @@ class StaticAttentionIOManager {

   StaticAttentionIOManager(StaticAttentionIOConfig config)
       : config_(std::move(config)),
-        kCaches_(
-            config_.n_caches,
-            config_.cache_len,
+        k_caches_(
+            config_.cache_lengths,
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
             config_.style),
-        vCaches_(
-            config_.n_caches,
-            config_.cache_len,
+        v_caches_(
+            config_.cache_lengths,
             config_.head_dim,
             config_.max_input_len,
             config_.n_heads_per_cache,
             config_.style) {
     ET_LOG(
         Info,
-        "Created StaticAttentionIOManager with"
-        " max input length = %zu cache length = %zu",
-        config_.max_input_len,
-        config_.cache_len);
+        "Created StaticAttentionIOManager with max input length = %zu",
+        config_.max_input_len);
+    for (auto cache_len : config_.cache_lengths) {
+      ET_LOG(Info, "Cache length = %zu", cache_len);
+    }
   }

+  using PerCacheLenMasks = std::vector<std::pair<
+      size_t,
+      std::unique_ptr<StaticAttentionMask<MaskT, MaskAllocatorT>>>>;
+
   /**
-   * Create a new StaticAttentionMask that will be managed by this object.
+   * Create a new StaticAttentionMask for each cache length used.
    */
-  StaticAttentionMask<MaskT, MaskAllocatorT>&
-  add_mask(size_t input_len, MaskT zero_val, MaskT mask_val) {
-    auto it = attentionMasks_.emplace(
-        std::piecewise_construct,
-        std::forward_as_tuple(input_len),
-        std::forward_as_tuple(
-            config_.cache_len,
-            input_len,
-            config_.head_dim,
-            zero_val,
-            mask_val,
-            config_.style));
+  PerCacheLenMasks& add_mask(size_t input_len, MaskT zero_val, MaskT mask_val) {
+    PerCacheLenMasks masks;
+    for (auto& pair : config_.cache_len_to_mask_idx) {
+      masks.emplace_back(
+          pair.first,
+          std::make_unique<StaticAttentionMask<MaskT, MaskAllocatorT>>(
+              pair.first,
+              input_len,
+              config_.head_dim,
+              zero_val,
+              mask_val,
+              config_.style));
+    }
+    auto it = attentionMasks_.emplace(input_len, std::move(masks));
     return it.first->second;
   }

   /**
    * Retrieve a mask suitable for given input length.
    */
-  StaticAttentionMask<MaskT, MaskAllocatorT>& get_mask(size_t input_len) {
+  PerCacheLenMasks& get_mask(size_t input_len) {
     return attentionMasks_.at(input_len);
   }

@@ -487,9 +491,9 @@ class StaticAttentionIOManager {
       torch::executor::Method& method,
       std::optional<const executorch::runtime::Span<size_t>> pos_offsets =
           std::nullopt) {
-    kCaches_.prepare(
+    k_caches_.prepare(
         method, config_.k_cache_input_indices, config_.k_cache_output_indices);
-    vCaches_.prepare(
+    v_caches_.prepare(
         method, config_.v_cache_input_indices, config_.v_cache_output_indices);

     size_t rope_dim = config_.head_dim / 2;
@@ -538,12 +542,14 @@ class StaticAttentionIOManager {
       size_t update_len,
       size_t cache_update_pos = 0) {
     input_pos_ += update_len;
-    kCaches_.update(
+    k_caches_.update(
        method, k_cache_output_indices, update_len, cache_update_pos);
-    vCaches_.update(
+    v_caches_.update(
        method, v_cache_output_indices, update_len, cache_update_pos);
     for (auto& it : attentionMasks_) {
-      it.second.unmask(update_len);
+      for (auto& mask : it.second) {
+        mask.second->unmask(update_len);
+      }
     }
   }

@@ -552,10 +558,12 @@ class StaticAttentionIOManager {
    */
   void reset() {
     input_pos_ = 0;
-    kCaches_.reset();
-    vCaches_.reset();
+    k_caches_.reset();
+    v_caches_.reset();
     for (auto& it : attentionMasks_) {
-      it.second.reset();
+      for (auto& mask : it.second) {
+        mask.second->reset();
+      }
     }
   }

@@ -570,7 +578,12 @@ class StaticAttentionIOManager {
       executorch::runtime::Span<TokenT> input_buffer,
       executorch::runtime::Method& method) {
     size_t input_len = input_buffer.size();
-    get_mask(input_buffer.size()).set_causal_mask();
+    auto& masks = get_mask(input_buffer.size());
+    for (auto& pair : masks) {
+      auto& mask = *pair.second;
+      mask.set_causal_mask();
+      set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
+    }

     size_t batch_len = 0;
     for (size_t i = 0; i < tokens.size(); i += input_len) {
@@ -600,8 +613,12 @@ class StaticAttentionIOManager {
       std::function<TokenT(executorch::runtime::Method&)>& sample,
       std::function<bool(TokenT)>& should_stop) {
     set_input(method, 0, input_buffer.data());
-    auto& mask = get_mask(input_buffer.size());
-    set_input(method, config_.attn_mask_input_index, mask.get());
+    auto& masks = get_mask(input_buffer.size());
+    for (auto& pair : masks) {
+      auto& mask = *pair.second;
+      mask.set_causal_mask();
+      set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
+    }

     std::vector<TokenT> generated_tokens;
     while (true) {
@@ -642,10 +659,18 @@ class StaticAttentionIOManager {
     size_t input_len = input_buffer.size();

     // Set up attention mask for current input length.
-    auto& mask = get_mask(input_buffer.size());
-    set_lookahead_decoding_mask(
-        mask, input_len, ngram_size, window_size, n_verifications);
-    set_input(method, config_.attn_mask_input_index, mask.get());
+    auto& masks = get_mask(input_buffer.size());
+    for (auto& pair : masks) {
+      auto& mask = *pair.second;
+      set_lookahead_decoding_mask(
+          mask,
+          input_len,
+          pair.first,
+          ngram_size,
+          window_size,
+          n_verifications);
+      set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
+    }

     // Position offsets relative to current position, for indexing RoPE
     // frequence tensors.
@@ -793,12 +818,14 @@ class StaticAttentionIOManager {
         const_cast<executorch::aten::TensorImpl::DimOrderType*>(
             inputMeta->dim_order().data()));
     executorch::runtime::etensor::Tensor t(&impl);
+    ET_CHECK(data != nullptr);
     ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok);
   }

   void set_lookahead_decoding_mask(
       StaticAttentionMask<MaskT, MaskAllocatorT>& mask,
       size_t input_len,
+      size_t cache_len,
       size_t ngram_size,
       size_t window_size,
       size_t n_verifications) {
@@ -815,8 +842,8 @@ class StaticAttentionIOManager {
       size_t stride_;
     };

-    size_t stride = config_.cache_len + input_len;
-    auto input_submask = SubMask(mask.get() + config_.cache_len, stride);
+    size_t stride = cache_len + input_len;
+    auto input_submask = SubMask(mask.get() + cache_len, stride);
     input_submask.at(0, 0) = mask.zero_val();

     // Fill entire input mask first.
@@ -895,10 +922,9 @@ class StaticAttentionIOManager {

   StaticAttentionIOConfig config_;
   size_t input_pos_ = 0;
-  StaticKVCache<CacheT, CacheAllocatorT> kCaches_;
-  StaticKVCache<CacheT, CacheAllocatorT> vCaches_;
-  std::unordered_map<size_t, StaticAttentionMask<MaskT, MaskAllocatorT>>
-      attentionMasks_;
+  StaticKVCache<CacheT, CacheAllocatorT> k_caches_;
+  StaticKVCache<CacheT, CacheAllocatorT> v_caches_;
+  std::unordered_map<size_t, PerCacheLenMasks> attentionMasks_;
   std::vector<RopeT> rope_freqs_cos_override_;
   std::vector<RopeT> rope_freqs_sin_override_;
 };
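To illustrate how the new config fields compose for local-global attention, here is a self-contained sketch. The ExampleConfig struct below is a stand-in that mirrors only two fields of StaticAttentionIOConfig from the diff (cache_lengths and cache_len_to_mask_idx); the layer pattern, window sizes, and mask input indices are hypothetical example values, not taken from this commit.

```cpp
// Sketch of a local-global layout: per-cache lengths plus a map from each
// distinct cache length to the attention-mask input that serves it.
#include <cstddef>
#include <iostream>
#include <unordered_map>
#include <vector>

struct ExampleConfig {
  std::vector<size_t> cache_lengths;                          // one entry per KV cache
  std::unordered_map<size_t, size_t> cache_len_to_mask_idx;   // one mask per length
};

int main() {
  const size_t n_layers = 8;
  const size_t local_window = 512;  // sliding-window layers
  const size_t global_len = 4096;   // full-attention layers
  ExampleConfig config;

  // Hypothetical pattern: every 4th layer attends globally, the rest locally.
  for (size_t layer = 0; layer < n_layers; ++layer) {
    config.cache_lengths.push_back(layer % 4 == 3 ? global_len : local_window);
  }

  // Layers that share a cache length share one mask input; the indices here
  // are placeholders for wherever those masks sit among the method's inputs.
  config.cache_len_to_mask_idx[local_window] = 1;
  config.cache_len_to_mask_idx[global_len] = 2;

  for (size_t layer = 0; layer < n_layers; ++layer) {
    size_t len = config.cache_lengths[layer];
    std::cout << "layer " << layer << ": cache_len=" << len
              << ", mask input idx=" << config.cache_len_to_mask_idx[len]
              << "\n";
  }
  return 0;
}
```

The design point this mirrors is that only one mask per distinct cache length is needed, which is why the manager now keys masks by cache length instead of holding a single mask input index.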
