@@ -158,16 +158,21 @@ def __init__(self, attention: Attention):
158158                attention .kv_cache [0 ].k_cache .shape 
159159            )
160160            cache_dtype  =  attention .kv_cache [0 ].k_cache .dtype 
161-             self .kv_cache  =  CustomKVCache (
162-                 max_batch_size , max_seq_length , n_heads , head_dim , cache_dtype 
163-             )
161+             # The `Attention` module being replaced can have multiple KV caches 
162+             # (denoted by `cache_lanes`).  Thus we follow the same setup format 
163+             # as in `Attention.setup_cache`. 
164+             cache_lanes  =  len (attention .kv_cache )
165+             self .kv_cache  =  nn .ModuleList ([
166+                 CustomKVCache (max_batch_size , max_seq_length , n_heads , head_dim , cache_dtype )
167+                 for  _  in  range (cache_lanes )
168+             ])
164169
165170            self .n_heads  =  attention .n_heads 
166171            self .head_dim  =  attention .head_dim 
167172            self .n_local_heads  =  attention .n_local_heads 
168173            self .dim  =  attention .dim 
169174
170-         def  forward (self , x , freqs_cis , mask , input_pos = None ):
175+         def  forward (self , x , freqs_cis , mask , input_pos = None ,  cache_lane :  int   =   0 ):
171176            bsz , seqlen , _  =  x .shape 
172177
173178            q  =  self .wq (x )
@@ -184,12 +189,13 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
184189
185190            # KV cache should always be enabled 
186191            assert  self .kv_cache  is  not None 
192+             kv_cache  =  self .kv_cache [cache_lane ]
187193            output  =  torch .ops .llama .sdpa_with_kv_cache (
188194                q ,
189195                k ,
190196                v ,
191-                 self . kv_cache .k_cache ,
192-                 self . kv_cache .v_cache ,
197+                 kv_cache .k_cache ,
198+                 kv_cache .v_cache ,
193199                input_pos [- 1 ].item (),
194200                seqlen ,
195201            )
0 commit comments