diff --git a/torchchat/export.py b/torchchat/export.py
index 21e7fcaa8..3de4233d7 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -155,16 +155,21 @@ def __init__(self, attention: Attention):
                 attention.kv_cache[0].k_cache.shape
             )
             cache_dtype = attention.kv_cache[0].k_cache.dtype
-            self.kv_cache = CustomKVCache(
-                max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
-            )
+            # The `Attention` module being replaced can have multiple KV caches
+            # (denoted by `cache_lanes`). Thus we follow the same setup format
+            # as in `Attention.setup_cache`.
+            cache_lanes = len(attention.kv_cache)
+            self.kv_cache = nn.ModuleList([
+                CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
+                for _ in range(cache_lanes)
+            ])
 
             self.n_heads = attention.n_heads
             self.head_dim = attention.head_dim
             self.n_local_heads = attention.n_local_heads
             self.dim = attention.dim
 
-        def forward(self, x, freqs_cis, mask, input_pos=None):
+        def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             bsz, seqlen, _ = x.shape
 
             q = self.wq(x)
@@ -181,12 +186,13 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
 
             # KV cache should always be enabled
             assert self.kv_cache is not None
+            kv_cache = self.kv_cache[cache_lane]
             output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.kv_cache.k_cache,
-                self.kv_cache.v_cache,
+                kv_cache.k_cache,
+                kv_cache.v_cache,
                 input_pos[-1].item(),
                 seqlen,
             )
diff --git a/torchchat/model.py b/torchchat/model.py
index 3300ebee9..00f783234 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -653,7 +653,7 @@ def distribute(self, device_mesh: DeviceMesh):
                 ColwiseParallel(output_layouts=Replicate()),
             )
 
-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
@@ -686,7 +686,9 @@ def distribute(self, device_mesh: DeviceMesh):
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
     ) -> Tensor:
-        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+        h = x + self.attention(
+            self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
+        )
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
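
For reference, here is a minimal, self-contained sketch of the per-lane KV cache pattern this diff introduces: one cache module per lane held in an `nn.ModuleList`, with the lane chosen by a `cache_lane` index at call time. `SimpleKVCache` and `MultiLaneCache` are hypothetical names for illustration only, not torchchat's `CustomKVCache` or `Attention`; the shapes and constructor arguments are assumptions.

```python
import torch
import torch.nn as nn


class SimpleKVCache(nn.Module):
    """Hypothetical stand-in for a single-lane KV cache (not CustomKVCache)."""

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        # Buffers so the caches move with .to()/.cuda() and appear in state_dict.
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))


class MultiLaneCache(nn.Module):
    """One independent KV cache per lane, selected by index at call time."""

    def __init__(self, cache_lanes: int, **cache_kwargs):
        super().__init__()
        self.kv_cache = nn.ModuleList(
            SimpleKVCache(**cache_kwargs) for _ in range(cache_lanes)
        )

    def forward(self, cache_lane: int = 0):
        # Pick the requested lane's buffers; a real attention forward would
        # read from and update these during decoding.
        kv_cache = self.kv_cache[cache_lane]
        return kv_cache.k_cache, kv_cache.v_cache


caches = MultiLaneCache(
    cache_lanes=2,
    max_batch_size=1,
    max_seq_length=128,
    n_heads=8,
    head_dim=64,
    dtype=torch.float16,
)
k0, _ = caches(cache_lane=0)  # lane 0 buffers
k1, _ = caches(cache_lane=1)  # lane 1 buffers, independent of lane 0
assert k0.data_ptr() != k1.data_ptr()
```

Holding the lanes in an `nn.ModuleList` (rather than a single cache) mirrors the `Attention.setup_cache` layout the comment in the diff refers to, so the exported replacement module can be addressed with the same `cache_lane` argument that `Transformer.forward` and `TransformerBlock.forward` now thread through.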