Commit 0c9d6f4

move to rocm flash attn fork
1 parent 7e6c3e2 commit 0c9d6f4

File tree

1 file changed (+12, −13 lines)

  • images/universal/training/rocm64-torch290-py312


images/universal/training/rocm64-torch290-py312/Dockerfile

Lines changed: 12 additions & 13 deletions
@@ -100,11 +100,6 @@ RUN dnf install -y --setopt=install_weak_deps=False \
     cmake \
     git && dnf clean all && rm -rf /var/cache/dnf/*
 
-# Install ninja as root (critical for flash-attention, reduces build from hours to minutes)
-# ninja-build package not available in base repos, so install via pip
-RUN pip install --no-cache-dir ninja && \
-    ln -sf /usr/local/bin/ninja /usr/bin/ninja
-
 # Bundle RDMA runtime libs to a staging dir
 RUN mkdir -p /opt/rdma-runtime \
     && cp -a /usr/lib64/libibverbs* /opt/rdma-runtime/ || true \
@@ -145,25 +140,29 @@ ENV UV_NO_CACHE=
 RUN pip install --retries 5 --timeout 300 --no-cache-dir \
     "git+https://github.com/opendatahub-io/kubeflow-sdk@main"
 
-# Install Flash Attention from original Dao-AILab repo
-# --no-build-isolation: Use already-installed torch instead of isolated env
+# Install Flash Attention from ROCm fork with Triton AMD backend
+# This is faster to build and optimized for AMD GPUs
 USER 0
 
 # Set build parallelism environment variables
 # MAX_JOBS: Controls PyTorch extension build parallelism
 # CMAKE_BUILD_PARALLEL_LEVEL: Controls CMake parallelism
-# NINJA_FLAGS: Controls ninja build parallelism
 # GPU_ARCHS: Target GPU architectures (gfx942=MI300, gfx90a=MI200/MI250)
 ENV GPU_ARCHS="gfx90a;gfx942" \
     MAX_JOBS=12 \
-    CMAKE_BUILD_PARALLEL_LEVEL=12 \
-    NINJA_FLAGS=-j12
+    CMAKE_BUILD_PARALLEL_LEVEL=12
+
+# Install Triton and ninja (required for ROCm flash-attention build)
+RUN /opt/app-root/bin/pip install --no-cache-dir triton==3.2.0 ninja
+
+# Enable Triton AMD backend for flash-attention
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
 
 RUN cd /tmp \
-    && git clone --depth 1 --branch v2.8.2 https://github.com/Dao-AILab/flash-attention.git \
+    && git clone https://github.com/ROCm/flash-attention.git \
     && cd flash-attention \
-    && git submodule update --init \
-    && pip install --no-build-isolation --no-cache-dir --no-deps . \
+    && git checkout main_perf \
+    && /opt/app-root/bin/python setup.py install \
     && cd / && rm -rf /tmp/flash-attention
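A minimal build-time sanity check could be appended after the flash-attention layer to catch a broken install early. This sketch is not part of the commit; it assumes the same /opt/app-root/bin/python interpreter and the FLASH_ATTENTION_TRITON_AMD_ENABLE setting introduced above.

# Hypothetical follow-up layer (not in this commit): fail the image build if the
# ROCm flash-attention fork cannot be imported against the preinstalled torch.
RUN FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE /opt/app-root/bin/python -c \
    "import flash_attn; print('flash-attn', flash_attn.__version__)"

A functional check of the Triton AMD kernels (for example, calling flash_attn.flash_attn_func on small tensors) would additionally require an AMD GPU to be visible at build time, so it is better left to a runtime test.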

Comments (0)