Commit 8c678d4

add per-token quantization
Signed-off-by: baodii <[email protected]>
1 parent 41ea86e commit 8c678d4
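
Per-token (per-row) dynamic quantization gives every token its own FP8 scale instead of one scale for the whole tensor. The snippet below is a minimal, self-contained sketch of the scheme the new reference function in this commit implements; the helper name and default dtype are illustrative and not part of the commit.

    import torch

    def per_token_fp8_quant_sketch(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
        finfo = torch.finfo(fp8_dtype)
        x_f32 = x.to(torch.float32)
        # One scale per row (token): row max magnitude divided by the FP8 max,
        # floored at 1 / (fp8_max * 512) so we never divide by a tiny scale.
        row_max = x_f32.abs().max(dim=-1, keepdim=True).values
        scales = (row_max / finfo.max).clamp(min=1.0 / (finfo.max * 512.0))
        # Scale, clamp to the representable FP8 range, then cast to the FP8 dtype.
        q = (x_f32 / scales).clamp(finfo.min, finfo.max).to(fp8_dtype)
        return q, scales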

File tree

2 files changed: 64 additions, 1 deletion


setup.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@
 from setuptools.command.build_ext import build_ext
 from torch.utils.cpp_extension import SYCL_HOME

-print("************************************* Using SYCL_HOME:", SYCL_HOME)

 def load_module_from_path(module_name, path):
     spec = importlib.util.spec_from_file_location(module_name, path)

tests/test_fp8_quant.py

Lines changed: 64 additions & 0 deletions
@@ -30,6 +30,42 @@ def ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype=torch.float8_e5m2):
         fp8_traits_min, fp8_traits_max).to(fp8_dtype)
     return ref_out, ref_scale.view((1, ))

+def ref_dynamic_per_token_quant(x: torch.tensor,
+                                quant_dtype: torch.dtype,
+                                scale_ub: Optional[torch.tensor] = None) \
+        -> tuple[torch.tensor, torch.tensor]:
+
+    assert quant_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]
+    # if scale_ub is not None:
+    #     assert quant_dtype == FP8_DTYPE
+
+    qtype_traits = torch.finfo(quant_dtype)
+    qtype_traits_max = qtype_traits.max
+    qtype_traits_min = qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
+    s_1 = as_float32_tensor(1.0)
+    s_512 = as_float32_tensor(512.0)
+
+    # For fp8, in order to match the cuda kernel output, we have to do exactly
+    # the same operations as in the corresponding fp8 kernel to prevent
+    # rounding errors.
+
+    # Compute scales
+    x_token_max, _ = x.abs().max(dim=-1)
+    x_token_max = as_float32_tensor(x_token_max)
+    if scale_ub is not None:
+        x_token_max = x_token_max.clamp(max=scale_ub)
+    scales = (x_token_max / qtype_max)[:, None]
+
+    # Quant
+    min_scaling_factor = s_1 / (qtype_max * s_512)
+    scales = scales.clamp(min=min_scaling_factor)
+    torch_out = as_float32_tensor(x) / scales
+    torch_out = torch_out.clamp(qtype_traits_min,
+                                qtype_traits_max).to(quant_dtype)
+
+    return torch_out, scales
+
 def seed_everything(seed):
     if seed is not None:
         random.seed(seed)
@@ -68,6 +104,34 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
                                ops_out.to(dtype=torch.float32))


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("fp8_dtype", FP8_DTYPES)
+@torch.inference_mode()
+def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
+                                     dtype: torch.dtype, scale_ub: bool,
+                                     seed: int, fp8_dtype: torch.dtype) -> None:
+    seed_everything(seed)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
+                   device="xpu") + 1e-6  # avoid nans
+
+    scale_ub = torch.mean(x).to(dtype=torch.float32, device='xpu') \
+        if scale_ub else None
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, fp8_dtype, scale_ub)
+
+    ops_out, ops_scales = scaled_fp8_quant(x,
+                                           scale_ub=scale_ub,
+                                           use_per_token_if_dynamic=True)
+
+    torch.testing.assert_close(ref_scales, ops_scales)
+    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
+                               ops_out.to(dtype=torch.float32))
+
+
 # Regression test for a case with large activations where an int32 index cannot
 # represent the number of elements.
 @torch.inference_mode()
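
For a quick check outside the pytest parameterization, the new reference can be invoked directly. This sketch assumes tests/test_fp8_quant.py is importable as a module and that an XPU device is available; the import path and tensor shape are illustrative.

    import torch
    from tests.test_fp8_quant import ref_dynamic_per_token_quant

    x = torch.rand(4, 16, dtype=torch.float16, device="xpu") + 1e-6
    out, scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
    print(out.shape, scales.shape)  # (4, 16) and (4, 1): one scale per token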
