from vllm import _custom_ops as ops
from vllm.logger import init_logger
- from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
+ from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)
@@ -96,8 +96,8 @@ def get_quant_method(self, layer: torch.nn.Module,
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


- def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
-                   qweight_type: int) -> torch.Tensor:
+ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
+                         qweight_type: int) -> torch.Tensor:
    # HACK: when doing chunked prefill we don't generate output tokens
    # so input to logits generator is empty which causes invalid parameter
    if x.shape[0] == 0:
@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
    return y


+ def _fused_mul_mat_gguf_fake(
+     x: torch.Tensor,
+     qweight: torch.Tensor,
+     qweight_type: int,
+ ) -> torch.Tensor:
+     return torch.empty(x.shape[0],
+                        qweight.shape[0],
+                        dtype=x.dtype,
+                        device=x.device)
+
+
+ try:
+     direct_register_custom_op(
+         op_name="_fused_mul_mat_gguf",
+         op_func=_fused_mul_mat_gguf,
+         mutates_args=[],
+         fake_impl=_fused_mul_mat_gguf_fake,
+     )
+     fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
+
+ except AttributeError as error:
+     raise error
+
+
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
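
Note: the `*_fake` functions registered above never run the GGUF kernels; they only report the output shape, dtype, and device so that `torch.compile` and other tracing paths can reason about the op. A minimal sketch of the same real-plus-fake registration pattern on a hypothetical toy op, using PyTorch's public `torch.library.custom_op` API (assumes PyTorch 2.4+) rather than vLLM's `direct_register_custom_op` helper:

```python
import torch
from torch.library import custom_op


# Hypothetical toy op mirroring the pattern in this file: the real body
# does the work, the fake body only describes the output metadata.
@custom_op("demo::scaled_matmul", mutates_args=())
def scaled_matmul(x: torch.Tensor, w: torch.Tensor, scale: float) -> torch.Tensor:
    return (x @ w.t()) * scale


@scaled_matmul.register_fake
def _(x: torch.Tensor, w: torch.Tensor, scale: float) -> torch.Tensor:
    # Same contract as _fused_mul_mat_gguf_fake: empty output of the right shape.
    return torch.empty(x.shape[0], w.shape[0], dtype=x.dtype, device=x.device)


x, w = torch.randn(4, 8), torch.randn(16, 8)
out = scaled_matmul(x, w, 0.5)  # eager call dispatches to the real body
assert out.shape == (4, 16)
```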
@@ -138,8 +162,21 @@ def _fused_moe_gguf(
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
-     act,
+     activation: str,
) -> torch.Tensor:
+
+     def act(x: torch.Tensor):
+         d = x.shape[-1] // 2
+         output_shape = (x.shape[:-1] + (d, ))
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         if activation == "silu":
+             torch.ops._C.silu_and_mul(out, x)
+         elif activation == "gelu":
+             torch.ops._C.gelu_and_mul(out, x)
+         else:
+             raise ValueError(f"Unsupported activation: {activation}")
+         return out
+
    # lazy import to avoid triggering triton import in CPU backend
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        moe_align_block_size)
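
Note: the inlined `act` closure above calls vLLM's fused `silu_and_mul` / `gelu_and_mul` CUDA kernels, which apply a gated activation over the last dimension. A rough pure-PyTorch reference of what they compute (illustration only, not the actual kernel):

```python
import torch
import torch.nn.functional as F


def ref_act_and_mul(x: torch.Tensor, activation: str) -> torch.Tensor:
    # Split the last dim in half; the activated first half gates the second,
    # matching the semantics of silu_and_mul / gelu_and_mul.
    x1, x2 = x.chunk(2, dim=-1)
    if activation == "silu":
        return F.silu(x1) * x2
    if activation == "gelu":
        return F.gelu(x1) * x2
    raise ValueError(f"Unsupported activation: {activation}")
```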
@@ -189,12 +226,12 @@ def _fused_moe_gguf(
            for ww, ii in zip(w, idx):
                expert_up = w1[ii]

-                 out = _fuse_mul_mat(inp, expert_up, qweight_type)
+                 out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
                out = act(out)

                expert_down = w2[ii]
-                 current_state = _fuse_mul_mat(out, expert_down,
-                                               qweight_type2).mul_(ww)
+                 current_state = fused_mul_mat_gguf(out, expert_down,
+                                                    qweight_type2).mul_(ww)
                if current_hidden_state is None:
                    current_hidden_state = current_state
                else:
@@ -203,6 +240,78 @@ def _fused_moe_gguf(
    return out_hidden_states


+ def _fused_moe_gguf_fake(
+     x: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     topk_weights: torch.Tensor,
+     topk_ids: torch.Tensor,
+     qweight_type: int,
+     qweight_type2: int,
+     activation: str,
+ ) -> torch.Tensor:
+     return torch.empty_like(x)
+
+
+ try:
+     direct_register_custom_op(
+         op_name="_fused_moe_gguf",
+         op_func=_fused_moe_gguf,
+         mutates_args=[],
+         fake_impl=_fused_moe_gguf_fake,
+     )
+     fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
+
+ except AttributeError as error:
+     raise error
+
+
+ def _apply_gguf_embedding(
+     x: torch.Tensor,
+     qweight: torch.Tensor,
+     qweight_type: int,
+     hidden_size: int,
+     dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
+     if qweight_type in UNQUANTIZED_TYPES:
+         return torch.embedding(qweight, x)
+     elif qweight_type in DEQUANT_TYPES:
+         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+         x_flat = x.flatten()
+         assert (hidden_size == qweight.shape[1] // type_size * block_size)
+         quant = torch.index_select(qweight, dim=0, index=x_flat)
+         dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
+                                       x_flat.shape[0], dtype)
+         return dequant.view(*x.shape, hidden_size)
+     else:
+         qweight_type = WeightType(qweight_type)
+         raise NotImplementedError(
+             f"Unsupported GGUF quantization type: {qweight_type}")
+
+
+ def _apply_gguf_embedding_fake(
+     x: torch.Tensor,
+     qweight: torch.Tensor,
+     qweight_type: int,
+     hidden_size: int,
+     dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
+     return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
+
+
+ try:
+     direct_register_custom_op(
+         op_name="_apply_gguf_embedding",
+         op_func=_apply_gguf_embedding,
+         mutates_args=[],
+         fake_impl=_apply_gguf_embedding_fake,
+     )
+     apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
+
+ except AttributeError as error:
+     raise error
+
+
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.
@@ -249,26 +358,76 @@ def create_weights(self, layer: torch.nn.Module,
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

+     def process_weights_after_loading(self, layer: torch.nn.Module):
+         qweight_type = layer.qweight_type.weight_type
+         if not (qweight_type in UNQUANTIZED_TYPES
+                 or qweight_type in DEQUANT_TYPES):
+             qweight_type = WeightType(qweight_type)
+             raise ValueError(
+                 f"Unsupported GGUF quantization type {qweight_type} in "
+                 f"layer {layer}.")
+         # For MergedColumnParallelLinear and QKVParallelLinear, we need to
+         # materialize the padded weight parameter for CUDA Graph compatibility.
+         self._create_padded_weight_param(layer)
+
+     def _create_padded_weight_param(self, layer: torch.nn.Module):
+         """Create padded weight parameter for GGUF MergedLinear layer."""
+         qweight = layer.qweight
+         shard_id_map = qweight.shard_id_map
+         shard_id = qweight.shard_id
+         if len(data_container := qweight.data_container) > 1:
+             dtype = {data.dtype for data in data_container}
+             assert len(dtype) == 1, ValueError(
+                 f"Data container has mixed dtypes: {dtype}")
+             dtype = next(iter(dtype))
+             # concat dim0 and pad dim1
+             padded_side = max(x.size(1) for x in data_container)
+             concat_side = sum(x.size(0) for x in data_container)
+             # Pad the quantized weights to dense tensor, and create a map
+             # with the location of each shard in the padded tensor.
+             padded_data = torch.zeros((concat_side, padded_side),
+                                       dtype=dtype,
+                                       device=qweight.device)
+             # (dim0_start, dim0_end, dim1_size)
+             shard_offset_map = dict[str, tuple[int, int, int]]()
+             for idx in shard_id:
+                 id_in_container = shard_id_map[idx]
+                 start = sum(
+                     x.size(0) for x in data_container[:id_in_container])
+                 end = start + data_container[id_in_container].size(0)
+                 size = data_container[id_in_container].size(1)
+                 padded_data[start:end, :size] = data_container[id_in_container]
+                 shard_offset_map[idx] = (start, end, size)
+             qweight.data_container.clear()
+             padded_param = Parameter(padded_data, requires_grad=False)
+             set_weight_attrs(padded_param, vars(qweight))
+             set_weight_attrs(padded_param,
+                              {"shard_offset_map": shard_offset_map})
+             layer.register_parameter("qweight", padded_param)
+
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-         shard_id = getattr(layer.qweight, "shard_id", None)
+         shard_id = layer.qweight.shard_id

        if shard_id:
            # dequantize shard weights respectively
            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
-             qweight = layer.qweight.unbind(0)
+             qweight = layer.qweight
            result = []
            for idx in shard_id:
-                 q_idx = layer.qweight.shard_id_map[idx]
+                 start, end, offset = layer.qweight.shard_offset_map[idx]
                qweight_type = layer.qweight_type.shard_weight_type[idx]
-                 result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
+                 result.append(
+                     fused_mul_mat_gguf(
+                         x, qweight[start:end, :offset].contiguous(),
+                         qweight_type))
            out = torch.cat(result, axis=1)
        else:
            qweight = layer.qweight
            qweight_type = layer.qweight_type.weight_type
-             out = _fuse_mul_mat(x, qweight, qweight_type)
+             out = fused_mul_mat_gguf(x, qweight, qweight_type)
        if bias is not None:
            out.add_(bias)
        return out
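
Note: the padded layout built in `_create_padded_weight_param` is what lets `apply()` address each shard with a static slice instead of the old nested-tensor `unbind`, which keeps the weights CUDA-graph friendly. A small illustration with hypothetical shard sizes (two quantization types with different per-row byte widths):

```python
import torch

# Hypothetical shard sizes: two quantized shards whose rows pack to different
# byte widths are stacked along dim 0 and zero-padded along dim 1, mirroring
# the layout produced by _create_padded_weight_param.
gate = torch.full((8, 18), 1, dtype=torch.uint8)  # e.g. a Q4_0-like shard
up = torch.full((8, 34), 2, dtype=torch.uint8)    # e.g. a Q8_0-like shard

padded = torch.zeros((16, 34), dtype=torch.uint8)
padded[0:8, :18] = gate
padded[8:16, :34] = up
# (dim0_start, dim0_end, dim1_size), as recorded in shard_offset_map
shard_offset_map = {"gate": (0, 8, 18), "up": (8, 16, 34)}

# apply() recovers each shard with a contiguous slice of the dense tensor:
start, end, size = shard_offset_map["gate"]
assert torch.equal(padded[start:end, :size], gate)
```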
@@ -338,7 +497,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
        layer.register_parameter("w2_qweight_type", w2_qweight_type)
-         self.act = SiluAndMul()

    def apply(
        self,
@@ -375,10 +533,10 @@ def apply(
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)
-         return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
-                                topk_weights, topk_ids,
-                                layer.w13_qweight_type.weight_type,
-                                layer.w2_qweight_type.weight_type, self.act)
+         return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
+                               topk_weights, topk_ids,
+                               layer.w13_qweight_type.weight_type,
+                               layer.w2_qweight_type.weight_type, activation)


class GGUFEmbeddingMethod(GGUFLinearMethod):
@@ -392,34 +550,15 @@ def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        qweight = layer.qweight
        qweight_type = layer.qweight_type.weight_type
+         hidden_size = qweight.tensor_shape[1]

-         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
-         hidden_size = qweight.shape[1] // type_size * block_size
-         if qweight_type < 2:
-             return torch.embedding(qweight, x)
-         x_flat = x.flatten()
-         quant = torch.index_select(qweight, dim=0, index=x_flat)
-         dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
-                                       x_flat.shape[0], self.params_dtype)
-         return dequant.view(*x.shape, hidden_size)
+         return apply_gguf_embedding(x,
+                                     qweight,
+                                     qweight_type,
+                                     hidden_size,
+                                     dtype=self.params_dtype)


class GGUFUninitializedParameter(UninitializedParameter):
    cls_to_become = Parameter
    data_container: list[torch.Tensor]
-
-     def materialize_nested(self) -> Parameter:
-         dtype = {data.dtype for data in self.data_container}
-         assert len(dtype) == 1, ValueError(
-             f"Data container has mixed dtypes: {dtype}")
-         dtype = next(iter(dtype))
-         nested_data = torch.nested.nested_tensor(self.data_container,
-                                                  device=self.device,
-                                                  dtype=dtype)
-         self.data_container.clear()
-         param = torch.Tensor._make_subclass(self.cls_to_become,
-                                             nested_data,
-                                             require_grad=False)
-         for k, v in self.__dict__.items():
-             setattr(param, k, v)
-         return param
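
Note: the `hidden_size` passed to `apply_gguf_embedding` is the logical embedding width, while `qweight.shape[1]` is the packed byte width of one row; the assert in `_apply_gguf_embedding` ties the two together via `GGML_QUANT_SIZES`. A quick sanity check of that relation, assuming a Q8_0-quantized embedding table with a hidden size of 4096:

```python
import gguf

# GGML_QUANT_SIZES maps a quantization type to (block_size, type_size):
# Q8_0 packs 32 values into 34 bytes (32 int8 weights plus one fp16 scale).
block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0]
hidden_size = 4096
row_bytes = hidden_size // block_size * type_size  # packed bytes per row
# This is the relation asserted in _apply_gguf_embedding.
assert row_bytes // type_size * block_size == hidden_size
```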