diff --git a/Dockerfile.ubi b/Dockerfile.ubi index e84473d21e0e..cd17fcb9e04d 100644 --- a/Dockerfile.ubi +++ b/Dockerfile.ubi @@ -1,9 +1,12 @@ +## Global Args ################################################################# +ARG BASE_UBI_IMAGE_TAG=9.5-1739420147 +ARG PYTHON_VERSION=3.12 -ARG BASE_UBI_IMAGE_TAG -ARG PYTHON_VERSION +ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" +ARG vllm_fa_cmake_gpu_arches='80-real;90-real' ## Base Layer ################################################################## -FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base +FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} as base ARG PYTHON_VERSION ENV PYTHON_VERSION=${PYTHON_VERSION} RUN microdnf -y update && microdnf install -y --nodocs \ @@ -16,14 +19,13 @@ ENV LANG=C.UTF-8 \ LC_ALL=C.UTF-8 # Some utils for dev purposes - tar required for kubectl cp - RUN microdnf install -y --nodocs \ - which procps findutils tar vim git \ + which procps findutils tar vim git\ && microdnf clean all ## Python Installer ############################################################ -FROM base AS python-install +FROM base as python-install ARG PYTHON_VERSION ENV VIRTUAL_ENV=/opt/vllm @@ -37,7 +39,7 @@ RUN microdnf install -y --nodocs \ ## CUDA Base ################################################################### -FROM python-install AS cuda-base +FROM python-install as cuda-base RUN curl -Lo /etc/yum.repos.d/cuda-rhel9.repo \ https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo @@ -51,6 +53,7 @@ RUN microdnf install -y --nodocs \ ln -s ${CUDA_HOME}/lib64/stubs/libcuda.so /usr/lib64/ + ## Python cuda base ################################################################# FROM cuda-base AS python-cuda-base @@ -65,9 +68,65 @@ RUN --mount=type=cache,target=/root/.cache/uv \ -r requirements-cuda.txt +## Development ################################################################# +FROM python-cuda-base AS dev + +# install build and runtime dependencies +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=requirements-common.txt,target=requirements-common.txt \ + --mount=type=bind,source=requirements-cuda.txt,target=requirements-cuda.txt \ + --mount=type=bind,source=requirements-dev.txt,target=requirements-dev.txt \ + --mount=type=bind,source=requirements-lint.txt,target=requirements-lint.txt \ + --mount=type=bind,source=requirements-test.txt,target=requirements-test.txt \ + uv pip install \ + -r requirements-cuda.txt \ + -r requirements-dev.txt + +## Builder ##################################################################### +FROM dev AS build + +# install build dependencies +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=requirements-build.txt,target=requirements-build.txt \ + uv pip install -r requirements-build.txt + +# install compiler cache to speed up compilation leveraging local or remote caching +# git is required for the cutlass kernels +RUN rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && rpm -ql epel-release && microdnf install -y --nodocs git ccache && microdnf clean all + +COPY . . + +ARG TORCH_CUDA_ARCH_LIST +ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST +ARG vllm_fa_cmake_gpu_arches +ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches} + +# max jobs used by Ninja to build extensions +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} +# number of threads used by nvcc +ARG nvcc_threads=8 +ENV NVCC_THREADS=$nvcc_threads +# make sure punica kernels are built (for LoRA) +ENV VLLM_INSTALL_PUNICA_KERNELS=1 + +# Make sure the cuda environment is in the PATH +ENV PATH=/usr/local/cuda/bin:$PATH + +ENV CCACHE_DIR=/root/.cache/ccache +RUN --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,src=.git,target=/workspace/.git \ + env CFLAGS="-march=haswell" \ + CXXFLAGS="$CFLAGS $CXXFLAGS" \ + CMAKE_BUILD_TYPE=Release \ + python3 setup.py bdist_wheel --dist-dir=dist #################### libsodium Build IMAGE #################### -FROM base AS libsodium-builder +FROM base as libsodium-builder RUN microdnf install -y --nodocs gcc gzip \ && microdnf clean all @@ -98,32 +157,24 @@ ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nv ENV LD_LIBRARY_PATH="${VIRTUAL_ENV}/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvtx/lib:${LD_LIBRARY_PATH}" # Triton needs a CC compiler - RUN microdnf install -y --nodocs gcc \ rsync \ && microdnf clean all +# install vllm wheel first, so that torch etc will be installed +RUN --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \ + --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ + uv pip install "$(echo dist/*.whl)[audio,video,tensorizer]" --verbose # Install libsodium for Tensorizer encryption RUN --mount=type=bind,from=libsodium-builder,src=/usr/src/libsodium,target=/usr/src/libsodium \ make -C /usr/src/libsodium install -COPY LICENSE /licenses/vllm.md -COPY examples/*.jinja /app/data/template/ - -# install vllm by running the payload script and then install flashinfer - -ARG VLLM_WHEEL_VERSION -ARG VLLM_WHEEL_INDEX -ARG FLASHINFER_VERSION -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,src=payload,target=/workspace/payload \ - --mount=type=secret,id=rhel-ai-private-index-auth/BOT_PAT \ - env BOT_PAT=$(cat /run/secrets/rhel-ai-private-index-auth/BOT_PAT) \ - VLLM_WHEEL_VERSION=${VLLM_VERSION} \ - VLLM_WHEEL_INDEX=${VLLM_WHEEL_INDEX} \ - ./payload/run.sh && \ - uv pip install "${FLASHINFER_VERSION}" +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ + uv pip install \ + "https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.0.post2/flashinfer_python-0.2.0.post2+cu124torch2.5-cp312-cp312-linux_x86_64.whl" ENV HF_HUB_OFFLINE=1 \ HOME=/home/vllm \ @@ -148,32 +199,26 @@ ENV HF_HUB_OFFLINE=1 \ RUN umask 002 && \ useradd --uid 2000 --gid 0 vllm && \ mkdir -p /home/vllm && \ + chown vllm:vllm /home/vllm && \ chmod g+rwx /home/vllm +COPY LICENSE /licenses/vllm.md +COPY examples/*.jinja /app/data/template/ + USER 2000 WORKDIR /home/vllm ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] -## TGIS Adapter layer ##################################################################### -FROM vllm-openai AS vllm-grpc-adapter +FROM vllm-openai as vllm-grpc-adapter USER root -ARG VLLM_TGIS_ADAPTER_VERSION -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,src=payload,target=/workspace/payload \ - --mount=type=secret,id=rhel-ai-private-index-auth/BOT_PAT \ - cd /workspace && \ - ls && \ - env HOME=/root \ - BOT_PAT=$(cat /run/secrets/rhel-ai-private-index-auth/BOT_PAT) \ - VLLM_WHEEL_VERSION=${VLLM_VERSION} \ - VLLM_TGIS_ADAPTER_VERSION=${VLLM_TGIS_ADAPTER_VERSION} \ - VLLM_WHEEL_INDEX=${VLLM_WHEEL_INDEX} \ - ./payload/run.sh - +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \ + HOME=/root uv pip install "$(echo /workspace/dist/*.whl)[tensorizer]" vllm-tgis-adapter==0.6.3 ENV GRPC_PORT=8033 \ PORT=8000 \ diff --git a/argfile.konflux b/argfile.konflux deleted file mode 100644 index 3d24e5066ff7..000000000000 --- a/argfile.konflux +++ /dev/null @@ -1,7 +0,0 @@ -BASE_UBI_IMAGE_TAG=9.5-1739420147 -PYTHON_VERSION=3.11 -LIBSODIUM_VERSION=1.0.20 -VLLM_TGIS_ADAPTER_VERSION=0.6.3 -FLASHINFER_VERSION=https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post1/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl -VLLM_WHEEL_VERSION=0.7.2 -VLLM_WHEEL_INDEX=https://gitlab.com/api/v4/projects/66664052/packages/pypi/simple diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83a..b3aa5b74e00b 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -4,6 +4,8 @@ #include +#include "core/math.hpp" + #include "cuda_compat.h" #include "dispatch_utils.h" @@ -31,6 +33,69 @@ __global__ void act_and_mul_kernel( } } +// NOTE: temporary vectorized version. + +template +__global__ void act_and_mul_kernel_vectorized( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { + const int64_t token_idx = blockIdx.x; + const int32_t blocks_per_token = gridDim.y; + + const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t); + + const int32_t tgt_elems_per_block = ceil_div(d, blocks_per_token); + const int32_t elems_per_block = + next_multiple_of(elems_per_128bit_load, tgt_elems_per_block); + const int64_t block_start = blockIdx.y * int64_t(elems_per_block); + int64_t block_end = block_start + elems_per_block; + block_end = block_end > d ? d : block_end; + + const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d; + const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d; + scalar_t* __restrict__ out_ptr = out + token_idx * d; + + // 128-bit vectorized code + const int32_t vec_loop_end = + prev_multiple_of(elems_per_128bit_load, block_end); + const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load; + const int32_t vec_start_idx = block_start / elems_per_128bit_load; + + const int4* __restrict__ x_128bit_ptr = reinterpret_cast(x_ptr); + const int4* __restrict__ y_128bit_ptr = reinterpret_cast(y_ptr); + int4* __restrict__ out_128bit_ptr = reinterpret_cast(out_ptr); + +#pragma unroll + for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx; + vec_idx += blockDim.x) { + const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]); + const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]); + using scalar_128bit_vec_t = std::array; + + scalar_128bit_vec_t out_vec; + const auto x_vec = reinterpret_cast(x_128bit); + const auto y_vec = reinterpret_cast(y_128bit); + +#pragma unroll + for (int i = 0; i < elems_per_128bit_load; i++) { + out_vec[i] = ACT_FN(x_vec[i]) * y_vec[i]; + } + + out_128bit_ptr[vec_idx] = reinterpret_cast(out_vec); + } + + // Scalar cleanup code + if (block_end > vec_loop_end) { + for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end; + idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + out_ptr[idx] = ACT_FN(x) * y; + } + } +} + template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) @@ -79,10 +144,26 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { input.data_ptr(), d); \ }); +// Launch activation and gating kernel. +// Vectorized Version +#define LAUNCH_ACTIVATION_GATE_KERNEL_VECTORIZED(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \ + dim3 block(std::min(d, 512)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel_vectorized", [&] { \ + vllm::act_and_mul_kernel_vectorized> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + void silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); + LAUNCH_ACTIVATION_GATE_KERNEL_VECTORIZED(vllm::silu_kernel); } void mul_and_silu(torch::Tensor& out, // [..., d] diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index ddfaca27147b..2cc05960d5cd 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -11,4 +11,16 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { template inline constexpr std::enable_if_t, T> ceil_div(T a, T b) { return (a + b - 1) / b; -} \ No newline at end of file +} + +// Compute the next multiple of a that is greater than or equal to b +template +static inline constexpr auto next_multiple_of(A a, B b) { + return ceil_div(b, a) * a; +} + +// Compute the largest multiple of a that is less than or equal to b +template +static inline constexpr auto prev_multiple_of(A a, B b) { + return (b / a) * a; +} diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh index 32ea5db3321b..5689f1b6e64c 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -12,52 +12,248 @@ namespace vllm { using c3x::cutlass_gemm_caller; -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; +#define CALL_CUTLASS_GEMM \ + cutlass_gemm_caller< \ + cutlass_3x_gemm>( \ + out, a, b, std::forward(args)...); + +struct sm90_fp8_config_M64 { + // M in [1, 64] + using ClusterShape = Shape<_1, _8, _1>; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n < 8 * 1024) { + using TileShape = Shape<_64, _64, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + CALL_CUTLASS_GEMM + + } else if (n < 16 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_64, _128, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + CALL_CUTLASS_GEMM + } + } }; -template typename Epilogue> struct sm90_fp8_config_M128 { // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + CALL_CUTLASS_GEMM + + } else if (n <= 8 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _64, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_128, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } + } }; -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; +struct sm90_fp8_config_M256 { + // M in (128, 256] + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 8 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_256, _128, _64>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + } + } +}; + +struct sm90_fp8_config_M3072 { + // M in (256, 3072] + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM - using Cutlass3xGemm = - cutlass_3x_gemm; + } else if (n <= 8 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_64, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } + } +}; + +struct sm90_fp8_config_default { + // M in (3072, inf) + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + static_assert(std::is_same()); + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } }; +#undef CALL_CUTLASS_GEMM + template typename Epilogue, typename... EpilogueArgs> @@ -69,29 +265,27 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_fp8_config_M128::Cutlass3xGemm; - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 - if (mp2 <= 64) { + if (m <= 64) { // m in [1, 64] - return cutlass_gemm_caller( + return sm90_fp8_config_M64::dispatch( out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { + } else if (m <= 128) { // m in (64, 128] - return cutlass_gemm_caller( + return sm90_fp8_config_M128::dispatch( + out, a, b, std::forward(args)...); + } else if (m <= 256) { + // m in (128, 256] + return sm90_fp8_config_M256::dispatch( + out, a, b, std::forward(args)...); + } else if (m <= 3072) { + // m in (256, 3072] + return sm90_fp8_config_M3072::dispatch( out, a, b, std::forward(args)...); } else { - // m in (128, inf) - return cutlass_gemm_caller( + // m in (3072, inf] + return sm90_fp8_config_default::dispatch( out, a, b, std::forward(args)...); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index ce7cf2f35282..b1bdd7dff173 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -71,12 +71,19 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; + +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, typename WarpShape, typename InstructionShape, int32_t MainLoopStages, - typename FP8MathOperator = cutlass::arch::OpMultiplyAdd> + typename FP8MathOperator = cutlass::arch::OpMultiplyAdd, + typename ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + GemmUniversalMode Mode_ = GemmUniversalMode::kGemm> struct cutlass_2x_gemm { + static const GemmUniversalMode Mode = Mode_; using ElementAB = ElementAB_; using ElementD = ElementD_; @@ -120,7 +127,7 @@ struct cutlass_2x_gemm { Arch, TileShape, WarpShape, InstructionShape, EVTD, - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + ThreadblockSwizzle, MainLoopStages, Operator, 1 /* epilogue stages */ >::GemmKernel>; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh new file mode 100644 index 000000000000..2e229787122c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh @@ -0,0 +1,378 @@ +template typename Epilogue> +struct sm89_fp8_config_0 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_1 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_2 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 5; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_3 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_4 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_5 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_6 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_7 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_8 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_9 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_10 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_11 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_12 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_13 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle< + 1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_14 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_15 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_16 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemm; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_17 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh index 4e82c99c3af3..a96b9e594a36 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh @@ -3,6 +3,8 @@ #include "scaled_mm_c2x.cuh" #include "cutlass/float8.h" +#include "scaled_mm_c2x_sm89_fp8_configs.cuh" + /** * This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm * shape. @@ -334,7 +336,551 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - uint32_t const m = a.size(0); + uint32_t const m = out.size(0); + uint32_t const n = out.size(1); + uint32_t const k = b.size(0); + + if (m == 1) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 16) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 32) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_5::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_6::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_7::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 64) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_9::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 128) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 256) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 512) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 1024) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 2048) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 4096) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else { // m512 kernels + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } + uint32_t const mp2 = std::max(static_cast(32), next_pow_2(m)); // next power of 2 diff --git a/examples/run_lm_eval.sh b/examples/run_lm_eval.sh new file mode 100644 index 000000000000..db36fc3755a7 --- /dev/null +++ b/examples/run_lm_eval.sh @@ -0,0 +1,2 @@ +MODEL=neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic +lm_eval --model vllm --model_args "pretrained=$MODEL" --tasks gsm8k --batch_size "auto" diff --git a/payload/run.sh b/payload/run.sh deleted file mode 100755 index 7d0c62fe3e16..000000000000 --- a/payload/run.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -# required env vars: -# $BOT_PAT -# $WHEEL_RELEASE_ARTIFACTS -# optional: -# $VLLM_TGIS_ADAPTER_VERSION -# $VLLM_WHEEL_VERSION -set -ex - -cat < ${HOME}/.netrc -machine gitlab.com -login rhel-ai-wheels-prefetch-token-rhoai -password $BOT_PAT -EOF - -trap "rm ${HOME}/.netrc" EXIT - -# https://docs.astral.sh/uv/configuration/indexes/#searching-across-multiple-indexes -# This will prefer to use the custom index, and fall back to pypi if needed -export UV_EXTRA_INDEX_URL=${VLLM_WHEEL_INDEX} -export UV_INDEX_STRATEGY=unsafe-first-match - -vllm="vllm[tensorizer,audio,video]" - -if [[ -n "$VLLM_TGIS_ADAPTER_VERSION" ]]; then - vllm_tgis_adapter="vllm-tgis-adapter==${VLLM_TGIS_ADAPTER_VERSION}" -fi - -if [[ -n "$VLLM_WHEEL_VERSION" ]]; then - vllm="${vllm}==${$VLLM_WHEEL_VERSION}" -fi - -uv pip install $vllm $vllm_tgis_adapter - diff --git a/requirements-common.txt b/requirements-common.txt index b7c94cbdba8b..f13009abd15a 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,8 +8,7 @@ py-cpuinfo transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. -fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9' -fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' +fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) pydantic >= 2.9 diff --git a/setup.py b/setup.py index d8a336c2d426..1d6d6b0b7008 100755 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ from packaging.version import Version, parse from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -from setuptools_scm import get_version from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME @@ -499,7 +498,8 @@ def get_gaudi_sw_version(): def get_vllm_version() -> str: - version = get_version(write_to="vllm/_version.py") + version = "0.7.3" + sep = "+" if "+" not in version else "." # dev versions might contain + if _no_device(): diff --git a/tests/accuracy/test_lm_eval_correctness.py b/tests/accuracy/test_lm_eval_correctness.py new file mode 100644 index 000000000000..7f3046efdd17 --- /dev/null +++ b/tests/accuracy/test_lm_eval_correctness.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +import itertools +import os +from pathlib import Path +from typing import TYPE_CHECKING + +import nltk +import numpy +import pandas as pd +import pytest +import yaml + +if TYPE_CHECKING: + import lm_eval as lm_eval_t + +# requires a particular lm-evaluation-harness +# pip install lm_eval==0.4.3 +lm_eval: "lm_eval_t" = pytest.importorskip("lm_eval", + reason="lm_eval required") + +MAX_MODEL_LEN = 4096 +RTOL = 0.040 +TEST_DATA_PATH = os.environ.get( + "LM_EVAL_TEST_DATA_FILE", + "../neuralmagic/lm-eval-configs/models/Meta-Llama-3-8B-Instruct.yaml") +# just show the test data file from the `neuralmagic/lm-eval-configs/models` +# directory. this could be a `model.yaml`, or a `leaderboard/model.yaml` +TEST_DATA_FILE = str(Path(TEST_DATA_PATH)).replace( + str(Path.cwd() / "../neuralmagic/lm-eval-configs/models"), "") + + +def launch_lm_eval(eval_config, tp_size): + model_args = { + "pretrained": eval_config['model_name'], + } + eval_config_model_args = eval_config.get('model_args') + if eval_config_model_args: + model_args.update(eval_config_model_args) + + model_backend = eval_config.get("backend", "vllm") + + if model_backend == "vllm": + model_args.update({ + "tensor_parallel_size": tp_size, + "distributed_executor_backend": "ray", + "max_model_len": MAX_MODEL_LEN + }) + + evaluate_args = { + "model": model_backend, + "model_args": ",".join([f"{k}={v}" for k, v in model_args.items()]), + "tasks": [task["name"] for task in eval_config["tasks"]], + "num_fewshot": eval_config["num_fewshot"], + "batch_size": "auto" + } + if "limit" in eval_config: + evaluate_args["limit"] = eval_config["limit"] + if "fewshot_as_multiturn" in eval_config: + evaluate_args["fewshot_as_multiturn"] = eval_config[ + "fewshot_as_multiturn"] + if "apply_chat_template" in eval_config: + evaluate_args["apply_chat_template"] = eval_config[ + "apply_chat_template"] + + simple_eval_args = ['{}={}'.format(k, v) for k, v in evaluate_args.items()] + print(f"lm_eval.simple_evaluate({', '.join(simple_eval_args)}") + results = lm_eval.simple_evaluate(**evaluate_args) + + return results + + +# pass the TEST_DATA_FILE in as a parameter so that the results +# are uniquely reported to TestMo +@pytest.mark.parametrize("test_data_file", [TEST_DATA_FILE]) +def test_lm_eval_correctness(num_gpus_available, test_data_file): + eval_config = yaml.safe_load( + Path(TEST_DATA_PATH).read_text(encoding="utf-8")) + eval_config_tasks = { + t['name']: { + m['name']: m['value'] + for m in t['metrics'] + } + for t in eval_config["tasks"] + } + # identify unique metrics we wish to report on. + eval_config_metrics = set( + itertools.chain.from_iterable([ + metric.keys() for metric in + [eval_config_tasks[task] for task in eval_config_tasks] + ])) + + # retrieve the ground truth values from the evaluation config + # we transpose the info into a set of records indexed by + # a "task" and "metric". The `dropna()` is necessary to remove extra + # rows where there is no ground truth value for the "task" and "metric" + ground_truth_df = pd.DataFrame.from_records( + eval_config_tasks, index=eval_config_metrics).transpose() + gt_listing_df = ground_truth_df.reset_index(names="task").melt( + id_vars="task", var_name="metric", + value_name="ground_truth").dropna().set_index(["task", "metric"]) + + # the ifeval task requires an additional set of data + if "leaderboard_ifeval" in [task["name"] for task in eval_config["tasks"]]: + nltk.download('punkt_tab') + + # Launch eval requests. + results = launch_lm_eval(eval_config, tp_size=num_gpus_available) + + # process the results into a dataframe that looks like the ground truth + # with records indexed by "task" and "metric", but with the measured value + # for each index. + results_df = pd.DataFrame.from_records( + results["results"], index=eval_config_metrics).transpose() + r_listing_df = (results_df.reset_index(names="task").melt( + id_vars="task", var_name="metric", + value_name="measured").dropna().set_index(["task", "metric"])) + + # present the results + # combine the ground truth and results into a single dataframe + # but eliminate any rows that do not have both values + # (This could happen if the eval_config includes a measure that's not + # generated, or if the LM Evaluation harness generates a measure that + # was not requested by the eval_config.) + comparing_metrics_df = pd.concat( + [gt_listing_df, r_listing_df], + axis="columns").reset_index(names=["task", "metric"]).dropna() + + # Add a column with the relative tolerance level for the task + task_rtol_map = { + t["name"]: t.get("rtol", RTOL) + for t in eval_config["tasks"] + } + comparing_metrics_df.loc[:, "rtol"] = comparing_metrics_df.apply( + lambda metric: task_rtol_map[metric.task], axis=1) + + # and determine if measured is close to ground truth + comparing_metrics_df.loc[:, "isclose"] = comparing_metrics_df.apply( + lambda metric: numpy.isclose( + metric.ground_truth, metric.measured, rtol=metric.rtol), + axis=1) + print("==== LM EVAL RESULT ====\n") + comparing_metrics_df.sort_values(by=["task", "metric"], inplace=True) + print(comparing_metrics_df.to_markdown(index=False)) + + # save the results for later summary + llm_results_md = Path("llmeval_results-" + + TEST_DATA_FILE.replace("/", "-")).with_suffix(".md") + llm_results_md.write_text( + f"## {eval_config['model_name']}\n" + f"{comparing_metrics_df.to_markdown(index=False)}\n") + + # fail if any scores fail to match ground truth. + assert comparing_metrics_df.loc[:, "isclose"].all() diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 98ea6a46133f..0c88cd7ac220 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1440,7 +1440,7 @@ class UnloadLoraAdapterRequest(BaseModel): class TranscriptionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation - #https://platform.openai.com/docs/api-reference/audio/createTranscription + # https://platform.openai.com/docs/api-reference/audio/createTranscription file: UploadFile """ diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0f4cb253258f..073a30d25e23 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -598,15 +598,11 @@ def input_processor_for_whisper(ctx: InputContext, inputs): audio, orig_sr = multi_modal_data["audio"] processor = cached_processor_from_config(ctx.model_config) target_sr = processor.feature_extractor.sampling_rate - # NOTE: resampling is expensive, so skip it if the audio data - # sent to the Engine is already in Whisper's SAMPLE_RATE=16000. - if orig_sr != target_sr: - audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) + audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) multi_modal_data["audio"] = (audio, target_sr) # Pre-allocate placeholder tokens in encoder sequence num_tokens = get_max_whisper_audio_tokens(ctx) inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens - return inputs @@ -627,9 +623,6 @@ def input_mapper_for_whisper( audios = [audio for audio, _ in multi_modal_data] - # 1) Pad out with empty audio to N_SAMPLES=480000 (30s * SAMPLE_RATE) - # 2) Apply log_mel_spectrogram to padded (N_MEL_FILTERS=128, N_FRAMES=3000) - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py#L175 # noqa: E501 kwargs = processor(audios, sampling_rate=sampling_rate, return_tensors="pt") diff --git a/vllm/version.py b/vllm/version.py index 70cd0289b441..1730e3651333 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -1,13 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -try: - from ._version import __version__, __version_tuple__ -except Exception as e: - import warnings - - warnings.warn(f"Failed to read commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) - - __version__ = "dev" - __version_tuple__ = (0, 0, __version__) +__version__ = "0.7.3" +__version_tuple__ = (0, 7, 3)