@@ -1,8 +1,11 @@
 """Custom activation functions."""
+from typing import Optional
+
 import torch
 import torch.nn as nn

 from vllm import activation_ops
+from vllm.model_executor.layers.quantization import QuantizationConfig


 class SiluAndMul(nn.Module):
@@ -39,6 +42,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return out


+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.scales = nn.Parameter(
+            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+
+    def forward(self, x: torch.Tensor):
+        return self.act(x) / self.scales
+
+
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_fast": FastGELU(),
@@ -48,9 +72,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 }


-def get_act_fn(act_fn: str) -> nn.Module:
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+) -> nn.Module:
     """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(
+            f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None:
+        if act_fn_name in quant_config.get_scaled_act_names():
+            if intermediate_size is None:
+                raise ValueError(
+                    "intermediate_size must be specified for scaled "
+                    "activation functions.")
+            return ScaledActivation(
+                act_fn,
+                intermediate_size,
+                params_dtype=torch.get_default_dtype(),
+            )
+    return act_fn
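Usage sketch (not part of the diff): how the updated get_act_fn dispatches between a plain registry activation and the new ScaledActivation wrapper. The DummyQuantConfig below is a hypothetical stand-in that only implements get_scaled_act_names(), the single method get_act_fn calls on the config; the import path is assumed to be vllm.model_executor.layers.activation, and intermediate_size=11008 is just an example width. ScaledActivation allocates its scales on CUDA and leaves them uninitialized until checkpoint weights are loaded, so the sketch stops at construction.

# Assumed location of this module; adjust to the actual file path.
from vllm.model_executor.layers.activation import get_act_fn


class DummyQuantConfig:
    """Hypothetical stand-in that duck-types QuantizationConfig."""

    def get_scaled_act_names(self):
        # Report "gelu" as needing post-scaling (as AWQ-style configs do).
        return ["gelu"]


# Without a quant config, the plain registry module comes back.
plain = get_act_fn("gelu")  # nn.GELU()

# With a config that scales "gelu", a ScaledActivation wrapper comes back;
# intermediate_size sizes its per-channel `scales` parameter. Requires a
# CUDA build, since ScaledActivation allocates the scales on the GPU.
scaled = get_act_fn("gelu", DummyQuantConfig(), intermediate_size=11008)
print(type(plain).__name__, type(scaled).__name__)  # GELU ScaledActivation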