@@ -121,6 +121,9 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
  *
  * The transpose flags are derived from the layouts of the passed in tensors
  *
+ * If the operands are in packed float4 format, `k`, `lda` and `ldb` are adjusted
+ * to their unpacked values to match what cuBLAS expects.
+ *
  * @param mat1 First input matrix
  * @param mat2 Second input matrix
  * @param c Output matrix (result)
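
For context on the packing mentioned above: a `Float4_e2m1fn_x2` tensor stores two fp4 (e2m1) values per byte, so the storage extent along the packed dimension is half of the logical extent the GEMM operates on. The sketch below is illustrative only, not part of this patch; the helper name and the choice of which nibble holds the first element are assumptions. It simply shows why `k`, `lda` and `ldb` are doubled before being handed to cuBLAS.

```cpp
#include <cstdint>
#include <iostream>

// Illustrative sketch of fp4x2 ("4x2") packing: two 4-bit values per byte.
// Which nibble holds the even-indexed element is an assumption here.
uint8_t pack_fp4x2(uint8_t lo_nibble, uint8_t hi_nibble) {
  return static_cast<uint8_t>(((hi_nibble & 0x0F) << 4) | (lo_nibble & 0x0F));
}

int main() {
  // Two example fp4 bit patterns packed into one byte of storage.
  uint8_t byte = pack_fp4x2(/*lo=*/0x1, /*hi=*/0x7);
  std::cout << "packed byte = 0x" << std::hex << int(byte) << std::dec << "\n";  // 0x71

  // A GEMM with logical K = 128 stores only 64 bytes along the packed K
  // dimension, so the packed tensor reports k = 64, lda = 64, ldb = 64.
  const std::int64_t k_packed = 64, lda_packed = 64, ldb_packed = 64;

  // cuBLAS is given the unpacked (logical) extents, i.e. twice the packed
  // ones, which is the adjustment made in cublasCommonArgs below.
  std::cout << "k=" << k_packed * 2
            << " lda=" << lda_packed * 2
            << " ldb=" << ldb_packed * 2 << "\n";  // k=128 lda=128 ldb=128
  return 0;
}
```
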
@@ -173,6 +176,14 @@ struct cublasCommonArgs {
     result_ld = result->stride(transpose_result ? 0 : 1);
     transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n';
     transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n';
+
+    // cuBLAS expects the unpacked values of `k`, `lda` and `ldb`; adjust for 4x2 packing
+    // if the gemm operands are in packed float4.
+    if (mat1.dtype() == at::kFloat4_e2m1fn_x2 && mat2.dtype() == at::kFloat4_e2m1fn_x2) {
+      k = k * 2;
+      lda = lda * 2;
+      ldb = ldb * 2;
+    }
   }

   // Matrix members
@@ -980,7 +991,7 @@ enum class ScalingType : std::uint8_t {
  * ---------------------------
  * Conditions and corresponding Scaling Types:
  *
- * - If scale tensors are Float8_e8m0fnu:
+ * - If both scale tensors are `Float8_e8m0fnu` or both are `Float8_e4m3fn`:
  *   - Returns BlockWise (with additional size checks).
  *
  * - If scale_a.numel() == 1 && scale_b.numel() == 1:
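
Restated as standalone code, the documented conditions look roughly like the sketch below. This is illustrative only: the hunk cuts off before naming the scaling type returned for the scalar-scale case, so the `TensorWise` branch and the dtype enum here are assumptions, not the exact logic of `get_scaling_type` (whose block-wise branch follows in the next hunk).

```cpp
#include <cstdint>
#include <iostream>

// Illustrative restatement of the documented dispatch; the names and the
// scalar-scale branch are assumptions rather than the ATen implementation.
enum class ScaleDtype { Float8_e8m0fnu, Float8_e4m3fn, Other };
enum class ScalingType : std::uint8_t { BlockWise, TensorWise, Unsupported };

ScalingType classify_scaling(ScaleDtype a, ScaleDtype b,
                             std::int64_t a_numel, std::int64_t b_numel) {
  // Both scale tensors carry a block-scale dtype (mxfp8 or nvfp4 recipes):
  // block-wise scaling, subject to the size checks shown further down.
  if (a == b && (a == ScaleDtype::Float8_e8m0fnu || a == ScaleDtype::Float8_e4m3fn)) {
    return ScalingType::BlockWise;
  }
  // A single scale value per operand: per-tensor scaling (assumed mapping).
  if (a_numel == 1 && b_numel == 1) {
    return ScalingType::TensorWise;
  }
  return ScalingType::Unsupported;
}

int main() {
  auto t = classify_scaling(ScaleDtype::Float8_e8m0fnu, ScaleDtype::Float8_e8m0fnu,
                            /*a_numel=*/1024, /*b_numel=*/1024);
  std::cout << (t == ScalingType::BlockWise) << "\n";  // 1
}
```
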
@@ -1001,14 +1012,22 @@ ScalingType get_scaling_type(
     int64_t dim_m,
     int64_t dim_k,
     int64_t dim_n) {
-  // Check for BlockWise scaling (FP8_E8M0 types)
-  if (scale_a.scalar_type() == scale_b.scalar_type() &&
-      scale_a.scalar_type() == at::kFloat8_e8m0fnu) {
-    constexpr int64_t BLOCK_SIZE_K = 32;
+  // Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types)
+  if ((scale_a.scalar_type() == scale_b.scalar_type()) &&
+      ((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) {
+    const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn;
+
+    // cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements.
+    // cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements.
+    const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32;
+
     constexpr int64_t BLOCK_SIZE_MN = 128;

+    // adjust for fp4x2 packing if necessary
+    const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k;
+
     auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
-    auto num_k_blocks = ceil_div(dim_k, BLOCK_SIZE_K);
+    auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K);
     auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;

     // TODO: We might want to enforce some structure on the shapes of the scale
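
As a quick check of the arithmetic above, the sketch below re-derives the padded K-block count outside of ATen (the function names are illustrative, not from the patch). For nvfp4 the stored K is fp4x2-packed, so it is doubled before dividing by the 16-element block size; for mxfp8 the block size is 32 and no unpacking is needed; in both cases the resulting K-block count is padded up to a multiple of 4.

```cpp
#include <cstdint>
#include <iostream>

// Re-statement of the block-count arithmetic from get_scaling_type above,
// purely for illustration.
std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

std::int64_t padded_k_blocks(std::int64_t dim_k, bool is_nvfp4) {
  // nvfp4 operands are fp4x2-packed, so the logical K is twice the stored K,
  // and each scale covers 16 unpacked elements instead of 32.
  const std::int64_t block_size_k = is_nvfp4 ? 16 : 32;
  const std::int64_t dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k;
  const std::int64_t num_k_blocks = ceil_div(dim_k_unpacked, block_size_k);
  return ceil_div(num_k_blocks, 4) * 4;  // pad the K-block count to a multiple of 4
}

int main() {
  // mxfp8, K = 96: 96 / 32 = 3 blocks, padded up to 4.
  std::cout << padded_k_blocks(96, /*is_nvfp4=*/false) << "\n";  // 4
  // nvfp4, stored K = 96 (packed), unpacked K = 192: 192 / 16 = 12 blocks,
  // already a multiple of 4.
  std::cout << padded_k_blocks(96, /*is_nvfp4=*/true) << "\n";   // 12
  return 0;
}
```
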
@@ -1149,13 +1168,16 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
       mat2.sizes()[1], ") must be divisible by 16");
   // Check types
   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
-  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat1.scalar_type()) || mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat1 to be Float8 or Float4_x2 matrix got ", mat1.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type());
 #ifndef USE_ROCM
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
       "Multiplication of two Float8_e5m2 matrices is not supported");
 #endif
+  if (use_fast_accum) {
+    TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype.");
+  }
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,