
Commit 7076fa1

TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)
Refactor the tensor parallelism, quantization, and weight-loading code. Summary of the new features enabled by this PR:

- **All models** can now be quantized with AWQ and SqueezeLLM, and [soon GPTQ](#1580).
- The model loading code is much simpler.
- Model parallelism is supported for all MQA/GQA models, even when the number of key/value heads is smaller than the tensor parallel size.
1 parent 660a7fc

36 files changed, +2159 −2508 lines
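For context, the quantization support described above is surfaced through the engine's `quantization` argument. A minimal usage sketch follows; the checkpoint name is a hypothetical placeholder, and the exact flag values assume a vLLM build from around this commit:

```python
from vllm import LLM, SamplingParams

# "TheBloke/Llama-2-7B-AWQ" is a placeholder name; any AWQ checkpoint
# that ships its quantization metadata should work here.
llm = LLM(
    model="TheBloke/Llama-2-7B-AWQ",
    quantization="awq",      # or "squeezellm"
    tensor_parallel_size=2,  # TP now composes with quantization
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```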

vllm/config.py

Lines changed: 30 additions & 19 deletions
```diff
@@ -140,8 +140,8 @@ def get_head_size(self) -> int:
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
-    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
-        """Returns the number of KV heads per GPU worker."""
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
         # For GPTBigCode & Falcon:
         # NOTE: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
@@ -155,23 +155,34 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         # Multi-query attention, only one KV head.
         # Currently, tensor parallelism is not supported in this case.
         return 1
-        # For Falcon:
-        if getattr(self.hf_config, "n_head_kv", None) is not None:
-            return (self.hf_config.n_head_kv //
-                    parallel_config.tensor_parallel_size)
-        if getattr(self.hf_config, "num_kv_heads", None) is not None:
-            return (self.hf_config.num_kv_heads //
-                    parallel_config.tensor_parallel_size)
-        # For LLaMA-2:
-        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
-            return (self.hf_config.num_key_value_heads //
-                    parallel_config.tensor_parallel_size)
-        # For ChatGLM-2:
-        if getattr(self.hf_config, "multi_query_group_num", None) is not None:
-            return (self.hf_config.multi_query_group_num //
-                    parallel_config.tensor_parallel_size)
-        total_num_attention_heads = self.hf_config.num_attention_heads
-        return total_num_attention_heads // parallel_config.tensor_parallel_size
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_config.num_attention_heads
+
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1,
+                   total_num_kv_heads // parallel_config.tensor_parallel_size)
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
```
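To make the new division rule concrete, here is a small standalone sketch (hypothetical values, not vLLM code) of what `get_num_kv_heads` now returns when a GQA model has fewer KV heads than the tensor parallel size:

```python
def num_kv_heads_per_gpu(total_num_kv_heads: int,
                         tensor_parallel_size: int) -> int:
    # Mirrors the new logic: partition KV heads across GPUs, but never
    # drop below one head per GPU; instead, heads get replicated.
    return max(1, total_num_kv_heads // tensor_parallel_size)

# A LLaMA-2-70B-style GQA model has 8 KV heads in total.
assert num_kv_heads_per_gpu(8, 4) == 2    # heads are partitioned
assert num_kv_heads_per_gpu(8, 8) == 1    # exactly one head per GPU
assert num_kv_heads_per_gpu(8, 16) == 1   # heads are replicated
```

The old code's plain `// tensor_parallel_size` division would floor to 0 in the last case, which is why MQA/GQA models with few KV heads could not run at high TP degrees before this PR.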

vllm/engine/async_llm_engine.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -142,10 +142,10 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
 
         self._request_streams[request_id].finish()
 
-    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
+    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
         sent to the engine."""
-        new_requests: List[dict] = []
+        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()
 
         while not self._finished_requests.empty():
```
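The typing change above is cosmetic (`typing.Dict` instead of the bare `dict` builtin), but the surrounding pattern, draining `asyncio` queues without blocking the event loop, is worth a small illustration. A minimal sketch under assumed names, not vLLM's actual implementation:

```python
import asyncio
from typing import Dict, List, Set, Tuple

def drain_queues(
    new_queue: "asyncio.Queue[Dict]",
    finished_queue: "asyncio.Queue[str]",
) -> Tuple[List[Dict], Set[str]]:
    """Collect everything currently queued, without awaiting."""
    finished: Set[str] = set()
    while not finished_queue.empty():
        finished.add(finished_queue.get_nowait())

    new_requests: List[Dict] = []
    while not new_queue.empty():
        new_requests.append(new_queue.get_nowait())

    return new_requests, finished
```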
