
Commit 726efc6

[Quantization][V1] BitsAndBytes support V1 (#15611)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent bd45912 commit 726efc6
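
In practice, this means a bitsandbytes-quantized model can now run on the V1 engine with only the quantization argument; as the test updates below show, the previously required load_format="bitsandbytes" is no longer passed. A minimal usage sketch (the checkpoint name and prompt are illustrative, not part of this commit):

from vllm import LLM, SamplingParams

# Illustrative 4-bit bitsandbytes checkpoint; any bnb-quantized model
# follows the same pattern.
llm = LLM(
    model="unsloth/llama-3-8b-bnb-4bit",
    quantization="bitsandbytes",  # load_format="bitsandbytes" no longer needed
    max_model_len=4096,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)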

File tree

7 files changed: +52 -24 lines changed


tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 0 additions & 1 deletion
@@ -425,7 +425,6 @@ def test_bnb_regression(
         max_model_len=4096,
         max_num_seqs=2,
         quantization="bitsandbytes",
-        load_format="bitsandbytes",
     )
     sampling_params = SamplingParams(
         temperature=0,

tests/models/test_transformers.py

Lines changed: 0 additions & 1 deletion
@@ -72,7 +72,6 @@ def test_distributed(
         "meta-llama/Llama-3.2-1B-Instruct",
         {
             "quantization": "bitsandbytes",
-            "load_format": "bitsandbytes",
         },
     ),
 ])

tests/quantization/test_bitsandbytes.py

Lines changed: 0 additions & 3 deletions
@@ -101,8 +101,6 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
         "--enable-prefix-caching",
         "--quantization",
         "bitsandbytes",
-        "--load-format",
-        "bitsandbytes",
         "--gpu-memory-utilization",
         "0.7",
     ]
@@ -137,7 +135,6 @@ def validate_generated_texts(hf_runner,
     # when using distributed inference
     with vllm_runner(model_name,
                      quantization='bitsandbytes',
-                     load_format='bitsandbytes',
                      tensor_parallel_size=vllm_tp_size,
                      enforce_eager=False) as llm:
         vllm_outputs = llm.generate_greedy(prompts, 8)

vllm/config.py

Lines changed: 4 additions & 2 deletions
@@ -682,8 +682,9 @@ def _verify_cuda_graph(self) -> None:
 
     def _verify_bnb_config(self) -> None:
         """
-        The current version of bitsandbytes (0.44.0) with 8-bit models does not
+        The current version of bitsandbytes (0.45.3) with 8-bit models does not
         yet support CUDA graph.
+        # TODO Remove this when bitsandbytes supports.
         """
         is_bitsandbytes = self.quantization == "bitsandbytes"
         has_quantization_config = (getattr(self.hf_config,
@@ -698,8 +699,9 @@ def _verify_bnb_config(self) -> None:
             not self.enforce_eager,
         ]):
             logger.warning(
-                "CUDA graph is not supported on BitAndBytes 8bit yet, "
+                "CUDA graph is not supported on BitsAndBytes 8bit yet, "
                 "fallback to the eager mode.")
+
             self.enforce_eager = True
 
     def _verify_with_expert_parallelism(self) -> None:
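
The effect of the check above is that an 8-bit bitsandbytes model falls back to eager mode (CUDA graphs disabled) even if enforce_eager=False was requested. A condensed, standalone restatement; the helper name _should_force_eager is ours, and the assumption that the 8-bit flag comes from the Hugging Face quantization_config's load_in_8bit entry mirrors the surrounding vLLM code not shown in this hunk:

def _should_force_eager(quantization: str, hf_quant_config: dict,
                        enforce_eager: bool) -> bool:
    # bitsandbytes 8-bit weights are not CUDA-graph compatible yet,
    # so vLLM forces eager execution for them.
    is_bnb_8bit = (quantization == "bitsandbytes"
                   and hf_quant_config.get("load_in_8bit", False))
    return is_bnb_8bit and not enforce_eager


print(_should_force_eager("bitsandbytes", {"load_in_8bit": True}, False))  # True
print(_should_force_eager("bitsandbytes", {"load_in_4bit": True}, False))  # False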

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1616,7 +1616,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             return False
 
         # Some quantization is not compatible with torch.compile.
-        V1_UNSUPPORTED_QUANT = ["bitsandbytes", "gguf"]
+        V1_UNSUPPORTED_QUANT = ["gguf"]
         if model_config.quantization in V1_UNSUPPORTED_QUANT:
             _raise_or_fallback(
                 feature_name=f"--quantization {model_config.quantization}",
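
With "bitsandbytes" dropped from V1_UNSUPPORTED_QUANT, requesting bitsandbytes no longer hits the _raise_or_fallback path, so the V1 engine can be used (or forced) for these models. A minimal sketch; VLLM_USE_V1 is vLLM's existing engine switch, and setting it is illustrative rather than required, since the oracle now accepts bitsandbytes on its own:

import os

# Force the V1 engine; after this commit the bitsandbytes case no longer
# trips the "unsupported on V1" check guarded by V1_UNSUPPORTED_QUANT above.
os.environ["VLLM_USE_V1"] = "1"  # set before importing/constructing vllm.LLM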

vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 45 additions & 16 deletions
@@ -9,6 +9,7 @@
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.utils import direct_register_custom_op
 
 
 class BitsAndBytesConfig(QuantizationConfig):
@@ -321,9 +322,6 @@ def _apply_4bit_weight(
             x: torch.Tensor,
             bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        # only load the bitsandbytes module when needed
-        from bitsandbytes import matmul_4bit
-
         original_type = x.dtype
         original_shape = x.shape
         reshape_after_matmul = False
@@ -343,19 +341,7 @@
                           out_dim_1,
                           dtype=torch.bfloat16,
                           device=x.device)
-
-        current_index = 0
-        for i in range(len(quant_states)):
-            output_size = quant_states[i].shape[0]
-            # It is more efficient to use out kwarg like
-            # matmul_4bit(..., out = ...). Infeasible now due to the bug
-            # https://github.com/TimDettmers/bitsandbytes/issues/1235.
-            # Need to change after the bug is fixed.
-            out[:, current_index:current_index + output_size] = matmul_4bit(
-                bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
-
-            current_index += output_size
-
+        apply_bnb_4bit(bf_x, qweight, offsets, out)
         out = out.to(original_type)
 
         if reshape_after_matmul:
@@ -365,3 +351,46 @@
             out += bias
 
         return out
+
+
+def _apply_bnb_4bit(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        offsets: torch.Tensor,
+        out: torch.Tensor,
+) -> None:
+    # only load the bitsandbytes module when needed
+    from bitsandbytes import matmul_4bit
+    quant_states = weight.bnb_quant_state
+    current_index = 0
+    for i in range(len(quant_states)):
+        output_size = quant_states[i].shape[0]
+        # It is more efficient to use out kwarg like
+        # matmul_4bit(..., out = ...). Infeasible now due to the bug
+        # https://github.com/TimDettmers/bitsandbytes/issues/1235.
+        # Need to change after the bug is fixed.
+        out[:, current_index:current_index + output_size] = matmul_4bit(
+            x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
+        current_index += output_size
+
+
+def _apply_bnb_4bit_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        offsets: torch.Tensor,
+        out: torch.Tensor,
+) -> None:
+    return
+
+
+try:
+    direct_register_custom_op(
+        op_name="apply_bnb_4bit",
+        op_func=_apply_bnb_4bit,
+        mutates_args=["out"],
+        fake_impl=_apply_bnb_4bit_fake,
+    )
+    apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
+
+except AttributeError as error:
+    raise error
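
Context for the new registration block: V1 runs the model under torch.compile (the arg_utils.py change above is what re-enables that path for bitsandbytes), and the shard loop over quant_states is data-dependent Python that the compiler should not trace through. Registering it as a custom op with a fake implementation turns it into a single opaque node whose output layout is known. Below is a minimal, self-contained sketch of that pattern using the public torch.library API rather than vLLM's direct_register_custom_op helper; the op name demo::scaled_matmul_out and the plain matmul are stand-ins for the real matmul_4bit shard loop:

import torch


@torch.library.custom_op("demo::scaled_matmul_out", mutates_args=["out"])
def scaled_matmul_out(x: torch.Tensor, weight: torch.Tensor,
                      out: torch.Tensor) -> None:
    # Stand-in for the real kernel (the bitsandbytes matmul_4bit loop):
    # computes into the preallocated output buffer.
    out.copy_(x @ weight.t())


@scaled_matmul_out.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, out: torch.Tensor) -> None:
    # Fake/meta implementation: no computation needed; the output buffer
    # already carries the shape and dtype the compiler has to know about.
    return


@torch.compile
def forward(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    out = torch.empty(x.shape[0], weight.shape[0],
                      dtype=x.dtype, device=x.device)
    torch.ops.demo.scaled_matmul_out(x, weight, out)
    return out


if __name__ == "__main__":
    x = torch.randn(4, 8)
    w = torch.randn(16, 8)
    print(forward(x, w).shape)  # torch.Size([4, 16])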

vllm/model_executor/model_loader/loader.py

Lines changed: 2 additions & 0 deletions
@@ -1259,6 +1259,8 @@ def _load_weights(self, model_config: ModelConfig,
                                               pack_ratio)
 
             offsets = np.concatenate(([0], np.cumsum(num_elements)))
+            # Make torch infer_schema happy
+            offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
 
         if load_8bit:
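
Why the extra conversion: the new apply_bnb_4bit custom op declares offsets as a torch.Tensor, so the value attached to the weight has to actually be a tensor rather than the numpy array the old code stored; otherwise the op's schema inference ("Make torch infer_schema happy") does not line up with what is passed at runtime. A tiny illustration with made-up shard sizes:

import numpy as np
import torch

num_elements = [1024, 1024, 4096]  # illustrative per-shard element counts
offsets = np.concatenate(([0], np.cumsum(num_elements)))
# Convert to a (CPU) tensor so it matches the Tensor-typed `offsets`
# argument of the registered custom op.
offsets = torch.tensor(offsets).cpu()
print(offsets)  # tensor([   0, 1024, 2048, 6144])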
