[Rocm][torch.compile] Adding layernorm + fp8 block quant and silu + fp8 block quant for Aiter #25693
```diff
@@ -9,6 +9,7 @@
     register_replacement)
 from torch._ops import OpOverload
 
+from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
```
```diff
@@ -36,6 +37,102 @@
     kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501
 
 
+def is_rocm_aiter_linear_enabled() -> bool:
+    return current_platform.is_rocm(
+    ) and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
+
+
+if is_rocm_aiter_linear_enabled():
+    import aiter as rocm_aiter
+    from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
+
+    from vllm.utils import direct_register_custom_op
+    rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
+    rocm_aiter_fp8_quant_group_size = 128
+
+    def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
+            x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]:
+        return act_mul_and_fp8_group_quant(
+            x,
+            activation="silu",
+            group_size=rocm_aiter_fp8_quant_group_size,
+            dtype_quant=rocm_aiter_fp8_dtype)
+
+    def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
+            x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]:
+        M, N = x.shape
+        assert N % 2 == 0
+        N_half = N // 2
+        x_fp8 = torch.empty((M, N_half),
+                            dtype=rocm_aiter_fp8_dtype,
+                            device=x.device)
+        out_bs = torch.empty(
+            (M, (N_half + rocm_aiter_fp8_quant_group_size - 1) //
+             rocm_aiter_fp8_quant_group_size),
+            dtype=torch.float32,
+            device=x.device)
+        return x_fp8, out_bs
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_act_mul_and_fp8_group_quant",
+        op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
+        mutates_args=[],
+        fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+    BLOCK_LINEAR_OP = torch.ops.vllm.apply_w8a8_block_fp8_linear.default
+    FUSED_SILU_MUL_QUANT_OP = \
+        torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
+    AITER_BLOCK_LINEAR_OP = \
+        torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale.default
+
+    class AiterSiluMulFp8BlockQuantPattern:
+
+        def __init__(self):
+            pass
+
+        def register(self, pm_pass: PatternMatcherPass):
+
+            def pattern(input: torch.Tensor, result_silu_mul: torch.Tensor,
+                        linear_weight: torch.Tensor,
+                        linear_weight_scale: torch.Tensor):
+                at1 = auto_functionalized(SILU_MUL_OP,
+                                          result=result_silu_mul,
+                                          input=input)
+                at2 = BLOCK_LINEAR_OP(input=at1[1],
+                                      weight=linear_weight,
+                                      block_size=[128, 128],
+                                      weight_scale=linear_weight_scale,
+                                      input_scale=None,
+                                      bias=None,
+                                      cutlass_block_fp8_supported=False,
+                                      use_aiter_and_is_supported=True)
+                return at2
+
+            def replacement(input: torch.Tensor, result_silu_mul: torch.Tensor,
+                            linear_weight: torch.Tensor,
+                            linear_weight_scale: torch.Tensor):
+                at1 = FUSED_SILU_MUL_QUANT_OP(x=input)
+                at2 = AITER_BLOCK_LINEAR_OP(A=at1[0],
+                                            B=linear_weight,
+                                            As=at1[1],
+                                            Bs=linear_weight_scale,
+                                            block_size=[128, 128],
+                                            output_dtype=input.dtype)
+                return at2
+
+            inputs = [
+                empty_bf16(5, 4),  # input
+                empty_bf16(5, 4),  # result_silu_mul
+                # linear_weight
+                torch.empty((2, 5), device="cuda", dtype=FP8_DTYPE),
+                empty_fp32(1, 1)  # linear_weight_scale
+            ]
+
+            register_replacement(pattern, replacement, inputs, fwd_only,
+                                 pm_pass)
+
+
 class ActivationQuantPattern(ABC):
     """
     The base class for Activation+Quant fusions.
```
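As a reading aid for the pattern above, here is a rough eager-mode reference for what silu-and-mul followed by per-group FP8 quantization computes. The helper name, the `FP8_DTYPE` choice, and the assumption that the half-width is a multiple of the group size are illustrative only, not taken from the aiter kernel:

```python
import torch
import torch.nn.functional as F

FP8_DTYPE = torch.float8_e4m3fn  # placeholder; the real kernel picks the fp8 variant per GPU
GROUP_SIZE = 128


def silu_mul_fp8_group_quant_ref(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference for silu_and_mul + per-group FP8 quant (group size 128)."""
    M, N = x.shape
    d = N // 2
    assert d % GROUP_SIZE == 0, "simplified: no ragged last group"
    # silu_and_mul: gate activation times the second half of the projection
    y = F.silu(x[:, :d]) * x[:, d:]
    # one scale per 128-element group along the hidden dimension
    groups = y.view(M, d // GROUP_SIZE, GROUP_SIZE).float()
    fp8_max = torch.finfo(FP8_DTYPE).max
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / fp8_max
    y_fp8 = (groups / scales).clamp(-fp8_max, fp8_max).to(FP8_DTYPE)
    return y_fp8.reshape(M, d), scales.squeeze(-1)
```

In the replacement graph this work is done by the single aiter kernel, and its outputs (FP8 data plus per-group scales) feed directly into `rocm_aiter_gemm_w8a8_blockscale`.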
```diff
@@ -176,6 +273,11 @@ def __init__(self, config: VllmConfig):
         pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
         pattern_silu_mul_nvfp4.register(self.patterns)
 
+        if is_rocm_aiter_linear_enabled():
+            pattern_silu_mul_aiter_block_fp8 = AiterSiluMulFp8BlockQuantPattern(
+            )
+            pattern_silu_mul_aiter_block_fp8.register(self.patterns)
+
         self.dump_patterns(config, self.patterns)
 
     @VllmInductorPass.time_and_log
```
```diff
@@ -186,4 +288,5 @@ def __call__(self, graph: torch.fx.Graph):
     def uuid(self):
         return VllmInductorPass.hash_source(self, ActivationQuantPattern,
                                             SiluMulFp8StaticQuantPattern,
-                                            SiluMulNvfp4QuantPattern)
+                                            SiluMulNvfp4QuantPattern,
+                                            AiterSiluMulFp8BlockQuantPattern)
```
```diff
@@ -61,9 +61,12 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    import aiter as rocm_aiter
-
-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    # MI300's fp8nuz should be enough to detect if we call ck vs triton
+    if current_platform.is_fp8_fnuz():
+        from aiter import gemm_a8w8_blockscale
+    else:
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
```
Comment on lines +64 to +69
Importing inside a function that is on a hot path, like this custom op implementation, can introduce performance overhead. It's best practice to move imports to the module level to ensure they are only executed once. I'd recommend defining a module-level variable that holds the correct `gemm_a8w8_blockscale` implementation. For example, you could add the following logic at the module level (e.g., near the top of the file):

```python
_gemm_a8w8_blockscale = None
if current_platform.is_rocm():
    try:
        # MI300's fp8nuz should be enough to detect if we call ck vs triton
        if current_platform.is_fp8_fnuz():
            from aiter import gemm_a8w8_blockscale
        else:
            from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
        _gemm_a8w8_blockscale = gemm_a8w8_blockscale
    except ImportError:
        # aiter is not installed, which is fine.
        # The error will be raised when the op is actually used.
        pass
```

And then this function's body can be simplified as suggested.
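A minimal sketch of how the op body might look under that suggestion (the parameter list follows the hunk above; the module-level `_gemm_a8w8_blockscale` holder is assumed to be in place, and the error message is illustrative):

```python
import torch


def rocm_aiter_gemm_w8a8_blockscale_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    # Surface the missing-aiter error only when the op is actually used,
    # mirroring the lazy-error behaviour described in the comment above.
    if _gemm_a8w8_blockscale is None:
        raise ImportError("aiter is required for gemm_a8w8_blockscale")
    return _gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
```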
```diff
@@ -87,8 +90,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
     op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
     fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
 )
-if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()):
+if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR:
 
     import aiter as rocm_aiter
     from aiter import get_hip_quant
```
```diff
@@ -737,8 +739,7 @@ def check_aiter_fp8_linear_support() -> bool:
     """AITER is only supported on ROCm and only for FP8_FNUZ
     and at the moment are MI300 series"""
     return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
-            and envs.VLLM_ROCM_USE_AITER_LINEAR
-            and current_platform.is_fp8_fnuz())
+            and envs.VLLM_ROCM_USE_AITER_LINEAR)
 
 
 def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
```
Can you check if the latest aiter allows you to skip `direct_register_custom_op`? I remember most ops should now work without calling `direct_register_custom_op` on the vLLM side, as it is done in the AITER repository. Moreover, removing the `direct_register_custom_op` wrappers can reduce additional CPU overhead; calling `direct_register_custom_op` can be costly in terms of overhead.

Please take a look at the benchmarking results in ROCm#717 (the second and third case), which show that removing `direct_register_custom_op` on the vLLM side improves performance.
Hey, thanks for the feedback. Is there a version of aiter that has `aiter.ops.triton.fused_fp8_quant` and also registers these custom ops as you mentioned? I wasn't able to figure out how to call `act_mul_and_fp8_group_quant` without calling `direct_register_custom_op` first. I'd be happy to investigate further if you can point me in the right direction; otherwise, I think we can always come back and remove these `direct_register_custom_op` calls if needed.
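If a future aiter build does register the fused kernel as a torch custom op itself, a sketch of how the vLLM side could prefer it and otherwise keep the wrapper; the `aiter::act_mul_and_fp8_group_quant` op name below is purely hypothetical and not a confirmed aiter API:

```python
import torch


def resolve_fused_silu_mul_quant_op():
    # Hypothetical: prefer an op registered by aiter itself, if one exists
    # under this (assumed) namespace and name.
    if hasattr(torch.ops.aiter, "act_mul_and_fp8_group_quant"):
        return torch.ops.aiter.act_mul_and_fp8_group_quant.default
    # Otherwise fall back to the wrapper this PR registers via
    # direct_register_custom_op.
    return torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
```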