Commit 8d17774

Add AWQ support for all models (#1714)
1 parent e946260 commit 8d17774

File tree

13 files changed: +90 −17 lines

vllm/model_executor/layers/activation.py

Lines changed: 47 additions & 5 deletions
@@ -1,8 +1,11 @@
 """Custom activation functions."""
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
 from vllm import activation_ops
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 
 class SiluAndMul(nn.Module):
@@ -39,6 +42,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.scales = nn.Parameter(
+            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+
+    def forward(self, x: torch.Tensor):
+        return self.act(x) / self.scales
+
+
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_fast": FastGELU(),
@@ -48,9 +72,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 }
 
 
-def get_act_fn(act_fn: str) -> nn.Module:
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+) -> nn.Module:
     """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(
+            f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None:
+        if act_fn_name in quant_config.get_scaled_act_names():
+            if intermediate_size is None:
+                raise ValueError(
+                    "intermediate_size must be specified for scaled "
+                    "activation functions.")
+            return ScaledActivation(
+                act_fn,
+                intermediate_size,
+                params_dtype=torch.get_default_dtype(),
+            )
+    return act_fn
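With this change, get_act_fn dispatches on both the activation name and the quantization method. As a usage sketch (not part of the commit; the AWQConfig constructor arguments below are assumed for illustration, and ScaledActivation allocates its scales on "cuda", so a GPU is required):

import torch

from vllm.model_executor.layers.activation import (ScaledActivation,
                                                   get_act_fn)
from vllm.model_executor.layers.quantization.awq import AWQConfig

# Hypothetical AWQ settings, shown only to exercise the new code path.
quant_config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)

# "gelu" is in AWQConfig.get_scaled_act_names(), so this returns a
# ScaledActivation wrapping nn.GELU(); forward computes act(x) / scales,
# where the scales are later loaded from the AWQ checkpoint.
act = get_act_fn("gelu", quant_config, intermediate_size=4096)
assert isinstance(act, ScaledActivation)

# With no quant config (or a name outside get_scaled_act_names()),
# the plain registry module is returned unchanged.
plain = get_act_fn("gelu")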

vllm/model_executor/layers/quantization/awq.py

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,9 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
     def get_linear_method(self) -> "AWQLinearMethod":
         return AWQLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+
 
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.

vllm/model_executor/layers/quantization/base_config.py

Lines changed: 8 additions & 0 deletions
@@ -54,3 +54,11 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
     def get_linear_method(self) -> LinearMethodBase:
         """Get the linear method to use for the quantized linear layer."""
         raise NotImplementedError
+
+    @abstractmethod
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
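Every quantization backend must now implement this hook. A minimal sketch of what a hypothetical new backend would add (the class is made up, and the other abstract methods it would also need are omitted):

from typing import List

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class MyQuantConfig(QuantizationConfig):
    """Hypothetical backend, shown only to illustrate the new hook."""

    # get_name, get_linear_method, from_config, etc. omitted for brevity.

    def get_scaled_act_names(self) -> List[str]:
        # This made-up method needs no activation post-scaling.
        return []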

vllm/model_executor/layers/quantization/squeezellm.py

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
     def get_linear_method(self) -> "SqueezeLLMLinearMethod":
         return SqueezeLLMLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
 
 class SqueezeLLMLinearMethod(LinearMethodBase):
     """Linear method for SqueezeLLM.

vllm/model_executor/models/bloom.py

Lines changed: 3 additions & 2 deletions
@@ -145,7 +145,8 @@ def __init__(
             4 * hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
@@ -154,7 +155,7 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.dense_h_to_4h(x)
-        x = self.act(x)
+        x = self.gelu_impl(x)
         x, _ = self.dense_4h_to_h(x)
         return x
 
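The getattr call above is the pattern repeated in every model below: linear_method is None for unquantized models, so the quant config is fetched defensively rather than via linear_method.quant_config. A sketch of the idea with a hypothetical helper:

from vllm.model_executor.layers.activation import get_act_fn


def _make_act(linear_method, act_name: str, intermediate_size: int):
    # linear_method is None when no quantization is configured, so a
    # direct .quant_config access would raise AttributeError.
    quant_config = getattr(linear_method, "quant_config", None)
    # With quant_config=None, get_act_fn returns the plain module; with
    # an AWQ config and a GELU variant, a ScaledActivation wrapper.
    return get_act_fn(act_name, quant_config, intermediate_size)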

vllm/model_executor/models/falcon.py

Lines changed: 4 additions & 1 deletion
@@ -27,6 +27,7 @@
 from transformers import FalconConfig as HF_FalconConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import (PagedAttention,
                                                   PagedAttentionWithALiBi,
                                                   PagedAttentionWithRoPE)
@@ -131,6 +132,7 @@ def __init__(
             self.hidden_size,
             bias=config.bias,
             skip_bias_add=True,
+            linear_method=linear_method,
             reduce_results=self.reduce_row_parallel_results)
 
         self.use_rotary = config.rotary
@@ -206,7 +208,8 @@ def __init__(
             bias=config.bias,
             skip_bias_add=True,
             linear_method=linear_method)
-        self.act = nn.GELU()
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(

vllm/model_executor/models/gpt2.py

Lines changed: 3 additions & 1 deletion
@@ -118,7 +118,9 @@ def __init__(
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)

vllm/model_executor/models/gpt_bigcode.py

Lines changed: 3 additions & 1 deletion
@@ -137,7 +137,9 @@ def __init__(
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)

vllm/model_executor/models/gpt_j.py

Lines changed: 3 additions & 1 deletion
@@ -128,7 +128,9 @@ def __init__(
             hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)

vllm/model_executor/models/gpt_neox.py

Lines changed: 3 additions & 1 deletion
@@ -124,7 +124,9 @@ def __init__(
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.hidden_act)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.hidden_act, quant_config,
+                              config.intermediate_size)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
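With these changes, AWQ checkpoints of the GELU-based architectures above (BLOOM, Falcon, GPT-2, GPT-BigCode, GPT-J, GPT-NeoX) can load with their activation scales applied. An end-to-end sketch; the model name is a placeholder, not one tested in this commit:

from vllm import LLM, SamplingParams

# Any AWQ-quantized checkpoint of one of the newly supported models.
llm = LLM(model="TheBloke/some-model-AWQ", quantization="awq")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)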
