#################### BASE BUILD IMAGE ####################
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
RUN apt-get update -y \
    && apt-get install -y python3-pip git
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
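# (Running ldconfig over the CUDA compat directory puts its libcuda.so into the
# dynamic linker cache, so it can be resolved inside the build container.)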
RUN ldconfig /usr/local/cuda-12.1/compat/
WORKDIR /workspace

COPY model-engine/model_engine_server/inference/batch_inference/requirements-build.txt requirements-build.txt
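# Use a BuildKit cache mount so pip's download cache persists across rebuilds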
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements-build.txt
#################### BASE BUILD IMAGE ####################

#################### FLASH_ATTENTION BUILD IMAGE ####################
FROM dev AS flash-attn-builder
# Maximum number of parallel jobs for the compile (flash-attn's build honors MAX_JOBS)
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# flash-attn version to download or build
ARG flash_attn_version=v2.5.6
ENV FLASH_ATTN_VERSION=${flash_attn_version}

WORKDIR /usr/src/flash-attention-v2

# Download the pre-built wheel, or build it from source if no pre-compiled release exists
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
    --no-build-isolation --no-deps --no-cache-dir
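# (The resulting .whl is left in this WORKDIR; the runtime stage below installs
# it from here via a bind mount.)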

#################### FLASH_ATTENTION BUILD IMAGE ####################

#################### Runtime IMAGE ####################
FROM nvcr.io/nvidia/pytorch:23.09-py3

RUN apt-get update && \
    # ...
    rm -rf /var/lib/apt/lists/* && \
    apt-get clean

# Install flash attention (from pre-built wheel)
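# (The bind mount pulls the wheel straight from the builder stage, so it is
# never copied into an image layer.)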
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
    pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir

RUN pip uninstall torch -y
RUN pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cu121
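# (Replaces the torch build shipped with the NGC base image with the CUDA 12.1
# wheel -- presumably to pin the exact version the flash-attn wheel and vLLM expect.)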

# ...
RUN pip install -r requirements.txt
COPY model-engine /workspace/model-engine
RUN pip install -e /workspace/model-engine
COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py

#################### Runtime IMAGE ####################
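
# Example build from the repo root (illustrative only -- the Dockerfile path and
# image tag below are assumptions, not taken from this repo). BuildKit must be
# enabled for the cache/bind mounts used above:
#   DOCKER_BUILDKIT=1 docker build \
#     -f <path/to/this/Dockerfile> \
#     --build-arg max_jobs=4 --build-arg flash_attn_version=v2.5.6 \
#     -t vllm-batch .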