
Commit 789562c

Support CUTLASS NVFP4 (w4a4) for Blackwell Geforce GPUs (SM120) (#21309)
Signed-off-by: LopezCastroRoberto <[email protected]>
1 parent: 3f36c32

6 files changed: +329 −13 lines

CMakeLists.txt

Lines changed: 20 additions & 1 deletion
@@ -529,6 +529,25 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()
 
+  # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
+  # CUDA 12.8 or later
+  cuda_archs_loose_intersection(FP4_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
+    set(SRCS
+      "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
+      "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${FP4_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
+    message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
+  else()
+    message(STATUS "Not building NVFP4 as no compatible archs were found.")
+    # clear FP4_ARCHS
+    set(FP4_ARCHS)
+  endif()
+
   # FP4 Archs and flags
   cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)

@@ -541,7 +560,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       SRCS "${SRCS}"
       CUDA_ARCHS "${FP4_ARCHS}")
     list(APPEND VLLM_EXT_SRC "${SRCS}")
-    list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
     list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
     message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
   else()
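
For reference, the "12.0;12.0a" arch list matched here corresponds to compute capability 12.0, i.e. Blackwell GeForce SM120 parts, and -DENABLE_NVFP4_SM120=1 is the macro the entry points below test at compile time. A minimal standalone CUDA sketch (not vLLM code; the arch check is an assumption about how SM120 reports its compute capability) of detecting such a device at runtime:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::fprintf(stderr, "no CUDA device visible\n");
    return 1;
  }
  // SM120 (Blackwell GeForce) reports compute capability 12.0.
  bool is_sm120 = (prop.major == 12 && prop.minor == 0);
  std::printf("compute capability %d.%d -> %s\n", prop.major, prop.minor,
              is_sm120 ? "SM120 NVFP4 path applies" : "other arch");
  return 0;
}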

csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu

Lines changed: 3 additions & 3 deletions
@@ -335,7 +335,7 @@ void run_fp4_blockwise_scaled_group_mm(
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
 }
 
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
 constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
 constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
 #endif

@@ -356,7 +356,7 @@ void cutlass_fp4_group_mm(
     const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
     const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
   // Input validation
   CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
   CHECK_INPUT(b, FLOAT4_E2M1X2, "b");

@@ -398,7 +398,7 @@ void cutlass_fp4_group_mm(
   TORCH_CHECK_NOT_IMPLEMENTED(
       false,
       "No compiled cutlass_fp4_group_mm kernel, vLLM must "
-      "be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
+      "be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
       "12.8 or above.");
 #endif
 }

csrc/quantization/fp4/nvfp4_quant_entry.cu

Lines changed: 8 additions & 6 deletions
@@ -16,14 +16,15 @@
 
 #include <torch/all.h>
 
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
-void scaled_fp4_quant_sm100a(torch::Tensor const& output,
+#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
+    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
+void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
                              torch::Tensor const& input,
                              torch::Tensor const& output_sf,
                              torch::Tensor const& input_sf);
 #endif
 
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
 void scaled_fp4_experts_quant_sm100a(
     torch::Tensor& output, torch::Tensor& output_scale,
     torch::Tensor const& input, torch::Tensor const& input_global_scale,

@@ -33,8 +34,9 @@ void scaled_fp4_experts_quant_sm100a(
 
 void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                       torch::Tensor& output_sf, torch::Tensor const& input_sf) {
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
-  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
+#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
+    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
+  return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf);
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
 }

@@ -44,7 +46,7 @@ void scaled_fp4_experts_quant(
     torch::Tensor const& input, torch::Tensor const& input_global_scale,
     torch::Tensor const& input_offset_by_experts,
     torch::Tensor const& output_scale_offset_by_experts) {
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
   return scaled_fp4_experts_quant_sm100a(
       output, output_scale, input, input_global_scale, input_offset_by_experts,
       output_scale_offset_by_experts);
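
The OR'd guard above is what lets the renamed scaled_fp4_quant_sm1xxa declaration serve builds that enable either arch, while the experts-quant path stays SM100-only. A minimal standalone sketch of the same preprocessor pattern (hypothetical names; compile with -DENABLE_NVFP4_SM100=1 and/or -DENABLE_NVFP4_SM120=1):

#include <cstdio>

#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
// Compiled whenever at least one of the two arch flags is set.
void quant_sm1xxa_demo() { std::puts("shared SM100/SM120 quant entry"); }
#endif

int main() {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
  quant_sm1xxa_demo();
#else
  std::puts("no nvfp4 quantization kernel compiled");
#endif
  return 0;
}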

csrc/quantization/fp4/nvfp4_quant_kernels.cu

Lines changed: 1 addition & 1 deletion
@@ -332,7 +332,7 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
                                     int multiProcessorCount,
                                     cudaStream_t stream);
 
-void scaled_fp4_quant_sm100a(torch::Tensor const& output,
+void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
                              torch::Tensor const& input,
                              torch::Tensor const& output_sf,
                              torch::Tensor const& input_sf) {

csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu

Lines changed: 12 additions & 2 deletions
@@ -16,20 +16,30 @@
 
 #include <torch/all.h>
 
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
 void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
                                   torch::Tensor const& B,
                                   torch::Tensor const& A_sf,
                                   torch::Tensor const& B_sf,
                                   torch::Tensor const& alpha);
 #endif
 
+#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
+void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
+                                  torch::Tensor const& B,
+                                  torch::Tensor const& A_sf,
+                                  torch::Tensor const& B_sf,
+                                  torch::Tensor const& alpha);
+#endif
+
 void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                            torch::Tensor const& B, torch::Tensor const& A_sf,
                            torch::Tensor const& B_sf,
                            torch::Tensor const& alpha) {
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
   return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
+#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
+  return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false,
                               "No compiled nvfp4 mm kernel, vLLM should "
