@@ -437,13 +437,15 @@ def setup_caches(self, max_batch_size, max_seq_length):
             and self.max_batch_size >= max_batch_size
         ):
             return
-        head_dim = self.config.dim // self.config.n_heads
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers.values():
-            b.attention.kv_cache = KVCache(
-                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
+            # Lower the setup_cache call to the attention module because tensor
+            # parallelism may have been applied there and the `n_local_heads`
+            # value may have been adjusted.
+            b.attention.setup_cache(
+                max_batch_size, max_seq_length,
             )
 
         freqs_cis = precompute_freqs_cis(
@@ -559,6 +561,16 @@ def __init__(self, config: TransformerArgs):
         self.dim = config.dim
         self._register_load_state_dict_pre_hook(self.load_hook)
 
+    def setup_cache(self, max_batch_size, max_seq_length):
+        n_local_heads = self.n_local_heads
+        # If TP is enabled, the heads would be divided and assigned to different ranks
+        if hasattr(self, "tp_degree"):
+            n_local_heads = self.n_local_heads // self.tp_degree
+
+        self.kv_cache = KVCache(
+            max_batch_size, max_seq_length, n_local_heads, self.head_dim
+        )
+
     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
         #     wq = state_dict.pop(prefix + "wq.weight")
@@ -599,14 +611,13 @@ def _unfuse_wqkv_state_dict(
 
     def distribute(self, device_mesh: DeviceMesh):
         self.device_mesh = device_mesh
+        self.tp_degree = device_mesh.size()
         parallelize_module(self.wq, device_mesh, ColwiseParallel())
         parallelize_module(self.wk, device_mesh, ColwiseParallel())
         parallelize_module(self.wv, device_mesh, ColwiseParallel())
         parallelize_module(
            self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))
         )
-        # TODO: enable kv cache in distributed case
-        self.kv_cache = None
 
     def forward(
         self,
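For readers following the change: the ordering it relies on is that `distribute()` records the TP degree before `setup_caches()` runs, so each rank allocates a KV cache sized for its local slice of the heads rather than the full head count. Below is a minimal, self-contained sketch of that logic; `MiniAttention` and `MiniKVCache` are hypothetical stand-ins for the repo's `Attention` and `KVCache`, and only the head-division arithmetic mirrors the diff.

```python
import torch


# Hypothetical stand-in for the repo's KVCache: just pre-allocates k/v buffers.
class MiniKVCache:
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim):
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.k_cache = torch.zeros(shape)
        self.v_cache = torch.zeros(shape)


# Hypothetical stand-in for the attention module touched by the diff.
class MiniAttention:
    def __init__(self, n_local_heads, head_dim):
        self.n_local_heads = n_local_heads
        self.head_dim = head_dim
        self.kv_cache = None

    def distribute(self, tp_degree):
        # Mirrors the diff: record the TP degree instead of disabling the cache.
        self.tp_degree = tp_degree

    def setup_cache(self, max_batch_size, max_seq_length):
        n_local_heads = self.n_local_heads
        # If TP was applied, each rank only holds a slice of the heads.
        if hasattr(self, "tp_degree"):
            n_local_heads = self.n_local_heads // self.tp_degree
        self.kv_cache = MiniKVCache(
            max_batch_size, max_seq_length, n_local_heads, self.head_dim
        )


attn = MiniAttention(n_local_heads=32, head_dim=128)
attn.distribute(tp_degree=4)  # TP shards the projections first ...
attn.setup_cache(max_batch_size=1, max_seq_length=2048)  # ... then the cache is sized per rank
print(attn.kv_cache.k_cache.shape)  # torch.Size([1, 8, 2048, 128]): 32 // 4 heads per rank
```

If the cache were still created in `Transformer.setup_caches()` from `self.config.n_local_heads`, every rank would allocate buffers for the full, unsharded head count; lowering the call into the attention module is what lets the per-rank head count be used instead.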