 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 
 logger = init_logger(__name__)
 
@@ -132,6 +135,7 @@ def create_weights( |
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]
 
         if params_dtype != torch.float16:
             raise ValueError(
@@ -170,64 +174,64 @@ def create_weights( |
                 "Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)
 
         # Determine if channelwise or not
         input_groups = (1 if self.quant_config.group_size == -1 else
                         input_size_per_partition //
                         self.quant_config.group_size)
 
-        scales = Parameter(
+        weight_scale_args = {
+            "data":
             torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)
 
         layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
     def apply(
         self,
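For readers unfamiliar with the new parameter API, the change replaces raw `torch.nn.Parameter` objects annotated via `set_weight_attrs` with vLLM's typed parameter classes, which carry their sharding metadata and a `weight_loader` callback directly. Below is a minimal sketch of that pattern, using only the constructor arguments visible in this diff; the stub `weight_loader`, the tensor shapes, and the `packed_factor`/`marlin_tile_size` values are illustrative assumptions rather than values taken from this file.

```python
# Minimal sketch of the new parameter-construction pattern, assuming the
# constructor arguments shown in the diff above. The weight_loader stub,
# tensor shapes, packed_factor=8 and marlin_tile_size=16 are illustrative.
import torch
from torch.nn import Parameter

from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           PackedvLLMParameter)


def weight_loader(param, loaded_weight):
    # Stand-in loader; real layers receive theirs via
    # extra_weight_attrs["weight_loader"].
    param.data.copy_(loaded_weight)


layer = torch.nn.Module()

# Packed 4-bit weights: metadata that previously went through
# set_weight_attrs is now passed straight to the parameter class.
qweight = PackedvLLMParameter(data=torch.empty(64, 1024, dtype=torch.int32),
                              input_dim=0,
                              output_dim=1,
                              packed_dim=1,
                              packed_factor=8,
                              marlin_tile_size=16,
                              weight_loader=weight_loader)

# Channelwise scales (the group_size == -1 branch in the diff).
scales = ChannelQuantScaleParameter(data=torch.empty(1, 512,
                                                     dtype=torch.float16),
                                    output_dim=1,
                                    weight_loader=weight_loader)

layer.register_parameter("B", qweight)
layer.register_parameter("s", scales)

# After loading, process_weights_after_loading re-wraps the tensors as plain
# Parameters so that torch.compile sees ordinary nn.Parameter objects.
layer.B = Parameter(layer.B.data, requires_grad=False)
layer.s = Parameter(layer.s.data, requires_grad=False)
```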