
Commit 07c21c1

[Distributed] Enable KV cache
1 parent: 03c9819

2 files changed: +24 -10 lines changed

dist_run.py

Lines changed: 8 additions & 5 deletions
@@ -55,7 +55,6 @@
     "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
     "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }
-CACHE_PRECISION = torch.bfloat16


 def _init_distributed():
@@ -245,8 +244,8 @@ def main(args):

     tokenizer = _build_chat_tokenizer(model_name)

-    set_precision(CACHE_PRECISION)
-    logger.info(f"Using cache precision {CACHE_PRECISION}")
+    set_precision(model_dtype)
+    logger.info(f"Using cache precision {model_dtype}")

     hf_config = get_hf_config_file(distribution)
     if hf_config is None:
@@ -285,8 +284,6 @@ def main(args):
     with device:
         model = Transformer(config)

-    model.setup_caches(1, 4096)
-
     # Distribute model on TP mesh
     model.distribute(tp_mesh)
     if rank == 0:
@@ -300,6 +297,12 @@ def main(args):
     dim = 4096  # embedding dimension
     assert seqlen % sp_degree == 0

+    # Setup KV caches (after model distribution)
+    # TODO: the setting below only works for the 1 micro-batch case. To support
+    # multiple micro-batches, we need the KV cache in the model to be aware of
+    # the number of micro-batches and the current micro-batch index.
+    model.setup_caches(mb_size, seqlen)
+
     mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
     activation = torch.rand(
         mb_size, seqlen // sp_degree, dim, device=device, dtype=model_dtype
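
As a minimal, self-contained sketch of the flow dist_run.py now follows (ToyModel, MODEL_TABLE, and the tp_degree argument are stand-ins, not torchchat's actual API): pick the cache precision from the model's dtype rather than the removed CACHE_PRECISION constant, distribute the model on the TP mesh, and only then size the KV caches with the micro-batch size and sequence length.

import torch

# Analogue of the model table near the top of dist_run.py (names from the diff).
MODEL_TABLE = {"llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16)}

class ToyModel:
    """Stand-in for torchchat's Transformer; only tracks call ordering."""
    def __init__(self):
        self.distributed = False
        self.caches_ready = False

    def distribute(self, tp_degree: int):
        # In the real model this shards attention/FFN weights across TP ranks.
        self.tp_degree = tp_degree
        self.distributed = True

    def setup_caches(self, mb_size: int, seqlen: int):
        # Cache shapes depend on the post-sharding local head count, so this
        # toy enforces the ordering explicitly.
        assert self.distributed, "set up KV caches only after distribution"
        self.caches_ready = True

_, model_dtype = MODEL_TABLE["llama3"]
torch.set_default_dtype(model_dtype)        # stand-in for set_precision(model_dtype)

model = ToyModel()
model.distribute(tp_degree=4)
model.setup_caches(mb_size=1, seqlen=4096)  # one micro-batch, per the TODO above
print(model.caches_ready, torch.get_default_dtype())

The real code does not assert this ordering; calling setup_caches before distribute would simply size the caches for the unsharded head count, which is what the reordering in this commit avoids.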

torchchat/model.py

Lines changed: 16 additions & 5 deletions
@@ -437,13 +437,15 @@ def setup_caches(self, max_batch_size, max_seq_length):
             and self.max_batch_size >= max_batch_size
         ):
             return
-        head_dim = self.config.dim // self.config.n_heads
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers.values():
-            b.attention.kv_cache = KVCache(
-                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
+            # Lower the setup_cache call to the attention module because tensor
+            # parallelism may have been applied there and the `n_local_heads`
+            # value adjusted accordingly.
+            b.attention.setup_cache(
+                max_batch_size, max_seq_length,
             )

         freqs_cis = precompute_freqs_cis(
@@ -559,6 +561,16 @@ def __init__(self, config: TransformerArgs):
         self.dim = config.dim
         self._register_load_state_dict_pre_hook(self.load_hook)

+    def setup_cache(self, max_batch_size, max_seq_length):
+        n_local_heads = self.n_local_heads
+        # If TP is enabled, the heads are divided and assigned to different ranks
+        if hasattr(self, "tp_degree"):
+            n_local_heads = self.n_local_heads // self.tp_degree
+
+        self.kv_cache = KVCache(
+            max_batch_size, max_seq_length, n_local_heads, self.head_dim
+        )
+
     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
         #     wq = state_dict.pop(prefix + "wq.weight")
@@ -599,14 +611,13 @@ def _unfuse_wqkv_state_dict(

     def distribute(self, device_mesh: DeviceMesh):
         self.device_mesh = device_mesh
+        self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
         parallelize_module(
             self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
         )
-        # TODO: enable kv cache in distributed case
-        self.kv_cache = None

     def forward(
         self,
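
The head-count arithmetic that motivates moving cache construction into Attention.setup_cache can be shown with a standalone sketch (plain PyTorch with hypothetical sizes; the real KVCache class may order its dimensions differently):

import torch

def per_rank_heads(n_local_heads: int, tp_degree: int = 1) -> int:
    # Mirrors the logic added above: each TP rank keeps only its share of heads.
    return n_local_heads // tp_degree

# Hypothetical sizes, not taken from any particular model config.
heads = per_rank_heads(n_local_heads=8, tp_degree=4)   # 2 heads per rank
head_dim, max_batch_size, max_seq_length = 128, 1, 4096

# One rank's key (or value) buffer under an assumed (batch, heads, seq, head_dim)
# layout; with TP it is tp_degree times smaller than an unsharded cache would be.
k_cache = torch.zeros(max_batch_size, heads, max_seq_length, head_dim,
                      dtype=torch.bfloat16)
print(heads, k_cache.shape)  # 2 torch.Size([1, 2, 4096, 128])

Recording self.tp_degree = device_mesh.size() inside distribute is what lets setup_cache detect, via hasattr, whether the heads have been sharded across ranks.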
