Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 110 additions & 3 deletions vllm/compilation/activation_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
register_replacement)
from torch._ops import OpOverload

from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
Expand Down Expand Up @@ -66,6 +67,102 @@ def register(self, pm_pass: PatternMatcherPass):
raise NotImplementedError


def is_rocm_aiter_linear_enabled() -> bool:
return current_platform.is_rocm(
) and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR


if is_rocm_aiter_linear_enabled():
import aiter as rocm_aiter
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant

from vllm.utils import direct_register_custom_op
rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128

def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]:
return act_mul_and_fp8_group_quant(
x,
activation="silu",
group_size=rocm_aiter_fp8_quant_group_size,
dtype_quant=rocm_aiter_fp8_dtype)

def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half),
dtype=rocm_aiter_fp8_dtype,
device=x.device)
out_bs = torch.empty(
(M, (N_half + rocm_aiter_fp8_quant_group_size - 1) //
rocm_aiter_fp8_quant_group_size),
dtype=torch.float32,
device=x.device)
return x_fp8, out_bs

direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
Copy link
Contributor

@tjtanaa tjtanaa Sep 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check if the latest aiter allows you to skip direct register custom ops? I remember most ops now should be able to work without calling direct_register_custom_ops on vLLM side as it is done in AITER repository. Moreover, removing the direct_register_custom_ops wrappers can reduce additional CPU overhead. Doing direct_register_custom_ops can be costly in terms of overhead.

Please take a look at the benchmarking results in this PR ROCm#717 (the second and third case) where it shows that removing the direct_register_custom_ops on vLLM side improves the perf.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, thanks for the feedback. Is there a version of aiter which has aiter.ops.triton.fused_fp8_quant and also has these direct_register_custom_ops that you mentioned? I wasn't able to figure out how to call act_mul_and_fp8_group_quant without calling direct_register_custom_op first. Would be happy to investigate further if you can point me in the right direction, otherwise I think we can always come back and get rid of these direct_register_custom_op calls if needed.

op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
BLOCK_LINEAR_OP = torch.ops.vllm.apply_w8a8_block_fp8_linear.default
FUSED_SILU_MUL_QUANT_OP = \
torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
AITER_BLOCK_LINEAR_OP = \
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale.default

class AiterSiluMulFp8BlockQuantPattern(ActivationQuantPattern):

def __init__(self):
pass

def register(self, pm_pass: PatternMatcherPass):

def pattern(input: torch.Tensor, result_silu_mul: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor):
at1 = auto_functionalized(SILU_MUL_OP,
result=result_silu_mul,
input=input)
at2 = BLOCK_LINEAR_OP(input=at1[1],
weight=linear_weight,
block_size=[128, 128],
weight_scale=linear_weight_scale,
input_scale=None,
bias=None,
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True)
return at2

def replacement(input: torch.Tensor, result_silu_mul: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor):
at1 = FUSED_SILU_MUL_QUANT_OP(x=input)
at2 = AITER_BLOCK_LINEAR_OP(A=at1[0],
B=linear_weight,
As=at1[1],
Bs=linear_weight_scale,
block_size=[128, 128],
output_dtype=input.dtype)
return at2

inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # result_silu_mul
# linear_weight
torch.empty((2, 5), device="cuda", dtype=FP8_DTYPE),
empty_fp32(1, 1) # linear_weight_scale
]

register_replacement(pattern, replacement, inputs, fwd_only,
pm_pass)


