vllm/compilation/activation_quant_fusion.py (105 changes: 104 additions & 1 deletion)

@@ -9,6 +9,7 @@
register_replacement)
from torch._ops import OpOverload

from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -36,6 +37,102 @@
kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501


def is_rocm_aiter_linear_enabled() -> bool:
return current_platform.is_rocm(
) and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR


if is_rocm_aiter_linear_enabled():
import aiter as rocm_aiter
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant

from vllm.utils import direct_register_custom_op
rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128

def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]:
return act_mul_and_fp8_group_quant(
x,
activation="silu",
group_size=rocm_aiter_fp8_quant_group_size,
dtype_quant=rocm_aiter_fp8_dtype)

def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half),
dtype=rocm_aiter_fp8_dtype,
device=x.device)
out_bs = torch.empty(
(M, (N_half + rocm_aiter_fp8_quant_group_size - 1) //
rocm_aiter_fp8_quant_group_size),
dtype=torch.float32,
device=x.device)
return x_fp8, out_bs

direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
Review comment from tjtanaa (Contributor), Sep 28, 2025:

Can you check whether the latest AITER allows you to skip direct_register_custom_op? I recall that most ops should now work without calling direct_register_custom_op on the vLLM side, since the registration is already done in the AITER repository. Moreover, removing the direct_register_custom_op wrappers avoids additional CPU overhead; direct_register_custom_op can be costly in terms of overhead.

Please take a look at the benchmarking results in PR ROCm#717 (the second and third cases), which show that removing direct_register_custom_op on the vLLM side improves performance.

Reply (Contributor):

Hey, thanks for the feedback. Is there a version of aiter that has aiter.ops.triton.fused_fp8_quant and also ships these ops pre-registered, as you mentioned? I wasn't able to figure out how to call act_mul_and_fp8_group_quant without calling direct_register_custom_op first. I'd be happy to investigate further if you can point me in the right direction; otherwise, we can always come back and remove these direct_register_custom_op calls later if needed.
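
Purely as an illustrative sketch of the direction discussed in this thread (not part of the PR): one could probe whether a torch-level custom op is already registered before falling back to vLLM-side registration. The qualified op name used below is an assumption for illustration only, not a confirmed AITER schema.

# Sketch only: check whether a custom op already exists under torch.ops
# before wrapping it with direct_register_custom_op on the vLLM side.
# The "aiter::act_mul_and_fp8_group_quant" name is assumed for this example.
import torch

def _torch_op_registered(qualname: str) -> bool:
    namespace, name = qualname.split("::")
    # _OpNamespace raises AttributeError for unknown ops, so hasattr works here.
    return hasattr(getattr(torch.ops, namespace), name)

if not _torch_op_registered("aiter::act_mul_and_fp8_group_quant"):
    # Fall back to vLLM-side registration (direct_register_custom_op) here.
    pass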

op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
BLOCK_LINEAR_OP = torch.ops.vllm.apply_w8a8_block_fp8_linear.default
FUSED_SILU_MUL_QUANT_OP = \
torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
AITER_BLOCK_LINEAR_OP = \
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale.default

class AiterSiluMulFp8BlockQuantPattern:

def __init__(self):
pass

def register(self, pm_pass: PatternMatcherPass):

def pattern(input: torch.Tensor, result_silu_mul: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor):
at1 = auto_functionalized(SILU_MUL_OP,
result=result_silu_mul,
input=input)
at2 = BLOCK_LINEAR_OP(input=at1[1],
weight=linear_weight,
block_size=[128, 128],
weight_scale=linear_weight_scale,
input_scale=None,
bias=None,
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True)
return at2

def replacement(input: torch.Tensor, result_silu_mul: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor):
at1 = FUSED_SILU_MUL_QUANT_OP(x=input)
at2 = AITER_BLOCK_LINEAR_OP(A=at1[0],
B=linear_weight,
As=at1[1],
Bs=linear_weight_scale,
block_size=[128, 128],
output_dtype=input.dtype)
return at2

inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # result_silu_mul
# linear_weight
torch.empty((2, 5), device="cuda", dtype=FP8_DTYPE),
empty_fp32(1, 1) # linear_weight_scale
]

register_replacement(pattern, replacement, inputs, fwd_only,
pm_pass)


class ActivationQuantPattern(ABC):
"""
The base class for Activation+Quant fusions.
@@ -176,6 +273,11 @@ def __init__(self, config: VllmConfig):
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)

if is_rocm_aiter_linear_enabled():
pattern_silu_mul_aiter_block_fp8 = AiterSiluMulFp8BlockQuantPattern(
)
pattern_silu_mul_aiter_block_fp8.register(self.patterns)

self.dump_patterns(config, self.patterns)

@VllmInductorPass.time_and_log
@@ -186,4 +288,5 @@ def __call__(self, graph: torch.fx.Graph):
def uuid(self):
return VllmInductorPass.hash_source(self, ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern)
SiluMulNvfp4QuantPattern,
AiterSiluMulFp8BlockQuantPattern)
Review comment (Collaborator):

This symbol definition is conditional on is_rocm_aiter_linear_enabled(); any run will fail here if it is not enabled.

Reply (Contributor):

Should be fixed now: cd059b9
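
For reference, one plausible shape of such a fix (a sketch only, assuming the surrounding module's names; the actual change in cd059b9 may differ) is to reference the pattern class only when the AITER path is enabled, for example by building the hash_source argument list conditionally:

# Sketch: avoid a NameError when AiterSiluMulFp8BlockQuantPattern is not
# defined (i.e. when is_rocm_aiter_linear_enabled() is False).
def uuid(self):
    sources = [
        self, ActivationQuantPattern, SiluMulFp8StaticQuantPattern,
        SiluMulNvfp4QuantPattern
    ]
    if is_rocm_aiter_linear_enabled():
        sources.append(AiterSiluMulFp8BlockQuantPattern)
    return VllmInductorPass.hash_source(*sources)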

vllm/model_executor/layers/quantization/utils/fp8_utils.py (15 changes: 8 additions & 7 deletions)

@@ -61,9 +61,12 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
import aiter as rocm_aiter

return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
# MI300's fp8nuz should be enough to detect if we call ck vs triton
if current_platform.is_fp8_fnuz():
from aiter import gemm_a8w8_blockscale
else:
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
Review comment (Contributor) on lines +64 to +69, severity: high:

Importing inside a function that is on a hot path, like this custom op implementation, can introduce performance overhead. It's best practice to move imports to the module level to ensure they are only executed once.

I'd recommend defining a module-level variable that holds the correct gemm_a8w8_blockscale function based on the platform, and then using that variable within this function. This avoids repeated import lookups.

For example, you could add the following logic at the module level (e.g., near the top of the file):

_gemm_a8w8_blockscale = None
if current_platform.is_rocm():
    try:
        # MI300's fp8nuz should be enough to detect if we call ck vs triton
        if current_platform.is_fp8_fnuz():
            from aiter import gemm_a8w8_blockscale
        else:
            from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
        _gemm_a8w8_blockscale = gemm_a8w8_blockscale
    except ImportError:
        # aiter is not installed, which is fine.
        # The error will be raised when the op is actually used.
        pass

And then this function's body can be simplified as suggested.

Suggested change (remove the inline imports and dispatch):
    # MI300's fp8nuz should be enough to detect if we call ck vs triton
    if current_platform.is_fp8_fnuz():
        from aiter import gemm_a8w8_blockscale
    else:
        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
and replace with:
    if _gemm_a8w8_blockscale is None:
        raise ImportError(
            "Aiter backend for gemm_a8w8_blockscale not available. "
            "Please install aiter.")
    return _gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)



def rocm_aiter_gemm_w8a8_blockscale_fake(
@@ -87,8 +90,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
)
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()):
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR:

import aiter as rocm_aiter
from aiter import get_hip_quant
@@ -737,8 +739,7 @@ def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm and only for FP8_FNUZ
and at the moment are MI300 series"""
return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz())
and envs.VLLM_ROCM_USE_AITER_LINEAR)


def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: