diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7ef40564c5bd..ea71e8ac0d2f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -113,7 +113,7 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/ - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -328,6 +328,14 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: OpenAI API correctness + source_file_dependencies: + - csrc/ + - vllm/entrypoints/openai/ + - vllm/model_executor/models/whisper.py + commands: # LMEval+Transcription WER check + - pytest -s entrypoints/openai/correctness/ + - label: Encoder Decoder tests # 5min source_file_dependencies: - vllm/ diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83a..b3aa5b74e00b 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -4,6 +4,8 @@ #include +#include "core/math.hpp" + #include "cuda_compat.h" #include "dispatch_utils.h" @@ -31,6 +33,69 @@ __global__ void act_and_mul_kernel( } } +// NOTE: temporary vectorized version. + +template +__global__ void act_and_mul_kernel_vectorized( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { + const int64_t token_idx = blockIdx.x; + const int32_t blocks_per_token = gridDim.y; + + const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t); + + const int32_t tgt_elems_per_block = ceil_div(d, blocks_per_token); + const int32_t elems_per_block = + next_multiple_of(elems_per_128bit_load, tgt_elems_per_block); + const int64_t block_start = blockIdx.y * int64_t(elems_per_block); + int64_t block_end = block_start + elems_per_block; + block_end = block_end > d ? d : block_end; + + const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d; + const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d; + scalar_t* __restrict__ out_ptr = out + token_idx * d; + + // 128-bit vectorized code + const int32_t vec_loop_end = + prev_multiple_of(elems_per_128bit_load, block_end); + const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load; + const int32_t vec_start_idx = block_start / elems_per_128bit_load; + + const int4* __restrict__ x_128bit_ptr = reinterpret_cast(x_ptr); + const int4* __restrict__ y_128bit_ptr = reinterpret_cast(y_ptr); + int4* __restrict__ out_128bit_ptr = reinterpret_cast(out_ptr); + +#pragma unroll + for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx; + vec_idx += blockDim.x) { + const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]); + const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]); + using scalar_128bit_vec_t = std::array; + + scalar_128bit_vec_t out_vec; + const auto x_vec = reinterpret_cast(x_128bit); + const auto y_vec = reinterpret_cast(y_128bit); + +#pragma unroll + for (int i = 0; i < elems_per_128bit_load; i++) { + out_vec[i] = ACT_FN(x_vec[i]) * y_vec[i]; + } + + out_128bit_ptr[vec_idx] = reinterpret_cast(out_vec); + } + + // Scalar cleanup code + if (block_end > vec_loop_end) { + for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end; + idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + out_ptr[idx] = ACT_FN(x) * y; + } + } +} + template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) @@ -79,10 +144,26 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { input.data_ptr(), d); \ }); +// Launch activation and gating kernel. +// Vectorized Version +#define LAUNCH_ACTIVATION_GATE_KERNEL_VECTORIZED(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \ + dim3 block(std::min(d, 512)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel_vectorized", [&] { \ + vllm::act_and_mul_kernel_vectorized> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + void silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); + LAUNCH_ACTIVATION_GATE_KERNEL_VECTORIZED(vllm::silu_kernel); } void mul_and_silu(torch::Tensor& out, // [..., d] diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index ddfaca27147b..2cc05960d5cd 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -11,4 +11,16 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { template inline constexpr std::enable_if_t, T> ceil_div(T a, T b) { return (a + b - 1) / b; -} \ No newline at end of file +} + +// Compute the next multiple of a that is greater than or equal to b +template +static inline constexpr auto next_multiple_of(A a, B b) { + return ceil_div(b, a) * a; +} + +// Compute the largest multiple of a that is less than or equal to b +template +static inline constexpr auto prev_multiple_of(A a, B b) { + return (b / a) * a; +} diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh index 32ea5db3321b..5689f1b6e64c 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -12,52 +12,248 @@ namespace vllm { using c3x::cutlass_gemm_caller; -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; +#define CALL_CUTLASS_GEMM \ + cutlass_gemm_caller< \ + cutlass_3x_gemm>( \ + out, a, b, std::forward(args)...); + +struct sm90_fp8_config_M64 { + // M in [1, 64] + using ClusterShape = Shape<_1, _8, _1>; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n < 8 * 1024) { + using TileShape = Shape<_64, _64, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + CALL_CUTLASS_GEMM + + } else if (n < 16 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_64, _128, _128>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + CALL_CUTLASS_GEMM + } + } }; -template typename Epilogue> struct sm90_fp8_config_M128 { // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + CALL_CUTLASS_GEMM + + } else if (n <= 8 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _64, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_128, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } + } }; -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; +struct sm90_fp8_config_M256 { + // M in (128, 256] + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 8 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_256, _128, _64>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + } + } +}; + +struct sm90_fp8_config_M3072 { + // M in (256, 3072] + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const n = out.size(1); + + if (n <= 4 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM - using Cutlass3xGemm = - cutlass_3x_gemm; + } else if (n <= 8 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else if (n <= 16 * 1024) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + + } else if (n <= 24 * 1024) { + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + CALL_CUTLASS_GEMM + + } else { + using TileShape = Shape<_64, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } + } +}; + +struct sm90_fp8_config_default { + // M in (3072, inf) + + template typename Epilogue, + typename... EpilogueArgs> + static void dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + + static_assert(std::is_same()); + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + CALL_CUTLASS_GEMM + } }; +#undef CALL_CUTLASS_GEMM + template typename Epilogue, typename... EpilogueArgs> @@ -69,29 +265,27 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_fp8_config_M128::Cutlass3xGemm; - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 - if (mp2 <= 64) { + if (m <= 64) { // m in [1, 64] - return cutlass_gemm_caller( + return sm90_fp8_config_M64::dispatch( out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { + } else if (m <= 128) { // m in (64, 128] - return cutlass_gemm_caller( + return sm90_fp8_config_M128::dispatch( + out, a, b, std::forward(args)...); + } else if (m <= 256) { + // m in (128, 256] + return sm90_fp8_config_M256::dispatch( + out, a, b, std::forward(args)...); + } else if (m <= 3072) { + // m in (256, 3072] + return sm90_fp8_config_M3072::dispatch( out, a, b, std::forward(args)...); } else { - // m in (128, inf) - return cutlass_gemm_caller( + // m in (3072, inf] + return sm90_fp8_config_default::dispatch( out, a, b, std::forward(args)...); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index f2fae4b66d65..fd0bb6ff2947 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -71,12 +71,19 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; + +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, typename WarpShape, typename InstructionShape, int32_t MainLoopStages, - typename FP8MathOperator = cutlass::arch::OpMultiplyAdd> + typename FP8MathOperator = cutlass::arch::OpMultiplyAdd, + typename ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + GemmUniversalMode Mode_ = GemmUniversalMode::kGemm> struct cutlass_2x_gemm { + static const GemmUniversalMode Mode = Mode_; using ElementAB = ElementAB_; using ElementD = ElementD_; @@ -115,7 +122,7 @@ struct cutlass_2x_gemm { Arch, TileShape, WarpShape, InstructionShape, EVTD, - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + ThreadblockSwizzle, MainLoopStages, Operator, 1 /* epilogue stages */ >::GemmKernel>; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh new file mode 100644 index 000000000000..2e229787122c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_configs.cuh @@ -0,0 +1,378 @@ +template typename Epilogue> +struct sm89_fp8_config_0 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_1 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_2 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 5; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_3 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_4 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_5 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_6 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_7 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_8 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_9 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_10 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_11 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_12 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_13 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 4; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle< + 1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_14 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_15 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_16 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 3; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemm; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm89_fp8_config_17 { + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + static constexpr int32_t MainLoopStages = 2; + using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum; + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmSplitKHorizontalThreadblockSwizzle; + static constexpr cutlass::gemm::GemmUniversalMode GemmMode = + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + + using Cutlass2xGemm = + vllm::cutlass_2x_gemm; +}; \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh index 4e82c99c3af3..a96b9e594a36 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh @@ -3,6 +3,8 @@ #include "scaled_mm_c2x.cuh" #include "cutlass/float8.h" +#include "scaled_mm_c2x_sm89_fp8_configs.cuh" + /** * This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm * shape. @@ -334,7 +336,551 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - uint32_t const m = a.size(0); + uint32_t const m = out.size(0); + uint32_t const n = out.size(1); + uint32_t const k = b.size(0); + + if (m == 1) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 16) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 32) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_0::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_3::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_5::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_6::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_7::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 64) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_4::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_9::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_8::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 128) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller< + typename sm89_fp8_config_1::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller< + typename sm89_fp8_config_2::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 256) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 512) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 1024) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 2048) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else if (m <= 4096) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } else { // m512 kernels + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass2xGemm>( + out, a, b, std::forward(args)...); + } + uint32_t const mp2 = std::max(static_cast(32), next_pow_2(m)); // next power of 2 diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 82ef54c16daf..64439475fdb5 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -41,6 +41,8 @@ We currently support the following OpenAI APIs: - *Note: `parallel_tool_calls` and `user` parameters are ignored.* - [Embeddings API](#embeddings-api) (`/v1/embeddings`) - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`). +- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`) + - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`). In addition, we have the following custom APIs: @@ -296,6 +298,17 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s :end-before: end-chat-embedding-extra-params ::: +(transcriptions-api)= + +### Transcriptions API + +Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription); +you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. + + + +Code example: + (tokenizer-api)= ### Tokenizer API diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py new file mode 100644 index 000000000000..bd3c02a8a95e --- /dev/null +++ b/examples/online_serving/openai_transcription_client.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import OpenAI + +from vllm.assets.audio import AudioAsset + +mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path() +winning_call = AudioAsset('winning_call').get_local_path() + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) +with open(str(mary_had_lamb), "rb") as f: + transcription = client.audio.transcriptions.create( + file=f, + model="openai/whisper-large-v3", + language="en", + response_format="text", + temperature=0.0) + print("transcription result:", transcription) diff --git a/examples/run_lm_eval.sh b/examples/run_lm_eval.sh new file mode 100644 index 000000000000..db36fc3755a7 --- /dev/null +++ b/examples/run_lm_eval.sh @@ -0,0 +1,2 @@ +MODEL=neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic +lm_eval --model vllm --model_args "pretrained=$MODEL" --tasks gsm8k --batch_size "auto" diff --git a/requirements-common.txt b/requirements-common.txt index cfa02025629f..2c1ac6595ef8 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,12 +8,10 @@ py-cpuinfo transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. -fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' -fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' +fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) -uvicorn[standard] -pydantic >= 2.9 # Required for fastapi >= 0.113.0 +pydantic >= 2.9 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 diff --git a/requirements-test.in b/requirements-test.in index 229d743ec802..ecf874ecc50f 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -19,6 +19,7 @@ pqdm ray[adag]==2.40.0 sentence-transformers # required for embedding tests soundfile # required for audio tests +jiwer # required for audio tests timm # required for internvl test torch==2.5.1 torchaudio==2.5.1 diff --git a/requirements-test.txt b/requirements-test.txt index e032aac710dd..648a2626c857 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -66,6 +66,7 @@ charset-normalizer==3.4.0 click==8.1.7 # via # black + # jiwer # nltk # ray colorama==0.4.6 @@ -187,6 +188,8 @@ jinja2==3.1.4 # via # datamodel-code-generator # torch +jiwer==3.0.5 + # via -r requirements-test.in jmespath==1.0.1 # via # boto3 @@ -470,6 +473,8 @@ pyyaml==6.0.2 # timm # transformers # vocos +rapidfuzz==3.12.1 + # via jiwer ray[adag]==2.40.0 # via -r requirements-test.in redis==5.2.0 diff --git a/setup.py b/setup.py index a4043c43a7d5..9fa0c1225e17 100755 --- a/setup.py +++ b/setup.py @@ -15,7 +15,6 @@ from packaging.version import Version, parse from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext -from setuptools_scm import get_version from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME @@ -479,9 +478,7 @@ def get_gaudi_sw_version(): def get_vllm_version() -> str: - version = get_version( - write_to="vllm/_version.py", # TODO: move this to pyproject.toml - ) + version = "0.7.2.1" sep = "+" if "+" not in version else "." # dev versions might contain + diff --git a/tests/accuracy/test_lm_eval_correctness.py b/tests/accuracy/test_lm_eval_correctness.py new file mode 100644 index 000000000000..c3c5f8a04d8d --- /dev/null +++ b/tests/accuracy/test_lm_eval_correctness.py @@ -0,0 +1,150 @@ +import itertools +import os +from pathlib import Path +from typing import TYPE_CHECKING + +import nltk +import numpy +import pandas as pd +import pytest +import yaml + +if TYPE_CHECKING: + import lm_eval as lm_eval_t + +# requires a particular lm-evaluation-harness +# pip install lm_eval==0.4.3 +lm_eval: "lm_eval_t" = pytest.importorskip("lm_eval", + reason="lm_eval required") + +MAX_MODEL_LEN = 4096 +RTOL = 0.040 +TEST_DATA_PATH = os.environ.get( + "LM_EVAL_TEST_DATA_FILE", + "../neuralmagic/lm-eval-configs/models/Meta-Llama-3-8B-Instruct.yaml") +# just show the test data file from the `neuralmagic/lm-eval-configs/models` +# directory. this could be a `model.yaml`, or a `leaderboard/model.yaml` +TEST_DATA_FILE = str(Path(TEST_DATA_PATH)).replace( + str(Path.cwd() / "../neuralmagic/lm-eval-configs/models"), "") + + +def launch_lm_eval(eval_config, tp_size): + model_args = { + "pretrained": eval_config['model_name'], + } + eval_config_model_args = eval_config.get('model_args') + if eval_config_model_args: + model_args.update(eval_config_model_args) + + model_backend = eval_config.get("backend", "vllm") + + if model_backend == "vllm": + model_args.update({ + "tensor_parallel_size": tp_size, + "distributed_executor_backend": "ray", + "max_model_len": MAX_MODEL_LEN + }) + + evaluate_args = { + "model": model_backend, + "model_args": ",".join([f"{k}={v}" for k, v in model_args.items()]), + "tasks": [task["name"] for task in eval_config["tasks"]], + "num_fewshot": eval_config["num_fewshot"], + "batch_size": "auto" + } + if "limit" in eval_config: + evaluate_args["limit"] = eval_config["limit"] + if "fewshot_as_multiturn" in eval_config: + evaluate_args["fewshot_as_multiturn"] = eval_config[ + "fewshot_as_multiturn"] + if "apply_chat_template" in eval_config: + evaluate_args["apply_chat_template"] = eval_config[ + "apply_chat_template"] + + simple_eval_args = ['{}={}'.format(k, v) for k, v in evaluate_args.items()] + print(f"lm_eval.simple_evaluate({', '.join(simple_eval_args)}") + results = lm_eval.simple_evaluate(**evaluate_args) + + return results + + +# pass the TEST_DATA_FILE in as a parameter so that the results +# are uniquely reported to TestMo +@pytest.mark.parametrize("test_data_file", [TEST_DATA_FILE]) +def test_lm_eval_correctness(num_gpus_available, test_data_file): + eval_config = yaml.safe_load( + Path(TEST_DATA_PATH).read_text(encoding="utf-8")) + eval_config_tasks = { + t['name']: {m['name']: m['value'] + for m in t['metrics']} + for t in eval_config["tasks"] + } + # identify unique metrics we wish to report on. + eval_config_metrics = set( + itertools.chain.from_iterable([ + metric.keys() for metric in + [eval_config_tasks[task] for task in eval_config_tasks] + ])) + + # retrieve the ground truth values from the evaluation config + # we transpose the info into a set of records indexed by + # a "task" and "metric". The `dropna()` is necessary to remove extra + # rows where there is no ground truth value for the "task" and "metric" + ground_truth_df = pd.DataFrame.from_records( + eval_config_tasks, index=eval_config_metrics).transpose() + gt_listing_df = ground_truth_df.reset_index(names="task").melt( + id_vars="task", var_name="metric", + value_name="ground_truth").dropna().set_index(["task", "metric"]) + + # the ifeval task requires an additional set of data + if "leaderboard_ifeval" in [task["name"] for task in eval_config["tasks"]]: + nltk.download('punkt_tab') + + # Launch eval requests. + results = launch_lm_eval(eval_config, tp_size=num_gpus_available) + + # process the results into a dataframe that looks like the ground truth + # with records indexed by "task" and "metric", but with the measured value + # for each index. + results_df = pd.DataFrame.from_records( + results["results"], index=eval_config_metrics).transpose() + r_listing_df = (results_df.reset_index(names="task").melt( + id_vars="task", var_name="metric", + value_name="measured").dropna().set_index(["task", "metric"])) + + # present the results + # combine the ground truth and results into a single dataframe + # but eliminate any rows that do not have both values + # (This could happen if the eval_config includes a measure that's not + # generated, or if the LM Evaluation harness generates a measure that + # was not requested by the eval_config.) + comparing_metrics_df = pd.concat( + [gt_listing_df, r_listing_df], + axis="columns").reset_index(names=["task", "metric"]).dropna() + + # Add a column with the relative tolerance level for the task + task_rtol_map = { + t["name"]: t.get("rtol", RTOL) + for t in eval_config["tasks"] + } + comparing_metrics_df.loc[:, "rtol"] = comparing_metrics_df.apply( + lambda metric: task_rtol_map[metric.task], axis=1) + + # and determine if measured is close to ground truth + comparing_metrics_df.loc[:, "isclose"] = comparing_metrics_df.apply( + lambda metric: numpy.isclose( + metric.ground_truth, metric.measured, rtol=metric.rtol), + axis=1) + print("==== LM EVAL RESULT ====\n") + comparing_metrics_df.sort_values(by=["task", "metric"], inplace=True) + print(comparing_metrics_df.to_markdown(index=False)) + + # save the results for later summary + llm_results_md = Path("llmeval_results-" + + TEST_DATA_FILE.replace("/", "-")).with_suffix(".md") + llm_results_md.write_text( + f"## {eval_config['model_name']}\n" + f"{comparing_metrics_df.to_markdown(index=False)}\n") + + # fail if any scores fail to match ground truth. + assert comparing_metrics_df.loc[:, "isclose"].all() diff --git a/tests/entrypoints/openai/correctness/__init__.py b/tests/entrypoints/openai/correctness/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/correctness/test_lmeval.py similarity index 98% rename from tests/entrypoints/openai/test_accuracy.py rename to tests/entrypoints/openai/correctness/test_lmeval.py index df25780cd0f4..ebb2ea4d9d14 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -13,7 +13,7 @@ from vllm.platforms import current_platform -from ...utils import RemoteOpenAIServer +from ....utils import RemoteOpenAIServer MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" NUM_CONCURRENT = 500 diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py new file mode 100644 index 000000000000..19d4735b9dde --- /dev/null +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Evaluate Transcription API correctness by computing Word Error Rate (WER) +on a given ASR dataset. When provided, it will also compare the WER against +a baseline. +This simulates real work usage of the API and makes sure that the frontend and +AsyncLLMEngine are working correctly. +""" +import asyncio +import io +import time +from statistics import mean, median +from typing import List + +import librosa +import pytest +import soundfile +import torch +from datasets import load_dataset +from evaluate import load +from transformers import AutoTokenizer + +from ....utils import RemoteOpenAIServer + + +def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + +async def transcribe_audio(client, tokenizer, y, sr): + # Send loaded audio directly instead of loading from disk, + # dont account for that time though + with to_bytes(y, sr) as f: + start_time = time.perf_counter() + transcription = await client.audio.transcriptions.create( + file=f, + model=tokenizer.name_or_path, + language="en", + temperature=0.0, + ) + end_time = time.perf_counter() + # NOTE there's no streaming in transcriptions, can't measure ttft + latency = end_time - start_time + num_output_tokens = len( + tokenizer(transcription.text, add_special_tokens=False).input_ids) + return latency, num_output_tokens, transcription.text + + +async def bound_transcribe(model_name, sem, client, audio, reference): + tokenizer = AutoTokenizer.from_pretrained(model_name) + # Use semaphore to limit concurrent requests. + async with sem: + result = await transcribe_audio(client, tokenizer, *audio) + # Normalize *english* output/reference for evaluation. + out = tokenizer.normalize(result[2]) + ref = tokenizer.normalize(reference) + return result[:2] + (out, ref) + + +async def process_dataset(model, client, data, concurrent_request): + sem = asyncio.Semaphore(concurrent_request) + + # Warmup call as the first `librosa.load` server-side is quite slow. + audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] + _ = await bound_transcribe(model, sem, client, (audio, sr), "") + + tasks: List[asyncio.Task] = [] + for sample in data: + audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] + task = asyncio.create_task( + bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + tasks.append(task) + return await asyncio.gather(*tasks) + + +def print_performance_metrics(results, total_time): + latencies = [res[0] for res in results] + total_tokens = sum([res[1] for res in results]) + + total = len(results) + print(f"Total Requests: {total}") + print(f"Successful Requests: {len(latencies)}") + print(f"Average Latency: {mean(latencies):.4f} seconds") + print(f"Median Latency: {median(latencies):.4f} seconds") + perc = sorted(latencies)[int(len(latencies) * 0.95) - 1] + print(f"95th Percentile Latency: {perc:.4f} seconds") + # Throughput + req_throughput = len(latencies) / total_time + print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s") + throughput = total_tokens / total_time + print(f"Estimated Throughput: {throughput:.2f} tok/s") + + +def add_duration(sample): + y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] + sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 + return sample + + +def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs): + ## Load and filter the dataset + dataset = load_dataset(dataset_repo, split=split, **hf_kwargs) + if 'duration_ms' not in dataset[0]: + # compute duration to filter + dataset = dataset.map(add_duration) + + # Whisper max supported duration + dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + return dataset + + +def run_evaluation(model: str, + client, + dataset, + max_concurrent_reqs: int, + n_examples: int = -1, + print_metrics: bool = True): + if n_examples > 0: + dataset = dataset.select(range(n_examples)) + start = time.perf_counter() + results = asyncio.run( + process_dataset(model, client, dataset, max_concurrent_reqs)) + end = time.perf_counter() + total_time = end - start + print(f"Total Test Time: {total_time:.4f} seconds") + if print_metrics: + print_performance_metrics(results, total_time) + # Compute WER + predictions = [res[2] for res in results] + references = [res[3] for res in results] + wer = load("wer") + wer_score = 100 * wer.compute(references=references, + predictions=predictions) + print("WER:", wer_score) + return wer_score + + +# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo".. +@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) +# Original dataset is 20GB+ in size, hence we use a pre-filtered slice. +@pytest.mark.parametrize( + "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]) +# NOTE: Expected WER measured with equivalent hf.transformers args: +# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. +@pytest.mark.parametrize("expected_wer", [12.744980]) +def test_wer_correctness(model_name, + dataset_repo, + expected_wer, + n_examples=-1, + max_concurrent_request=None): + with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: + dataset = load_hf_dataset(dataset_repo) + + if not max_concurrent_request: + # No max concurrency + max_concurrent_request = n_examples if n_examples > 0\ + else len(dataset) + + client = remote_server.get_async_client() + wer = run_evaluation(model_name, client, dataset, + max_concurrent_request, n_examples) + if expected_wer: + torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py new file mode 100644 index 000000000000..5d4a5de4badd --- /dev/null +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 + +# imports for guided decoding tests +import io +import json + +import librosa +import numpy as np +import openai +import pytest +import soundfile as sf + +from vllm.assets.audio import AudioAsset + +from ...utils import RemoteOpenAIServer + + +@pytest.fixture +def mary_had_lamb(): + path = AudioAsset('mary_had_lamb').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.fixture +def winning_call(): + path = AudioAsset('winning_call').get_local_path() + with open(str(path), "rb") as f: + yield f + + +@pytest.mark.asyncio +async def test_basic_audio(mary_had_lamb): + model_name = "openai/whisper-large-v3-turbo" + server_args = ["--enforce-eager"] + # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. + prompt = "THE FIRST WORDS I SPOKE" + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + transcription = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0) + out = json.loads(transcription)['text'] + assert "Mary had a little lamb," in out + # This should "force" whisper to continue prompt in all caps + transcription_wprompt = await client.audio.transcriptions.create( + model=model_name, + file=mary_had_lamb, + language="en", + response_format="text", + prompt=prompt, + temperature=0.0) + out_capital = json.loads(transcription_wprompt)['text'] + assert prompt not in out_capital + + +@pytest.mark.asyncio +async def test_bad_requests(mary_had_lamb): + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + + # invalid language + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=model_name, + file=mary_had_lamb, + language="hh", + temperature=0.0) + + # Expect audio too long: repeat the timeseries + mary_had_lamb.seek(0) + audio, sr = librosa.load(mary_had_lamb) + repeated_audio = np.tile(audio, 10) + # Repeated audio to buffer + buffer = io.BytesIO() + sf.write(buffer, repeated_audio, sr, format='WAV') + buffer.seek(0) + with pytest.raises(openai.BadRequestError): + await client.audio.transcriptions.create(model=model_name, + file=buffer, + language="en", + temperature=0.0) + + +@pytest.mark.asyncio +async def test_non_asr_model(winning_call): + # text to text model + model_name = "JackFram/llama-68m" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.audio.transcriptions.create(model=model_name, + file=winning_call, + language="en", + temperature=0.0) + assert res.code == 400 and not res.text + assert res.message == "The model does not support Transcriptions API" + + +@pytest.mark.asyncio +async def test_completion_endpoints(): + # text to text model + model_name = "openai/whisper-small" + server_args = ["--enforce-eager"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + res = await client.chat.completions.create( + model=model_name, + messages=[{ + "role": "system", + "content": "You are a helpful assistant." + }]) + assert res.code == 400 + assert res.message == "The model does not support Chat Completions API" + + res = await client.completions.create(model=model_name, prompt="Hello") + assert res.code == 400 + assert res.message == "The model does not support Completions API" diff --git a/tests/test_config.py b/tests/test_config.py index 2dfae218b47d..3fb83b4c0328 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -17,6 +17,7 @@ ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), + ("openai/whisper-small", "transcription", "transcription"), ], ) def test_auto_task(model_id, expected_runner_type, expected_task): diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index d9e51082e6ca..0203dc092a71 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +from pathlib import Path from typing import Literal from urllib.parse import urljoin @@ -28,6 +29,10 @@ def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) + def get_local_path(self) -> Path: + return get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + @property def url(self) -> str: return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/config.py b/vllm/config.py index 9ba497576124..561569162ed8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -54,17 +54,18 @@ _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", - "score", "reward"] + "score", "reward", "transcription"] _ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", - "draft"] + "draft", "transcription"] -RunnerType = Literal["generate", "pooling", "draft"] +RunnerType = Literal["generate", "pooling", "draft", "transcription"] _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = { "generate": ["generate"], "pooling": ["embed", "classify", "score", "reward"], "draft": ["draft"], + "transcription": ["transcription"], } _TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { @@ -483,6 +484,8 @@ def _get_preferred_task( return "embed" if ModelRegistry.is_cross_encoder_model(architectures): return "score" + if ModelRegistry.is_transcription_model(architectures): + return "transcription" suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ # Other models follow this pattern @@ -515,6 +518,8 @@ def _resolve_task( runner_support: Dict[RunnerType, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them + "transcription": + ModelRegistry.is_transcription_model(architectures), "generate": ModelRegistry.is_text_generation_model(architectures), "pooling": ModelRegistry.is_pooling_model(architectures), } diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b8f54d6c7804..781cff350529 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -17,10 +17,10 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union +from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, FastAPI, Form, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -62,6 +62,8 @@ ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, UnloadLoraAdapterRequest) from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable @@ -76,6 +78,8 @@ from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_transcription import ( + OpenAIServingTranscription) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.utils import with_cancellation from vllm.logger import init_logger @@ -319,6 +323,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization +def transcription(request: Request) -> OpenAIServingTranscription: + return request.app.state.openai_serving_transcription + + def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -511,6 +519,31 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) +@router.post("/v1/audio/transcriptions") +@with_cancellation +async def create_transcriptions(request: Annotated[TranscriptionRequest, + Form()], + raw_request: Request): + + handler = transcription(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Transcriptions API") + + audio_data = await request.file.read() + generator = await handler.create_transcription(audio_data, request, + raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranscriptionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + @router.post("/rerank") @with_cancellation async def do_rerank(request: RerankRequest, raw_request: Request): @@ -821,6 +854,12 @@ async def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) + state.openai_serving_transcription = OpenAIServingTranscription( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.runner_type == "transcription" else None state.task = model_config.task diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 83b841826231..4bd52df2a821 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -8,9 +8,10 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union import torch +from fastapi import UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) -from typing_extensions import Annotated +from typing_extensions import Annotated, TypeAlias from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger @@ -1426,3 +1427,163 @@ class LoadLoraAdapterRequest(BaseModel): class UnloadLoraAdapterRequest(BaseModel): lora_name: str lora_int_id: Optional[int] = Field(default=None) + + +## Protocols for Audio +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", + "vtt"] + + +class TranscriptionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/audio/createTranscription + + file: UploadFile + """ + The audio file object (not file name) to transcribe, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: str + """ID of the model to use. + """ + + language: Optional[str] = None + """The language of the input audio. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy and latency. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + ## TODO (varun) : Support if set to 0, certain thresholds are met !! + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + + timestamp_granularities: List[Literal["word", "segment"]] = Field( + alias="timestamp_granularities[]", default=[]) + """The timestamp granularities to populate for this transcription. + + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. Note: + There is no additional latency for segment timestamps, but generating word + timestamps incurs additional latency. + """ + + # Default sampling parameters for transcription requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + # TODO(#9845): remove max_tokens when field is removed from OpenAI API + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens) + + +# Transcription response objects +class TranscriptionResponse(OpenAIBaseModel): + text: str + """The transcribed text.""" + + +class TranscriptionWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class TranscriptionSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. + """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: List[int] + """Array of token IDs for the text content.""" + + +class TranscriptionResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[List[TranscriptionSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[List[TranscriptionWord]] = None + """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8d39fdcb7483..d4ce4e91251a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -31,7 +31,8 @@ ErrorResponse, RerankRequest, ScoreRequest, TokenizeChatRequest, - TokenizeCompletionRequest) + TokenizeCompletionRequest, + TranscriptionRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable @@ -57,7 +58,8 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, + TranscriptionRequest] class TextTokensPrompt(TypedDict): diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py new file mode 100644 index 000000000000..da4930e0e2d8 --- /dev/null +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import io +from typing import AsyncGenerator, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, + RequestResponseMetadata, + TranscriptionRequest, + TranscriptionResponse, + TranscriptionResponseVerbose) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.utils import PlaceholderModule + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +logger = init_logger(__name__) + +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages +# TODO these configs should live somewhere with the model so we can support +# additional ones + +ISO639_1_SUPPORTED_LANGS = { + "af": "Afrikaans", + "ar": "Arabic", + "hy": "Armenian", + "az": "Azerbaijani", + "be": "Belarusian", + "bs": "Bosnian", + "bg": "Bulgarian", + "ca": "Catalan", + "zh": "Chinese", + "hr": "Croatian", + "cs": "Czech", + "da": "Danish", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "gl": "Galician", + "de": "German", + "el": "Greek", + "he": "Hebrew", + "hi": "Hindi", + "hu": "Hungarian", + "is": "Icelandic", + "id": "Indonesian", + "it": "Italian", + "ja": "Japanese", + "kn": "Kannada", + "kk": "Kazakh", + "ko": "Korean", + "lv": "Latvian", + "lt": "Lithuanian", + "mk": "Macedonian", + "ms": "Malay", + "mr": "Marathi", + "mi": "Maori", + "ne": "Nepali", + "no": "Norwegian", + "fa": "Persian", + "pl": "Polish", + "pt": "Portuguese", + "ro": "Romanian", + "ru": "Russian", + "sr": "Serbian", + "sk": "Slovak", + "sl": "Slovenian", + "es": "Spanish", + "sw": "Swahili", + "sv": "Swedish", + "tl": "Tagalog", + "ta": "Tamil", + "th": "Thai", + "tr": "Turkish", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "cy": "Welsh" +} +ISO639_1_OTHER_LANGS = { + "lo": "Lao", + "jw": "Javanese", + "tk": "Turkmen", + "yi": "Yiddish", + "so": "Somali", + "bn": "Bengali", + "nn": "Norwegian Nynorsk", + "si": "Sinhala", + "yo": "Yoruba", + "sa": "Sanskrit", + "mi": "Māori", + "fo": "Faroese", # codespell:ignore + "mt": "Maltese", + "tg": "Tajik", + "mg": "Malagasy", + "haw": "Hawaiian", + "km": "Khmer", + "br": "Breton", + "ps": "Pashto", + "ln": "Lingala", + "la": "Latin", + "ml": "Malayalam", + "sq": "Albanian", + "su": "Sundanese", + "eu": "Basque", + "ka": "Georgian", + "uz": "Uzbek", + "sn": "Shona", + "ht": "Haitian", + "as": "Assamese", + "mn": "Mongolian", + "te": "Telugu", + "pa": "Panjabi", + "tt": "Tatar", + "gu": "Gujarati", + "oc": "Occitan", + "ha": "Hausa", + "ba": "Bashkir", + "my": "Burmese", + "sd": "Sindhi", + "am": "Amharic", + "lb": "Luxembourgish", + "bo": "Tibetan" +} + +# As per https://platform.openai.com/docs/guides/speech-to-text#overview. +# TODO configurable +MAX_AUDIO_CLIP_FILESIZE_MB = 25 +# TODO get from processor.feature_extractor.chunk_length +MAX_AUDIO_CLIP_DURATION_S = 30 + + +class OpenAIServingTranscription(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) + + async def _preprocess_transcription( + self, + request: TranscriptionRequest, + audio_data: bytes, + ) -> PromptType: + # Validate request + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + lang_token = f"<|{request.language}|>" if request.language else "<|en|>" + if request.language: + if request.language in ISO639_1_SUPPORTED_LANGS: + pass + elif request.language in ISO639_1_OTHER_LANGS: + logger.warning( + "The selected language %s has limited accuracy with" + " reported WER>=0.5. Results may be less accurate " + "for this choice.", request.language) + else: + raise ValueError( + f"Unsupported language: {request.language}." + "Language should be one of:" + + f" {list(ISO639_1_SUPPORTED_LANGS.values())}" + + f"or {list(ISO639_1_OTHER_LANGS.values())}") + + if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("Maximum file size exceeded.") + + with io.BytesIO(audio_data) as bytes_: + y, sr = librosa.load(bytes_) + if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S: + raise ValueError( + f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) " + "exceeded.") + + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (y, sr), + }, + }, + "decoder_prompt": + f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}" + } + return cast(PromptType, prompt) + + # TODO (varun) : Make verbose response work ! + async def create_transcription( + self, audio_data: bytes, request: TranscriptionRequest, + raw_request: Request + ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose, + ErrorResponse]: + """Transcription API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranscription + for the API specification. This API mimics the OpenAI transcription API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + + # TODO cmpl->transcription? + request_id = f"cmpl-{self._base_request_id(raw_request)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for Transcription.") + if prompt_adapter_request: + return self.create_error_response( + "Currently do not support PromptAdapter for Transcription." + ) + + prompt = await self._preprocess_transcription( + request=request, + audio_data=audio_data, + ) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None + try: + # TODO(rob): subtract len of tokenized prompt. + default_max_tokens = self.model_config.max_model_len + default_params = self.model_config.get_diff_sampling_param() + sampling_params = request.to_sampling_params( + default_max_tokens, default_params) + + self._log_inputs( + request_id, + prompt['decoder_prompt'], # type: ignore + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + result_generator = self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # TODO(rob): figure out a way to pipe streaming in. + # Non-streaming response. + try: + async for op in result_generator: + result = op + return TranscriptionResponse(text=result.outputs[0].text) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 0fc5c4db179c..a0a1b69ad502 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -441,3 +441,30 @@ def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: return is_pooling_model(model) and _supports_cross_encoding(model) + + +@runtime_checkable +class SupportsTranscription(Protocol): + """The interface required for all models that support transcription.""" + + supports_transcription: ClassVar[Literal[True]] = True + + +@overload +def supports_transcription( + model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: + ... + + +@overload +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: + ... + + +def supports_transcription( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + if isinstance(model, type): + return isinstance(model, SupportsTranscription) + + return isinstance(model, SupportsTranscription) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3b2a7069efc9..c3f1dd36fc1b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -22,7 +22,7 @@ from .interfaces import (has_inner_state, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, - supports_pp) + supports_pp, supports_transcription) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -212,6 +212,7 @@ class _ModelInfo: has_inner_state: bool is_attention_free: bool is_hybrid: bool + supports_transcription: bool @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": @@ -225,7 +226,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), is_hybrid=is_hybrid(model), - ) + supports_transcription=supports_transcription(model)) class _BaseRegisteredModel(ABC): @@ -473,6 +474,13 @@ def is_hybrid_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.is_hybrid + def is_transcription_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_transcription + ModelRegistry = _ModelRegistry({ model_arch: diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0a3011d36101..0b506072094e 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -31,7 +31,7 @@ from vllm.sequence import SequenceData from vllm.transformers_utils.processor import cached_get_processor -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsTranscription from .utils import AutoWeightsLoader, WeightsMapper, make_layers logger = init_logger(__name__) @@ -637,7 +637,8 @@ def input_mapper_for_whisper( @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "audio", get_max_whisper_audio_tokens) -class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal): +class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", diff --git a/vllm/version.py b/vllm/version.py index 70cd0289b441..233e9fe5f7b4 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 try: - from ._version import __version__, __version_tuple__ + __version__ = "0.7.2.1" + __version_tuple__ = (0, 7, 2, 1) + except Exception as e: import warnings