class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
"""
Fusion for SiluMul+Fp8StaticQuant Pattern
Expand Down Expand Up @@ -176,6 +273,11 @@ def __init__(self, config: VllmConfig):
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)

if is_rocm_aiter_linear_enabled():
pattern_silu_mul_aiter_block_fp8 = AiterSiluMulFp8BlockQuantPattern(
)
pattern_silu_mul_aiter_block_fp8.register(self.patterns)

self.dump_patterns(config, self.patterns)

@VllmInductorPass.time_and_log
Expand All @@ -184,6 +286,11 @@ def __call__(self, graph: torch.fx.Graph):
logger.debug("Replaced %s patterns", self.matched_count)

def uuid(self):
return VllmInductorPass.hash_source(self, ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern)
fusion_patterns = [
ActivationQuantPattern,
SiluMulFp8StaticQuantPattern,
SiluMulNvfp4QuantPattern,
]
if is_rocm_aiter_linear_enabled():
fusion_patterns.append(AiterSiluMulFp8BlockQuantPattern)
return VllmInductorPass.hash_source(self, *fusion_patterns)
165 changes: 160 additions & 5 deletions vllm/compilation/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
Expand Down Expand Up @@ -77,6 +78,28 @@ def __str__(self):
}


def is_rocm_aiter_enabled() -> bool:
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER


if is_rocm_aiter_enabled():
AITER_RMS_GROUP_QUANT_OP = \
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
AITER_RMS_ADD_GROUP_QUANT_OP = \
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default

BLOCK_LINEAR_OP = torch.ops.vllm.apply_w8a8_block_fp8_linear.default
AITER_BLOCK_LINEAR_OP = \
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale.default

AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default

import aiter as rocm_aiter
rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
rocm_aiter_fp8_quant_group_size = 128


class RMSNormQuantPattern:

def __init__(self, epsilon: float, key: FusedRMSQuantKey):
Expand Down Expand Up @@ -338,6 +361,123 @@ def replacement(result: torch.Tensor, input: torch.Tensor,
)


if is_rocm_aiter_enabled():

class AiterRMSGroupQuantFP8Pattern:

def __init__(self, epsilon: float, quant_dtype: torch.dtype):
self.epsilon = epsilon
self.quant_dtype = quant_dtype

def register(self, pm_pass: PatternMatcherPass):

def pattern(
input: torch.Tensor,
weight: torch.Tensor, #result_rms: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor):
at1 = AITER_RMS_OP(x=input,
weight=weight,
variance_epsilon=self.epsilon)

at2 = BLOCK_LINEAR_OP(input=at1,
weight=linear_weight,
block_size=[128, 128],
weight_scale=linear_weight_scale,
input_scale=None,
bias=None,
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True)

return at2

def replacement(input: torch.Tensor, weight: torch.Tensor,
linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor):
at1 = AITER_RMS_GROUP_QUANT_OP(x=input,
residual=None,
weight=weight,
variance_epsilon=self.epsilon)

at2 = AITER_BLOCK_LINEAR_OP(A=at1[0],
B=linear_weight,
As=at1[1],
Bs=linear_weight_scale,
block_size=[128, 128],
output_dtype=input.dtype)

return at2

inputs = [
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
torch.empty((2, 5), device="cuda",
dtype=FP8_DTYPE), # linear_weight
empty_fp32(1, 1), # linear_weight_scale
]

pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
pm_pass)

class AiterFusedAddRMSGroupQuantPattern:

def __init__(self, epsilon: float, quant_dtype: torch.dtype):
self.epsilon = epsilon
self.quant_dtype = quant_dtype

def register(self, pm_pass: PatternMatcherPass):

def pattern(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor):
at1 = AITER_RMS_ADD_OP(x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon)

at2 = BLOCK_LINEAR_OP(input=at1[0],
weight=linear_weight,
block_size=[128, 128],
weight_scale=linear_weight_scale,
input_scale=None,
bias=None,
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True)
# result, residual
return at2, at1[1]

def replacement(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, linear_weight: torch.Tensor,
linear_weight_scale: torch.Tensor):

at1 = AITER_RMS_ADD_GROUP_QUANT_OP(
x=input,
residual=residual,
weight=weight,
variance_epsilon=self.epsilon)

at2 = AITER_BLOCK_LINEAR_OP(A=at1[0],
B=linear_weight,
As=at1[1],
Bs=linear_weight_scale,
block_size=[128, 128],
output_dtype=input.dtype)
# result, residual
return at2, at1[2]

inputs = [
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
torch.empty((2, 5), device="cuda",
dtype=FP8_DTYPE), # linear_weight
empty_fp32(1, 1), # linear_weight_scale
]

pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
pm_pass)


class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
Expand Down Expand Up @@ -368,6 +508,14 @@ def __init__(self, config: VllmConfig):
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns)

if is_rocm_aiter_enabled():
# Fuse rms_norm + dynamic group fp8 quant
AiterRMSGroupQuantFP8Pattern(epsilon,
FP8_DTYPE).register(self.patterns)

AiterFusedAddRMSGroupQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns)

self.dump_patterns(config, self.patterns)

@VllmInductorPass.time_and_log
Expand All @@ -376,8 +524,15 @@ def __call__(self, graph: fx.Graph):
logger.debug("Replaced %s patterns", self.matched_count)

def uuid(self) -> Any:
return self.hash_source(self, RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern)
fusion_patterns = [
RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern,
]
if is_rocm_aiter_enabled():
fusion_patterns.extend([
AiterRMSGroupQuantFP8Pattern, AiterFusedAddRMSGroupQuantPattern
])
return self.hash_source(self, *fusion_patterns)
Loading