2
2
from typing import Callable , List , Optional
3
3
4
4
import torch
5
- import torch .nn .functional as F
6
5
from torch .nn .parameter import Parameter
7
6
8
7
from vllm .model_executor .layers .quantization .compressed_tensors .schemes import (
9
8
CompressedTensorsScheme )
10
- from vllm .model_executor .layers .quantization .utils .nvfp4_emulation_utils import ( # noqa: E501
11
- dequantize_to_dtype )
9
+ from vllm .model_executor .layers .quantization .utils .marlin_utils_fp4 import (
10
+ apply_fp4_marlin_linear , prepare_fp4_layer_for_marlin )
12
11
from vllm .model_executor .parameter import (GroupQuantScaleParameter ,
13
12
ModelWeightParameter ,
14
13
PerTensorScaleParameter )
@@ -31,6 +30,10 @@ def create_weights(self, layer: torch.nn.Module,
31
30
input_size_per_partition : int ,
32
31
params_dtype : torch .dtype , weight_loader : Callable ,
33
32
** kwargs ):
33
+ output_size_per_partition = sum (output_partition_sizes )
34
+ layer .logical_widths = output_partition_sizes
35
+ layer .input_size_per_partition = input_size_per_partition
36
+ layer .output_size_per_partition = output_size_per_partition
34
37
35
38
# Weight
36
39
weight = ModelWeightParameter (data = torch .empty (
@@ -60,48 +63,30 @@ def create_weights(self, layer: torch.nn.Module,
60
63
61
64
layer .register_parameter ("weight_scale" , weight_scale )
62
65
63
- def swizzle_blockscale (self , scale : torch .tensor ):
64
- assert (scale .dtype == torch .float8_e4m3fn )
65
- # Pad and blockwise interleave weight_scale
66
- scale_ndim = scale .ndim
67
- if scale .ndim == 2 :
68
- scale = scale .unsqueeze (0 )
69
- assert scale .ndim == 3
70
- B , M , K = scale .shape
71
- round_up_multiple = lambda x , m : (x + m - 1 ) // m * m
72
- M_padded = round_up_multiple (M , 128 )
73
- K_padded = round_up_multiple (K , 4 )
74
- padded_scale = torch .zeros ((B , M_padded , K_padded ), dtype = scale .dtype )
75
- padded_scale [:B , :M , :K ] = scale
76
- batches , rows , cols = padded_scale .shape
77
- assert rows % 128 == 0
78
- assert cols % 4 == 0
79
- padded_scale = padded_scale .reshape (batches , rows // 128 , 4 , 32 ,
80
- cols // 4 , 4 )
81
- swizzled_scale = padded_scale .permute ((0 , 1 , 4 , 3 , 2 , 5 ))
82
- swizzled_scale = swizzled_scale .contiguous ().cuda ()
83
- return (swizzled_scale .reshape (M , K )
84
- if scale_ndim == 2 else swizzled_scale .reshape (B , M , K ))
85
-
86
66
def process_weights_after_loading(self, layer) -> None:
    """Rename the loaded parameters for Marlin and repack the layer.

    Reads ``layer.weight_packed`` and ``layer.weight_global_scale``,
    re-registers them as ``layer.weight`` and ``layer.weight_scale_2``,
    deletes the original attributes, and then hands the layer to
    ``prepare_fp4_layer_for_marlin``.

    Args:
        layer: module whose ``weight_packed`` and ``weight_global_scale``
            parameters have already been loaded.
    """
    # Marlin expects the packed weight under the name ``weight``.
    packed = layer.weight_packed.data
    layer.weight = Parameter(packed, requires_grad=False)
    del layer.weight_packed

    # Marlin expects the global scale under the name ``weight_scale_2``.
    # Note: compressed-tensors stores the inverse of what the marlin kernel
    # expects, hence the reciprocal of the (max-reduced) global scale.
    global_scale = layer.weight_global_scale.max().to(torch.float32)
    layer.weight_scale_2 = Parameter(1 / global_scale, requires_grad=False)
    del layer.weight_global_scale

    prepare_fp4_layer_for_marlin(layer)
94
80
95
81
def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the quantized linear layer through the FP4 Marlin helper.

    Args:
        layer: module carrying ``weight``, ``weight_scale``,
            ``weight_scale_2``, ``workspace``, ``output_size_per_partition``
            and ``input_size_per_partition``.
        x: input activations.
        bias: optional bias forwarded unchanged to the helper.

    Returns:
        Whatever ``apply_fp4_marlin_linear`` returns for these arguments.
    """
    # NOTE(review): ``layer.workspace`` is not assigned in the visible code;
    # presumably ``prepare_fp4_layer_for_marlin`` attaches it -- confirm.
    return apply_fp4_marlin_linear(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        weight_scale_2=layer.weight_scale_2,
        workspace=layer.workspace,
        size_n=layer.output_size_per_partition,
        size_k=layer.input_size_per_partition,
        bias=bias,
    )
0 commit comments