
Commit 2976dc2

Authored by tjtanaa, hongxiayang, and kliuae
[Bug] [ROCm] Fix Llama 4 Enablement Bug on ROCm: V0 ROCmFlashAttentionImpl and Triton Fused MoE bugs (#16198)
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: kliuae <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: kliuae <[email protected]>
1 parent 102bf96 commit 2976dc2

3 files changed: +15 −9 lines


vllm/attention/backends/rocm_flash_attn.py

Lines changed: 4 additions & 1 deletion
@@ -471,7 +471,10 @@ def __init__(
         if blocksparse_params is not None:
             raise ValueError(
                 "ROCmFlashAttention does not support blocksparse attention.")
-
+        if use_irope:
+            logger.warning(
+                "Using irope in V0 is not supported yet, it will fall back "
+                "to global attention for long context.")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             self.logits_soft_cap = 0.0
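For orientation, a minimal standalone sketch (not vLLM code; the function name and logger setup are illustrative) of the behaviour this hunk adds: when irope is requested on the V0 ROCm backend, the layer only logs a warning and continues with global attention rather than raising.

import logging

logger = logging.getLogger(__name__)


def warn_if_irope_unsupported(use_irope: bool) -> None:
    # Illustrative stand-in for the guard added to
    # ROCmFlashAttentionImpl.__init__ above.
    if use_irope:
        logger.warning(
            "Using irope in V0 is not supported yet, it will fall back "
            "to global attention for long context.")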

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 0 deletions
@@ -1002,6 +1002,7 @@ def inplace_fused_experts_fake(
     op_func=inplace_fused_experts,
     mutates_args=["hidden_states"],
     fake_impl=inplace_fused_experts_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
 )


@@ -1060,6 +1061,7 @@ def outplace_fused_experts_fake(
     op_func=outplace_fused_experts,
     mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
 )


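The new tags argument matters under torch.compile: torch.Tag.needs_fixed_stride_order asks the compiler to pass the op its inputs in the same stride order as in eager mode, which the Triton fused-MoE path relies on. Below is a self-contained sketch of attaching the tag; the library name and op are made up for illustration, and it assumes a PyTorch build whose Library.define accepts a tags keyword, the same assumption the patched helper in vllm/utils.py below makes.

import torch
from torch.library import Library

# Made-up fragment library and op, purely to show a tag being attached
# at definition time.
_demo_lib = Library("tag_demo", "FRAGMENT")
_demo_lib.define("copy_like(Tensor x) -> Tensor",
                 tags=(torch.Tag.needs_fixed_stride_order, ))
_demo_lib.impl("copy_like", lambda x: x.clone(), dispatch_key="CPU")

# The tag is visible on the registered overload, where the compiler
# consults it before altering input layouts.
print(torch.ops.tag_demo.copy_like.default.tags)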
vllm/utils.py

Lines changed: 9 additions & 8 deletions
@@ -40,7 +40,7 @@
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Type, TypeVar, Union, cast, overload)
+                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
 from uuid import uuid4

 import cachetools
@@ -1935,12 +1935,13 @@ def __getattr__(self, key: str):


 def direct_register_custom_op(
-        op_name: str,
-        op_func: Callable,
-        mutates_args: list[str],
-        fake_impl: Optional[Callable] = None,
-        target_lib: Optional[Library] = None,
-        dispatch_key: str = "CUDA",
+    op_name: str,
+    op_func: Callable,
+    mutates_args: list[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+    dispatch_key: str = "CUDA",
+    tags: Tuple[torch.Tag, ...] = (),
 ):
     """
     `torch.library.custom_op` can have significant overhead because it
@@ -1979,7 +1980,7 @@ def direct_register_custom_op(
     import torch._custom_op.impl
     schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
     my_lib = target_lib or vllm_lib
-    my_lib.define(op_name + schema_str)
+    my_lib.define(op_name + schema_str, tags=tags)
     my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
     if fake_impl is not None:
         my_lib._register_fake(op_name, fake_impl)
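Putting the three files together, here is a caller-side sketch of how the extended helper is intended to be used, in the same shape as the fused_moe registrations above. The op name and both functions are hypothetical stand-ins, not vLLM kernels.

import torch

from vllm.utils import direct_register_custom_op


def weighted_sum(hidden_states: torch.Tensor,
                 weights: torch.Tensor) -> torch.Tensor:
    # Hypothetical "real" implementation (registered under the default
    # "CUDA" dispatch key by direct_register_custom_op).
    return (hidden_states * weights).sum(dim=-1)


def weighted_sum_fake(hidden_states: torch.Tensor,
                      weights: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation used during tracing/compilation.
    return hidden_states.new_empty(hidden_states.shape[:-1])


direct_register_custom_op(
    op_name="weighted_sum_demo",  # hypothetical op name
    op_func=weighted_sum,
    mutates_args=[],
    fake_impl=weighted_sum_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
# With the change above, the tuple reaches my_lib.define(), so the tag is
# attached to the op that torch.compile sees.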
