
Commit 3a318ff

baodii authored and zufangzhu committed
add fp8_e5m2 support and fixing UT
Signed-off-by: baodii <[email protected]>
1 parent 1b96119 commit 3a318ff

4 files changed: +16 −6 lines changed

csrc/xpu/dispatch_utils.h
Lines changed: 3 additions & 1 deletion

```diff
@@ -22,10 +22,12 @@
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
 #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
-  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
 
 #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
 
 // When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
```
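
The two FP8 formats the dispatch macro now covers trade precision for range: Float8_e4m3fn keeps more mantissa bits, while Float8_e5m2 has a wider exponent and a much larger representable maximum. A quick PyTorch check (illustration only, not part of this commit) shows the difference:

```python
import torch

# The two FP8 formats now handled by VLLM_DISPATCH_CASE_FP8_TYPES:
# e4m3fn: 4 exponent bits, 3 mantissa bits (no infinities), max = 448.0
# e5m2:   5 exponent bits, 2 mantissa bits,                 max = 57344.0
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, smallest normal={info.tiny}")
```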

csrc/xpu/quantization/fp8/fp8_quant.cpp
Lines changed: 1 addition & 0 deletions

```diff
@@ -5,6 +5,7 @@
 #include <sycl/sycl.hpp>
 
 #include "xpu/dispatch_utils.h"
+#include "xpu/ops.h"
 
 #include "fp8_quant.h"
 #include "utils.h"
```

tests/ops/fp8_quant_op.py
Lines changed: 7 additions & 1 deletion

```diff
@@ -7,6 +7,12 @@
 import torch.nn as nn
 import vllm.envs as envs
 
+import sys
+import os
+
+# Add parent directory to Python path
+# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 import tests.register_ops as ops
 
 
@@ -45,7 +51,7 @@ def scaled_fp8_quant(
     assert (input.ndim == 2)
     shape: Union[tuple[int, int], torch.Size] = input.shape
     # out_dtype: torch.dtype = current_platform.fp8_dtype()
-    out_dtype: torch.dtype = torch.fp8_e5m2
+    out_dtype: torch.dtype = torch.float8_e5m2
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    if output is None:
```
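
The fix in the second hunk replaces `torch.fp8_e5m2`, which is not a real attribute, with the actual PyTorch dtype `torch.float8_e5m2`. The rest of `scaled_fp8_quant` is not visible in the diff; below is a minimal sketch of how the padded shape and `out_dtype` typically feed into output allocation, assuming the standard pattern (the helper name and the allocation are illustrative, not the file's actual code):

```python
from typing import Optional

import torch

def scaled_fp8_quant_sketch(input: torch.Tensor,
                            num_token_padding: Optional[int] = None,
                            output: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Mirrors the visible lines: 2-D input, optional row padding, e5m2 output.
    assert input.ndim == 2
    shape = input.shape
    out_dtype = torch.float8_e5m2  # torch.fp8_e5m2 does not exist; this dtype does
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    if output is None:
        # Hypothetical allocation; the actual kernel invocation is not shown in this diff.
        output = torch.empty(shape, dtype=out_dtype, device=input.device)
    return output
```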

tests/test_fp8_quant.py
Lines changed: 5 additions & 4 deletions

```diff
@@ -11,7 +11,7 @@
 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='xpu')
 
-def ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype):
+def ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype=torch.float8_e5m2):
 
     fp8_traits = torch.finfo(fp8_dtype)
     fp8_traits_max = fp8_traits.max
@@ -43,16 +43,17 @@ def seed_everything(seed):
 NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
 SCALE_UBS = [True, False]
 SEEDS = [0]
-FP8_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]
+FP8_DTYPES = [torch.float8_e5m2]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("fp8_dtype", FP8_DTYPES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
-                                      dtype: torch.dtype,
+                                      fp8_dtype: torch.dtype, dtype: torch.dtype,
                                       seed: int) -> None:
     seed_everything(seed)
 
@@ -93,4 +94,4 @@ def test_fp8_quant_large(seed: int, fp8_dtype: torch.dtype) -> None:
     torch.testing.assert_close(ref_out, ops_out)
 
 if __name__ == "__main__":
-    test_dynamic_per_tensor_fp8_quant(1024, 1024, torch.float16, 0)
+    test_dynamic_per_tensor_fp8_quant(1024, 1024, torch.float8_e5m2, torch.float16, 0)
```
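
The test now parametrizes over `FP8_DTYPES` (currently only `torch.float8_e5m2`) and threads `fp8_dtype` through both the test signature and the `__main__` invocation. For reference, dynamic per-tensor FP8 quantization of the kind this test checks usually follows a scale-clamp-cast recipe; the sketch below is a generic illustration under that assumption, not the repository's `ref_dynamic_per_tensor_fp8_quant`, whose body is only partially visible here:

```python
import torch

def dynamic_per_tensor_fp8_quant_sketch(x: torch.Tensor,
                                        fp8_dtype: torch.dtype = torch.float8_e5m2):
    # One scale for the whole tensor, chosen so the largest magnitude
    # maps onto the FP8 type's maximum representable value.
    fp8_max = torch.finfo(fp8_dtype).max
    x_f32 = x.to(torch.float32)
    scale = x_f32.abs().max().clamp(min=1e-12) / fp8_max
    quantized = (x_f32 / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return quantized, scale

# Rough usage, mirroring what the parametrized test exercises:
# x = torch.randn(83, 1024, dtype=torch.float16, device="xpu")
# q, s = dynamic_per_tensor_fp8_quant_sketch(x, torch.float8_e5m2)
# dequantized = q.to(torch.float32) * s  # compare against the kernel's output
```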
