
Commit b514d3c

Revert MptConfig to MPTConfig (#1668)

1 parent: 7076fa1

6 files changed: +260 additions, −26 deletions

vllm/model_executor/model_loader.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     "MistralForCausalLM": MistralForCausalLM,
     # transformers's mpt class has lower case
-    "MptForCausalLM": MptForCausalLM,
-    "MPTForCausalLM": MptForCausalLM,
+    "MptForCausalLM": MPTForCausalLM,
+    "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
     "QWenLMHeadModel": QWenLMHeadModel,
     "RWForCausalLM": FalconForCausalLM,

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from vllm.model_executor.models.internlm import InternLMForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mpt import MptForCausalLM
+from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.models.qwen import QWenLMHeadModel
 from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
@@ -29,7 +29,7 @@
     "GPTNeoXForCausalLM",
     "InternLMForCausalLM",
     "LlamaForCausalLM",
-    "MptForCausalLM",
+    "MPTForCausalLM",
     "OPTForCausalLM",
     "QWenLMHeadModel",
     "MistralForCausalLM",

vllm/model_executor/models/mpt.py

Lines changed: 20 additions & 20 deletions
@@ -5,7 +5,6 @@
 
 import torch
 import torch.nn as nn
-from transformers import MptConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import get_act_fn
@@ -22,6 +21,7 @@
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs.mpt import MPTConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -39,21 +39,21 @@ def _get_alibi_slopes(
     return slopes
 
 
-class MptAttention(nn.Module):
+class MPTAttention(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
-        self.clip_qkv = config.attn_config.clip_qkv
-        self.qk_ln = config.attn_config.qk_ln
-        self.alibi_bias_max = config.attn_config.alibi_bias_max
-        assert not config.attn_config.prefix_lm
-        assert config.attn_config.alibi
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        assert not config.attn_config["prefix_lm"]
+        assert config.attn_config["alibi"]
 
         # pylint: disable=invalid-name
         self.Wqkv = QKVParallelLinear(
@@ -113,11 +113,11 @@ def forward(
         return output
 
 
-class MptMLP(nn.Module):
+class MPTMLP(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -145,19 +145,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class MptBlock(nn.Module):
+class MPTBlock(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MptAttention(config, linear_method)
+        self.attn = MPTAttention(config, linear_method)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MptMLP(config, linear_method)
+        self.ffn = MPTMLP(config, linear_method)
 
     def forward(
         self,
@@ -182,11 +182,11 @@ def forward(
         return hidden_states
 
 
-class MptModel(nn.Module):
+class MPTModel(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -198,7 +198,7 @@ def __init__(
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MptBlock(config, linear_method) for _ in range(config.n_layers)])
+            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -233,19 +233,19 @@ def forward(
         return hidden_states
 
 
-class MptForCausalLM(nn.Module):
+class MPTForCausalLM(nn.Module):
 
     def __init__(
         self,
-        config: MptConfig,
+        config: MPTConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
        self.config = config
         assert config.tie_word_embeddings
         self.linear_method = linear_method
 
-        self.transformer = MptModel(config, linear_method)
+        self.transformer = MPTModel(config, linear_method)
         self.lm_head_weight = self.transformer.wte.weight
         self.sampler = Sampler(config.vocab_size)

vllm/transformers_utils/config.py

Lines changed: 2 additions & 2 deletions
@@ -1,14 +1,14 @@
 from typing import Optional
 
-from transformers import AutoConfig, MptConfig, PretrainedConfig
+from transformers import AutoConfig, PretrainedConfig
 
 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
 
 _CONFIG_REGISTRY = {
     "aquila": AquilaConfig,
     "baichuan": BaiChuanConfig,
     "chatglm": ChatGLMConfig,
-    "mpt": MptConfig,
+    "mpt": MPTConfig,
     "qwen": QWenConfig,
     "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
     "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)

vllm/transformers_utils/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 from vllm.transformers_utils.configs.aquila import AquilaConfig
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
+from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.qwen import QWenConfig
 # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
@@ -12,6 +13,7 @@
     "AquilaConfig",
     "BaiChuanConfig",
     "ChatGLMConfig",
+    "MPTConfig",
     "QWenConfig",
     "RWConfig",
     "YiConfig",
