Commit 404422f: [Model] Add support for MPT (#334)

1 parent 7717d08
11 files changed: +388, -4 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
 - LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 
 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
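With MPT in the supported list, an MPT checkpoint can be served like any other model. A minimal usage sketch against the vLLM Python API (the prompt and sampling values are illustrative; at this commit the custom MPT config is resolved internally by the `get_config` helper added further below, so no extra flags should be needed):

    from vllm import LLM, SamplingParams

    # Load an MPT checkpoint; vLLM resolves MosaicML's custom config class
    # through its own get_config helper (see vllm/transformers_utils/config.py).
    llm = LLM(model="mosaicml/mpt-7b")

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
    outputs = llm.generate(["MosaicML's MPT models are"], sampling_params)
    print(outputs[0].outputs[0].text)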

csrc/attention/attention_kernels.cu

Lines changed: 3 additions & 0 deletions

@@ -395,6 +395,9 @@ void single_query_cached_kv_attention_launcher(
     case 96:
       LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
       break;
+    case 112:
+      LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
+      break;
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;

docs/source/models/supported_models.rst

Lines changed: 3 additions & 0 deletions

@@ -29,6 +29,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`LlamaForCausalLM`
     - LLaMA, Vicuna, Alpaca, Koala, Guanaco
     - :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
+  * - :code:`MPTForCausalLM`
+    - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
+    - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
   * - :code:`OPTForCausalLM`
     - OPT, OPT-IML
     - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.

vllm/config.py

Lines changed: 3 additions & 2 deletions

@@ -1,9 +1,10 @@
 from typing import Optional
 
 import torch
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 
 from vllm.logger import init_logger
+from vllm.transformers_utils.config import get_config
 from vllm.utils import get_cpu_memory
 
 logger = init_logger(__name__)
@@ -49,7 +50,7 @@ def __init__(
         self.use_dummy_weights = use_dummy_weights
         self.seed = seed
 
-        self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
+        self.hf_config = get_config(model)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_tokenizer_mode()
 
vllm/model_executor/layers/attention.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata
 
-_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]
 
 
 class PagedAttention(nn.Module):
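Head size 112 is what MPT-30B needs: per its published config it spreads a 7168-dimensional hidden state across 64 heads, so each head is 112 wide, which is why this commit adds the 112 case both to the CUDA kernel dispatch in attention_kernels.cu and to this Python-side whitelist. A small sanity-check sketch (the MPT-7B/30B numbers are taken from their public configs; the helper below is illustrative, not vLLM code):

    # Illustrative check that MPT head sizes map onto supported kernel sizes.
    _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]

    def head_size(d_model: int, n_heads: int) -> int:
        assert d_model % n_heads == 0, "hidden size must divide evenly across heads"
        return d_model // n_heads

    # mosaicml/mpt-7b:  d_model=4096, n_heads=32 -> 128 (already supported)
    # mosaicml/mpt-30b: d_model=7168, n_heads=64 -> 112 (the case added here)
    for name, (d_model, n_heads) in {"mpt-7b": (4096, 32), "mpt-30b": (7168, 64)}.items():
        hs = head_size(d_model, n_heads)
        print(name, hs, hs in _SUPPORTED_HEAD_SIZES)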

vllm/model_executor/model_loader.py

Lines changed: 2 additions & 1 deletion

@@ -16,7 +16,8 @@
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
-    "LLaMAForCausalLM": LlamaForCausalLM,
+    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
+    "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
 }
 
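The keys of this dictionary are matched against the `architectures` field of a checkpoint's Hugging Face config when the model is loaded. A rough sketch of that lookup (the `resolve_model_class` helper and `registry` argument are illustrative, not vLLM's exact internals):

    # Illustrative: how an architecture string resolves to a registered class.
    def resolve_model_class(hf_config, registry):
        for arch in getattr(hf_config, "architectures", []) or []:
            if arch in registry:
                return registry[arch]
        raise ValueError(f"Unsupported architectures: {hf_config.architectures}")

    # mosaicml/mpt-7b publishes architectures=["MPTForCausalLM"], so the lookup
    # lands on the MPTForCausalLM entry added in this diff.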

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 
 __all__ = [
@@ -11,5 +12,6 @@
     "GPTBigCodeForCausalLM",
     "GPTNeoXForCausalLM",
     "LlamaForCausalLM",
+    "MPTForCausalLM",
     "OPTForCausalLM",
 ]

vllm/model_executor/models/mpt.py

Lines changed: 279 additions & 0 deletions

@@ -0,0 +1,279 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


class MPTAttention(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        self.qkv_proj = ColumnParallelLinear(
            self.d_model,
            3 * self.d_model,
            bias=not config.no_bias,
            gather_output=False,
            perform_initialization=False,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            input_is_parallel=True,
            perform_initialization=False,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        del position_ids  # unused.
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class MPTMLP(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(hidden_size,
                                            intermediate_size,
                                            bias=not config.no_bias,
                                            gather_output=False,
                                            perform_initialization=False)
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=not config.no_bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MPTBlock(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MPTModel(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.d_model,
                                          perform_initialization=False)
        self.blocks = nn.ModuleList(
            [MPTBlock(config) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias"):
                    if isinstance(module.bias, nn.Parameter):
                        # Remove the bias term in Linear and LayerNorm.
                        module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings

        self.transformer = MPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
    _row_parallel_weights = ["out_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "Wqkv" in name:
                # NOTE(woosuk): MPT's fused QKV has the shape of
                # [3 * num_heads * head_size, hidden_size].
                # When tensor model parallelism is used, we need to shard
                # the weight along the hidden dimension.
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tp_world_size
                head_start = tp_rank * num_heads
                head_end = (tp_rank + 1) * num_heads

                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected parameter name {name}")
                name = name.replace("Wqkv", "qkv_proj")
            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
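
To make the ALiBi slope construction concrete, here is a worked example of the `_get_alibi_slopes` formula under MPT's default `alibi_bias_max=8` (values follow directly from the code above; the tensor-parallel slicing then just hands each rank its contiguous chunk of heads):

    import math
    import torch

    # Same arithmetic as _get_alibi_slopes for 8 heads, alibi_bias_max=8:
    # m = [1, 2, ..., 8], slopes = 1 / 2**m.
    total_num_heads, alibi_bias_max = 8, 8
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))  # 8
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m * (alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    print(slopes)
    # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

    # Under 2-way tensor parallelism, rank 0 keeps heads 0-3 and rank 1 keeps
    # heads 4-7, so each rank passes only its four slopes to PagedAttentionWithALiBi.
    print(slopes[:4], slopes[4:])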

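The fused `Wqkv` sharding in `load_weights` is easier to see with shapes. A shape-only sketch using MPT-7B-like sizes (d_model=4096, 32 heads of size 128, per the public config) and 2-way tensor parallelism; the tensor here is empty and purely illustrative:

    import torch

    hidden_size, total_num_heads, tp_world_size, tp_rank = 4096, 32, 2, 0
    head_size = hidden_size // total_num_heads      # 128
    num_heads = total_num_heads // tp_world_size    # 16 heads per rank
    head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

    # The checkpoint stores Wqkv.weight as [3 * hidden_size, hidden_size].
    wqkv = torch.empty(3 * hidden_size, hidden_size)

    # View as (q/k/v, heads, head_size, hidden), keep this rank's heads, flatten back.
    shard = (wqkv.view(3, total_num_heads, head_size, hidden_size)
                 [:, head_start:head_end, :, :]
                 .reshape(-1, hidden_size))
    print(shard.shape)  # torch.Size([6144, 4096]) == [3 * 16 * 128, 4096]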
vllm/transformers_utils/config.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import

_CONFIG_REGISTRY = {
    "mpt": MPTConfig,
}


def get_config(model: str) -> PretrainedConfig:
    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model)
    return config
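
The indirection here lets `AutoConfig` (with `trust_remote_code=True`) identify the model type, and then swaps in vLLM's own `MPTConfig` from `vllm.transformers_utils.configs`, so downstream code works with a config class vLLM controls rather than the checkpoint's remote-code class. A rough usage sketch (`d_model`/`n_heads` are standard MPT config fields):

    from vllm.transformers_utils.config import get_config

    config = get_config("mosaicml/mpt-7b")
    print(type(config).__name__)            # MPTConfig, resolved via _CONFIG_REGISTRY
    print(config.model_type)                # "mpt"
    print(config.d_model, config.n_heads)   # 4096 32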
vllm/transformers_utils/configs/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
from vllm.transformers_utils.configs.mpt import MPTConfig

__all__ = [
    "MPTConfig",
]
