
Commit 307939f

mgoin and dsikka authored
Use NVFP4 Marlin for CompressedTensorsW4A16Fp4 (#18000)
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Dipika <[email protected]>
Co-authored-by: Dipika <[email protected]>
1 parent 9d7ea9d commit 307939f

File tree

1 file changed: +26 -41 lines changed


vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py

Lines changed: 26 additions & 41 deletions
@@ -2,13 +2,12 @@
 from typing import Callable, List, Optional
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -31,6 +30,10 @@ def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
 
         # Weight
         weight = ModelWeightParameter(data=torch.empty(
@@ -60,48 +63,30 @@ def create_weights(self, layer: torch.nn.Module,
 
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer) -> None:
-        layer.weight_global_scale = Parameter(
-            layer.weight_global_scale.max().to(torch.float32),
+        # Process parameters for marlin repacking
+
+        # Rename weight_packed to weight that marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+        # Rename weight_global_scale to weight_scale_2 that marlin expects
+        # Note: ct stores the inverse of what is expected by the marlin kernel
+        layer.weight_scale_2 = Parameter(
+            1 / layer.weight_global_scale.max().to(torch.float32),
             requires_grad=False)
-        # Note: a post weight loading step but not required for the emulation
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
-        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                requires_grad=False)
+        del layer.weight_global_scale
+
+        prepare_fp4_layer_for_marlin(layer)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-
-        w_fp4 = layer.weight_packed.data
-        w_global_scale = layer.weight_global_scale
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   x.dtype, x.device, self.group_size)
-        out = F.linear(x, w_dq)
-        del w_dq, w_fp4, w_global_scale, w_blockscale
-        return out
+        return apply_fp4_marlin_linear(input=x,
+                                       weight=layer.weight,
+                                       weight_scale=layer.weight_scale,
+                                       weight_scale_2=layer.weight_scale_2,
+                                       workspace=layer.workspace,
+                                       size_n=layer.output_size_per_partition,
+                                       size_k=layer.input_size_per_partition,
+                                       bias=bias)
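
For context on the reciprocal taken in process_weights_after_loading: compressed-tensors ("ct" in the diff comment) stores weight_global_scale so that dequantization divides by it, while the Marlin FP4 kernel consumes weight_scale_2 as a multiplicative factor, hence the 1 / weight_global_scale.max(). A minimal sketch of that relationship, using made-up tensor values and a simplified dequant formula (illustrative assumptions only, not code from the kernel):

import torch

# Hedged sketch: the values and the dequant formula below are illustrative
# assumptions, not taken from vLLM or the Marlin kernel. The point is only
# that multiplying by weight_scale_2 = 1 / global_scale (new Marlin path)
# matches dividing by the global scale (old emulation path).
fp4_values = torch.tensor([0.5, 1.0, 1.5, 6.0])   # decoded FP4 values (illustrative)
block_scale = torch.tensor(2.0)                   # FP8 per-block scale (illustrative)
weight_global_scale = torch.tensor([896.0])       # ct-style global scale (illustrative)

# Old emulation-style dequant: divide by the global scale.
w_old = fp4_values * block_scale / weight_global_scale.max()

# New Marlin-style dequant: multiply by the reciprocal, as the commit does.
weight_scale_2 = 1 / weight_global_scale.max().to(torch.float32)
w_new = fp4_values * block_scale * weight_scale_2

assert torch.allclose(w_old, w_new)

The renames to layer.weight and layer.weight_scale_2 simply match the attribute names that prepare_fp4_layer_for_marlin and apply_fp4_marlin_linear read, as the diff's comments note.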
