@@ -232,22 +232,16 @@ def __init__(
         max_seq_length: int,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_transposed = transpose_cache
-        if transpose_cache:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
-        self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
+        # input_pos: [S], k_val: [B, H, S, D]
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            if self.transpose_cache:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-            else:
-                k_out[:, input_pos] = k_val
-                v_out[:, input_pos] = v_val
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val

             return k_out, v_out


 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -314,18 +302,13 @@ def __init__(
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
@@ -336,6 +319,8 @@ def forward(
         else:
             attn_mask = mask[None, None, input_pos]

+        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+        # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
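
The TODO added in this hunk refers to the enable_gqa keyword that recent PyTorch releases (2.5+) expose on F.scaled_dot_product_attention. A minimal sketch of that alternative, assuming a runtime whose SDPA kernel accepts the flag; this is not part of the PR:

# Hypothetical follow-up, not in this diff: let SDPA broadcast the kv heads itself
# instead of materializing them with repeat_interleave.
y = F.scaled_dot_product_attention(
    q, k, v, attn_mask=attn_mask, dropout_p=0.0, enable_gqa=True
)
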
@@ -383,11 +368,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
-                kv_cache=self.kv_cache,
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
@@ -414,15 +397,16 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         if self.use_kv_cache:
             assert input_pos is not None
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
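
Taken together, these hunks fix the cache layout at (max_batch_size, n_heads, max_seq_length, head_dim) and move the cache update out of SDPA into Attention.forward. A minimal standalone sketch of the new contract, assuming max_batch_size is the constructor argument just above the first hunk and using made-up sizes (bs=1, 8 heads, head_dim=64) purely for illustration:

import torch

# Illustrative only; KVCache and SDPA are the classes patched above.
cache = KVCache(
    max_batch_size=1,
    max_seq_length=128,
    n_heads=8,
    head_dim=64,
    enable_dynamic_shape=False,
)
sdpa = SDPA(dim=8 * 64, head_dim=64, n_rep=1, max_seq_len=128, enable_dynamic_shape=False)

input_pos = torch.tensor([0])
q = torch.randn(1, 8, 1, 64)  # (bs, n_local_heads, seqlen, head_dim), already transposed
k = torch.randn(1, 8, 1, 64)  # (bs, n_local_kv_heads, seqlen, head_dim)
v = torch.randn(1, 8, 1, 64)
mask = torch.zeros(128, 128)  # full-attention mask, just for the sketch

k, v = cache.update(input_pos, k, v)        # the caller owns the cache update now
out = sdpa(input_pos, q, k, v, 1, 1, mask)  # SDPA no longer touches the cache
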