@@ -442,8 +442,11 @@ def setup_caches(self, max_batch_size, max_seq_length):
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers.values():
-            b.attention.kv_cache = KVCache(
-                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
+            # Lower the setup_cache call into the attention module, because tensor
+            # parallelism may have been applied there and the `n_local_heads` value
+            # adjusted accordingly.
+            b.attention.setup_cache(
+                max_batch_size, max_seq_length,
             )
 
         freqs_cis = precompute_freqs_cis(
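
With this change the Transformer stops sizing the cache itself from `self.config.n_local_heads`; each attention layer allocates its own cache once it knows how many heads it actually holds on this rank. For reference, a minimal sketch of a KV cache with a (batch, heads, seq, head_dim) buffer layout follows; the class name, the exact layout, and the default dtype are assumptions for illustration, not necessarily the `KVCache` defined in this file.

import torch
import torch.nn as nn

class KVCacheSketch(nn.Module):
    # Pre-allocates key/value buffers sized by the *local* head count, so a
    # tensor-parallel rank only pays for the heads it owns.
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [seq_len]; k_val/v_val: [batch, local_heads, seq_len, head_dim]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache
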
@@ -559,6 +562,16 @@ def __init__(self, config: TransformerArgs):
         self.dim = config.dim
         self._register_load_state_dict_pre_hook(self.load_hook)
 
+    def setup_cache(self, max_batch_size, max_seq_length):
+        n_local_heads = self.n_local_heads
+        # If TP is enabled, the heads are divided and assigned to different ranks.
+        if hasattr(self, "tp_degree"):
+            n_local_heads = self.n_local_heads // self.tp_degree
+
+        self.kv_cache = KVCache(
+            max_batch_size, max_seq_length, n_local_heads, self.head_dim
+        )
+
     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
         #     wq = state_dict.pop(prefix + "wq.weight")
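
The new `setup_cache` only divides the head count when `distribute` has already run, which it detects via the `tp_degree` attribute set there. A small sizing illustration with hypothetical numbers (the concrete values below are examples, not taken from the config):

# Hypothetical per-rank KV-cache sizing under tensor parallelism.
n_local_heads = 32    # example head count of the unsharded model
head_dim = 128        # example head dimension
tp_degree = 4         # device_mesh.size() recorded by distribute()

# Column-wise sharding of wk/wv leaves each rank a contiguous slice of heads,
# so the cache needs only n_local_heads // tp_degree head slots per rank.
assert n_local_heads % tp_degree == 0, "heads must divide evenly across ranks"
heads_per_rank = n_local_heads // tp_degree          # -> 8
kv_width_per_token = 2 * heads_per_rank * head_dim   # k + v -> 2048 values per rank
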
@@ -599,14 +612,13 @@ def _unfuse_wqkv_state_dict(
 
     def distribute(self, device_mesh: DeviceMesh):
         self.device_mesh = device_mesh
+        self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
         parallelize_module(
             self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
         )
-        # TODO: enable kv cache in distributed case
-        self.kv_cache = None
 
     def forward(
         self,
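
Since `setup_cache` keys off the `tp_degree` attribute, sharding has to happen before the caches are allocated; otherwise every rank still builds a full-size cache. A rough call-order sketch, assuming a 1-D mesh from `torch.distributed.device_mesh.init_device_mesh`; the model constructor and the way `distribute` is reached here are assumptions, only the ordering requirement comes from the diff:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

model = build_model()                     # hypothetical constructor for the Transformer
for block in model.layers.values():
    block.attention.distribute(mesh)      # records tp_degree and shards wq/wk/wv/wo

# Only now allocate caches, so each rank sizes them for its local heads.
model.setup_caches(max_batch_size=1, max_seq_length=2048)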