
Commit 6d2ef4a

[Distributed] Fix cache lane (#1194)

* Fix cache_lane in forward
* Add cache lane to CustomKVCache

1 parent c40c6bb · commit 6d2ef4a

2 files changed: +16 −8

torchchat/export.py
Lines changed: 12 additions & 6 deletions

@@ -158,16 +158,21 @@ def __init__(self, attention: Attention):
             attention.kv_cache[0].k_cache.shape
         )
         cache_dtype = attention.kv_cache[0].k_cache.dtype
-        self.kv_cache = CustomKVCache(
-            max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
-        )
+        # The `Attention` module being replaced can have multiple KV caches
+        # (denoted by `cache_lanes`). Thus we follow the same setup format
+        # as in `Attention.setup_cache`.
+        cache_lanes = len(attention.kv_cache)
+        self.kv_cache = nn.ModuleList([
+            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
+            for _ in range(cache_lanes)
+        ])
 
         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
         self.n_local_heads = attention.n_local_heads
         self.dim = attention.dim
 
-    def forward(self, x, freqs_cis, mask, input_pos=None):
+    def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         bsz, seqlen, _ = x.shape
 
         q = self.wq(x)
@@ -184,12 +189,13 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
 
         # KV cache should always be enabled
         assert self.kv_cache is not None
+        kv_cache = self.kv_cache[cache_lane]
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            kv_cache.k_cache,
+            kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
         )
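The pattern above can be summarized with a small, self-contained sketch: one cache module per lane held in an nn.ModuleList and indexed by cache_lane at call time. ToyKVCache and LanedAttentionStub below are hypothetical stand-ins for CustomKVCache and the replaced attention module, not the torchchat implementation; only the ModuleList-per-lane indexing mirrors the diff.

import torch
import torch.nn as nn


class ToyKVCache(nn.Module):
    # Stand-in for CustomKVCache: fixed-size k/v buffers updated in place.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float32):
        super().__init__()
        shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k, v):
        # Write the new k/v entries at the given positions and return the full caches.
        self.k_cache[:, input_pos] = k
        self.v_cache[:, input_pos] = v
        return self.k_cache, self.v_cache


class LanedAttentionStub(nn.Module):
    # One cache per lane, mirroring the nn.ModuleList setup in the diff above.
    def __init__(self, cache_lanes, max_batch_size, max_seq_length, n_heads, head_dim):
        super().__init__()
        self.kv_cache = nn.ModuleList([
            ToyKVCache(max_batch_size, max_seq_length, n_heads, head_dim)
            for _ in range(cache_lanes)
        ])

    def forward(self, k, v, input_pos, cache_lane: int = 0):
        # Pick the lane's cache instead of assuming a single kv_cache.
        kv_cache = self.kv_cache[cache_lane]
        return kv_cache.update(input_pos, k, v)


attn = LanedAttentionStub(cache_lanes=2, max_batch_size=1, max_seq_length=8, n_heads=4, head_dim=16)
k = v = torch.randn(1, 1, 4, 16)
attn(k, v, input_pos=torch.tensor([0]), cache_lane=0)
attn(k, v, input_pos=torch.tensor([0]), cache_lane=1)
# Each lane owns its own buffers.
assert attn.kv_cache[0].k_cache is not attn.kv_cache[1].k_cache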

torchchat/model.py
Lines changed: 4 additions & 2 deletions

@@ -653,7 +653,7 @@ def distribute(self, device_mesh: DeviceMesh):
             ColwiseParallel(output_layouts=Replicate()),
         )
 
-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
@@ -686,7 +686,9 @@ def distribute(self, device_mesh: DeviceMesh):
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
     ) -> Tensor:
-        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+        h = x + self.attention(
+            self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
+        )
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
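For context, the two model.py changes make sure the lane index actually reaches the attention caches: the top-level forward now defaults to lane 0, and each block passes cache_lane on instead of silently dropping it. Below is a minimal sketch of that threading using hypothetical ToyModel, ToyBlock, and ToyAttention classes; how lanes map to in-flight batches in the distributed setup is not shown in this diff and is left abstract here.

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    # Counts hits per lane, standing in for per-lane KV-cache reads/writes.
    def __init__(self, dim, cache_lanes=2):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.register_buffer("lane_hits", torch.zeros(cache_lanes, dtype=torch.long))

    def forward(self, x, input_pos, cache_lane: int = 0):
        self.lane_hits[cache_lane] += 1  # record which lane's cache would be used
        return self.proj(x)


class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attention = ToyAttention(dim)
        self.ffn = nn.Linear(dim, dim)

    def forward(self, x, input_pos, cache_lane: int = 0):
        # Pass cache_lane through explicitly; dropping it here is what the commit fixes.
        h = x + self.attention(x, input_pos, cache_lane=cache_lane)
        return h + self.ffn(h)


class ToyModel(nn.Module):
    def __init__(self, dim, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock(dim) for _ in range(n_layers)])

    def forward(self, x, input_pos, cache_lane: int = 0):
        # Default to lane 0, matching the corrected signature above.
        for layer in self.layers:
            x = layer(x, input_pos, cache_lane=cache_lane)
        return x


model = ToyModel(dim=8)
x = torch.randn(1, 1, 8)
model(x, input_pos=torch.tensor([0]))                # lane 0 by default
model(x, input_pos=torch.tensor([0]), cache_lane=1)  # a second, independent lane
print(model.layers[0].attention.lane_hits)           # tensor([1, 1])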
