
Commit 033bad0

cthi authored and facebook-github-bot committed
Remove e5m2 from f8f8bf16_rowwise_batched (pytorch#4908)
Summary:
Pull Request resolved: pytorch#4908
X-link: facebookresearch/FBGEMM#1934

Fully remove e5m2 support from f8f8bf16_rowwise_batched. Most of the kernel instances have already been removed; this cleans up the remaining one. Also move some `TORCH_CHECK` calls up into the main non-templated API level, which should reduce binary size a little (likely trivial).

Reviewed By: jiawenliu64

Differential Revision: D82976652

fbshipit-source-id: 9fe1058a6993942d097d27b272368d172f5ca55f
1 parent 78f7058 commit 033bad0

11 files changed: +29, −164 lines
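For context on the binary-size note in the summary: a `TORCH_CHECK` placed inside a templated kernel wrapper is compiled into every instantiation (one per tile/cluster/arch configuration), while the same check in the single non-templated dispatch function is emitted once. A minimal sketch of the idea, using hypothetical names (`batched_gemm_instance`, `batched_gemm_dispatch`) rather than the actual FBGEMM symbols:

#include <ATen/ATen.h>

// Templated wrapper: the linker emits one copy of this body per
// instantiation. Any TORCH_CHECK placed here (plus its message strings and
// exception-throwing machinery) would be duplicated in every instantiation.
template <int TB_M, int TB_N, int TB_K, bool COOP>
at::Tensor batched_gemm_instance(at::Tensor XQ, at::Tensor WQ) {
  // ... launch a CUTLASS kernel specialized on <TB_M, TB_N, TB_K, COOP> ...
  return at::empty(
      {XQ.size(0), XQ.size(1), WQ.size(1)},
      XQ.options().dtype(at::kBFloat16));
}

// Non-templated dispatch: checks written here are compiled exactly once,
// no matter how many kernel instances the heuristic below can pick from.
at::Tensor batched_gemm_dispatch(at::Tensor XQ, at::Tensor WQ) {
  TORCH_CHECK(XQ.dim() == 3 && WQ.dim() == 3, "inputs must be 3D");
  TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn, "XQ must be FP8 e4m3fn");
  if (XQ.size(1) <= 64) {
    return batched_gemm_instance<64, 128, 128, false>(XQ, WQ);
  }
  return batched_gemm_instance<128, 128, 128, true>(XQ, WQ);
}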

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu

Lines changed: 11 additions & 7 deletions
@@ -32,23 +32,27 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
     at::Tensor w_scale, // FP32
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
-  const int arch = getDeviceArch();
-
   TORCH_CHECK(
       (XQ.dim() == 3 && WQ.dim() == 3),
       "FP8 rowwise batched GEMM only supports 3D inputs");
   int M, N;
   M = XQ.size(1);
   N = WQ.size(1);
 
-  const bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2;
-  if (use_e5m2) {
+  TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
+  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
+  TORCH_CHECK(XQ.dtype() == at::kFloat8_e4m3fn, "XQ must be FP8 e4m3fn");
+  TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "WQ must be FP8 e4m3fn");
+  TORCH_CHECK(
+      x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat,
+      "Scale tensors must be float32.");
+  if (bias.has_value()) {
     TORCH_CHECK(
-        arch == 9, "f8f8bf16_rowwise_batched only supports FP8 e5m2 on SM90");
-    return f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f_e5m2(
-        XQ, WQ, x_scale, w_scale, bias, output);
+        bias.value().dtype() == at::kFloat,
+        "Bias type must be float32 if provided.");
   }
 
+  const int arch = getDeviceArch();
   if (arch == 10) {
     if ((M * N <= 4096 * 4096) || (N % 256 > 0 && M % 256 == 0) ||
         (M % 256 > 0 && N % 256 > 0) || M >= 1024 && N >= 1024) {
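With validation consolidated in `dispatch_fp8_rowwise_batched_kernel`, an e5m2 input now fails fast at the API boundary instead of routing to the (now deleted) e5m2 instance. A rough usage sketch, assuming a PyTorch build with float8 dtypes and the FBGEMM gen_ai extension loaded; shapes and scale layouts are illustrative only:

#include <ATen/ATen.h>

void example() {
  auto opts = at::TensorOptions().device(at::kCUDA);
  // Batched 3D inputs: XQ is [B, M, K], WQ is [B, N, K], both FP8 e4m3fn.
  auto XQ = at::randn({4, 1024, 512}, opts).to(at::kFloat8_e4m3fn);
  auto WQ = at::randn({4, 2048, 512}, opts).to(at::kFloat8_e4m3fn);
  auto x_scale = at::ones({4, 1024}, opts.dtype(at::kFloat));
  auto w_scale = at::ones({4, 2048}, opts.dtype(at::kFloat));
  // fbgemm_gpu::f8f8bf16_rowwise_batched(XQ, WQ, x_scale, w_scale)
  // dispatches to a CUTLASS instance exactly as before.

  // An e5m2 activation tensor is now rejected by the dtype TORCH_CHECK in
  // the non-templated dispatch function, with "XQ must be FP8 e4m3fn".
  auto XQ_e5m2 = at::randn({4, 1024, 512}, opts).to(at::kFloat8_e5m2);
  // fbgemm_gpu::f8f8bf16_rowwise_batched(XQ_e5m2, WQ, x_scale, w_scale);  // throws
}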

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t.cu

Lines changed: 2 additions & 10 deletions
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      128,
-      128,
-      128,
-      1,
-      2,
-      1,
-      10,
-      true,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 1, 2, 1, 10, true>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
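Each of the remaining instance files changes the same way: the old `f8f8bf16_rowwise_batched_wrapper`, which carried the element type (`cutlass::float_e4m3_t` vs. `cutlass::float_e5m2_t`) as its last template argument, is replaced by `f8f8bf16_rowwise_batched_impl`, which is e4m3-only and drops the dtype parameter. A rough sketch of the assumed signature, not the exact FBGEMM declaration:

// Assumed shape of the new non-dtype-templated entry point; the actual
// declaration lives in FBGEMM's cutlass_extensions headers and may differ.
template <
    int TB_M, int TB_N, int TB_K,    // threadblock tile shape
    int TBS_M, int TBS_N, int TBS_K, // cluster shape
    int ARCH,                        // 9 = SM90, 10 = SM100
    bool FLAG>                       // the boolean encoded as the trailing _t/_f
at::Tensor f8f8bf16_rowwise_batched_impl(
    at::Tensor XQ,      // FP8 e4m3fn, now enforced once at the dispatch level
    at::Tensor WQ,      // FP8 e4m3fn
    at::Tensor x_scale, // FP32 rowwise scales
    at::Tensor w_scale, // FP32 rowwise scales
    std::optional<at::Tensor> bias = std::nullopt,
    std::optional<at::Tensor> output = std::nullopt);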

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t.cu

Lines changed: 2 additions & 10 deletions
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      128,
-      128,
-      128,
-      1,
-      2,
-      1,
-      9,
-      true,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 1, 2, 1, 9, true>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t.cu

Lines changed: 2 additions & 10 deletions
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      128,
-      128,
-      128,
-      2,
-      1,
-      1,
-      10,
-      true,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 2, 1, 1, 10, true>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t.cu

Lines changed: 2 additions & 10 deletions
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      128,
-      128,
-      128,
-      2,
-      1,
-      1,
-      9,
-      true,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<128, 128, 128, 2, 1, 1, 9, true>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f.cu

Lines changed: 2 additions & 10 deletions
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      64,
-      128,
-      128,
-      1,
-      2,
-      1,
-      10,
-      false,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<64, 128, 128, 1, 2, 1, 10, false>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f.cu

Lines changed: 2 additions & 10 deletions
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      64,
-      128,
-      128,
-      1,
-      2,
-      1,
-      9,
-      false,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<64, 128, 128, 1, 2, 1, 9, false>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f.cu

Lines changed: 2 additions & 10 deletions
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      64,
-      128,
-      128,
-      2,
-      1,
-      1,
-      10,
-      false,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<64, 128, 128, 2, 1, 1, 10, false>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f.cu

Lines changed: 2 additions & 10 deletions
@@ -18,16 +18,8 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f(
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<
-      64,
-      128,
-      128,
-      2,
-      1,
-      1,
-      9,
-      false,
-      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
+  return f8f8bf16_rowwise_batched_impl<64, 128, 128, 2, 1, 1, 9, false>(
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f_e5m2.cu

Lines changed: 0 additions & 33 deletions
This file was deleted.
