
Commit f375ec8

[ROCm] Upgrade xformers version for ROCm & update doc (#2079)
Co-authored-by: miloice <[email protected]>
1 parent 518369d commit f375ec8

6 files changed (+84, -56 lines)


Dockerfile.rocm

Lines changed: 2 additions & 2 deletions
@@ -47,12 +47,12 @@ RUN mkdir libs \
 COPY ./ /app/vllm
 
 RUN python3 -m pip install --upgrade pip
-RUN pip install xformers==0.0.22.post7 --no-deps
+RUN pip install xformers==0.0.23 --no-deps
 
 RUN cd /app \
     && cd vllm \
    && pip install -U -r requirements-rocm.txt \
-    && bash patch_xformers-0.0.22.post7.rocm.sh \
+    && bash patch_xformers-0.0.23.rocm.sh \
    && python3 setup.py install \
    && cd ..
 
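
As a point of reference, here is a minimal sketch of how the updated Dockerfile might be exercised; the image tag vllm-rocm and the device flags are illustrative assumptions, not part of this commit:

    # Build the ROCm image from the repository root with the updated Dockerfile.
    docker build -f Dockerfile.rocm -t vllm-rocm .

    # Run it with GPU access; the exact device and permission flags depend on the
    # host ROCm setup (the docs in this commit pass --network=host and --group-add=video).
    docker run -it --network=host --group-add=video \
        --device /dev/kfd --device /dev/dri \
        vllm-rocm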

docs/source/getting_started/amd-installation.rst

Lines changed: 8 additions & 8 deletions
@@ -3,7 +3,7 @@
 Installation with ROCm
 ======================
 
-vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm.
+vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
 At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
 Data types currently supported in ROCm are FP16 and BF16.
 
@@ -29,7 +29,7 @@ Installation options:
 
 .. code-block:: console
 
-    $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.3
+    $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
     $ docker run -it \
        --network=host \
        --group-add=video \
@@ -70,12 +70,12 @@ You can build and install vLLM from source:
 - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
 - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
 
-2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention
+2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
 
 .. code-block:: console
 
-   $ pip install xformers==0.0.22.post7 --no-deps
-   $ bash patch_xformers-0.0.22.post7.rocm.sh
+   $ pip install xformers==0.0.23 --no-deps
+   $ bash patch_xformers.rocm.sh
 
 3. Build vLLM.
 
@@ -127,12 +127,12 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
 - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
 - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
 
-2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention
+2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
 
 .. code-block:: console
 
-   $ pip install xformers==0.0.22.post7 --no-deps
-   $ bash patch_xformers-0.0.22.post7.rocm.sh
+   $ pip install xformers==0.0.23 --no-deps
+   $ bash patch_xformers.rocm.sh
 
 3. Build vLLM.
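
Taken together, the updated from-source flow documented above roughly becomes the following; this is a sketch assembled from the Dockerfile and doc changes in this commit, run from inside the vLLM checkout:

    # Install the pinned xformers build without pulling in its CUDA dependencies,
    # then apply the ROCm patches via the renamed helper script.
    pip install xformers==0.0.23 --no-deps
    bash patch_xformers.rocm.sh

    # Install the ROCm requirements and build vLLM itself.
    pip install -U -r requirements-rocm.txt
    python3 setup.py install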

patch_xformers-0.0.22.post7.rocm.sh renamed to patch_xformers.rocm.sh

Lines changed: 17 additions & 6 deletions
@@ -1,21 +1,32 @@
 #!/bin/bash
+set -e
+
+XFORMERS_VERSION="0.0.23"
+
+export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
+
+if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
+    echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
+    exit 1
+fi
+
 export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
 export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
 
-echo $XFORMERS_FMHA_FLASH_PATH
-echo $XFORMERS_FMHA_COMMON_PATH
+echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
+echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
 
-if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then
+if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
    echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
-    patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"
+    patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
    echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
 else
    echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
 fi
 
-if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then
+if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
    echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
-    patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"
+    patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
    echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
 else
    echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"

requirements-rocm.txt

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@ pyarrow  # Required for Ray data.
 sentencepiece  # Required for LLaMA tokenizer.
 numpy
 tokenizers>=0.15.0
-huggingface_hub<0.18,>=0.16.4
 transformers >= 4.36.0  # Required for Mixtral.
 fastapi
 uvicorn[standard]
rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch renamed to rocm_patch/commonpy_xformers-0.0.23.rocm.patch

File renamed without changes.

rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch renamed to rocm_patch/flashpy_xformers-0.0.23.rocm.patch

Lines changed: 57 additions & 39 deletions
@@ -1,6 +1,6 @@
---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
-+++ flash.py 2023-11-28 16:14:25.206128903 +0000
-@@ -31,39 +31,39 @@
+--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
++++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
+@@ -36,44 +36,44 @@
 
  FLASH_VERSION = "0.0.0"
 try:
@@ -15,9 +15,12 @@
 - from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 -
 - FLASH_VERSION = flash_attn.__version__
-- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-- if flash_ver_parsed < (2, 3):
-- raise ImportError("Requires 2.3 for sliding window support")
+- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+- if (
+- flash_ver_parsed != (2, 3, 6)
+- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+- ):
+- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 + #try:
 + # from ... import _C_flashattention # type: ignore[attr-defined]
 + # from ..._cpp_lib import _build_metadata
@@ -29,88 +32,103 @@
 + from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 +
 + FLASH_VERSION = flash_attn.__version__
-+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-+ # if flash_ver_parsed < (2, 3):
-+ # raise ImportError("Requires 2.3 for sliding window support")
++ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
++ # if (
++ # flash_ver_parsed != (2, 3, 6)
++ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
++ # ):
++ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 
 # create library so that flash-attn goes through the PyTorch Dispatcher
 - _flash_lib = torch.library.Library("xformers_flash", "DEF")
-+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
-
+-
 - _flash_lib.define(
 - "flash_fwd(Tensor query, Tensor key, Tensor value, "
-- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
 - "int max_seqlen_q, int max_seqlen_k, "
 - "float p, float softmax_scale, "
-- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+- "bool is_causal, int window_left, "
+- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
 - )
---
+++ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
+
 - _flash_lib.define(
 - "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
 - "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
 - "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
 - "int max_seqlen_q, int max_seqlen_k, "
-- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+- "float p, float softmax_scale, bool is_causal, "
+- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
 - )
 + #_flash_lib.define(
 + # "flash_fwd(Tensor query, Tensor key, Tensor value, "
-+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
++ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
 + # "int max_seqlen_q, int max_seqlen_k, "
 + # "float p, float softmax_scale, "
-+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
++ # "bool is_causal, int window_left, "
++ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
 + #)
 +
 + #_flash_lib.define(
 + # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
 + # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
 + # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
 + # "int max_seqlen_q, int max_seqlen_k, "
-+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
++ # "float p, float softmax_scale, bool is_causal, "
++ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
 + #)
 
 def _flash_fwd(
 query,
-@@ -98,8 +98,8 @@
+@@ -111,8 +111,8 @@
 p,
 softmax_scale,
 is_causal,
-- window_size - 1, # window_size_left
-- -1, # window_size_right
-+ # window_size - 1, # window_size_left
-+ # -1, # window_size_right
+- window_left, # window_size_left
+- window_right, # window_size_right
++ # window_left, # window_size_left
++ # window_right, # window_size_right
 return_softmax,
 None, # rng
 )
-@@ -127,8 +127,8 @@
+@@ -134,15 +134,15 @@
+ out,
+ cu_seq_lens_q,
+ cu_seq_lens_k,
+- seqused_k,
++ # seqused_k,
+ max_seq_len_q,
+ max_seq_len_k,
+ p,
 softmax_scale,
 False,
 is_causal,
-- window_size - 1, # window_size_left
-- -1, # window_size_right
-+ # window_size - 1, # window_size_left
-+ # -1, # window_size_right
+- window_left,
+- window_right,
++ # window_left,
++ # window_right,
 return_softmax,
 None,
 )
-@@ -169,8 +169,8 @@
+@@ -184,8 +184,8 @@
 p,
 softmax_scale,
 is_causal,
-- window_size - 1, # window_size_left
-- -1, # window_size_right
-+ # window_size - 1, # window_size_left
-+ # -1, # window_size_right
+- window_left,
+- window_right,
++ # window_left,
++ # window_right,
 None,
 rng_state,
 )
-@@ -193,15 +193,15 @@
+@@ -208,15 +208,15 @@
 softmax_scale,
 False, # zero_tensors
 is_causal,
-- window_size - 1, # window_size_left
-- -1, # window_size_right
-+ # window_size - 1, # window_size_left
-+ # -1, # window_size_right
+- window_left,
+- window_right,
++ # window_left,
++ # window_right,
 None,
 rng_state,
 )
@@ -123,7 +141,7 @@
 except ImportError:
 pass
 
-@@ -348,7 +348,7 @@
+@@ -400,7 +400,7 @@
 implementation.
 """
 
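
After patching, xformers' flash backend should pick up ROCm's flash_attn package, since the patched module imports flash_attn directly. One hedged way to sanity-check the result (assuming flash_attn is installed in the environment) is to print the FLASH_VERSION the patched module resolved:

    # FLASH_VERSION is set from flash_attn.__version__ in the patched flash.py;
    # it stays at "0.0.0" if the import inside the try block failed.
    python -c 'from xformers.ops.fmha import flash; print(flash.FLASH_VERSION)'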
