```diff
@@ -9,6 +9,7 @@
     set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.utils import direct_register_custom_op
 
 
 class BitsAndBytesConfig(QuantizationConfig):
@@ -321,9 +322,6 @@ def _apply_4bit_weight(
             x: torch.Tensor,
             bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        # only load the bitsandbytes module when needed
-        from bitsandbytes import matmul_4bit
-
         original_type = x.dtype
         original_shape = x.shape
         reshape_after_matmul = False
@@ -343,19 +341,7 @@ def _apply_4bit_weight(
                           out_dim_1,
                           dtype=torch.bfloat16,
                           device=x.device)
-
-        current_index = 0
-        for i in range(len(quant_states)):
-            output_size = quant_states[i].shape[0]
-            # It is more efficient to use out kwarg like
-            # matmul_4bit(..., out = ...). Infeasible now due to the bug
-            # https://github.com/TimDettmers/bitsandbytes/issues/1235.
-            # Need to change after the bug is fixed.
-            out[:, current_index:current_index + output_size] = matmul_4bit(
-                bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
-
-            current_index += output_size
-
+        apply_bnb_4bit(bf_x, qweight, offsets, out)
         out = out.to(original_type)
 
         if reshape_after_matmul:
@@ -365,3 +351,46 @@ def _apply_4bit_weight(
             out += bias
 
         return out
+
+
+def _apply_bnb_4bit(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        offsets: torch.Tensor,
+        out: torch.Tensor,
+) -> None:
+    # only load the bitsandbytes module when needed
+    from bitsandbytes import matmul_4bit
+    quant_states = weight.bnb_quant_state
+    current_index = 0
+    for i in range(len(quant_states)):
+        output_size = quant_states[i].shape[0]
+        # It is more efficient to use out kwarg like
+        # matmul_4bit(..., out = ...). Infeasible now due to the bug
+        # https://github.com/TimDettmers/bitsandbytes/issues/1235.
+        # Need to change after the bug is fixed.
+        out[:, current_index:current_index + output_size] = matmul_4bit(
+            x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
+        current_index += output_size
+
+
+def _apply_bnb_4bit_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        offsets: torch.Tensor,
+        out: torch.Tensor,
+) -> None:
+    return
+
+
+try:
+    direct_register_custom_op(
+        op_name="apply_bnb_4bit",
+        op_func=_apply_bnb_4bit,
+        mutates_args=["out"],
+        fake_impl=_apply_bnb_4bit_fake,
+    )
+    apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
+
+except AttributeError as error:
+    raise error
```
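
For context on the registration pattern: the diff routes the bitsandbytes loop through `direct_register_custom_op`, vLLM's helper (built, as far as I can tell, on PyTorch's `torch.library` machinery) that exposes `_apply_bnb_4bit` as `torch.ops.vllm.apply_bnb_4bit` and attaches `_apply_bnb_4bit_fake` so FakeTensor/`torch.compile` tracing never has to import or run bitsandbytes. Below is a minimal, self-contained sketch of the same out-mutating custom-op pattern using the public `torch.library.custom_op` API (PyTorch >= 2.4); the `mylib::toy_apply` op and its dense-matmul body are purely illustrative, not vLLM code.

```python
import torch


# Hypothetical toy op mirroring the registration above: like
# _apply_bnb_4bit, it writes into `out` in place and returns nothing,
# so the mutation must be declared (cf. mutates_args=["out"] in the diff).
@torch.library.custom_op("mylib::toy_apply", mutates_args=("out",))
def toy_apply(x: torch.Tensor, weight: torch.Tensor,
              out: torch.Tensor) -> None:
    out.copy_(x @ weight.t())


# Fake (meta) implementation, run under FakeTensor tracing instead of the
# real kernel. Like _apply_bnb_4bit_fake it computes nothing: the op only
# fills a preallocated `out`, so there are no output shapes to infer.
@toy_apply.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, out: torch.Tensor) -> None:
    return


x = torch.randn(2, 4)
w = torch.randn(3, 4)
out = torch.empty(2, 3)
torch.ops.mylib.toy_apply(x, w, out)  # dispatch like apply_bnb_4bit above
assert torch.allclose(out, x @ w.t())
```

Keeping the fake implementation a no-op and declaring the mutation via `mutates_args` mirrors the diff exactly; the benefit is that `torch.compile` sees a single opaque op instead of tracing into the Python loop over `quant_states`, which appears to be the motivation for the change.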