
Commit bdc4d3b

kimishpatel authored and Zonglin Peng committed
[ExecuTorch][BE] Split kv cache and SDPA for better code sharing
Differential Revision: D67914054
Pull Request resolved: pytorch#7413
1 parent 87e605a commit bdc4d3b

File tree

19 files changed: +516 −387 lines changed


.ci/scripts/test_llama.sh

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,8 @@ fi

 if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
   QUANTIZE_KV_CACHE=ON
+  # quantize_kv cache transform uses custom kv cache update op
+  CUSTOM=ON
 else
   QUANTIZE_KV_CACHE=OFF
 fi

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 1 addition & 1 deletion
@@ -374,7 +374,7 @@ def get_custom_quant_ios_dtype(
     """
     This function is specific for llama inputs and outputs
     """
-    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
+    if node.op == "placeholder" and "attention_kv_cache_past_" in node.name:
         return kv_dtype

     # Tag index put node before copy node, because copy is a skipped node in qnn
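
The renamed prefix tracks the new module hierarchy: the cache buffers used to be registered under attention.SDPA.kv_cache, so their exported placeholder names carried an "sdpa" segment; after this change the cache hangs directly off the attention module. A minimal sketch (toy modules and a hypothetical buffer name, not the real ExecuTorch classes) of how a registered buffer's qualified name follows the module path:

import torch
import torch.nn as nn


class ToyKVCache(nn.Module):
    def __init__(self):
        super().__init__()
        # Cache buffers surface in the exported graph as placeholders named
        # after their module path.
        self.register_buffer("past_k", torch.zeros(1, 2, 8, 4))


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        # Post-commit layout: the cache is owned by the attention module itself,
        # not by an SDPA submodule, so the "sdpa" segment disappears from names.
        self.kv_cache = ToyKVCache()


print([name for name, _ in ToyAttention().named_buffers()])
# ['kv_cache.past_k']  -- no "sdpa" in the qualified name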

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
@@ -667,6 +667,8 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()

+    builder_exported.run_canonical_optimizations()
+
     if args.export_only:
         exit()
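
With this change, canonical optimizations run on the exported program immediately after export() and before the export_only early exit, so export-only runs also get the optimized graph. A tiny sketch of just that ordering, using stand-in objects (only the run_canonical_optimizations() step is taken from the diff; everything else here is hypothetical):

class FakeBuilder:
    """Stand-in for the manager returned by _prepare_for_llama_export(args)."""

    def export(self):
        print("export()")
        return self

    def run_canonical_optimizations(self):
        print("run_canonical_optimizations()  # new: runs right after export")


def export_flow(export_only: bool) -> None:
    builder_exported = FakeBuilder().export()
    builder_exported.run_canonical_optimizations()
    if export_only:
        return  # export-only runs still see the canonicalized graph
    print("continue to lowering / to_edge ...")


export_flow(export_only=True)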

examples/models/llama/llama_transformer.py

Lines changed: 15 additions & 31 deletions
@@ -232,22 +232,16 @@ def __init__(
         max_seq_length: int,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_transposed = transpose_cache
-        if transpose_cache:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
-        self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
+        # input_pos: [S], k_val: [B, H, S, D]
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            if self.transpose_cache:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-            else:
-                k_out[:, input_pos] = k_val
-                v_out[:, input_pos] = v_val
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val

             return k_out, v_out


 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -314,18 +302,13 @@ def __init__(
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
@@ -336,6 +319,8 @@ def forward(
         else:
             attn_mask = mask[None, None, input_pos]

+        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+        # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
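
The TODO added above refers to the enable_gqa flag on torch.nn.functional.scaled_dot_product_attention, available in recent PyTorch releases, which lets q carry n_heads while k/v keep n_kv_heads so the repeat_interleave expansion becomes unnecessary. A standalone sketch of the two paths (independent of the ExecuTorch classes; requires a PyTorch version that exposes enable_gqa):

import torch
import torch.nn.functional as F

bsz, n_heads, n_kv_heads, seqlen, head_dim = 1, 8, 2, 16, 64
n_rep = n_heads // n_kv_heads

q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

# Current path: expand k/v to n_heads before attention.
y_repeat = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    dropout_p=0.0,
)

# Path the TODO points at: let SDPA broadcast the kv heads itself.
y_gqa = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, enable_gqa=True)

print((y_repeat - y_gqa).abs().max())  # ~0 (same math, possibly different kernels)
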
@@ -383,11 +368,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
-                kv_cache=self.kv_cache,
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
@@ -414,15 +397,16 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         if self.use_kv_cache:
             assert input_pos is not None
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
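
Net effect of the changes in this file: KVCache always stores (bs, n_heads, max_seq_len, head_dim), SDPA no longer owns or updates a cache, and Attention.forward transposes q/k/v once, updates the cache, and then calls SDPA. A minimal, self-contained sketch of that wiring with toy shapes (the real classes additionally handle rope, masks, GQA and dynamic shapes):

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyKVCache(nn.Module):
    """Always (bs, n_heads, max_seq_len, head_dim), like the patched KVCache."""

    def __init__(self, max_batch_size, n_heads, max_seq_len, head_dim):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos, k_val, v_val):
        # k_val / v_val arrive already transposed: (bs, n_heads, seqlen, head_dim)
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class TinySDPA(nn.Module):
    """Cache-free: it only runs attention on whatever k/v it is handed."""

    def forward(self, q, k, v, mask):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)


bs, n_heads, seqlen, head_dim, max_seq_len = 1, 4, 3, 8, 16
cache = TinyKVCache(bs, n_heads, max_seq_len, head_dim)
sdpa = TinySDPA()

# Wiring as in the patched Attention.forward: transpose, update cache, then attend.
q = torch.randn(bs, seqlen, n_heads, head_dim).transpose(1, 2)
k = torch.randn(bs, seqlen, n_heads, head_dim).transpose(1, 2)
v = torch.randn(bs, seqlen, n_heads, head_dim).transpose(1, 2)

input_pos = torch.arange(seqlen)
k_all, v_all = cache.update(input_pos, k, v)
mask = torch.zeros(seqlen, max_seq_len)  # toy additive mask over cached positions
out = sdpa(q, k_all, v_all, mask)
print(out.shape)  # torch.Size([1, 4, 3, 8])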

examples/models/llama/source_transformation/attention_sink.py

Lines changed: 6 additions & 21 deletions
@@ -111,7 +111,6 @@ def __init__(
         self,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         rope: RopeWithAttentionSink,
         window_size: int,
@@ -125,7 +124,6 @@ def __init__(
             max_seq_length=window_size + sink_size,
             n_heads=n_heads,
             head_dim=head_dim,
-            transpose_cache=transpose_cache,
             enable_dynamic_shape=enable_dynamic_shape,
             dtype=dtype,
         )
@@ -161,28 +159,17 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
                 input_pos_item + self.position_shift - self.sink_size - num_to_evict
             )
             num_empty_space = self.window_size - num_to_keep
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             k_to_keep = self.k_cache.narrow(
                 dim_to_slice,
                 self.sink_size + num_to_evict,  # pyre-ignore [6]
                 num_to_keep,  # pyre-ignore [6]
             )
-            if self.transpose_cache:
-                k_to_keep = self.rope.rerotate_k(
-                    k=k_to_keep.transpose(1, 2),
-                    original_position=(  # pyre-ignore [6]
-                        self.sink_size + num_to_evict
-                    ),
-                    new_position=self.sink_size,
-                ).transpose(1, 2)
-            else:
-                k_to_keep = self.rope.rerotate_k(
-                    k=k_to_keep,
-                    original_position=(  # pyre-ignore [6]
-                        self.sink_size + num_to_evict
-                    ),
-                    new_position=self.sink_size,
-                )
+            k_to_keep = self.rope.rerotate_k(
+                k=k_to_keep.transpose(1, 2),
+                original_position=(self.sink_size + num_to_evict),  # pyre-ignore [6]
+                new_position=self.sink_size,
+            ).transpose(1, 2)
             self.k_cache = torch.cat(
                 [
                     self.k_cache.narrow(dim_to_slice, 0, self.sink_size),
@@ -278,7 +265,6 @@ def _replace_attention(
             kv_cache_with_attention_sink = KVCacheWithAttentionSink(
                 n_heads=kv_cache.n_heads,
                 head_dim=kv_cache.head_dim,
-                transpose_cache=kv_cache.transpose_cache,
                 enable_dynamic_shape=kv_cache.enable_dynamic_shape,
                 rope=rope_with_attention_sink,
                 max_batch_size=kv_cache.max_batch_size,
@@ -288,7 +274,6 @@ def _replace_attention(
                 dtype=kv_cache.k_cache.dtype,
             )
             child_module.kv_cache = kv_cache_with_attention_sink
-            child_module.SDPA.kv_cache = kv_cache_with_attention_sink
             child_module.forward = types.MethodType(  # pyre-ignore
                 attention_sink_forward, child_module
             )
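
With the transposed layout now the only layout, KVCacheWithAttentionSink always slices the sequence dimension at index 2 and round-trips through a transpose so rerotate_k sees (bs, seqlen, n_heads, head_dim); and since SDPA no longer holds a cache reference, only child_module.kv_cache needs re-pointing. A small standalone sketch of that transpose round-trip (rerotate_k is stubbed; the real evict_tokens also refills the freed space, elided here):

import torch

bs, n_heads, head_dim = 1, 4, 8
sink_size, window_size = 2, 12
num_to_evict, num_to_keep = 3, 5

k_cache = torch.randn(bs, n_heads, sink_size + window_size, head_dim)  # (B, H, S, D)
dim_to_slice = 2  # sequence dim in the always-transposed layout
k_to_keep = k_cache.narrow(dim_to_slice, sink_size + num_to_evict, num_to_keep)


def rerotate_k(k, original_position, new_position):
    # Stand-in for RopeWithAttentionSink.rerotate_k, which expects (B, S, H, D).
    assert k.shape == (bs, num_to_keep, n_heads, head_dim)
    return k


# Cache is (B, H, S, D); rerotate_k wants (B, S, H, D): transpose in, transpose back.
k_to_keep = rerotate_k(
    k=k_to_keep.transpose(1, 2),
    original_position=sink_size + num_to_evict,
    new_position=sink_size,
).transpose(1, 2)

k_cache = torch.cat(
    [k_cache.narrow(dim_to_slice, 0, sink_size), k_to_keep], dim=dim_to_slice
)
print(k_cache.shape)  # torch.Size([1, 4, 7, 8]) -- sink tokens + kept tokens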
