Skip to content

Commit 521b35f

Browse files
authored
Support Microsoft Phi 1.5 (#1664)
1 parent cb08cd0 commit 521b35f

File tree

6 files changed

+320
-0
lines changed

6 files changed

+320
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
5959
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
6060
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
6161
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
62+
- Phi-1.5 (`microsoft/phi-1_5`, etc.)
6263
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
6364

6465
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pandas # Required for Ray data.
55
pyarrow # Required for Ray data.
66
sentencepiece # Required for LLaMA tokenizer.
77
numpy
8+
einops # Required for phi-1_5
89
torch >= 2.1.0
910
transformers >= 4.34.0 # Required for Mistral.
1011
xformers >= 0.0.22.post7 # Required for CUDA 12.1.

tests/models/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"EleutherAI/pythia-70m",
1616
"bigscience/bloom-560m",
1717
"mosaicml/mpt-7b",
18+
"microsoft/phi-1_5",
1819
]
1920

2021

vllm/model_executor/model_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"MptForCausalLM": MPTForCausalLM,
3333
"MPTForCausalLM": MPTForCausalLM,
3434
"OPTForCausalLM": OPTForCausalLM,
35+
"PhiForCausalLM": PhiForCausalLM,
3536
"QWenLMHeadModel": QWenLMHeadModel,
3637
"RWForCausalLM": FalconForCausalLM,
3738
"YiForCausalLM": YiForCausalLM,

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from vllm.model_executor.models.mistral import MistralForCausalLM
1313
from vllm.model_executor.models.mpt import MPTForCausalLM
1414
from vllm.model_executor.models.opt import OPTForCausalLM
15+
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
1516
from vllm.model_executor.models.qwen import QWenLMHeadModel
1617
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
1718
from vllm.model_executor.models.yi import YiForCausalLM
@@ -31,6 +32,7 @@
3132
"LlamaForCausalLM",
3233
"MPTForCausalLM",
3334
"OPTForCausalLM",
35+
"PhiForCausalLM",
3436
"QWenLMHeadModel",
3537
"MistralForCausalLM",
3638
"YiForCausalLM",
vllm/model_executor/models/phi_1_5.py

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
# coding=utf-8
2+
# Adapted from
3+
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
4+
# Copyright 2023 The vLLM team.
5+
# Copyright (c) Microsoft Corporation.
6+
# Licensed under the MIT license.
7+
#
8+
# BSD 3-Clause License
9+
#
10+
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
11+
# All rights reserved.
12+
#
13+
# Redistribution and use in source and binary forms, with or without
14+
# modification, are permitted provided that the following conditions are met:
15+
#
16+
# * Redistributions of source code must retain the above copyright notice, this
17+
# list of conditions and the following disclaimer.
18+
#
19+
# * Redistributions in binary form must reproduce the above copyright notice,
20+
# this list of conditions and the following disclaimer in the documentation
21+
# and/or other materials provided with the distribution.
22+
#
23+
# * Neither the name of the copyright holder nor the names of its
24+
# contributors may be used to endorse or promote products derived from
25+
# this software without specific prior written permission.
26+
#
27+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37+
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.
38+
39+
The input of the model is flattened to a 1D tensor of tokens. The model uses
40+
InputMetadata to extract the original 2D shape of the input.
41+
"""
42+
from typing import List, Optional, Tuple
43+
44+
import torch
45+
from torch import nn
46+
from transformers import PretrainedConfig
47+
48+
from vllm.model_executor.input_metadata import InputMetadata
49+
from vllm.model_executor.layers.activation import get_act_fn
50+
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
51+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
52+
LinearMethodBase,
53+
QKVParallelLinear,
54+
RowParallelLinear)
55+
from vllm.model_executor.layers.sampler import Sampler
56+
from vllm.model_executor.layers.vocab_parallel_embedding import (
57+
VocabParallelEmbedding, ParallelLMHead)
58+
from vllm.model_executor.parallel_utils.parallel_state import (
59+
get_tensor_model_parallel_world_size)
60+
from vllm.model_executor.weight_utils import (default_weight_loader,
61+
hf_model_weights_iterator)
62+
from vllm.sequence import SamplerOutput
63+
64+
# A (key cache, value cache) tensor pair; the model receives one such pair
# per transformer layer.
KVCache = Tuple[torch.Tensor, torch.Tensor]
65+
66+
67+
class PhiEmbedding(nn.Module):
    """Token-embedding stage of the Phi model.

    Wraps a single ``VocabParallelEmbedding`` stored as ``wte``; the weight
    loader looks parameters up by name, so the attribute name is significant.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        vocab_size, hidden_size = config.vocab_size, config.hidden_size
        self.wte = VocabParallelEmbedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.LongTensor):
        """Return the embeddings looked up for ``input_ids``."""
        embeddings = self.wte(input_ids)
        return embeddings
