@@ -88,7 +88,8 @@ ENV NV_CUDA_CUDART_DEV_VERSION=12.1.55-1 \
     NV_NVML_DEV_VERSION=12.1.55-1 \
     NV_LIBCUBLAS_DEV_VERSION=12.1.0.26-1 \
     NV_LIBNPP_DEV_VERSION=12.0.2.50-1 \
-    NV_LIBNCCL_DEV_PACKAGE_VERSION=2.18.3-1+cuda12.1
+    NV_LIBNCCL_DEV_PACKAGE_VERSION=2.18.3-1+cuda12.1 \
+    NV_CUDNN9_CUDA_VERSION=9.6.0.74-1
 
 RUN dnf config-manager \
     --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
@@ -103,6 +104,15 @@ RUN dnf config-manager \
     libnccl-devel-${NV_LIBNCCL_DEV_PACKAGE_VERSION} \
     && dnf clean all
 
+# opening connection for too long in one go was resulting in timeouts
+RUN dnf config-manager \
+    --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo \
+    && dnf clean packages \
+    && dnf install -y \
+    libcusparselt0 libcusparselt-devel \
+    cudnn9-cuda-12-6-${NV_CUDNN9_CUDA_VERSION} \
+    && dnf clean all
+
 ENV LIBRARY_PATH="$CUDA_HOME/lib64/stubs"
 
 FROM cuda-devel AS python-installations
@@ -138,7 +148,8 @@ RUN if [[ -z "${WHEEL_VERSION}" ]]; \
 RUN --mount=type=cache,target=/home/${USER}/.cache/pip,uid=${USER_UID} \
     python -m pip install --user wheel && \
     python -m pip install --user "$(head bdist_name)" && \
-    python -m pip install --user "$(head bdist_name)[flash-attn]"
+    python -m pip install --user "$(head bdist_name)[flash-attn]" && \
+    python -m pip install --user "$(head bdist_name)[mamba]"
 
 # fms_acceleration_peft = PEFT-training, e.g., 4bit QLoRA
 # fms_acceleration_foak = Fused LoRA and triton kernels
0 commit comments