Commit de60a3f

Added DeciLM-7b and DeciLM-7b-instruct (#2062)
1 parent 21d5daa commit de60a3f

File tree

5 files changed: 129 additions & 0 deletions

README.md

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ vLLM seamlessly supports many Hugging Face models, including the following architectures:
 - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
+- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
 - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
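
With this entry in place, the model can be exercised through vLLM's offline Python API. A minimal sketch, assuming the prompt and sampling settings are purely illustrative; `trust_remote_code=True` may be needed because the DeciLM checkpoints ship custom config code:

    from vllm import LLM, SamplingParams

    # trust_remote_code may be required to load DeciLM's custom HF config.
    llm = LLM(model="Deci/DeciLM-7B-instruct", trust_remote_code=True)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    outputs = llm.generate(["How do I count to ten in French?"], sampling_params)
    print(outputs[0].outputs[0].text)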

docs/source/models/supported_models.rst

Lines changed: 3 additions & 0 deletions

@@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`ChatGLMModel`
     - ChatGLM
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
+  * - :code:`DeciLMForCausalLM`
+    - DeciLM
+    - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
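
The first column of this table is matched against the `architectures` field of a checkpoint's Hugging Face config. A quick way to see what a given checkpoint reports (a sketch; `trust_remote_code=True` is assumed to be needed for DeciLM's custom config class):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("Deci/DeciLM-7B", trust_remote_code=True)
    print(config.architectures)  # expected to contain "DeciLMForCausalLM"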

tests/models/test_models.py

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
     "mistralai/Mistral-7B-v0.1",
+    "Deci/DeciLM-7b",
     "tiiuae/falcon-7b",
     "gpt2",
     "bigcode/tiny_starcoder_py",

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),

vllm/model_executor/models/decilm.py

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""

from typing import Optional

import torch
from transformers import PretrainedConfig

from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)


class DeciLMForCausalLM(LlamaForCausalLM):
    """
    Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
    Based on the llama executor.

    The main difference is that DeciLM uses Variable Grouped Query Attention.
    The constant number of GQA heads in the decoder is overridden with a value
    per layer.

    Usually, in the HuggingFace implementation, instead of
    "config.num_key_value_heads", we use
    "config.num_key_value_heads_per_layer[i]" which varies.

    Currently, PagedAttention does not work well with variable GQA, so we
    normalize the weights upon loading, and use uniform GQA with the max value
    instead.
    """

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        # Collapse the per-layer KV head counts to their maximum so the model
        # runs with a uniform GQA configuration.
        config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
        delattr(config, "num_key_value_heads_per_layer")
        super().__init__(config=config, linear_method=linear_method)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue

            if "k_proj" in name or "v_proj" in name:
                loaded_weight = self._degroup_weight(loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
        hidden_size = self.config.hidden_size
        head_size = self.config.hidden_size // self.config.num_attention_heads
        target_num_kv_heads = self.config.num_key_value_heads
        num_kv_heads = loaded_weight.shape[0] // head_size
        n_repeats = target_num_kv_heads / num_kv_heads
        assert n_repeats == int(n_repeats)

        # Repeat each KV head so this layer matches the max KV head count.
        n_repeats = int(n_repeats)
        loaded_weight = loaded_weight.view(num_kv_heads, head_size,
                                           hidden_size)
        loaded_weight = torch.repeat_interleave(loaded_weight,
                                                repeats=n_repeats,
                                                dim=0)
        loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
                                              hidden_size)

        return loaded_weight
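
As a standalone illustration of the normalization that `_degroup_weight` performs, here is a sketch with made-up shapes (not DeciLM's real dimensions): a k_proj weight with 2 KV heads is expanded to a 4-head target by repeating each head's rows.

    import torch

    # Toy shapes for illustration only.
    hidden_size, head_size = 32, 8
    num_kv_heads, target_num_kv_heads = 2, 4

    weight = torch.randn(num_kv_heads * head_size, hidden_size)
    n_repeats = target_num_kv_heads // num_kv_heads

    degrouped = weight.view(num_kv_heads, head_size, hidden_size)
    degrouped = torch.repeat_interleave(degrouped, repeats=n_repeats, dim=0)
    degrouped = degrouped.reshape(target_num_kv_heads * head_size, hidden_size)

    # (16, 32) -> (32, 32): each original KV head's rows now appear twice,
    # giving uniform GQA across layers.
    print(weight.shape, degrouped.shape)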
