From a27a49b1bc051744b2e2f03e5b2801da9566240e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 10 Feb 2025 18:09:06 +0000 Subject: [PATCH 01/10] added version str --- setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/setup.py b/setup.py index a4043c43a7d5..a35ededc3161 100755 --- a/setup.py +++ b/setup.py @@ -479,9 +479,7 @@ def get_gaudi_sw_version(): def get_vllm_version() -> str: - version = get_version( - write_to="vllm/_version.py", # TODO: move this to pyproject.toml - ) + version = "0.7.2.0" sep = "+" if "+" not in version else "." # dev versions might contain + From e6132b82f84bf7489fa60ca8a53bed4cf34447f3 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:08:08 -0500 Subject: [PATCH 02/10] [Kernel] Vectorized SiluAndMul (#182) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SUMMARY: * add vectorized `SiluAndMul` kernel. ```bash MODEL=neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic lm_eval --model vllm --model_args "pretrained=$MODEL" --tasks gsm8k --batch_size "auto" >> |Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| >> |-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| >> |gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7650|± |0.0117| >> | | |strict-match | 5|exact_match|↑ |0.7377|± |0.0121| ``` - this pr ```bash python3 benchmark_throughput.py --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic --dataset ShareGPT_V3_unfiltered_cleaned_split.json >> Throughput: 53.50 requests/s, 22122.56 total tokens/s, 10610.50 output tokens/s ``` **BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE** ---
PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

--- csrc/activation_kernels.cu | 83 +++++++++++++++++++++++++++++++++++++- csrc/core/math.hpp | 14 ++++++- examples/run_lm_eval.sh | 2 + setup.py | 1 - 4 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 examples/run_lm_eval.sh diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83a..b3aa5b74e00b 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -4,6 +4,8 @@ #include +#include "core/math.hpp" + #include "cuda_compat.h" #include "dispatch_utils.h" @@ -31,6 +33,69 @@ __global__ void act_and_mul_kernel( } } +// NOTE: temporary vectorized version. + +template +__global__ void act_and_mul_kernel_vectorized( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { + const int64_t token_idx = blockIdx.x; + const int32_t blocks_per_token = gridDim.y; + + const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t); + + const int32_t tgt_elems_per_block = ceil_div(d, blocks_per_token); + const int32_t elems_per_block = + next_multiple_of(elems_per_128bit_load, tgt_elems_per_block); + const int64_t block_start = blockIdx.y * int64_t(elems_per_block); + int64_t block_end = block_start + elems_per_block; + block_end = block_end > d ? d : block_end; + + const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d; + const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d; + scalar_t* __restrict__ out_ptr = out + token_idx * d; + + // 128-bit vectorized code + const int32_t vec_loop_end = + prev_multiple_of(elems_per_128bit_load, block_end); + const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load; + const int32_t vec_start_idx = block_start / elems_per_128bit_load; + + const int4* __restrict__ x_128bit_ptr = reinterpret_cast(x_ptr); + const int4* __restrict__ y_128bit_ptr = reinterpret_cast(y_ptr); + int4* __restrict__ out_128bit_ptr = reinterpret_cast(out_ptr); + +#pragma unroll + for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx; + vec_idx += blockDim.x) { + const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]); + const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]); + using scalar_128bit_vec_t = std::array; + + scalar_128bit_vec_t out_vec; + const auto x_vec = reinterpret_cast(x_128bit); + const auto y_vec = reinterpret_cast(y_128bit); + +#pragma unroll + for (int i = 0; i < elems_per_128bit_load; i++) { + out_vec[i] = ACT_FN(x_vec[i]) * y_vec[i]; + } + + out_128bit_ptr[vec_idx] = reinterpret_cast(out_vec); + } + + // Scalar cleanup code + if (block_end > vec_loop_end) { + for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end; + idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + out_ptr[idx] = ACT_FN(x) * y; + } + } +} + template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) @@ -79,10 +144,26 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { input.data_ptr(), d); \ }); +// Launch activation and gating kernel. +// Vectorized Version +#define LAUNCH_ACTIVATION_GATE_KERNEL_VECTORIZED(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \ + dim3 block(std::min(d, 512)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel_vectorized", [&] { \ + vllm::act_and_mul_kernel_vectorized> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + void silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); + LAUNCH_ACTIVATION_GATE_KERNEL_VECTORIZED(vllm::silu_kernel); } void mul_and_silu(torch::Tensor& out, // [..., d] diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index ddfaca27147b..2cc05960d5cd 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -11,4 +11,16 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { template inline constexpr std::enable_if_t, T> ceil_div(T a, T b) { return (a + b - 1) / b; -} \ No newline at end of file +} + +// Compute the next multiple of a that is greater than or equal to b +template +static inline constexpr auto next_multiple_of(A a, B b) { + return ceil_div(b, a) * a; +} + +// Compute the largest multiple of a that is less than or equal to b +template +static inline constexpr auto prev_multiple_of(A a, B b) { + return (b / a) * a; +} diff --git a/examples/run_lm_eval.sh b/examples/run_lm_eval.sh new file mode 100644 index 000000000000..db36fc3755a7 --- /dev/null +++ b/examples/run_lm_eval.sh @@ -0,0 +1,2 @@ +MODEL=neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic +lm_eval --model vllm --model_args "pretrained=$MODEL" --tasks gsm8k --batch_size "auto" diff --git a/setup.py b/setup.py index a35ededc3161..380f37b0e100 100755 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ from packaging.version import Version, parse from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext -from setuptools_scm import get_version from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME From 8c9e2020dddf84b11d46391fdc8e99fcc024076d Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:35:30 -0500 Subject: [PATCH 03/10] Fp8 Kernels c2x (#183) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SUMMARY: * Fp8 kernels with tuned configs ```bash MODEL=neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic lm_eval --model vllm --model_args "pretrained=$MODEL" --tasks gsm8k --batch_size "auto" >> vllm (pretrained=neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto >> |Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| >> |-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| >> |gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7680|± |0.0116| >> | | |strict-match | 5|exact_match|↑ |0.7407|± |0.0121| ``` - this branch ```bash MODEL=neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic VLLM_USE_V1=1 python3 benchmark_throughput.py --model $MODEL --dataset ShareGPT_V3_unfiltered_cleaned_split.json >> Throughput: 21.53 requests/s, 8902.71 total tokens/s, 4269.95 output tokens/s ``` - upstream wheel: ```bash MODEL=neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic VLLM_USE_V1=1 python3 benchmark_throughput.py --model $MODEL --dataset ShareGPT_V3_unfiltered_cleaned_split.json >> Throughput: 13.95 requests/s, 5770.85 total tokens/s, 2767.83 output tokens/s ``` **BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE** ---
PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

--------- Signed-off-by: rshaw@neuralmagic.com --- .../cutlass_w8a8/scaled_mm_c2x.cuh | 11 +- .../scaled_mm_c2x_sm89_fp8_configs.cuh | 378 ++++++++++++ .../scaled_mm_c2x_sm89_fp8_dispatch.cuh | 548 +++++++++++++++++- 3 files changed, 934 insertions(+), 3 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index f2fae4b66d65..fd0bb6ff2947 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -71,12 +71,19 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; + +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, typename WarpShape, typename InstructionShape, int32_t MainLoopStages, - typename FP8MathOperator = cutlass::arch::OpMultiplyAdd> + typename FP8MathOperator = cutlass::arch::OpMultiplyAdd, + typename ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + GemmUniversalMode Mode_ = GemmUniversalMode::kGemm> struct cutlass_2x_gemm { + static const GemmUniversalMode Mode = Mode_; using ElementAB = ElementAB_; using ElementD = ElementD_; @@ -115,7 +122,7 @@ struct cutlass_2x_gemm { Arch, TileShape, WarpShape, InstructionShape, EVTD, - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + ThreadblockSwizzle, MainLoopStages, Operator, 1 /* epilogue stages */ >::GemmKernel>; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh new file mode 100644 index 000000000000..2e229787122c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh @@ -0,0 +1,378 @@ +template typename Epilogue> +struct sm89_fp8_config_0 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_1 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_2 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 5; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_3 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_4 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_5 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_6 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_7 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_8 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_9 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_10 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_11 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_12 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_13 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle< + 1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_14 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_15 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_16 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemm; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_17 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh index 4e82c99c3af3..a96b9e594a36 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh @@ -3,6 +3,8 @@ #include "scaled_mm_c2x.cuh" #include "cutlass/float8.h" +#include "scaled_mm_c2x_sm89_fp8_configs.cuh" + /** * This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm * shape. @@ -334,7 +336,551 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - uint32_t const m = a.size(0); + uint32_t const m = out.size(0); + uint32_t const n = out.size(1); + uint32_t const k = b.size(0); + + if (m == 1) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 16) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 32) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_5::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_6::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_7::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 64) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_9::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 128) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 256) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 512) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 1024) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 2048) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 4096) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else { // m512 kernels + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } + uint32_t const mp2 = std::max(static_cast(32), next_pow_2(m)); // next power of 2 From db753c0db9a3398dd915b5f7a61f4bf5816c59d1 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 10 Feb 2025 21:33:35 -0500 Subject: [PATCH 04/10] [Kernel] Cutlass c3x Performance (#186) SUMMARY: * add cutlass c3x configs --------- Signed-off-by: rshaw@neuralmagic.com --- .../c3x/scaled_mm_sm90_fp8_dispatch.cuh | 300 ++++++++++++++---- 1 file changed, 247 insertions(+), 53 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh index 32ea5db3321b..5689f1b6e64c 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -12,52 +12,248 @@ namespace vllm { using c3x::cutlass_gemm_caller; -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; +#define CALL_CUTLASS_GEMM \ + cutlass_gemm_caller< \ + cutlass_3x_gemm>( \ + out, a, b, std::forward(args)...); + +struct sm90_fp8_config_M64 { + // M in [1, 64] + using ClusterShape = Shape<_1, _8, _1>; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n < 8 * 1024) { + using TileShape = Shape<_64, _64, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + CALL_CUTLASS_GEMM + + } else if (n < 16 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_64, _128, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + CALL_CUTLASS_GEMM + } + } }; -template typename Epilogue> struct sm90_fp8_config_M128 { // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + CALL_CUTLASS_GEMM + + } else if (n <= 8 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _64, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_128, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } + } }; -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; +struct sm90_fp8_config_M256 { + // M in (128, 256] + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 8 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_256, _128, _64>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + } + } +}; + +struct sm90_fp8_config_M3072 { + // M in (256, 3072] + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM - using Cutlass3xGemm = - cutlass_3x_gemm; + } else if (n <= 8 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_64, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } + } +}; + +struct sm90_fp8_config_default { + // M in (3072, inf) + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + static_assert(std::is_same()); + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } }; +#undef CALL_CUTLASS_GEMM + template typename Epilogue, typename... EpilogueArgs> @@ -69,29 +265,27 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_fp8_config_M128::Cutlass3xGemm; - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 - if (mp2 <= 64) { + if (m <= 64) { // m in [1, 64] - return cutlass_gemm_caller( + return sm90_fp8_config_M64::dispatch( out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { + } else if (m <= 128) { // m in (64, 128] - return cutlass_gemm_caller( + return sm90_fp8_config_M128::dispatch( + out, a, b, std::forward(args)...); + } else if (m <= 256) { + // m in (128, 256] + return sm90_fp8_config_M256::dispatch( + out, a, b, std::forward(args)...); + } else if (m <= 3072) { + // m in (256, 3072] + return sm90_fp8_config_M3072::dispatch( out, a, b, std::forward(args)...); } else { - // m in (128, inf) - return cutlass_gemm_caller( + // m in (3072, inf] + return sm90_fp8_config_default::dispatch( out, a, b, std::forward(args)...); } } From 4e91c9d511eedafb5ef8f4d9084c3d43806c72a9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 11 Feb 2025 04:06:19 +0000 Subject: [PATCH 05/10] added version --- vllm/version.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/version.py b/vllm/version.py index 70cd0289b441..10f87d243497 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 try: - from ._version import __version__, __version_tuple__ + __version__ = "0.7.2.0" + __version_tuple__ = (0, 7, 2, 0) + except Exception as e: import warnings From 153c77bec4160480fb8c7db2786a395fe4df0a2c Mon Sep 17 00:00:00 2001 From: andy-neuma Date: Thu, 13 Feb 2025 09:24:39 -0500 Subject: [PATCH 06/10] add missing lm-eval file --- tests/accuracy/test_lm_eval_correctness.py | 150 +++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 tests/accuracy/test_lm_eval_correctness.py diff --git a/tests/accuracy/test_lm_eval_correctness.py b/tests/accuracy/test_lm_eval_correctness.py new file mode 100644 index 000000000000..c3c5f8a04d8d --- /dev/null +++ b/tests/accuracy/test_lm_eval_correctness.py @@ -0,0 +1,150 @@ +import itertools +import os +from pathlib import Path +from typing import TYPE_CHECKING + +import nltk +import numpy +import pandas as pd +import pytest +import yaml + +if TYPE_CHECKING: + import lm_eval as lm_eval_t + +# requires a particular lm-evaluation-harness +# pip install lm_eval==0.4.3 +lm_eval: "lm_eval_t" = pytest.importorskip("lm_eval", + reason="lm_eval required") + +MAX_MODEL_LEN = 4096 +RTOL = 0.040 +TEST_DATA_PATH = os.environ.get( + "LM_EVAL_TEST_DATA_FILE", + "../neuralmagic/lm-eval-configs/models/Meta-Llama-3-8B-Instruct.yaml") +# just show the test data file from the `neuralmagic/lm-eval-configs/models` +# directory. this could be a `model.yaml`, or a `leaderboard/model.yaml` +TEST_DATA_FILE = str(Path(TEST_DATA_PATH)).replace( + str(Path.cwd() / "../neuralmagic/lm-eval-configs/models"), "") + + +def launch_lm_eval(eval_config, tp_size): + model_args = { + "pretrained": eval_config['model_name'], + } + eval_config_model_args = eval_config.get('model_args') + if eval_config_model_args: + model_args.update(eval_config_model_args) + + model_backend = eval_config.get("backend", "vllm") + + if model_backend == "vllm": + model_args.update({ + "tensor_parallel_size": tp_size, + "distributed_executor_backend": "ray", + "max_model_len": MAX_MODEL_LEN + }) + + evaluate_args = { + "model": model_backend, + "model_args": ",".join([f"{k}={v}" for k, v in model_args.items()]), + "tasks": [task["name"] for task in eval_config["tasks"]], + "num_fewshot": eval_config["num_fewshot"], + "batch_size": "auto" + } + if "limit" in eval_config: + evaluate_args["limit"] = eval_config["limit"] + if "fewshot_as_multiturn" in eval_config: + evaluate_args["fewshot_as_multiturn"] = eval_config[ + "fewshot_as_multiturn"] + if "apply_chat_template" in eval_config: + evaluate_args["apply_chat_template"] = eval_config[ + "apply_chat_template"] + + simple_eval_args = ['{}={}'.format(k, v) for k, v in evaluate_args.items()] + print(f"lm_eval.simple_evaluate({', '.join(simple_eval_args)}") + results = lm_eval.simple_evaluate(**evaluate_args) + + return results + + +# pass the TEST_DATA_FILE in as a parameter so that the results +# are uniquely reported to TestMo +@pytest.mark.parametrize("test_data_file", [TEST_DATA_FILE]) +def test_lm_eval_correctness(num_gpus_available, test_data_file): + eval_config = yaml.safe_load( + Path(TEST_DATA_PATH).read_text(encoding="utf-8")) + eval_config_tasks = { + t['name']: {m['name']: m['value'] + for m in t['metrics']} + for t in eval_config["tasks"] + } + # identify unique metrics we wish to report on. + eval_config_metrics = set( + itertools.chain.from_iterable([ + metric.keys() for metric in + [eval_config_tasks[task] for task in eval_config_tasks] + ])) + + # retrieve the ground truth values from the evaluation config + # we transpose the info into a set of records indexed by + # a "task" and "metric". The `dropna()` is necessary to remove extra + # rows where there is no ground truth value for the "task" and "metric" + ground_truth_df = pd.DataFrame.from_records( + eval_config_tasks, index=eval_config_metrics).transpose() + gt_listing_df = ground_truth_df.reset_index(names="task").melt( + id_vars="task", var_name="metric", + value_name="ground_truth").dropna().set_index(["task", "metric"]) + + # the ifeval task requires an additional set of data + if "leaderboard_ifeval" in [task["name"] for task in eval_config["tasks"]]: + nltk.download('punkt_tab') + + # Launch eval requests. + results = launch_lm_eval(eval_config, tp_size=num_gpus_available) + + # process the results into a dataframe that looks like the ground truth + # with records indexed by "task" and "metric", but with the measured value + # for each index. + results_df = pd.DataFrame.from_records( + results["results"], index=eval_config_metrics).transpose() + r_listing_df = (results_df.reset_index(names="task").melt( + id_vars="task", var_name="metric", + value_name="measured").dropna().set_index(["task", "metric"])) + + # present the results + # combine the ground truth and results into a single dataframe + # but eliminate any rows that do not have both values + # (This could happen if the eval_config includes a measure that's not + # generated, or if the LM Evaluation harness generates a measure that + # was not requested by the eval_config.) + comparing_metrics_df = pd.concat( + [gt_listing_df, r_listing_df], + axis="columns").reset_index(names=["task", "metric"]).dropna() + + # Add a column with the relative tolerance level for the task + task_rtol_map = { + t["name"]: t.get("rtol", RTOL) + for t in eval_config["tasks"] + } + comparing_metrics_df.loc[:, "rtol"] = comparing_metrics_df.apply( + lambda metric: task_rtol_map[metric.task], axis=1) + + # and determine if measured is close to ground truth + comparing_metrics_df.loc[:, "isclose"] = comparing_metrics_df.apply( + lambda metric: numpy.isclose( + metric.ground_truth, metric.measured, rtol=metric.rtol), + axis=1) + print("==== LM EVAL RESULT ====\n") + comparing_metrics_df.sort_values(by=["task", "metric"], inplace=True) + print(comparing_metrics_df.to_markdown(index=False)) + + # save the results for later summary + llm_results_md = Path("llmeval_results-" + + TEST_DATA_FILE.replace("/", "-")).with_suffix(".md") + llm_results_md.write_text( + f"## {eval_config['model_name']}\n" + f"{comparing_metrics_df.to_markdown(index=False)}\n") + + # fail if any scores fail to match ground truth. + assert comparing_metrics_df.loc[:, "isclose"].all() From 7c056e72cc915f3d0878f3ffd8b6d52ce8761b51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Thu, 13 Feb 2025 16:23:45 +0100 Subject: [PATCH 07/10] [Frontend] Add `/v1/audio/transcriptions` OpenAI API endpoint (#12909) --- .buildkite/test-pipeline.yaml | 10 +- .../serving/openai_compatible_server.md | 13 + .../openai_transcription_client.py | 23 ++ requirements-common.txt | 7 +- requirements-test.in | 1 + requirements-test.txt | 5 + .../openai/correctness/__init__.py | 0 .../test_lmeval.py} | 2 +- .../test_transcription_api_correctness.py | 166 ++++++++++ .../openai/test_transcription_validation.py | 122 +++++++ tests/test_config.py | 1 + vllm/assets/audio.py | 5 + vllm/config.py | 11 +- vllm/entrypoints/openai/api_server.py | 43 ++- vllm/entrypoints/openai/protocol.py | 163 +++++++++- vllm/entrypoints/openai/serving_engine.py | 6 +- .../openai/serving_transcription.py | 305 ++++++++++++++++++ vllm/model_executor/models/interfaces.py | 27 ++ vllm/model_executor/models/registry.py | 12 +- vllm/model_executor/models/whisper.py | 5 +- 20 files changed, 909 insertions(+), 18 deletions(-) create mode 100644 examples/online_serving/openai_transcription_client.py create mode 100644 tests/entrypoints/openai/correctness/__init__.py rename tests/entrypoints/openai/{test_accuracy.py => correctness/test_lmeval.py} (98%) create mode 100644 tests/entrypoints/openai/correctness/test_transcription_api_correctness.py create mode 100644 tests/entrypoints/openai/test_transcription_validation.py create mode 100644 vllm/entrypoints/openai/serving_transcription.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7ef40564c5bd..ea71e8ac0d2f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -113,7 +113,7 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/ - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -328,6 +328,14 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: OpenAI API correctness + source_file_dependencies: + - csrc/ + - vllm/entrypoints/openai/ + - vllm/model_executor/models/whisper.py + commands: # LMEval+Transcription WER check + - pytest -s entrypoints/openai/correctness/ + - label: Encoder Decoder tests # 5min source_file_dependencies: - vllm/ diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 82ef54c16daf..64439475fdb5 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -41,6 +41,8 @@ We currently support the following OpenAI APIs: - *Note: `parallel_tool_calls` and `user` parameters are ignored.* - [Embeddings API](#embeddings-api) (`/v1/embeddings`) - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`). +- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`) + - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`). In addition, we have the following custom APIs: @@ -296,6 +298,17 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s :end-before: end-chat-embedding-extra-params ::: +(transcriptions-api)= + +### Transcriptions API + +Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription); +you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. + + + +Code example: + (tokenizer-api)= ### Tokenizer API diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py new file mode 100644 index 000000000000..bd3c02a8a95e --- /dev/null +++ b/examples/online_serving/openai_transcription_client.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import OpenAI + +from vllm.assets.audio import AudioAsset + +mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path() +winning_call = AudioAsset('winning_call').get_local_path() + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) +with open(str(mary_had_lamb), "rb") as f: + transcription = client.audio.transcriptions.create( + file=f, + model="openai/whisper-large-v3", + language="en", + response_format="text", + temperature=0.0) + print("transcription result:", transcription) diff --git a/requirements-common.txt b/requirements-common.txt index cfa02025629f..0b7253cc121d 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,12 +8,11 @@ py-cpuinfo transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. -fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' -fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' +fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9' +fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' aiohttp openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) -uvicorn[standard] -pydantic >= 2.9 # Required for fastapi >= 0.113.0 +pydantic >= 2.9 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 diff --git a/requirements-test.in b/requirements-test.in index 229d743ec802..ecf874ecc50f 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -19,6 +19,7 @@ pqdm ray[adag]==2.40.0 sentence-transformers # required for embedding tests soundfile # required for audio tests +jiwer # required for audio tests timm # required for internvl test torch==2.5.1 torchaudio==2.5.1 diff --git a/requirements-test.txt b/requirements-test.txt index e032aac710dd..648a2626c857 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -66,6 +66,7 @@ charset-normalizer==3.4.0 click==8.1.7 # via # black + # jiwer # nltk # ray colorama==0.4.6 @@ -187,6 +188,8 @@ jinja2==3.1.4 # via # datamodel-code-generator # torch +jiwer==3.0.5 + # via -r requirements-test.in jmespath==1.0.1 # via # boto3 @@ -470,6 +473,8 @@ pyyaml==6.0.2 # timm # transformers # vocos +rapidfuzz==3.12.1 + # via jiwer ray[adag]==2.40.0 # via -r requirements-test.in redis==5.2.0 diff --git a/tests/entrypoints/openai/correctness/__init__.py b/tests/entrypoints/openai/correctness/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/correctness/test_lmeval.py similarity index 98% rename from tests/entrypoints/openai/test_accuracy.py rename to tests/entrypoints/openai/correctness/test_lmeval.py index df25780cd0f4..ebb2ea4d9d14 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -13,7 +13,7 @@ from vllm.platforms import current_platform -from ...utils import RemoteOpenAIServer +from ....utils import RemoteOpenAIServer MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" NUM_CONCURRENT = 500 diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py new file mode 100644 index 000000000000..19d4735b9dde --- /dev/null +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Evaluate Transcription API correctness by computing Word Error Rate (WER) +on a given ASR dataset. When provided, it will also compare the WER against +a baseline. +This simulates real work usage of the API and makes sure that the frontend and +AsyncLLMEngine are working correctly. +""" +import asyncio +import io +import time +from statistics import mean, median +from typing import List + +import librosa +import pytest +import soundfile +import torch +from datasets import load_dataset +from evaluate import load +from transformers import AutoTokenizer + +from ....utils import RemoteOpenAIServer + + +def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + +async def transcribe_audio(client, tokenizer, y, sr): + # Send loaded audio directly instead of loading from disk, + # dont account for that time though + with to_bytes(y, sr) as f: + start_time = time.perf_counter() + transcription = await client.audio.transcriptions.create( + file=f, + model=tokenizer.name_or_path, + language="en", + temperature=0.0, + ) + end_time = time.perf_counter() + # NOTE there's no streaming in transcriptions, can't measure ttft + latency = end_time - start_time + num_output_tokens = len( + tokenizer(transcription.text, add_special_tokens=False).input_ids) + return latency, num_output_tokens, transcription.text + + +async def bound_transcribe(model_name, sem, client, audio, reference): + tokenizer = AutoTokenizer.from_pretrained(model_name) + # Use semaphore to limit concurrent requests. + async with sem: + result = await transcribe_audio(client, tokenizer, *audio) + # Normalize *english* output/reference for evaluation. + out = tokenizer.normalize(result[2]) + ref = tokenizer.normalize(reference) + return result[:2] + (out, ref) + + +async def process_dataset(model, client, data, concurrent_request): + sem = asyncio.Semaphore(concurrent_request) + + # Warmup call as the first `librosa.load` server-side is quite slow. + audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] + _ = await bound_transcribe(model, sem, client, (audio, sr), "") + + tasks: List[asyncio.Task] = [] + for sample in data: + audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] + task = asyncio.create_task( + bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + tasks.append(task) + return await asyncio.gather(*tasks) + + +def print_performance_metrics(results, total_time): + latencies = [res[0] for res in results] + total_tokens = sum([res[1] for res in results]) + + total = len(results) + print(f"Total Requests: {total}") + print(f"Successful Requests: {len(latencies)}") + print(f"Average Latency: {mean(latencies):.4f} seconds") + print(f"Median Latency: {median(latencies):.4f} seconds") + perc = sorted(latencies)[int(len(latencies) * 0.95) - 1] + print(f"95th Percentile Latency: {perc:.4f} seconds") + # Throughput + req_throughput = len(latencies) / total_time + print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s") + throughput = total_tokens / total_time + print(f"Estimated Throughput: {throughput:.2f} tok/s") + + +def add_duration(sample): + y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] + sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 + return sample + + +def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs): + ## Load and filter the dataset + dataset = load_dataset(dataset_repo, split=split, **hf_kwargs) + if 'duration_ms' not in dataset[0]: + # compute duration to filter + dataset = dataset.map(add_duration) + + # Whisper max supported duration + dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + return dataset + + +def run_evaluation(model: str, + client, + dataset, + max_concurrent_reqs: int, + n_examples: int = -1, + print_metrics: bool = True): + if n_examples > 0: + dataset = dataset.select(range(n_examples)) + start = time.perf_counter() + results = asyncio.run( + process_dataset(model, client, dataset, max_concurrent_reqs)) + end = time.perf_counter() + total_time = end - start + print(f"Total Test Time: {total_time:.4f} seconds") + if print_metrics: + print_performance_metrics(results, total_time) + # Compute WER + predictions = [res[2] for res in results] + references = [res[3] for res in results] + wer = load("wer") + wer_score = 100 * wer.compute(references=references, + predictions=predictions) + print("WER:", wer_score) + return wer_score + + +# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo".. +@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) +# Original dataset is 20GB+ in size, hence we use a pre-filtered slice. +@pytest.mark.parametrize( + "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]) +# NOTE: Expected WER measured with equivalent hf.transformers args: +# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. +@pytest.mark.parametrize("expected_wer", [12.744980]) +def test_wer_correctness(model_name, + dataset_repo, + expected_wer, + n_examples=-1, + max_concurrent_request=None): + with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: + dataset = load_hf_dataset(dataset_repo) + + if not max_concurrent_request: + # No max concurrency + max_concurrent_request = n_examples if n_examples > 0\ + else len(dataset) + + client = remote_server.get_async_client() + wer = run_evaluation(model_name, client, dataset, + max_concurrent_request, n_examples) + if expected_wer: + torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py new file mode 100644 index 000000000000..5d4a5de4badd --- /dev/null +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 + +# imports for guided decoding tests +import io +import json + +import librosa +import numpy as np +import openai +import pytest +import soundfile as sf + +from vllm.assets.audio import AudioAsset + +from ...utils import RemoteOpenAIServer + + +@pytest.fixture +def mary_had_lamb(): + path = AudioAsset('mary_had_lamb').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def winning_call(): + path = AudioAsset('winning_call').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.mark.asyncio +async def test_basic_audio(mary_had_lamb): + model_name = "openai/whisper-large-v3-turbo" + server_args = ["--enforce-eager"] + # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. + prompt = "THE FIRST WORDS I SPOKE" + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "Mary had a little lamb," in out + # This should "force" whisper to continue prompt in all caps + transcription_wprompt = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + response_format="text", + prompt=prompt, + temperature=0.0) + out_capital = json.loads(transcription_wprompt)['text'] + assert prompt not in out_capital + + +@pytest.mark.asyncio +async def test_bad_requests(mary_had_lamb): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + + # invalid language + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=model_name, + file=mary_had_lamb, + language="hh", + temperature=0.0) + + # Expect audio too long: repeat the timeseries + mary_had_lamb.seek(0) + audio, sr = librosa.load(mary_had_lamb) + repeated_audio = np.tile(audio, 10) + # Repeated audio to buffer + buffer = io.BytesIO() + sf.write(buffer, repeated_audio, sr, format='WAV') + buffer.seek(0) + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=model_name, + file=buffer, + language="en", + temperature=0.0) + + +@pytest.mark.asyncio +async def test_non_asr_model(winning_call): + # text to text model + model_name = "JackFram/llama-68m" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.audio.transcriptions.create(model=model_name, + file=winning_call, + language="en", + temperature=0.0) + assert res.code == 400 and not res.text + assert res.message == "The model does not support Transcriptions API" + + +@pytest.mark.asyncio +async def test_completion_endpoints(): + # text to text model + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }]) + assert res.code == 400 + assert res.message == "The model does not support Chat Completions API" + + res = await client.completions.create(model=model_name, prompt="Hello") + assert res.code == 400 + assert res.message == "The model does not support Completions API" diff --git a/tests/test_config.py b/tests/test_config.py index 2dfae218b47d..3fb83b4c0328 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -17,6 +17,7 @@ ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), + ("openai/whisper-small", "transcription", "transcription"), ], ) def test_auto_task(model_id, expected_runner_type, expected_task): diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index d9e51082e6ca..0203dc092a71 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +from pathlib import Path from typing import Literal from urllib.parse import urljoin @@ -28,6 +29,10 @@ def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) + def get_local_path(self) -> Path: + return get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + @property def url(self) -> str: return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/config.py b/vllm/config.py index 9ba497576124..561569162ed8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -54,17 +54,18 @@ _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", - "score", "reward"] + "score", "reward", "transcription"] _ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", - "draft"] + "draft", "transcription"] -RunnerType = Literal["generate", "pooling", "draft"] +RunnerType = Literal["generate", "pooling", "draft", "transcription"] _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = { "generate": ["generate"], "pooling": ["embed", "classify", "score", "reward"], "draft": ["draft"], + "transcription": ["transcription"], } _TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { @@ -483,6 +484,8 @@ def _get_preferred_task( return "embed" if ModelRegistry.is_cross_encoder_model(architectures): return "score" + if ModelRegistry.is_transcription_model(architectures): + return "transcription" suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ # Other models follow this pattern @@ -515,6 +518,8 @@ def _resolve_task( runner_support: Dict[RunnerType, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them + "transcription": + ModelRegistry.is_transcription_model(architectures), "generate": ModelRegistry.is_text_generation_model(architectures), "pooling": ModelRegistry.is_pooling_model(architectures), } diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b8f54d6c7804..781cff350529 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -17,10 +17,10 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union +from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -62,6 +62,8 @@ ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, UnloadLoraAdapterRequest) from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable @@ -76,6 +78,8 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_transcription import ( + OpenAIServingTranscription) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger @@ -319,6 +323,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization +def transcription(request: Request) -> OpenAIServingTranscription: + return request.app.state.openai_serving_transcription + + def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -511,6 +519,31 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) +@router.post("/v1/audio/transcriptions") +@with_cancellation +async def create_transcriptions(request: Annotated[TranscriptionRequest, + Form()], + raw_request: Request): + + handler = transcription(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Transcriptions API") + + audio_data = await request.file.read() + generator = await handler.create_transcription(audio_data, request, + raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranscriptionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + @router.post("/rerank") @with_cancellation async def do_rerank(request: RerankRequest, raw_request: Request): @@ -821,6 +854,12 @@ async def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) + state.openai_serving_transcription = OpenAIServingTranscription( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.runner_type == "transcription" else None state.task = model_config.task diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 83b841826231..2bcfdc235776 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -8,9 +8,10 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union import torch +from fastapi import UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) -from typing_extensions import Annotated +from typing_extensions import Annotated, TypeAlias from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger @@ -1426,3 +1427,163 @@ class LoadLoraAdapterRequest(BaseModel): class UnloadLoraAdapterRequest(BaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) + + +## Protocols for Audio +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", + "vtt"] + + +class TranscriptionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + #https://platform.openai.com/docs/api-reference/audio/createTranscription + + file: UploadFile + """ + The audio file object (not file name) to transcribe, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: str + """ID of the model to use. + """ + + language: Optional[str] = None + """The language of the input audio. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy and latency. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + ## TODO (varun) : Support if set to 0, certain thresholds are met !! + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + + timestamp_granularities: List[Literal["word", "segment"]] = Field( + alias="timestamp_granularities[]", default=[]) + """The timestamp granularities to populate for this transcription. + + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. Note: + There is no additional latency for segment timestamps, but generating word + timestamps incurs additional latency. + """ + + # Default sampling parameters for transcription requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens) + + +# Transcription response objects +class TranscriptionResponse(OpenAIBaseModel): + text: str + """The transcribed text.""" + + +class TranscriptionWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class TranscriptionSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. + """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: List[int] + """Array of token IDs for the text content.""" + + +class TranscriptionResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[List[TranscriptionSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[List[TranscriptionWord]] = None + """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8d39fdcb7483..d4ce4e91251a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -31,7 +31,8 @@ ErrorResponse, RerankRequest, ScoreRequest, TokenizeChatRequest, - TokenizeCompletionRequest) + TokenizeCompletionRequest, + TranscriptionRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable @@ -57,7 +58,8 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, + TranscriptionRequest] class TextTokensPrompt(TypedDict): diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py new file mode 100644 index 000000000000..da4930e0e2d8 --- /dev/null +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import io +from typing import AsyncGenerator, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, + RequestResponseMetadata, + TranscriptionRequest, + TranscriptionResponse, + TranscriptionResponseVerbose) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.utils import PlaceholderModule + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +logger = init_logger(__name__) + +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages +# TODO these configs should live somewhere with the model so we can support +# additional ones + +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": "Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + "sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh" +} +ISO639_1_OTHER_LANGS = { + "lo": "Lao", + "jw": "Javanese", + "tk": "Turkmen", + "yi": "Yiddish", + "so": "Somali", + "bn": "Bengali", + "nn": "Norwegian Nynorsk", + "si": "Sinhala", + "yo": "Yoruba", + "sa": "Sanskrit", + "mi": "Māori", + "fo": "Faroese", # codespell:ignore + "mt": "Maltese", + "tg": "Tajik", + "mg": "Malagasy", + "haw": "Hawaiian", + "km": "Khmer", + "br": "Breton", + "ps": "Pashto", + "ln": "Lingala", + "la": "Latin", + "ml": "Malayalam", + "sq": "Albanian", + "su": "Sundanese", + "eu": "Basque", + "ka": "Georgian", + "uz": "Uzbek", + "sn": "Shona", + "ht": "Haitian", + "as": "Assamese", + "mn": "Mongolian", + "te": "Telugu", + "pa": "Panjabi", + "tt": "Tatar", + "gu": "Gujarati", + "oc": "Occitan", + "ha": "Hausa", + "ba": "Bashkir", + "my": "Burmese", + "sd": "Sindhi", + "am": "Amharic", + "lb": "Luxembourgish", + "bo": "Tibetan" +} + +# As per https://platform.openai.com/docs/guides/speech-to-text#overview. +# TODO configurable +MAX_AUDIO_CLIP_FILESIZE_MB = 25 +# TODO get from processor.feature_extractor.chunk_length +MAX_AUDIO_CLIP_DURATION_S = 30 + + +class OpenAIServingTranscription(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) + + async def _preprocess_transcription( + self, + request: TranscriptionRequest, + audio_data: bytes, + ) -> PromptType: + # Validate request + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + lang_token = f"<|{request.language}|>" if request.language else "<|en|>" + if request.language: + if request.language in ISO639_1_SUPPORTED_LANGS: + pass + elif request.language in ISO639_1_OTHER_LANGS: + logger.warning( + "The selected language %s has limited accuracy with" + " reported WER>=0.5. Results may be less accurate " + "for this choice.", request.language) + else: + raise ValueError( + f"Unsupported language: {request.language}." + "Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + + f"or {list(ISO639_1_OTHER_LANGS.values())}") + + if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("Maximum file size exceeded.") + + with io.BytesIO(audio_data) as bytes_: + y, sr = librosa.load(bytes_) + if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S: + raise ValueError( + f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) " + "exceeded.") + + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (y, sr), + }, + }, + "decoder_prompt": + f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" + } + return cast(PromptType, prompt) + + # TODO (varun) : Make verbose response work ! + async def create_transcription( + self, audio_data: bytes, request: TranscriptionRequest, + raw_request: Request + ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose, + ErrorResponse]: + """Transcription API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranscription + for the API specification. This API mimics the OpenAI transcription API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + + # TODO cmpl->transcription? + request_id = f"cmpl-{self._base_request_id(raw_request)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for Transcription.") + if prompt_adapter_request: + return self.create_error_response( + "Currently do not support PromptAdapter for Transcription." + ) + + prompt = await self._preprocess_transcription( + request=request, + audio_data=audio_data, + ) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None + try: + # TODO(rob): subtract len of tokenized prompt. + default_max_tokens = self.model_config.max_model_len + default_params = self.model_config.get_diff_sampling_param() + sampling_params = request.to_sampling_params( + default_max_tokens, default_params) + + self._log_inputs( + request_id, + prompt['decoder_prompt'], # type: ignore + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + result_generator = self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # TODO(rob): figure out a way to pipe streaming in. + # Non-streaming response. + try: + async for op in result_generator: + result = op + return TranscriptionResponse(text=result.outputs[0].text) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 0fc5c4db179c..a0a1b69ad502 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -441,3 +441,30 @@ def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: return is_pooling_model(model) and _supports_cross_encoding(model) + + +@runtime_checkable +class SupportsTranscription(Protocol): + """The interface required for all models that support transcription.""" + + supports_transcription: ClassVar[Literal[True]] = True + + +@overload +def supports_transcription( + model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: + ... + + +@overload +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: + ... + + +def supports_transcription( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + if isinstance(model, type): + return isinstance(model, SupportsTranscription) + + return isinstance(model, SupportsTranscription) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3b2a7069efc9..c3f1dd36fc1b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -22,7 +22,7 @@ from .interfaces import (has_inner_state, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, - supports_pp) + supports_pp, supports_transcription) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -212,6 +212,7 @@ class _ModelInfo: has_inner_state: bool is_attention_free: bool is_hybrid: bool + supports_transcription: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -225,7 +226,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), is_hybrid=is_hybrid(model), - ) + supports_transcription=supports_transcription(model)) class _BaseRegisteredModel(ABC): @@ -473,6 +474,13 @@ def is_hybrid_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_hybrid + def is_transcription_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_transcription + ModelRegistry = _ModelRegistry({ model_arch: diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0a3011d36101..0b506072094e 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -31,7 +31,7 @@ from vllm.sequence import SequenceData from vllm.transformers_utils.processor import cached_get_processor -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsTranscription from .utils import AutoWeightsLoader, WeightsMapper, make_layers logger = init_logger(__name__) @@ -637,7 +637,8 @@ def input_mapper_for_whisper( @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "audio", get_max_whisper_audio_tokens) -class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): +class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", From 902459cf79dcc4c0e855b2dfca25364ecf467f88 Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Tue, 25 Feb 2025 22:03:33 +0800 Subject: [PATCH 08/10] Fix `/v1/audio/transcriptions ` Bad Request Error (#13811) --- requirements-common.txt | 3 +-- vllm/entrypoints/openai/protocol.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 0b7253cc121d..2c1ac6595ef8 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,8 +8,7 @@ py-cpuinfo transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. -fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9' -fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' +fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) pydantic >= 2.9 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 2bcfdc235776..4bd52df2a821 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1436,7 +1436,7 @@ class UnloadLoraAdapterRequest(BaseModel): class TranscriptionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation - #https://platform.openai.com/docs/api-reference/audio/createTranscription + # https://platform.openai.com/docs/api-reference/audio/createTranscription file: UploadFile """ From d53845c919b4e7e4e9999bad3efabd265ced21ff Mon Sep 17 00:00:00 2001 From: andy-neuma Date: Tue, 25 Mar 2025 10:27:08 -0400 Subject: [PATCH 09/10] update version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 380f37b0e100..9fa0c1225e17 100755 --- a/setup.py +++ b/setup.py @@ -478,7 +478,7 @@ def get_gaudi_sw_version(): def get_vllm_version() -> str: - version = "0.7.2.0" + version = "0.7.2.1" sep = "+" if "+" not in version else "." # dev versions might contain + From bab9c19d31d4bba8db0ac2edc73800462ecbbddf Mon Sep 17 00:00:00 2001 From: andy-neuma Date: Tue, 25 Mar 2025 17:36:42 -0400 Subject: [PATCH 10/10] missed a spot --- vllm/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/version.py b/vllm/version.py index 10f87d243497..233e9fe5f7b4 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 try: - __version__ = "0.7.2.0" - __version_tuple__ = (0, 7, 2, 0) + __version__ = "0.7.2.1" + __version_tuple__ = (0, 7, 2, 1) except Exception as e: import warnings