
Commit 3166c62
Add cache lane to CustomKVCache
1 parent b0593b6

torchchat/export.py

Lines changed: 12 additions & 6 deletions
@@ -155,16 +155,21 @@ def __init__(self, attention: Attention):
             attention.kv_cache[0].k_cache.shape
         )
         cache_dtype = attention.kv_cache[0].k_cache.dtype
-        self.kv_cache = CustomKVCache(
-            max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
-        )
+        # The `Attention` module being replaced can have multiple KV caches
+        # (denoted by `cache_lanes`). Thus we follow the same setup format
+        # as in `Attention.setup_cache`.
+        cache_lanes = len(attention.kv_cache)
+        self.kv_cache = nn.ModuleList([
+            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
+            for _ in range(cache_lanes)
+        ])
 
         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
         self.n_local_heads = attention.n_local_heads
         self.dim = attention.dim
 
-    def forward(self, x, freqs_cis, mask, input_pos=None):
+    def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -181,12 +186,13 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
 
         # KV cache should always be enabled
         assert self.kv_cache is not None
+        kv_cache = self.kv_cache[cache_lane]
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            kv_cache.k_cache,
+            kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
         )
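
For readers following the diff, the sketch below shows the cache-lane pattern this commit introduces: one KV cache per lane held in an nn.ModuleList, with forward selecting a lane by index. ToyKVCache and ToyAttention are hypothetical stand-ins, not the real CustomKVCache or the exported attention module; only the constructor arguments and the k_cache/v_cache buffers mirror the diff, and the buffer shape is an assumption for illustration.

# Minimal sketch of the cache-lane pattern (assumed simplifications, not torchchat code).
import torch
import torch.nn as nn


class ToyKVCache(nn.Module):
    # Hypothetical stand-in for CustomKVCache: just holds k/v buffers.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, max_seq_length, n_heads, head_dim)  # assumed layout
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))


class ToyAttention(nn.Module):
    def __init__(self, cache_lanes: int, max_batch_size=1, max_seq_length=128,
                 n_heads=8, head_dim=64, dtype=torch.float32):
        super().__init__()
        # One independent KV cache per lane, mirroring the ModuleList in the diff.
        self.kv_cache = nn.ModuleList([
            ToyKVCache(max_batch_size, max_seq_length, n_heads, head_dim, dtype)
            for _ in range(cache_lanes)
        ])

    def forward(self, x, cache_lane: int = 0):
        # Pick the cache for the requested lane; lane 0 keeps the old
        # single-cache behaviour as the default.
        kv_cache = self.kv_cache[cache_lane]
        # ... real attention math would read/write kv_cache.k_cache and
        # kv_cache.v_cache here (e.g. via an sdpa-with-kv-cache op).
        return x, kv_cache.k_cache.shape


if __name__ == "__main__":
    attn = ToyAttention(cache_lanes=2)
    out, cache_shape = attn(torch.randn(1, 4, 512), cache_lane=1)
    print(cache_shape)  # torch.Size([1, 128, 8, 64])

Because lane 0 is the default, callers that never pass cache_lane keep the previous single-cache behaviour unchanged.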
