Commit c82c141

Revert "torch._scaled_mm with MXFP8 (pytorch#147548)"
This reverts commit e34c15a. Reverted pytorch#147548 on behalf of https://github.com/wdvr due to a failing internal build; discussed with the author (see comment on pytorch#147548).
1 parent 0633f63 commit c82c141

7 files changed: 16 additions, 463 deletions

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 4 additions & 14 deletions
@@ -14,7 +14,6 @@
 #include <c10/macros/Export.h>
 #include <c10/util/env.h>
 #include <c10/util/irange.h>
-#include <c10/core/ScalarType.h>
 
 #ifdef USE_ROCM
 #include <hipblaslt/hipblaslt-ext.hpp>
@@ -1504,12 +1503,10 @@ void scaled_gemm(
     const void* mat1_scale_ptr,
     int64_t mat1_ld,
     ScalarType mat1_dtype,
-    ScalarType mat1_scale_dtype,
     const void* mat2_ptr,
     const void* mat2_scale_ptr,
     int64_t mat2_ld,
     ScalarType mat2_dtype,
-    ScalarType mat2_scale_dtype,
     const void* bias_ptr,
     ScalarType bias_dtype,
     void* result_ptr,
@@ -1537,8 +1534,10 @@ void scaled_gemm(
   // rowwise isn't supported using cublaslt or older hipblaslt
   TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
 #endif
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
+  {
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
+  }
   if (result_scale_ptr != nullptr) {
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
   }
@@ -1561,15 +1560,6 @@ void scaled_gemm(
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
   }
 
-  if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
-#if CUDA_VERSION >= 12080
-    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
-    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
-#else
-    TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
-#endif // CUDA_VERSION >= 12080
-  }
-
   auto stream = c10::cuda::getCurrentCUDAStream();
   size_t workspaceSize = 0;
   auto workspace_ptr = _getWorkspace(workspaceSize);
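
For readers unfamiliar with the feature being reverted: the block removed above is what switched cuBLASLt into MXFP8 block-scaling mode. Below is a minimal, self-contained sketch of the same idea against the raw cuBLASLt C API rather than PyTorch's computeDesc wrapper; the helper name set_mxfp8_scales is made up for illustration, while the attribute and enum names come from the removed lines and require CUDA 12.8+.

// Illustrative sketch only: raw cuBLASLt calls mirroring the reverted
// MXFP8 block-scale setup. Assumes CUDA >= 12.8 and an FP8 matmul
// descriptor that is already otherwise configured; error codes are ignored.
#include <cstdint>
#include <cuda.h>
#include <cublasLt.h>

void set_mxfp8_scales(cublasLtMatmulDesc_t desc,
                      const void* a_scale_ptr,
                      const void* b_scale_ptr) {
  // Scale pointers are set the same way in the retained code path.
  cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_ptr, sizeof(a_scale_ptr));
  cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_ptr, sizeof(b_scale_ptr));
#if CUDA_VERSION >= 12080
  // The reverted change additionally selected the UE8M0 "vec32" scale mode:
  // one float8_e8m0fnu scale factor per 32 contiguous elements along K.
  int32_t scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
  cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scale_mode, sizeof(scale_mode));
  cublasLtMatmulDescSetAttribute(
      desc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scale_mode, sizeof(scale_mode));
#endif
}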

aten/src/ATen/cuda/CUDABlas.h

Lines changed: 0 additions & 2 deletions
@@ -130,12 +130,10 @@ void scaled_gemm(
     const void* mat1_scale_ptr,
     int64_t mat1_ld,
     ScalarType mat1_dtype,
-    ScalarType mat1_scale_dtype,
     const void* mat2_ptr,
     const void* mat2_scale_ptr,
     int64_t mat2_ld,
     ScalarType mat2_dtype,
-    ScalarType mat2_scale_dtype,
     const void* bias_ptr,
     ScalarType bias_dtype,
     void* result_ptr,

aten/src/ATen/cuda/tunable/GemmCommon.h

Lines changed: 0 additions & 3 deletions
@@ -10,7 +10,6 @@
 #pragma once
 
 #include <string>
-#include <c10/core/ScalarType.h>
 
 #include <ATen/cuda/tunable/TunableOp.h>
 #include <ATen/cuda/CUDABlas.h>
@@ -425,12 +424,10 @@ struct ScaledGemmParams : OpParams {
   const void* a_scale_ptr{};
   int64_t lda{};
   ScalarType a_dtype{};
-  ScalarType a_scale_dtype{};
   const void* b{};
   const void* b_scale_ptr{};
   int64_t ldb{};
   ScalarType b_dtype{};
-  ScalarType b_scale_dtype{};
   const void* bias_ptr{};
   ScalarType bias_dtype{};
   void* c{};

aten/src/ATen/cuda/tunable/TunableGemm.h

Lines changed: 0 additions & 2 deletions
@@ -95,12 +95,10 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
         params->a_scale_ptr,
         params->lda,
         params->a_dtype,
-        params->a_scale_dtype,
         params->b,
         params->b_scale_ptr,
         params->ldb,
         params->b_dtype,
-        params->b_scale_dtype,
         params->bias_ptr,
         params->bias_dtype,
         params->c,

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 11 additions & 83 deletions
@@ -1,5 +1,4 @@
 #include <cstdint>
-#include <c10/util/typeid.h>
 #include <c10/util/Exception.h>
 #include <c10/core/Scalar.h>
 #include <c10/core/ScalarType.h>
@@ -96,33 +95,11 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
 }
 
 struct cublasCommonArgs {
-  cublasCommonArgs(
-      const Tensor& mat1,
-      const Tensor& mat2,
-      Tensor& c,
-      const c10::optional<Tensor>& scale_a = c10::nullopt,
-      const c10::optional<Tensor>& scale_b = c10::nullopt,
-      const c10::optional<Tensor>& scale_result = c10::nullopt) {
+  cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) {
     bool transpose_result = false, transpose_mat1 = false, transpose_mat2 = false;
     result = prepare_matrix_for_cublas(c, transpose_result);
     mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
     matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);
-
-    // Handle scale tensors if provided
-    if (scale_a && scale_b) {
-      // By default since we return in row-major we run the gemm
-      // as B.T @ A.T, check transpose_result to determine if we flip the scales
-      scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
-      scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
-      scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
-      scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
-    }
-
-    if (scale_result) {
-      scale_result_ptr = scale_result->data_ptr();
-      scale_result_dtype = scale_result->scalar_type();
-    }
-
     auto mat1_sizes = mat1.sizes();
     auto mat2_sizes = mat2.sizes();
     if (transpose_result) {
@@ -138,23 +115,13 @@ struct cublasCommonArgs {
     lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0);
     ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0);
     result_ld = result->stride(transpose_result ? 0 : 1);
-    transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
-    transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
+    transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
+    transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
   }
-
-  // Matrix members
   char transa, transb;
   int64_t m, n, k;
   int64_t lda, ldb, result_ld;
   c10::MaybeOwned<Tensor> mata, matb, result;
-
-  // Scale members
-  void* scale_mata_ptr = nullptr;
-  void* scale_matb_ptr = nullptr;
-  void* scale_result_ptr = nullptr;
-  c10::optional<c10::ScalarType> scale_mata_dtype;
-  c10::optional<c10::ScalarType> scale_matb_dtype;
-  c10::optional<c10::ScalarType> scale_result_dtype;
 };
 } // namespace
 
@@ -936,24 +903,20 @@ static bool _scaled_mm_is_fnuz() {
 
 namespace{
 
-enum class ScalingType : std::uint8_t {
+enum class ScalingType {
   TensorWise,
   RowWise,
-  BlockWise,
   Error
 };
 /*
  * Scaling Type Determination:
  * ---------------------------
  * Conditions and corresponding Scaling Types:
  *
- * - If scale tensors are Float8_e8m0fnu:
- *   - Returns BlockWise (with additional size checks).
- *
  * - If scale_a.numel() == 1 && scale_b.numel() == 1:
  *   - Returns TensorWise.
 *
- * - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
+ * - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
  *   - Returns RowWise.
 *
  * - Otherwise:
@@ -966,40 +929,7 @@ ScalingType get_scaling_type(
     const at::Tensor& scale_a,
     const at::Tensor& scale_b,
     int64_t dim_m,
-    int64_t dim_k,
     int64_t dim_n) {
-  // Check for BlockWise scaling (FP8_E8M0 types)
-  if (scale_a.scalar_type() == scale_b.scalar_type() &&
-      scale_a.scalar_type() == at::kFloat8_e8m0fnu) {
-    constexpr int64_t BLOCK_SIZE_K = 32;
-    constexpr int64_t BLOCK_SIZE_MN = 128;
-
-    auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
-    auto num_k_blocks = ceil_div(dim_k, BLOCK_SIZE_K);
-    auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;
-
-    // TODO: We might want to enforce some structure on the shapes of the scale
-    // tensors
-
-    // Check expected sizes for block-wise scaling
-    auto expected_a_size =
-        BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks;
-    auto expected_b_size =
-        BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;
-
-    TORCH_CHECK(scale_a.numel() == expected_a_size,
-        "For BlockWise scaling: Expected scale_a size to be ",
-        expected_a_size, " but got ", scale_a.numel());
-    TORCH_CHECK(scale_b.numel() == expected_b_size,
-        "For BlockWise scaling: Expected scale_b size to be ",
-        expected_b_size, " but got ", scale_b.numel());
-
-    TORCH_CHECK(
-        scale_a.is_contiguous() && scale_b.is_contiguous(),
-        "For BlockWise scaling: Both scale_a and scale_b must be contiguous");
-
-    return ScalingType::BlockWise;
-  }
   // Both Per-Tensor and Row-wise scaling expect fp32 tensors
   TORCH_CHECK(
       scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
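
The BlockWise branch deleted above derives the required number of scale elements from the GEMM shape: one float8_e8m0fnu scale per 1x32 block along K, with the K-block count padded up to a multiple of 4 and the M/N extent rounded up to a multiple of 128. A standalone sketch of that arithmetic follows (expected_scale_size is a hypothetical helper written for this note, not PyTorch code), with a worked 512x512 example.

// Sketch of the expected-scale-size math from the removed BlockWise check.
// expected_scale_size(M, K) returns how many e8m0 scale elements the
// reverted code required for an M x K operand scaled in 1x32 blocks.
#include <cstdint>
#include <iostream>

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int64_t expected_scale_size(int64_t dim_mn, int64_t dim_k) {
  constexpr int64_t BLOCK_SIZE_K = 32;   // one scale per 32 elements along K
  constexpr int64_t BLOCK_SIZE_MN = 128; // M/N rounded up to multiples of 128
  int64_t num_k_blocks = ceil_div(dim_k, BLOCK_SIZE_K);
  int64_t padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4; // pad to 4 blocks
  return BLOCK_SIZE_MN * ceil_div(dim_mn, BLOCK_SIZE_MN) * padded_num_k_blocks;
}

int main() {
  // Example: a 512 x 512 operand needs 512 * 16 = 8192 block scales.
  std::cout << expected_scale_size(/*dim_mn=*/512, /*dim_k=*/512) << "\n"; // 8192
  return 0;
}
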
@@ -1097,7 +1027,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
     mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
 
   // Check what type of scaling we are doing based on inputs
-  ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1));
+  ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
   TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
 
   TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
@@ -1190,7 +1120,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   }
 #endif
 
-  cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result);
+  cublasCommonArgs args(mat1, mat2, out);
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
 
@@ -1300,27 +1230,25 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   }
   else
 #endif
-  {
+  {
     at::cuda::blas::scaled_gemm(
         args.transa,
         args.transb,
         args.m,
         args.n,
         args.k,
         args.mata->data_ptr(),
-        args.scale_mata_ptr,
+        scale_a.data_ptr(),
         args.lda,
         args.mata->scalar_type(),
-        args.scale_mata_dtype.value(),
         args.matb->data_ptr(),
-        args.scale_matb_ptr,
+        scale_b.data_ptr(),
         args.ldb,
         args.matb->scalar_type(),
-        args.scale_matb_dtype.value(),
         bias ? bias->data_ptr(): nullptr,
         bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
         args.result->data_ptr(),
-        args.scale_result_ptr,
+        scale_result ? scale_result->data_ptr() : nullptr,
         args.result_ld,
         out_dtype_,
         use_fast_accum,
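
One behavioral detail restored by this hunk: the call site once again passes the scale_a/scale_b/scale_result pointers directly, whereas the reverted code routed them through cublasCommonArgs, which swapped the two operand scales whenever the row-major output is obtained by running the GEMM as B.T @ A.T (the transpose_result case described in the removed constructor comment). A minimal sketch of that swap rule follows; the struct and helper names are illustrative only, not PyTorch API.

// Illustrative sketch of the scale-routing rule from the removed
// cublasCommonArgs constructor: because a row-major result is produced by
// computing B^T * A^T, the roles of the two operand scales flip whenever
// transpose_result is true.
struct GemmScales {
  const void* a_scale; // scale fed to the GEMM's "A" operand
  const void* b_scale; // scale fed to the GEMM's "B" operand
};

GemmScales pick_scales(bool transpose_result,
                       const void* scale_a_ptr,
                       const void* scale_b_ptr) {
  return transpose_result ? GemmScales{scale_b_ptr, scale_a_ptr}
                          : GemmScales{scale_a_ptr, scale_b_ptr};
}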
