@@ -50,9 +50,14 @@ def __init__(
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
+       self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+       if "kv_n_heads" in config.attn_config:
+           self.total_num_kv_heads = config.attn_config['kv_n_heads']
+       else:
+           self.total_num_kv_heads = self.total_num_heads
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

@@ -61,6 +66,7 @@ def __init__(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
+           self.total_num_kv_heads,
            bias=not config.no_bias,
            linear_method=linear_method,
        )
@@ -78,6 +84,17 @@ def __init__(
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

+       if self.total_num_kv_heads >= tp_world_size:
+           # Number of KV heads is greater than TP size, so we partition
+           # the KV heads across multiple tensor parallel GPUs.
+           assert self.total_num_kv_heads % tp_world_size == 0
+       else:
+           # Number of KV heads is less than TP size, so we replicate
+           # the KV heads across multiple tensor parallel GPUs.
+           assert tp_world_size % self.total_num_kv_heads == 0
+       self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+       self.q_size = self.num_heads * self.head_dim
+       self.kv_size = self.num_kv_heads * self.head_dim
        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
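
For reference, a minimal standalone sketch (not part of the diff) of the per-rank head arithmetic introduced above, using hypothetical sizes of 32 query heads, 8 KV heads, and head_dim 128:

    # Hypothetical GQA config: 32 query heads, 8 KV heads, head_dim 128.
    total_num_heads, total_num_kv_heads, head_dim = 32, 8, 128
    for tp_world_size in (4, 16):
        if total_num_kv_heads >= tp_world_size:
            # Enough KV heads to partition them across ranks.
            assert total_num_kv_heads % tp_world_size == 0
        else:
            # Fewer KV heads than ranks: replicate, so several ranks share one head.
            assert tp_world_size % total_num_kv_heads == 0
        num_heads = total_num_heads // tp_world_size
        num_kv_heads = max(1, total_num_kv_heads // tp_world_size)
        q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim
        # tp=4  -> 8 query heads, 2 KV heads per rank (q_size=1024, kv_size=256)
        # tp=16 -> 2 query heads, 1 KV head  per rank (q_size=256,  kv_size=128)
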
@@ -91,7 +108,8 @@ def __init__(
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scaling,
-                                  alibi_slopes=alibi_slopes)
+                                  alibi_slopes=alibi_slopes,
+                                  num_kv_heads=self.num_kv_heads)

    def forward(
        self,
@@ -105,7 +123,7 @@ def forward(
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
-       q, k, v = qkv.chunk(chunks=3, dim=-1)
+       q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
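
Side note (not part of the diff): with grouped-query attention the fused QKV projection output is no longer three equal slices, so an equal three-way chunk would cut through the K and V blocks; split() with explicit widths keeps them aligned. A small sketch with hypothetical per-rank sizes:

    import torch

    # Hypothetical per-rank sizes: 8 query heads, 2 KV heads, head_dim 128.
    q_size, kv_size = 8 * 128, 2 * 128
    qkv = torch.randn(4, q_size + 2 * kv_size)  # fused QKV projection output
    # chunk(chunks=3) would produce three equal 512-wide pieces and misalign Q/K/V;
    # split() honors the unequal widths.
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    assert q.shape[-1] == q_size and k.shape[-1] == kv_size == v.shape[-1]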