@@ -100,11 +100,6 @@ RUN dnf install -y --setopt=install_weak_deps=False \
100100 cmake \
101101 git && dnf clean all && rm -rf /var/cache/dnf/*
102102
103- # Install ninja as root (critical for flash-attention, reduces build from hours to minutes)
104- # ninja-build package not available in base repos, so install via pip
105- RUN pip install --no-cache-dir ninja && \
106- ln -sf /usr/local/bin/ninja /usr/bin/ninja
107-
108103# Bundle RDMA runtime libs to a staging dir
109104RUN mkdir -p /opt/rdma-runtime \
110105 && cp -a /usr/lib64/libibverbs* /opt/rdma-runtime/ || true \
@@ -145,25 +140,29 @@ ENV UV_NO_CACHE=
145140RUN pip install --retries 5 --timeout 300 --no-cache-dir \
146141 "git+https://github.com/opendatahub-io/kubeflow-sdk@main"
147142
148- # Install Flash Attention from original Dao-AILab repo
149- # --no-build-isolation: Use already-installed torch instead of isolated env
143+ # Install Flash Attention from ROCm fork with Triton AMD backend
144+ # This is faster to build and optimized for AMD GPUs
150145USER 0
151146
152147# Set build parallelism environment variables
153148# MAX_JOBS: Controls PyTorch extension build parallelism
154149# CMAKE_BUILD_PARALLEL_LEVEL: Controls CMake parallelism
155- # NINJA_FLAGS: Controls ninja build parallelism
156150# GPU_ARCHS: Target GPU architectures (gfx942=MI300, gfx90a=MI200/MI250)
157151ENV GPU_ARCHS="gfx90a;gfx942" \
158152 MAX_JOBS=12 \
159- CMAKE_BUILD_PARALLEL_LEVEL=12 \
160- NINJA_FLAGS=-j12
153+ CMAKE_BUILD_PARALLEL_LEVEL=12
154+
155+ # Install Triton and ninja (required for ROCm flash-attention build)
156+ RUN /opt/app-root/bin/pip install --no-cache-dir triton==3.2.0 ninja
157+
158+ # Enable Triton AMD backend for flash-attention
159+ ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
161160
162161RUN cd /tmp \
163-	&& git clone --depth 1 --branch v2.8.2 https://github.com/Dao-AILab/flash-attention.git \
162+	&& git clone https://github.com/ROCm/flash-attention.git \
164163 && cd flash-attention \
165- && git submodule update --init \
166- && pip install --no-build-isolation --no-cache-dir --no-deps . \
164+ && git checkout main_perf \
165+ && /opt/app-root/bin/python setup.py install \
167166 && cd / && rm -rf /tmp/flash-attention
168167
169168
0 commit comments