79+
80+
81+
class PhiAttention(nn.Module):
    """Attention block of a Phi layer (fused QKV projection + rotary attention).

    The block owns exactly one fused query/key/value projection (``Wqkv``),
    an output projection (``out_proj``) and a ``PagedAttentionWithRoPE``
    operator. Attention heads are split evenly across tensor-parallel ranks.

    Args:
        config: Model configuration. Reads ``num_attention_heads``,
            ``hidden_size``, ``rotary_dim`` and, optionally, ``n_positions``.
        linear_method: Optional linear-layer method forwarded unchanged to
            the parallel linear layers.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Every tensor-parallel rank handles an equal share of the heads.
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # The attribute is named ``Wqkv`` (not snake_case) because the weight
        # loader resolves parameters by the names found in the checkpoint.
        # pylint: disable=C0103
        self.Wqkv = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            linear_method=linear_method,
        )

        scaling = self.head_size**-0.5
        rotary_dim = config.rotary_dim
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = 10000
        # Fall back to 2048 positions when the config has no ``n_positions``.
        max_position_embeddings = getattr(config, "n_positions", 2048)
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_size,
            scaling,
            rotary_dim,
            base=rope_theta,
            max_position=max_position_embeddings)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Project to Q/K/V, run rotary paged attention, project back.

        Args:
            position_ids: Token positions handed to the attention operator.
            hidden_states: Input activations for this layer.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Batch metadata forwarded to the attention
                operator.
            cache_event: Optional CUDA event forwarded to the attention
                operator.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        qkv, _ = self.Wqkv(hidden_states)
        # The fused projection is split into three equal parts on the last
        # dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output
149+
150+
151+
class PhiMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.

    The inner width is ``config.n_inner`` when that attribute exists and is
    not ``None``; otherwise it is four times ``config.hidden_size``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()

        inner_dim = getattr(config, "n_inner", None)
        if inner_dim is None:
            inner_dim = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        inner_dim,
                                        linear_method=linear_method)
        self.fc2 = RowParallelLinear(inner_dim,
                                     config.hidden_size,
                                     linear_method=linear_method)
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states):
        """Apply the up-projection, the activation and the down-projection."""
        # Both parallel linear layers return a pair; only the first element
        # is used here.
        projected, _ = self.fc1(hidden_states)
        activated = self.act(projected)
        output, _ = self.fc2(activated)
        return output
178+
179+
180+
class PhiLayer(nn.Module):
    """One Phi transformer layer.

    A single LayerNorm feeds both the attention block (``mixer``) and the
    MLP; their outputs are added together with the un-normalised input.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        hidden_size = config.hidden_size
        self.ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = PhiAttention(config, linear_method)
        self.mlp = PhiMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Return ``attention(ln(x)) + mlp(ln(x)) + x`` for input ``x``."""
        normed = self.ln(hidden_states)
        attn_out = self.mixer(position_ids, normed, kv_cache, input_metadata,
                              cache_event)
        mlp_out = self.mlp(normed)
        # Same left-to-right summation order as before: (attn + mlp) + input.
        return attn_out + mlp_out + hidden_states
211+
212+
213+
class PhiCausalLMHead(nn.Module):
    """Final LayerNorm, biased LM head and sampler."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        vocab_size = config.vocab_size
        hidden_size = config.hidden_size

        self.ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.linear = ParallelLMHead(vocab_size, hidden_size, bias=True)
        self.sampler = Sampler(vocab_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ):
        """Normalise ``hidden_states`` and return the sampler's result.

        ``self.linear`` is never called directly; its weight and bias are
        passed to the sampler together with the normalised hidden states.
        """
        normed = self.ln(hidden_states)
        head = self.linear
        return self.sampler(head.weight, normed, input_metadata, head.bias)
233+
234+
235+
class PhiModel(nn.Module):
    """Phi transformer trunk: token embedding followed by a stack of layers.

    Args:
        config: Model configuration; ``num_hidden_layers`` sets the number
            of ``PhiLayer`` instances created.
        linear_method: Optional linear-layer method passed to every layer.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.embd = PhiEmbedding(config)
        self.h = nn.ModuleList([
            PhiLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every layer in order.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions forwarded to each layer.
            kv_caches: One ``(key_cache, value_cache)`` pair per layer,
                indexed by layer number.
            input_metadata: Batch metadata forwarded to each layer.
            cache_events: Optional per-layer CUDA events; ``None`` means no
                event is passed to any layer.

        Returns:
            The hidden states produced by the last layer. No sampling is
            performed in this module.
        """
        hidden_states = self.embd(input_ids)
        for i, layer in enumerate(self.h):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        return hidden_states
272+
273+
274+
class PhiForCausalLM(nn.Module):
    """Top-level Phi causal LM: transformer trunk plus LM head and sampler."""

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method

        # Attribute names matter: load_weights() resolves parameters by name.
        self.transformer = PhiModel(config, linear_method)
        self.lm_head = PhiCausalLMHead(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the trunk, then return whatever the LM head produces."""
        trunk_output = self.transformer(input_ids, positions, kv_caches,
                                        input_metadata, cache_events)
        return self.lm_head(trunk_output, input_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy checkpoint tensors into this model's parameters.

        Each tensor yielded by ``hf_model_weights_iterator`` is matched to
        the parameter with the same name. Entries whose name contains
        ``"rotary_emb.inv_freq"`` are skipped. A name with no matching
        parameter raises ``KeyError``.
        """
        named_params = dict(self.named_parameters())
        checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                               load_format, revision)
        for name, loaded_weight in checkpoint:
            if "rotary_emb.inv_freq" not in name:
                # pylint: disable=E1136
                target = named_params[name]
                # Use the parameter's own loader when it defines one.
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, loaded_weight)

0 commit comments

Comments
 (0)