
Commit 6e9ca9a

tjohnson31415 authored and njhill committed
build: update Flash Attention v2 cache build and install
Makes some changes to the Flash Attention v2 build and install to simplify it and to install a complete Python wheel package including metadata, rather than copying only a subset of the files. The metadata is needed so that transformers can detect the installation: transformers uses importlib.metadata.version() to inspect the metadata and parse a version.

NB: This PR changes the format of the cache image, so older versions of the cache image will no longer work.

Signed-off-by: Travis Johnson <[email protected]>
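As a minimal sketch of the detection mechanism described above (assuming the package is distributed under the name flash_attn, as on PyPI), the difference between the two install methods shows up like this:

import importlib.metadata

# With the wheel installed, dist-info metadata is present and this returns
# a version string such as "2.3.6". If only build artifacts were copied
# into site-packages, the lookup raises PackageNotFoundError, so
# transformers would conclude flash-attn is not installed.
try:
    flash_attn_version = importlib.metadata.version("flash_attn")
except importlib.metadata.PackageNotFoundError:
    flash_attn_version = None

print(flash_attn_version)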
1 parent 7b24de1 commit 6e9ca9a

File tree

3 files changed (+9, -20 lines)


Dockerfile

Lines changed: 9 additions & 6 deletions
@@ -205,11 +205,13 @@ RUN pip install torch==$PYTORCH_VERSION+cu118 --index-url "${PYTORCH_INDEX}/cu11
 
 ## Build flash attention v2 ####################################################
 FROM python-builder as flash-att-v2-builder
+ARG FLASH_ATT_VERSION=v2.3.6
 
-WORKDIR /usr/src
+WORKDIR /usr/src/flash-attention-v2
 
-COPY server/Makefile-flash-att-v2 Makefile
-RUN MAX_JOBS=4 make build-flash-attention-v2
+# Download the wheel or build it if a pre-compiled release doesn't exist
+RUN MAX_JOBS=4 pip --verbose wheel flash-attn==${FLASH_ATT_VERSION} \
+    --no-build-isolation --no-deps --no-cache-dir
 
 ## Build flash attention ######################################################
 FROM python-builder as flash-att-builder
@@ -253,7 +255,7 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build /usr/sr
 
 ## Flash attention v2 cached build image #######################################
 FROM base as flash-att-v2-cache
-COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build /usr/src/flash-attention-v2/build
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2 /usr/src/flash-attention-v2
 
 
 ## Final Inference Server image ################################################
@@ -278,8 +280,9 @@ COPY --from=flash-att-cache /usr/src/flash-attention/build/lib.linux-x86_64-cpyt
 COPY --from=flash-att-cache /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-* ${SITE_PACKAGES}
 COPY --from=flash-att-cache /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-* ${SITE_PACKAGES}
 
-# Copy build artifacts from flash attention v2 builder
-COPY --from=flash-att-v2-cache /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-* ${SITE_PACKAGES}
+# Install flash attention v2 from the cache build
+RUN --mount=type=bind,from=flash-att-v2-cache,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
+    pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
 
 # Copy build artifacts from exllama kernels builder
 COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-* ${SITE_PACKAGES}
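For completeness, a hedged sketch of how the end result could be verified inside the final image (this check is not part of the PR; is_flash_attn_2_available is a transformers utility that also expects a CUDA-enabled torch at runtime):

import importlib.metadata

from transformers.utils import is_flash_attn_2_available

# Both checks should now pass in the final image: the wheel install provides
# the dist-info metadata for the version lookup, and transformers detects
# flash-attn v2 (the latter additionally requires CUDA to be available).
print(importlib.metadata.version("flash_attn"))
print(is_flash_attn_2_available())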

server/Makefile

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 include Makefile-flash-att
-include Makefile-flash-att-v2
 
 gen-server:
 	# Compile protos

server/Makefile-flash-att-v2

Lines changed: 0 additions & 13 deletions
This file was deleted.
