
Commit 4d0ec37

[Quantization][FP8] Adding support for fp8 gemm layer input in fp8 (#14578)
Signed-off-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
1 parent e7f720e commit 4d0ec37

5 files changed, +41 -9 lines changed
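Summary of the change: Fp8LinearOp.apply gains an explicit out_dtype argument and now accepts activations that are already quantized to fp8, in which case the quantization step is skipped and the supplied input_scale is used directly. The output dtype can no longer be inferred from input.dtype (which may itself be fp8), so each calling layer records torch.get_default_dtype() at construction time and passes it through. The snippet below is a minimal standalone sketch of that idea, not vLLM code: run_fp8_gemm is a made-up helper, and it assumes per-tensor float32 scales, an fp8-capable GPU, and shapes accepted by torch._scaled_mm.

# Minimal sketch (not vLLM code) of the new behavior: reuse an already-fp8
# activation instead of re-quantizing it, and take the output dtype from an
# explicit argument rather than from input.dtype.
import torch

def run_fp8_gemm(x: torch.Tensor,          # activation: fp8, or bf16/fp16
                 w_fp8: torch.Tensor,      # fp8 weight, column-major for _scaled_mm
                 x_scale: torch.Tensor,    # per-tensor float32 scale for x
                 w_scale: torch.Tensor,    # per-tensor float32 scale for w
                 out_dtype: torch.dtype) -> torch.Tensor:
    fp8 = torch.float8_e4m3fn
    if x.dtype != fp8:
        # Higher-precision input: quantize first (vLLM uses ops.scaled_fp8_quant).
        x = (x.float() / x_scale).clamp(-448.0, 448.0).to(fp8)
    # else: the input is already fp8, so it and the provided scale are used as-is.
    # out_dtype is passed explicitly because x.dtype may itself be fp8.
    return torch._scaled_mm(x, w_fp8,
                            scale_a=x_scale, scale_b=w_scale,
                            out_dtype=out_dtype)

The per-file hunks below follow the same pattern on the vLLM side: each linear method stores self.out_dtype = torch.get_default_dtype() in __init__ and forwards it to Fp8LinearOp.apply.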

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
 
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
+        self.out_dtype = torch.get_default_dtype()
         self.is_static_input_scheme = is_static_input_scheme
         self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
@@ -143,5 +144,6 @@ def apply_weights(self,
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
+                                     out_dtype=self.out_dtype,
                                      input_scale=layer.input_scale,
                                      bias=bias)

vllm/model_executor/layers/quantization/fbgemm_fp8.py

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
         self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+        self.out_dtype = torch.get_default_dtype()
 
     def create_weights(
         self,
@@ -161,6 +162,7 @@ def apply(self,
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
+                                     out_dtype=self.out_dtype,
                                      input_scale=None,
                                      input_scale_ub=layer.input_scale_ub,
                                      bias=bias)

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 17 additions & 0 deletions
@@ -116,6 +116,21 @@ def get_quant_method(self, layer: torch.nn.Module,
             return Fp8KVCacheMethod(self)
         return None
 
+    def get_cache_scale(self, name: str) -> Optional[str]:
+        """
+        Check whether the param name matches the format for k/v cache scales
+        in compressed-tensors. If this is the case, return its equivalent
+        param name expected by vLLM
+
+        :param name: param name
+        :return: matching param name for KV cache scale in vLLM
+        """
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        return None
+
 
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
@@ -138,6 +153,7 @@ class Fp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
+        self.out_dtype = torch.get_default_dtype()
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
@@ -386,6 +402,7 @@ def apply(self,
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
+                                     out_dtype=self.out_dtype,
                                      input_scale=layer.input_scale,
                                      bias=bias)
 
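For reference, the new Fp8Config.get_cache_scale helper added above maps compressed-tensors checkpoint names for k/v output scales onto the attention-layer scale names vLLM expects. A rough illustration (the parameter names below are made up, and cfg stands for an Fp8Config instance):

# Hypothetical parameter names, for illustration only.
cfg.get_cache_scale("model.layers.0.self_attn.k_proj.output_scale")
# -> "model.layers.0.self_attn.attn.k_scale"
cfg.get_cache_scale("model.layers.0.self_attn.v_proj.output_scale")
# -> "model.layers.0.self_attn.attn.v_scale"
cfg.get_cache_scale("model.layers.0.mlp.down_proj.weight")
# -> None (not a k/v cache output scale)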

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
         self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+        self.out_dtype = torch.get_default_dtype()
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -134,5 +135,6 @@ def apply_weights(self,
         return self.fp8_linear.apply(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
+                                     out_dtype=self.out_dtype,
                                      input_scale=layer.input_scale,
                                      bias=bias)

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 18 additions & 9 deletions
@@ -163,6 +163,7 @@ def apply(
         input: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        out_dtype: Optional[torch.dtype] = None,
         input_scale: Optional[torch.Tensor] = None,
         input_scale_ub: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
@@ -182,8 +183,13 @@
         if use_per_token_if_dynamic is None:
             use_per_token_if_dynamic = self.use_per_token_if_dynamic
 
+        if out_dtype is None:
+            out_dtype = input.dtype
+
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
         if self.cutlass_fp8_supported:
+            assert input.dtype != current_platform.fp8_dtype(
+            ), "FP8 input to cutlass is not currently implemented"
             qinput, x_scale = ops.scaled_fp8_quant(
                 input_2d,
                 input_scale,
@@ -193,7 +199,7 @@
             # Fused GEMM_DQ
             output = ops.cutlass_scaled_mm(qinput,
                                            weight,
-                                           out_dtype=input.dtype,
+                                           out_dtype=out_dtype,
                                            scale_a=x_scale,
                                            scale_b=weight_scale,
                                            bias=bias)
@@ -202,12 +208,15 @@
         # torch.scaled_mm supports per tensor weights + activations only
         # so fallback to naive if per channel or per token
         else:
-            # Maybe apply padding to output, see comment in __init__
-            qinput, x_scale = ops.scaled_fp8_quant(
-                input_2d,
-                input_scale,
-                num_token_padding=self.output_padding,
-                use_per_token_if_dynamic=use_per_token_if_dynamic)
+            if input.dtype != current_platform.fp8_dtype():
+                # Maybe apply padding to output, see comment in __init__
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=self.output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic)
+            else:
+                qinput, x_scale = input_2d, input_scale
 
             per_tensor_weights = (weight_scale.numel() == 1)
             per_tensor_activations = (x_scale.numel() == 1)
@@ -216,7 +225,7 @@
             # Fused GEMM_DQ
             output = torch._scaled_mm(qinput,
                                       weight,
-                                      out_dtype=input.dtype,
+                                      out_dtype=out_dtype,
                                       scale_a=x_scale,
                                       scale_b=weight_scale,
                                       bias=bias)
@@ -240,7 +249,7 @@
             # Fused GEMM_DQ Rowwise GEMM
             output = torch._scaled_mm(qinput,
                                       weight,
-                                      out_dtype=input.dtype,
+                                      out_dtype=out_dtype,
                                       scale_a=x_scale,
                                       scale_b=weight_scale.t(),
                                       bias=bias)
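Note the asymmetry between the two code paths above: the CUTLASS branch still quantizes the activation itself and now asserts that it is not handed an fp8 input, while the torch._scaled_mm fallback accepts an already-quantized activation and reuses the provided input_scale. In both branches the output dtype comes from the new out_dtype argument, which defaults to input.dtype so callers that keep passing higher-precision activations see no change in behavior.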
