 #include <cstdint>
-#include <c10/util/typeid.h>
 #include <c10/util/Exception.h>
 #include <c10/core/Scalar.h>
 #include <c10/core/ScalarType.h>
@@ -96,33 +95,11 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
 }
 
 struct cublasCommonArgs {
-  cublasCommonArgs(
-      const Tensor& mat1,
-      const Tensor& mat2,
-      Tensor& c,
-      const c10::optional<Tensor>& scale_a = c10::nullopt,
-      const c10::optional<Tensor>& scale_b = c10::nullopt,
-      const c10::optional<Tensor>& scale_result = c10::nullopt) {
+  cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) {
     bool transpose_result = false, transpose_mat1 = false, transpose_mat2 = false;
     result = prepare_matrix_for_cublas(c, transpose_result);
     mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
     matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);
-
-    // Handle scale tensors if provided
-    if (scale_a && scale_b) {
-      // By default since we return in row-major we run the gemm
-      // as B.T @ A.T, check transpose_result to determine if we flip the scales
-      scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
-      scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
-      scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
-      scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
-    }
-
-    if (scale_result) {
-      scale_result_ptr = scale_result->data_ptr();
-      scale_result_dtype = scale_result->scalar_type();
-    }
-
     auto mat1_sizes = mat1.sizes();
     auto mat2_sizes = mat2.sizes();
     if (transpose_result) {
@@ -138,23 +115,13 @@ struct cublasCommonArgs {
     lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0);
     ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0);
     result_ld = result->stride(transpose_result ? 0 : 1);
-    transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
-    transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
+    transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
+    transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
   }
-
-  // Matrix members
   char transa, transb;
   int64_t m, n, k;
   int64_t lda, ldb, result_ld;
   c10::MaybeOwned<Tensor> mata, matb, result;
-
-  // Scale members
-  void* scale_mata_ptr = nullptr;
-  void* scale_matb_ptr = nullptr;
-  void* scale_result_ptr = nullptr;
-  c10::optional<c10::ScalarType> scale_mata_dtype;
-  c10::optional<c10::ScalarType> scale_matb_dtype;
-  c10::optional<c10::ScalarType> scale_result_dtype;
 };
 } // namespace
 
@@ -936,24 +903,20 @@ static bool _scaled_mm_is_fnuz() {
 
 namespace {
 
-enum class ScalingType : std::uint8_t {
+enum class ScalingType {
   TensorWise,
   RowWise,
-  BlockWise,
   Error
 };
 /*
  * Scaling Type Determination:
  * ---------------------------
  * Conditions and corresponding Scaling Types:
  *
- * - If scale tensors are Float8_e8m0fnu:
- *   - Returns BlockWise (with additional size checks).
- *
  * - If scale_a.numel() == 1 && scale_b.numel() == 1:
  *   - Returns TensorWise.
  *
- * - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
+ * - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
  *   - Returns RowWise.
  *
  * - Otherwise:
@@ -966,40 +929,7 @@ ScalingType get_scaling_type(
     const at::Tensor& scale_a,
     const at::Tensor& scale_b,
     int64_t dim_m,
-    int64_t dim_k,
     int64_t dim_n) {
-  // Check for BlockWise scaling (FP8_E8M0 types)
-  if (scale_a.scalar_type() == scale_b.scalar_type() &&
-      scale_a.scalar_type() == at::kFloat8_e8m0fnu) {
-    constexpr int64_t BLOCK_SIZE_K = 32;
-    constexpr int64_t BLOCK_SIZE_MN = 128;
-
-    auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
-    auto num_k_blocks = ceil_div(dim_k, BLOCK_SIZE_K);
-    auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;
-
-    // TODO: We might want to enforce some structure on the shapes of the scale
-    // tensors
-
-    // Check expected sizes for block-wise scaling
-    auto expected_a_size =
-        BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks;
-    auto expected_b_size =
-        BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;
-
-    TORCH_CHECK(scale_a.numel() == expected_a_size,
-        "For BlockWise scaling: Expected scale_a size to be ",
-        expected_a_size, " but got ", scale_a.numel());
-    TORCH_CHECK(scale_b.numel() == expected_b_size,
-        "For BlockWise scaling: Expected scale_b size to be ",
-        expected_b_size, " but got ", scale_b.numel());
-
-    TORCH_CHECK(
-        scale_a.is_contiguous() && scale_b.is_contiguous(),
-        "For BlockWise scaling: Both scale_a and scale_b must be contiguous");
-
-    return ScalingType::BlockWise;
-  }
   // Both Per-Tensor and Row-wise scaling expect fp32 tensors
   TORCH_CHECK(
       scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
@@ -1097,7 +1027,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
       mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
 
   // Check what type of scaling we are doing based on inputs
-  ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1));
+  ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
   TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
 
   TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
@@ -1190,7 +1120,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   }
 #endif
 
-  cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result);
+  cublasCommonArgs args(mat1, mat2, out);
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
 
@@ -1300,27 +1230,25 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   }
   else
 #endif
-  {
+  {
     at::cuda::blas::scaled_gemm(
         args.transa,
         args.transb,
         args.m,
         args.n,
         args.k,
         args.mata->data_ptr(),
-        args.scale_mata_ptr,
+        scale_a.data_ptr(),
         args.lda,
         args.mata->scalar_type(),
-        args.scale_mata_dtype.value(),
         args.matb->data_ptr(),
-        args.scale_matb_ptr,
+        scale_b.data_ptr(),
         args.ldb,
         args.matb->scalar_type(),
-        args.scale_matb_dtype.value(),
         bias ? bias->data_ptr(): nullptr,
         bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
         args.result->data_ptr(),
-        args.scale_result_ptr,
+        scale_result ? scale_result->data_ptr() : nullptr,
         args.result_ld,
         out_dtype_,
         use_fast_accum,
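
For reference, the scaling-type selection that remains after this change (per the "Scaling Type Determination" comment above) reduces to: per-tensor scaling when both scales hold a single element, row-wise scaling when scale_a carries one value per row of mat1 and scale_b one per column of mat2, and an error otherwise. The following is a minimal standalone sketch of that decision, not taken from the diff; it substitutes plain integers for the at::Tensor scale arguments, and the simplified name classify_scaling is illustrative only.

    #include <cstdint>

    // Mirrors the post-change enum (BlockWise removed).
    enum class ScalingType { TensorWise, RowWise, Error };

    // Hypothetical, simplified stand-in for get_scaling_type: the real function
    // inspects at::Tensor scales; here their numel/dim/size(0) are passed directly.
    ScalingType classify_scaling(
        int64_t scale_a_numel, int64_t scale_a_dim, int64_t scale_a_size0,
        int64_t scale_b_numel, int64_t scale_b_size0,
        int64_t dim_m, int64_t dim_n) {
      if (scale_a_numel == 1 && scale_b_numel == 1) {
        return ScalingType::TensorWise;  // one scale per tensor
      }
      if (scale_a_dim == 1 && scale_a_size0 == dim_m && scale_b_size0 == dim_n) {
        return ScalingType::RowWise;     // one scale per row of mat1 / column of mat2
      }
      return ScalingType::Error;         // anything else is rejected by the caller
    }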