
Commit 95e5b08

[AMD][Hardware][Misc][Bugfix] xformer cleanup and light navi logic and CI fixes and refactoring (#4129)
1 parent a37d815 commit 95e5b08

File tree

6 files changed: 19 additions & 217 deletions

.buildkite/test-pipeline.yaml

Lines changed: 0 additions & 2 deletions
@@ -15,10 +15,8 @@ steps:
   commands:
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
-  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py

 - label: Core Test
   command: pytest -v -s core
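
Note on the CI change above: these pipeline steps pick the attention backend solely through the VLLM_ATTENTION_BACKEND environment variable, and this commit drops the ROCM_FLASH variants from the basic-correctness steps. A minimal sketch of the same selection mechanism outside CI, assuming an arbitrary small model and prompt purely for illustration (they are not part of this commit):

# Minimal sketch: select the attention backend the same way the CI steps above do,
# by setting VLLM_ATTENTION_BACKEND before vLLM is imported.
# Model name and prompt are illustrative assumptions.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # or "XFORMERS"

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)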

Dockerfile.rocm

Lines changed: 1 addition & 4 deletions
@@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
 ARG FA_GFX_ARCHS="gfx90a;gfx942"
 RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

-ARG FA_BRANCH="3d2b6f5"
+ARG FA_BRANCH="ae7928c"
 RUN echo "FA_BRANCH is $FA_BRANCH"

 # whether to build flash-attention
@@ -92,13 +92,10 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
 COPY ./ /app/vllm

 RUN python3 -m pip install --upgrade pip numba
-RUN python3 -m pip install xformers==0.0.23 --no-deps

 RUN cd /app \
     && cd vllm \
     && pip install -U -r requirements-rocm.txt \
-    && if [ "$BUILD_FA" = "1" ]; then \
-       bash patch_xformers.rocm.sh; fi \
     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
     && cd ..

patch_xformers.rocm.sh

Lines changed: 0 additions & 33 deletions
This file was deleted.

rocm_patch/commonpy_xformers-0.0.23.rocm.patch

Lines changed: 0 additions & 13 deletions
This file was deleted.

rocm_patch/flashpy_xformers-0.0.23.rocm.patch

Lines changed: 0 additions & 152 deletions
This file was deleted.

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 18 additions & 13 deletions
@@ -154,25 +154,30 @@ def __init__(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {suppored_head_sizes}.")

-        self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
+        self.use_naive_attn = False
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = (os.environ.get(
             "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
-        if self.use_naive_attn:
-            # AMD Radeon 7900 series (gfx1100) currently does not support
-            # xFormers nor FlashAttention. As a temporary workaround, we use
-            # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention
-            logger.debug("Using naive attention in ROCmBackend")
-        elif self.use_triton_flash_attn:
+        if self.use_triton_flash_attn:
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
         else:
-            from flash_attn import flash_attn_varlen_func  # noqa: F401
-            self.attn_func = flash_attn_varlen_func
-            logger.debug("Using CK FA in ROCmBackend")
+            # if not using triton, navi3x not use flash-attn either
+            if torch.cuda.get_device_capability()[0] == 11:
+                self.use_naive_attn = True
+            else:
+                try:
+                    from flash_attn import flash_attn_varlen_func  # noqa: F401
+                    self.attn_func = flash_attn_varlen_func
+                    logger.debug("Using CK FA in ROCmBackend")
+                except ModuleNotFoundError:
+                    self.use_naive_attn = True
+
+            if self.use_naive_attn:
+                self.attn_func = _naive_attention
+                logger.debug("Using naive attention in ROCmBackend")

     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -247,13 +252,13 @@ def forward(
             # triton attention
             # When block_tables are not filled, it means q and k are the
             # prompt, and they have the same length.
-            if self.use_naive_attn or self.use_triton_flash_attn:
+            if self.use_triton_flash_attn or self.use_naive_attn:
                 if self.num_kv_heads != self.num_heads:
                     # Interleave for MQA workaround.
                     key = self.repeat_kv(key, self.num_queries_per_kv)
                     value = self.repeat_kv(value, self.num_queries_per_kv)
                 if self.use_naive_attn:
-                    out = self.attn_fuc(
+                    out = self.attn_func(
                         query,
                         key,
                         value,
0 commit comments
