Commit 1f1b1bc

[V1][Quantization] Add CUDA graph compatible v1 GGUF support (#18646)
Signed-off-by: Isotr0py <[email protected]>
1 parent 1f88dbd commit 1f1b1bc

File tree

5 files changed: +188 -59 lines

tests/kernels/quantization/test_gguf.py
Lines changed: 1 addition & 3 deletions

@@ -8,7 +8,6 @@
 from huggingface_hub import snapshot_download

 import vllm._custom_ops as ops
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
 from vllm.platforms import current_platform
@@ -176,12 +175,11 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,

     w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
                               device="cuda").to(dtype)
-    act = SiluAndMul()

     output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
                              torch.tensor(w2.data,
                                           device="cuda"), topk_weights,
-                             topk_ids, quant_type, quant_type, act)
+                             topk_ids, quant_type, quant_type, "silu")

     ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
                                topk_ids).reshape(output.shape)
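
The test now selects the activation by name ("silu") instead of passing a SiluAndMul module. For reference, the following pure-PyTorch sketch (not vLLM's CUDA kernel) shows the computation that selector maps to; "gelu" behaves analogously:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the fused gate/up projection in half: SiLU(gate) * up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]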

tests/models/quantization/test_gguf.py
Lines changed: 6 additions & 2 deletions

@@ -78,8 +78,12 @@ def gguf_model(self):
 )

 MODELS = [
-    LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG,
-    DOLPHIN_CONFIG
+    LLAMA_CONFIG,
+    QWEN2_CONFIG,
+    PHI3_CONFIG,
+    GPT2_CONFIG,
+    # STABLELM_CONFIG, # enable this when v1 supports head_size=80
+    DOLPHIN_CONFIG,
     # STARCODER_CONFIG, # broken
 ]

vllm/engine/arg_utils.py
Lines changed: 0 additions & 8 deletions

@@ -1291,14 +1291,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 recommend_to_remove=False)
             return False

-        # Some quantization is not compatible with torch.compile.
-        V1_UNSUPPORTED_QUANT = ["gguf"]
-        if model_config.quantization in V1_UNSUPPORTED_QUANT:
-            _raise_or_fallback(
-                feature_name=f"--quantization {model_config.quantization}",
-                recommend_to_remove=False)
-            return False
-
         # No Embedding Models so far.
         if model_config.task not in ["generate"]:
             _raise_or_fallback(feature_name=f"--task {model_config.task}",
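
With this check removed, a GGUF checkpoint no longer forces a fallback out of the V1 engine. A minimal usage sketch (the model file and tokenizer names are placeholders; using the base model's tokenizer is the recommended setup for GGUF):

from vllm import LLM

# Hypothetical paths; any local GGUF file plus its base model's tokenizer.
llm = LLM(model="./tinyllama-1.1b-chat.Q4_K_M.gguf",
          tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print(llm.generate("Hello, my name is")[0].outputs[0].text)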

vllm/model_executor/layers/linear.py
Lines changed: 0 additions & 4 deletions

@@ -587,8 +587,6 @@ def weight_loader(self,
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
             return

         param_data = param.data
@@ -982,8 +980,6 @@ def weight_loader(self,
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
             return

         param_data = param.data

vllm/model_executor/layers/quantization/gguf.py
Lines changed: 181 additions & 42 deletions

@@ -9,7 +9,6 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@@ -19,6 +18,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)

@@ -96,8 +96,8 @@ def get_quant_method(self, layer: torch.nn.Module,
 MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


-def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
-                  qweight_type: int) -> torch.Tensor:
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
+                        qweight_type: int) -> torch.Tensor:
     # HACK: when doing chunked prefill we don't generate output tokens
     # so input to logits generator is empty which causes invalid parameter
     if x.shape[0] == 0:
@@ -130,6 +130,30 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     return y


+def _fused_mul_mat_gguf_fake(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+) -> torch.Tensor:
+    return torch.empty(x.shape[0],
+                       qweight.shape[0],
+                       dtype=x.dtype,
+                       device=x.device)
+
+
+try:
+    direct_register_custom_op(
+        op_name="_fused_mul_mat_gguf",
+        op_func=_fused_mul_mat_gguf,
+        mutates_args=[],
+        fake_impl=_fused_mul_mat_gguf_fake,
+    )
+    fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
+
+except AttributeError as error:
+    raise error
+
+
 def _fused_moe_gguf(
     x: torch.Tensor,
     w1: torch.Tensor,
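
The pattern here is an eager implementation paired with a shape-only fake implementation: the fake path lets torch.compile infer output shapes while tracing, and the eager path is what actually runs (and is what CUDA graphs capture). The same idea, sketched with plain torch.library (PyTorch >= 2.4) rather than vLLM's direct_register_custom_op helper; the op name and body are illustrative only:

import torch
from torch.library import custom_op

@custom_op("demo::dequant_matmul", mutates_args=())
def dequant_matmul(x: torch.Tensor, qweight: torch.Tensor) -> torch.Tensor:
    # Eager path: stand-in for dequantize-then-matmul.
    return x @ qweight.to(x.dtype).t()

@dequant_matmul.register_fake
def _(x, qweight):
    # Shape/dtype-only path used during tracing; no kernels are launched.
    return x.new_empty(x.shape[0], qweight.shape[0])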
@@ -138,8 +162,21 @@
     topk_ids: torch.Tensor,
     qweight_type: int,
     qweight_type2: int,
-    act,
+    activation: str,
 ) -> torch.Tensor:
+
+    def act(x: torch.Tensor):
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        if activation == "silu":
+            torch.ops._C.silu_and_mul(out, x)
+        elif activation == "gelu":
+            torch.ops._C.gelu_and_mul(out, x)
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+        return out
+
     # lazy import to avoid triggering triton import in CPU backend
     from vllm.model_executor.layers.fused_moe.fused_moe import (
         moe_align_block_size)
@@ -189,12 +226,12 @@ def _fused_moe_gguf(
             for ww, ii in zip(w, idx):
                 expert_up = w1[ii]

-                out = _fuse_mul_mat(inp, expert_up, qweight_type)
+                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
                 out = act(out)

                 expert_down = w2[ii]
-                current_state = _fuse_mul_mat(out, expert_down,
-                                              qweight_type2).mul_(ww)
+                current_state = fused_mul_mat_gguf(out, expert_down,
+                                                   qweight_type2).mul_(ww)
                 if current_hidden_state is None:
                     current_hidden_state = current_state
                 else:
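
The activation becomes a string because the whole MoE kernel is now registered as a custom op, and op schemas can only carry plain types (Tensor, int, str), not an nn.Module. The per-token loop itself is the usual top-k expert mixing, just over GGUF-packed weights; a dequantized pure-PyTorch reference for a single token (a readability sketch with assumed stacked weight tensors, not the kernel path):

import torch

def moe_token_ref(inp, w1, w2, topk_w, topk_i, act):
    # inp: (1, hidden); w1: (E, 2*inter, hidden); w2: (E, hidden, inter)
    # act: e.g. the silu_and_mul_ref sketch shown earlier.
    out = None
    for ww, ii in zip(topk_w.tolist(), topk_i.tolist()):
        h = act(inp @ w1[ii].t())       # fused gate/up projection + activation
        h = (h @ w2[ii].t()) * ww       # down projection, scaled by routing weight
        out = h if out is None else out + h
    return out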
@@ -203,6 +240,78 @@ def _fused_moe_gguf(
     return out_hidden_states


+def _fused_moe_gguf_fake(
+    x: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    qweight_type: int,
+    qweight_type2: int,
+    activation: str,
+) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
+try:
+    direct_register_custom_op(
+        op_name="_fused_moe_gguf",
+        op_func=_fused_moe_gguf,
+        mutates_args=[],
+        fake_impl=_fused_moe_gguf_fake,
+    )
+    fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
+
+except AttributeError as error:
+    raise error
+
+
+def _apply_gguf_embedding(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    hidden_size: int,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    if qweight_type in UNQUANTIZED_TYPES:
+        return torch.embedding(qweight, x)
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        x_flat = x.flatten()
+        assert (hidden_size == qweight.shape[1] // type_size * block_size)
+        quant = torch.index_select(qweight, dim=0, index=x_flat)
+        dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
+                                      x_flat.shape[0], dtype)
+        return dequant.view(*x.shape, hidden_size)
+    else:
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(
+            f"Unsupported GGUF quantization type: {qweight_type}")
+
+
+def _apply_gguf_embedding_fake(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    hidden_size: int,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
+
+
+try:
+    direct_register_custom_op(
+        op_name="_apply_gguf_embedding",
+        op_func=_apply_gguf_embedding,
+        mutates_args=[],
+        fake_impl=_apply_gguf_embedding_fake,
+    )
+    apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
+
+except AttributeError as error:
+    raise error
+
+
 class GGUFLinearMethod(LinearMethodBase):
     """Linear method for GGUF.
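
_apply_gguf_embedding keeps the lookup cheap by gathering only the packed rows for the requested token ids and dequantizing just those rows. A pure-PyTorch sketch of that flow, with a stand-in dequantize() in place of ops.ggml_dequantize (sketch only, not the registered op):

import torch

def gguf_embedding_ref(ids, qweight, hidden_size, dequantize):
    flat = ids.flatten()
    # Gather only the packed rows that are actually needed, then dequantize.
    packed_rows = torch.index_select(qweight, dim=0, index=flat)
    rows = dequantize(packed_rows)          # -> (flat.numel(), hidden_size)
    return rows.view(*ids.shape, hidden_size)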
@@ -249,26 +358,76 @@ def create_weights(self, layer: torch.nn.Module,
         set_weight_attrs(qweight_type, extra_weight_attrs)
         layer.register_parameter("qweight_type", qweight_type)

+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        qweight_type = layer.qweight_type.weight_type
+        if not (qweight_type in UNQUANTIZED_TYPES
+                or qweight_type in DEQUANT_TYPES):
+            qweight_type = WeightType(qweight_type)
+            raise ValueError(
+                f"Unsupported GGUF quantization type {qweight_type} in "
+                f"layer {layer}.")
+        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
+        # materialize the padded weight parameter for CUDA Graph compatibility.
+        self._create_padded_weight_param(layer)
+
+    def _create_padded_weight_param(self, layer: torch.nn.Module):
+        """Create padded weight parameter for GGUF MergedLinear layer."""
+        qweight = layer.qweight
+        shard_id_map = qweight.shard_id_map
+        shard_id = qweight.shard_id
+        if len(data_container := qweight.data_container) > 1:
+            dtype = {data.dtype for data in data_container}
+            assert len(dtype) == 1, ValueError(
+                f"Data container has mixed dtypes: {dtype}")
+            dtype = next(iter(dtype))
+            # concat dim0 and pad dim1
+            padded_side = max(x.size(1) for x in data_container)
+            concat_side = sum(x.size(0) for x in data_container)
+            # Pad the quantized weights to dense tensor, and create a map
+            # with the location of each shard in the padded tensor.
+            padded_data = torch.zeros((concat_side, padded_side),
+                                      dtype=dtype,
+                                      device=qweight.device)
+            # (dim0_start, dim0_end, dim1_size)
+            shard_offset_map = dict[str, tuple[int, int, int]]()
+            for idx in shard_id:
+                id_in_container = shard_id_map[idx]
+                start = sum(
+                    x.size(0) for x in data_container[:id_in_container])
+                end = start + data_container[id_in_container].size(0)
+                size = data_container[id_in_container].size(1)
+                padded_data[start:end, :size] = data_container[id_in_container]
+                shard_offset_map[idx] = (start, end, size)
+            qweight.data_container.clear()
+            padded_param = Parameter(padded_data, requires_grad=False)
+            set_weight_attrs(padded_param, vars(qweight))
+            set_weight_attrs(padded_param,
+                             {"shard_offset_map": shard_offset_map})
+            layer.register_parameter("qweight", padded_param)
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        shard_id = getattr(layer.qweight, "shard_id", None)
+        shard_id = layer.qweight.shard_id

         if shard_id:
             # dequantize shard weights respectively
             shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
-            qweight = layer.qweight.unbind(0)
+            qweight = layer.qweight
             result = []
             for idx in shard_id:
-                q_idx = layer.qweight.shard_id_map[idx]
+                start, end, offset = layer.qweight.shard_offset_map[idx]
                 qweight_type = layer.qweight_type.shard_weight_type[idx]
-                result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
+                result.append(
+                    fused_mul_mat_gguf(
+                        x, qweight[start:end, :offset].contiguous(),
+                        qweight_type))
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight
             qweight_type = layer.qweight_type.weight_type
-            out = _fuse_mul_mat(x, qweight, qweight_type)
+            out = fused_mul_mat_gguf(x, qweight, qweight_type)
         if bias is not None:
             out.add_(bias)
         return out
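
The padded layout replaces the earlier nested-tensor approach: shards are concatenated along dim 0 and right-padded along dim 1 to a fixed rectangle, so the merged parameter has a static shape that CUDA graphs can capture, and shard_offset_map records where each shard lives. A small self-contained sketch of that layout and of the slicing apply() performs (toy float shards standing in for GGUF byte tensors):

import torch

shards = {"q": torch.ones(4, 6), "k": torch.full((2, 3), 2.0),
          "v": torch.full((2, 3), 3.0)}
concat_side = sum(t.size(0) for t in shards.values())
padded_side = max(t.size(1) for t in shards.values())
qweight = torch.zeros(concat_side, padded_side)

shard_offset_map, start = {}, 0
for name, t in shards.items():
    end = start + t.size(0)
    qweight[start:end, :t.size(1)] = t
    shard_offset_map[name] = (start, end, t.size(1))  # (dim0_start, dim0_end, dim1_size)
    start = end

# apply() later recovers an individual shard with a plain slice:
s, e, w = shard_offset_map["k"]
assert torch.equal(qweight[s:e, :w], shards["k"])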
@@ -338,7 +497,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,

         set_weight_attrs(w2_qweight_type, extra_weight_attrs)
         layer.register_parameter("w2_qweight_type", w2_qweight_type)
-        self.act = SiluAndMul()

     def apply(
         self,
@@ -375,10 +533,10 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
-        return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
-                               topk_weights, topk_ids,
-                               layer.w13_qweight_type.weight_type,
-                               layer.w2_qweight_type.weight_type, self.act)
+        return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
+                              topk_weights, topk_ids,
+                              layer.w13_qweight_type.weight_type,
+                              layer.w2_qweight_type.weight_type, activation)


 class GGUFEmbeddingMethod(GGUFLinearMethod):
@@ -392,34 +550,15 @@ def embedding(self, layer: torch.nn.Module,
                   x: torch.Tensor) -> torch.Tensor:
         qweight = layer.qweight
         qweight_type = layer.qweight_type.weight_type
+        hidden_size = qweight.tensor_shape[1]

-        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
-        hidden_size = qweight.shape[1] // type_size * block_size
-        if qweight_type < 2:
-            return torch.embedding(qweight, x)
-        x_flat = x.flatten()
-        quant = torch.index_select(qweight, dim=0, index=x_flat)
-        dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
-                                      x_flat.shape[0], self.params_dtype)
-        return dequant.view(*x.shape, hidden_size)
+        return apply_gguf_embedding(x,
+                                    qweight,
+                                    qweight_type,
+                                    hidden_size,
+                                    dtype=self.params_dtype)


 class GGUFUninitializedParameter(UninitializedParameter):
     cls_to_become = Parameter
     data_container: list[torch.Tensor]
-
-    def materialize_nested(self) -> Parameter:
-        dtype = {data.dtype for data in self.data_container}
-        assert len(dtype) == 1, ValueError(
-            f"Data container has mixed dtypes: {dtype}")
-        dtype = next(iter(dtype))
-        nested_data = torch.nested.nested_tensor(self.data_container,
-                                                 device=self.device,
-                                                 dtype=dtype)
-        self.data_container.clear()
-        param = torch.Tensor._make_subclass(self.cls_to_become,
-                                            nested_data,
-                                            require_grad=False)
-        for k, v in self.__dict__.items():
-            setattr(param, k, v)
-        return param
