[Rocm][torch.compile] Adding layernorm + fp8 block quant and silu + fp8 block quant for Aiter #25693
Conversation
Signed-off-by: charlifu <[email protected]>
Code Review

This pull request introduces new fusion passes for ROCm AITER, specifically for layernorm + fp8 block quant and silu + fp8 block quant. This is achieved by adding a new pattern AiterSiluMulFp8BlockQuantPattern and registering a new custom operator. Additionally, the changes in fp8_utils.py extend AITER support to non-MI300 series GPUs by providing a Triton-based fallback, which is a great enhancement.

My main feedback is on a performance concern in fp8_utils.py, where an import is performed inside a performance-critical function. I've suggested a refactoring to move the import to the module level to avoid repeated overhead.
    # MI300's fp8nuz should be enough to detect if we call ck vs triton
    if current_platform.is_fp8_fnuz():
        from aiter import gemm_a8w8_blockscale
    else:
        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
Importing inside a function that is on a hot path, like this custom op implementation, can introduce performance overhead. It's best practice to move imports to the module level so they are only executed once.

I'd recommend defining a module-level variable that holds the correct gemm_a8w8_blockscale function based on the platform, and then using that variable within this function. This avoids repeated import lookups.

For example, you could add the following logic at the module level (e.g., near the top of the file):
_gemm_a8w8_blockscale = None
if current_platform.is_rocm():
    try:
        # MI300's fp8nuz should be enough to detect if we call ck vs triton
        if current_platform.is_fp8_fnuz():
            from aiter import gemm_a8w8_blockscale
        else:
            from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
        _gemm_a8w8_blockscale = gemm_a8w8_blockscale
    except ImportError:
        # aiter is not installed, which is fine.
        # The error will be raised when the op is actually used.
        pass
And then this function's body can be simplified as suggested.
-    # MI300's fp8nuz should be enough to detect if we call ck vs triton
-    if current_platform.is_fp8_fnuz():
-        from aiter import gemm_a8w8_blockscale
-    else:
-        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    if _gemm_a8w8_blockscale is None:
+        raise ImportError(
+            "Aiter backend for gemm_a8w8_blockscale not available. "
+            "Please install aiter.")
+    return _gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
Signed-off-by: Micah Williamson <[email protected]>
I'm currently overhauling custom op matching in #24604. We also recently added a torch implementation of group quant; could you compare its performance with AITER? Also, could you compare the perf of the fused AITER kernel to the fused torch.compile kernel for rmsnorm+quant? Happy to help out with instructions, but overall:
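To make those comparisons, a simple CUDA-event timing loop is usually enough. The sketch below is illustrative only: time_kernel, the shapes, and the two lambda stand-ins are placeholders, not the actual torch group-quant reference or the fused AITER kernel.

import torch


def time_kernel(fn, *args, warmup: int = 10, iters: int = 100) -> float:
    """Return average milliseconds per call using CUDA events."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

# Placeholders only: swap in the torch group-quant reference and the fused
# AITER kernel being compared.
ref_fn = lambda t: t.to(torch.float8_e4m3fnuz)
fused_fn = lambda t: t.to(torch.float8_e4m3fnuz)

print(f"torch reference: {time_kernel(ref_fn, x):.4f} ms/iter")
print(f"AITER fused:     {time_kernel(fused_fn, x):.4f} ms/iter")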
     SiluMulFp8StaticQuantPattern,
-    SiluMulNvfp4QuantPattern)
+    SiluMulNvfp4QuantPattern,
+    AiterSiluMulFp8BlockQuantPattern)
This symbol's definition is conditional on is_rocm_aiter_linear_enabled(), so any run will fail here when it is not enabled.
Should be fixed now cd059b9
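For reference, the usual way to avoid this class of failure is to define the symbol unconditionally and gate its registration instead of its definition. The sketch below uses illustrative names only and is not necessarily what cd059b9 does.

# Illustrative names; not vLLM's actual code.
aiter_enabled = False  # stand-in for is_rocm_aiter_linear_enabled()


class AiterSiluMulFp8BlockQuantPattern:
    """Always importable, even when AITER is disabled."""

    def register(self, patterns: list) -> None:
        patterns.append(self)


patterns: list = []
if aiter_enabled:  # gate the use, not the definition
    AiterSiluMulFp8BlockQuantPattern().register(patterns)

print(patterns)  # [] when AITER is disabled; the import itself never fails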
    return x_fp8, out_bs


direct_register_custom_op(
    op_name="rocm_aiter_act_mul_and_fp8_group_quant",
Can you check whether the latest aiter allows you to skip direct_register_custom_op? I remember most ops should now work without calling direct_register_custom_op on the vLLM side, since the registration is done in the AITER repository. Moreover, removing the direct_register_custom_op wrappers reduces CPU overhead; the extra dispatch they add can be costly.

Please take a look at the benchmarking results in ROCm#717 (the second and third case), which show that removing direct_register_custom_op on the vLLM side improves perf.
Hey, thanks for the feedback. Is there a version of aiter that has aiter.ops.triton.fused_fp8_quant and also does the op registration you mentioned? I wasn't able to figure out how to call act_mul_and_fp8_group_quant without calling direct_register_custom_op first. I'd be happy to investigate further if you can point me in the right direction; otherwise, I think we can always come back and remove these direct_register_custom_op calls if needed.
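For readers following along, the registration being discussed looks roughly like the sketch below. The op name comes from this PR's diff, but the op body, its signature, and the fake impl are simplified placeholders, and the direct_register_custom_op signature is assumed from vLLM's utils; the real op returns the quantized tensor together with its block scales.

import torch
from vllm.utils import direct_register_custom_op


def rocm_aiter_act_mul_and_fp8_group_quant(x: torch.Tensor) -> torch.Tensor:
    # Simplified placeholder; the real implementation calls the AITER Triton
    # fused act-mul + fp8 group quant kernel and returns (x_fp8, out_bs).
    return x


def rocm_aiter_act_mul_and_fp8_group_quant_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake (meta) impl so torch.compile can trace shapes without the kernel.
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="rocm_aiter_act_mul_and_fp8_group_quant",
    op_func=rocm_aiter_act_mul_and_fp8_group_quant,
    mutates_args=[],
    fake_impl=rocm_aiter_act_mul_and_fp8_group_quant_fake,
)

x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)
# Calling through the dispatcher adds per-call CPU overhead...
out = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant(x)
# ...while calling the Python/Triton function directly skips that layer.
out_direct = rocm_aiter_act_mul_and_fp8_group_quant(x)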
Signed-off-by: Micah Williamson <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: charlifu <[email protected]>
This PR adds a few fusion passes for Aiter to fuse layernorm + fp8 block quant and silu + fp8 block quant.
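For context, the unfused reference computation that the silu + fp8 block quant pass targets looks roughly like the sketch below: SiLU-and-mul followed by per-group (block) fp8 quantization. The group size, shapes, and fp8 dtype are illustrative assumptions rather than values taken from the kernels in this PR.

import torch

GROUP_SIZE = 128
FP8_DTYPE = torch.float8_e4m3fnuz  # fnuz variant used on MI300-class GPUs
FP8_MAX = torch.finfo(FP8_DTYPE).max


def silu_mul_fp8_group_quant_ref(x: torch.Tensor):
    # x: [num_tokens, 2 * d]; SiLU on the gate half, multiplied by the up half.
    d = x.shape[-1] // 2
    y = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
    # Per-group quantization: one scale per GROUP_SIZE contiguous elements.
    yg = y.float().view(y.shape[0], d // GROUP_SIZE, GROUP_SIZE)
    scales = yg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    y_fp8 = (yg / scales).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE).view(y.shape)
    return y_fp8, scales.squeeze(-1)


x = torch.randn(4, 2 * 512, device="cuda", dtype=torch.bfloat16)
y_fp8, scales = silu_mul_fp8_group_quant_ref(x)
print(y_fp8.shape, scales.shape)  # [4, 512] and [4, 4]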