
Commit 6428f1d

Support MPT with GQA (#1938)

Authored by megha95 and WoosukKwon
Co-authored-by: Woosuk Kwon <[email protected]>
Parent: 7e1b21d

File tree

2 files changed (+28, -6 lines):

  vllm/model_executor/layers/attention.py
  vllm/model_executor/models/mpt.py

vllm/model_executor/layers/attention.py

Lines changed: 8 additions & 4 deletions
@@ -138,7 +138,8 @@ def forward(
                     input_metadata.attn_bias = attn_bias
                 else:
                     input_metadata.attn_bias = _make_alibi_bias(
-                        self.alibi_slopes, batch_size, seq_len, query.dtype)
+                        self.alibi_slopes, self.num_kv_heads, batch_size,
+                        seq_len, query.dtype)
 
             # TODO(woosuk): Too many view operations. Let's try to reduce them
             # in the future for code readability.
@@ -180,31 +181,34 @@ def forward(
 
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
     batch_size: int,
     seq_len: int,
     dtype: torch.dtype,
 ) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype)
+    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
     # NOTE(zhuohan): HF uses
     #     `bias = bias[None, :].repeat(prompt_len, 1)`
     # here. We find that both biases give the same results, but
     # the bias below more accurately follows the original ALiBi
     # paper.
     bias = bias[None, :] - bias[:, None]
-    bias = bias.to(alibi_slopes.device)
 
     # When using custom attention bias, xformers requires the bias to
     # be sliced from a tensor whose length is a multiple of 8.
     padded_len = (seq_len + 7) // 8 * 8
+    num_heads = alibi_slopes.shape[0]
     bias = torch.empty(
         batch_size,
-        alibi_slopes.shape[0],
+        num_heads,
         seq_len,
         padded_len,
         device=alibi_slopes.device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
+    if num_heads != num_kv_heads:
+        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
     attn_bias = LowerTriangularMaskWithTensorBias(bias)
     return attn_bias
 
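
The new `unflatten` at the end regroups the per-query-head bias by KV head, turning a
`[batch, num_heads, seq_len, seq_len]` bias into `[batch, num_kv_heads, heads_per_kv, seq_len, seq_len]`
so it can line up with the grouped head layout used on the GQA path. A minimal standalone sketch
of the shape bookkeeping, with illustrative sizes (8 query heads, 2 KV heads, batch 1, seq_len 4)
that are not taken from the diff:

    import torch

    num_heads, num_kv_heads, batch_size, seq_len = 8, 2, 1, 4
    alibi_slopes = 2.0 ** -torch.arange(1, num_heads + 1, dtype=torch.float32)  # illustrative slopes

    # Relative-position bias, one copy per query head, padded so the last
    # dimension is a multiple of 8 (the xformers requirement noted in the diff).
    bias = torch.arange(seq_len, dtype=torch.float32)
    bias = bias[None, :] - bias[:, None]                   # [seq_len, seq_len]
    padded_len = (seq_len + 7) // 8 * 8
    bias = torch.empty(batch_size, num_heads, seq_len,
                       padded_len)[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    print(bias.shape)   # torch.Size([1, 8, 4, 4])

    # GQA: regroup the 8 query heads into 2 KV groups of 4 query heads each.
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    print(bias.shape)   # torch.Size([1, 2, 4, 4, 4])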

vllm/model_executor/models/mpt.py

Lines changed: 20 additions & 2 deletions
@@ -50,9 +50,14 @@ def __init__(
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
         self.clip_qkv = config.attn_config["clip_qkv"]
         self.qk_ln = config.attn_config["qk_ln"]
         self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        if "kv_n_heads" in config.attn_config:
+            self.total_num_kv_heads = config.attn_config['kv_n_heads']
+        else:
+            self.total_num_kv_heads = self.total_num_heads
         assert not config.attn_config["prefix_lm"]
         assert config.attn_config["alibi"]
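
MPT checkpoints that use grouped-query attention expose the KV head count as
`attn_config["kv_n_heads"]`; older multi-head checkpoints omit the key, so the fallback above
keeps them working by giving every query head its own KV head. A minimal sketch of that lookup
against hypothetical config dicts (the field names mirror the diff, the numbers are made up):

    # Hypothetical MPT-style configs; real values come from the checkpoint's MPTConfig.
    gqa_cfg = {"n_heads": 32, "attn_config": {"kv_n_heads": 8, "alibi": True}}
    mha_cfg = {"n_heads": 32, "attn_config": {"alibi": True}}

    def total_kv_heads(cfg: dict) -> int:
        # Same fallback as above: use kv_n_heads when present, else n_heads.
        attn_config = cfg["attn_config"]
        if "kv_n_heads" in attn_config:
            return attn_config["kv_n_heads"]
        return cfg["n_heads"]

    print(total_kv_heads(gqa_cfg))  # 8  -> grouped-query attention
    print(total_kv_heads(mha_cfg))  # 32 -> plain multi-head attention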

@@ -61,6 +66,7 @@ def __init__(
             self.d_model,
             self.d_model // self.total_num_heads,
             self.total_num_heads,
+            self.total_num_kv_heads,
             bias=not config.no_bias,
             linear_method=linear_method,
         )
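
Threading `total_num_kv_heads` into the fused QKV projection matters because, with GQA, the K
and V slices are narrower than the Q slice: the fused output width becomes roughly
`(n_heads + 2 * kv_n_heads) * head_dim` instead of `3 * n_heads * head_dim`, which is also why
the forward pass below switches from `chunk` to `split`. A back-of-the-envelope sketch with
assumed sizes (single GPU, no tensor parallelism):

    d_model, n_heads, kv_n_heads = 4096, 32, 8   # assumed GQA-style sizes
    head_dim = d_model // n_heads                # 128

    q_size = n_heads * head_dim                  # 4096
    kv_size = kv_n_heads * head_dim              # 1024
    fused_width = q_size + 2 * kv_size           # 6144, vs. 3 * 4096 = 12288 for MHA
    print(head_dim, q_size, kv_size, fused_width)
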
@@ -78,6 +84,17 @@ def __init__(
         assert self.total_num_heads % tp_world_size == 0
         self.num_heads = self.total_num_heads // tp_world_size
 
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
         # Create the alibi slopes and slice them.
         tp_rank = get_tensor_model_parallel_rank()
         head_start = tp_rank * self.num_heads
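
The two new asserts cover both tensor-parallel layouts: when there are at least as many KV
heads as ranks, the KV heads are partitioned across ranks; otherwise each KV head is replicated
so that every rank still holds at least one. A standalone sketch of the resulting per-rank
count, using 8 total KV heads and a few illustrative TP sizes:

    def kv_heads_per_rank(total_num_kv_heads: int, tp_world_size: int) -> int:
        if total_num_kv_heads >= tp_world_size:
            # Partition: each rank owns a disjoint slice of the KV heads.
            assert total_num_kv_heads % tp_world_size == 0
        else:
            # Replicate: several ranks share a copy of the same KV head.
            assert tp_world_size % total_num_kv_heads == 0
        return max(1, total_num_kv_heads // tp_world_size)

    for tp in (1, 2, 8, 16):
        print(tp, kv_heads_per_rank(8, tp))   # 1 -> 8, 2 -> 4, 8 -> 1, 16 -> 1 (replicated)
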
@@ -91,7 +108,8 @@ def __init__(
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scaling,
-                                   alibi_slopes=alibi_slopes)
+                                   alibi_slopes=alibi_slopes,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -105,7 +123,7 @@ def forward(
         qkv, _ = self.Wqkv(hidden_states)
         if self.clip_qkv is not None:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.qk_ln:
             q = self.q_ln(q)
             k = self.k_ln(k)
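
With GQA the fused QKV output is no longer three equal thirds, so `chunk(chunks=3, dim=-1)`
would cut K and V at the wrong offsets; `split([self.q_size, self.kv_size, self.kv_size], dim=-1)`
slices exactly along the per-rank head boundaries instead. A quick shape check under assumed
sizes (32 query heads, 8 KV heads, head_dim 128, no tensor parallelism):

    import torch

    num_heads, num_kv_heads, head_dim = 32, 8, 128   # assumed sizes
    q_size = num_heads * head_dim                    # 4096
    kv_size = num_kv_heads * head_dim                # 1024

    qkv = torch.randn(2, 16, q_size + 2 * kv_size)   # [batch, seq, fused width]
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    print(q.shape, k.shape, v.shape)
    # torch.Size([2, 16, 4096]) torch.Size([2, 16, 1024]) torch.Size([2, 16, 1024])
    # chunk(3) would instead give three 2048-wide slices, mixing Q and K columns.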
