@@ -155,16 +155,21 @@ def __init__(self, attention: Attention):
             attention.kv_cache[0].k_cache.shape
         )
         cache_dtype = attention.kv_cache[0].k_cache.dtype
-        self.kv_cache = CustomKVCache(
-            max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
-        )
+        # The `Attention` module being replaced can have multiple KV caches
+        # (denoted by `cache_lanes`). Thus we follow the same setup format
+        # as in `Attention.setup_cache`.
+        cache_lanes = len(attention.kv_cache)
+        self.kv_cache = nn.ModuleList([
+            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
+            for _ in range(cache_lanes)
+        ])
 
         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
         self.n_local_heads = attention.n_local_heads
         self.dim = attention.dim
 
-    def forward(self, x, freqs_cis, mask, input_pos=None):
+    def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -181,12 +186,13 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
 
         # KV cache should always be enabled
         assert self.kv_cache is not None
+        kv_cache = self.kv_cache[cache_lane]
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            kv_cache.k_cache,
+            kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
         )
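For context (not part of the commit), below is a minimal, self-contained sketch of the per-lane cache pattern the patch adopts: one cache object per lane held in an nn.ModuleList, with the lane selected at forward time via `cache_lane`. `ToyKVCache` and `ToyAttention` are hypothetical stand-ins for `CustomKVCache` and the replaced attention module; shapes and sizes are illustrative only.

import torch
import torch.nn as nn

class ToyKVCache(nn.Module):
    # Stand-in for CustomKVCache: just holds k/v buffers of a fixed size.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

class ToyAttention(nn.Module):
    def __init__(self, cache_lanes: int = 2):
        super().__init__()
        # One independent cache per lane, mirroring the nn.ModuleList setup in the diff.
        self.kv_cache = nn.ModuleList(
            ToyKVCache(1, 128, 8, 64, torch.float32) for _ in range(cache_lanes)
        )

    def forward(self, cache_lane: int = 0):
        # Pick the lane's cache, as the patched forward() does before calling SDPA.
        kv_cache = self.kv_cache[cache_lane]
        return kv_cache.k_cache.shape

print(ToyAttention()(cache_lane=1))  # torch.Size([1, 8, 128, 64])

Because `cache_lane` defaults to 0, callers that are unaware of multiple lanes keep the previous single-cache behavior.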