
Commit d0eb60a

chore: run pre-commit fixes

Signed-off-by: EdalatiAli <[email protected]>

1 parent 0b2be10 · commit d0eb60a

7 files changed · +259 −293 lines

csrc/moe/mxfp8_grouped_gemm/es_sm100_mxfp8_blockscaled.cu

Lines changed: 24 additions & 18 deletions
@@ -8,39 +8,45 @@
 #include "es_sm100_mxfp8_blockscaled_launcher.cuh"
 
 void es_sm100_mxfp8_blockscaled_grouped_mm(
-    const torch::Tensor& a,
-    const torch::Tensor& b,
-    const torch::Tensor& sfa,
-    const torch::Tensor& sfb,
-    torch::Tensor& d,
-    const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets,
+    const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
+    const torch::Tensor& sfb, torch::Tensor& d,
+    const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
     const torch::Tensor& blockscale_offsets) {
 #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
   TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
-  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
-  TORCH_CHECK(
-      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
-  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
+  TORCH_CHECK(problem_sizes.size(1) == 3,
+              "problem_sizes must have shape (num_experts, 3)");
+  TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
+              "Number of experts in problem_sizes must match expert_offsets");
+  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
+              "problem_sizes must be int32");
   TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
-  TORCH_CHECK(b.dim() == 3, "b must be a 3D tensor of shape (num_experts, k, n)");
-  TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, "k should align 128");
+  TORCH_CHECK(b.dim() == 3,
+              "b must be a 3D tensor of shape (num_experts, k, n)");
+  TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
+              "k should align 128");
   TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
   TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
   TORCH_CHECK(b.strides()[1] == 1, "b must be column major");
 
   auto stream = at::cuda::getCurrentCUDAStream();
   if (d.dtype() == torch::kBFloat16) {
-    expert_specialization::es_sm100_mxfp8_blockscaled_group_mm_dispatch_out_dtype<cutlass::bfloat16_t>(
-        a, b, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets, stream);
+    expert_specialization::
+        es_sm100_mxfp8_blockscaled_group_mm_dispatch_out_dtype<
+            cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes,
+                                 expert_offsets, blockscale_offsets, stream);
   } else if (d.dtype() == torch::kFloat16) {
-    expert_specialization::es_sm100_mxfp8_blockscaled_group_mm_dispatch_out_dtype<cutlass::half_t>(
-        a, b, sfa, sfb, d, problem_sizes, expert_offsets, blockscale_offsets, stream);
+    expert_specialization::
+        es_sm100_mxfp8_blockscaled_group_mm_dispatch_out_dtype<cutlass::half_t>(
+            a, b, sfa, sfb, d, problem_sizes, expert_offsets,
+            blockscale_offsets, stream);
   } else {
     TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
   }
 #else
-  TORCH_CHECK(false, "No implemented es_sm100_mxfp8_blockscaled_grouped_mm for current device");
+  TORCH_CHECK(false,
+              "No implemented es_sm100_mxfp8_blockscaled_grouped_mm for "
+              "current device");
 #endif
 }
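For readers of this diff: the TORCH_CHECKs above pin down a fairly specific input contract. The sketch below constructs host-side tensors that satisfy it; all sizes and variable names are illustrative assumptions (not part of this commit), the scale-factor tensors sfa/sfb are omitted because their shapes depend on the blockscale layout, and a libtorch build with FP8 dtypes is assumed.

// Sketch only: inputs that satisfy the checks in
// es_sm100_mxfp8_blockscaled_grouped_mm. Sizes are illustrative.
#include <torch/torch.h>

void make_conforming_inputs() {
  const int64_t num_tokens = 256, k = 256, n = 512, num_experts = 4;
  auto cuda = torch::TensorOptions().device(torch::kCUDA);

  // a: (num_tokens, k) with k % 128 == 0 and strides()[1] == 1 (row major).
  auto a = torch::empty({num_tokens, k}, cuda.dtype(torch::kFloat8_e4m3fn));

  // b: (num_experts, k, n) with strides()[1] == 1, i.e. each expert's
  // (k, n) slice is column major. Allocate (num_experts, n, k) and swap
  // the last two dims so the k dimension stays contiguous.
  auto b = torch::empty({num_experts, n, k}, cuda.dtype(torch::kFloat8_e4m3fn))
               .transpose(1, 2);

  // problem_sizes: int32, shape (num_experts, 3), one (m, n, k) per expert.
  auto problem_sizes = torch::tensor(
      {{64, 512, 256}, {64, 512, 256}, {64, 512, 256}, {64, 512, 256}},
      cuda.dtype(torch::kInt32));

  // expert_offsets: starting row of each expert's tokens within a; its
  // length must equal problem_sizes.size(0), as the checks require.
  auto expert_offsets =
      torch::tensor({0, 64, 128, 192}, cuda.dtype(torch::kInt32));
}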

csrc/moe/mxfp8_grouped_gemm/es_sm100_mxfp8_blockscaled_functor.cuh

Lines changed: 22 additions & 20 deletions
@@ -38,18 +38,10 @@ struct Sm100Mxfp8BlockScaledOffsetFunctor {
 
   Sm100Mxfp8BlockScaledOffsetFunctor() = default;
   Sm100Mxfp8BlockScaledOffsetFunctor(
-      int* _expert_offsets,
-      int* _blockscale_offsets,
-      ElementA* _a_base,
-      ElementB* _b_base,
-      ElementSF* _sfa_base,
-      ElementSF* _sfb_base,
-      ElementD* _d_base,
-      ElementA** _a_offsets,
-      ElementB** _b_offsets,
-      ElementSF** _sfa_offsets,
-      ElementSF** _sfb_offsets,
-      ElementD** _d_offsets)
+      int* _expert_offsets, int* _blockscale_offsets, ElementA* _a_base,
+      ElementB* _b_base, ElementSF* _sfa_base, ElementSF* _sfb_base,
+      ElementD* _d_base, ElementA** _a_offsets, ElementB** _b_offsets,
+      ElementSF** _sfa_offsets, ElementSF** _sfb_offsets, ElementD** _d_offsets)
       : expert_offsets{_expert_offsets},
         blockscale_offsets{_blockscale_offsets},
         a_base(_a_base),
@@ -65,7 +57,8 @@ struct Sm100Mxfp8BlockScaledOffsetFunctor {
 
   void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
     int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
-    int64_t blockscale_offset = static_cast<int64_t>(blockscale_offsets[expert_id]);
+    int64_t blockscale_offset =
+        static_cast<int64_t>(blockscale_offsets[expert_id]);
     int64_t a_stride = expert_offset * k;
     int64_t b_stride = expert_id * k * n;
     int64_t d_stride = expert_offset * n;
@@ -89,14 +82,17 @@ struct Sm100Mxfp8BlockScaledLayoutFunctor {
   LayoutSFB* layout_sfb_base{nullptr};
 
   Sm100Mxfp8BlockScaledLayoutFunctor() = default;
-  Sm100Mxfp8BlockScaledLayoutFunctor(LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base)
+  Sm100Mxfp8BlockScaledLayoutFunctor(LayoutSFA* _layout_sfa_base,
+                                     LayoutSFB* _layout_sfb_base)
       : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {}
 
   void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
     LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id;
     LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id;
-    *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
-    *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
+    *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
+        cute::make_shape(m, n, k, 1));
+    *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
+        cute::make_shape(m, n, k, 1));
   }
 };
 
@@ -110,8 +106,12 @@ struct Sm100Mxfp8BlockScaledStrideFunctor {
   StrideD* stride_D_base{nullptr};
 
   Sm100Mxfp8BlockScaledStrideFunctor() = default;
-  Sm100Mxfp8BlockScaledStrideFunctor(StrideA* _stride_A_base, StrideB* _stride_B_base, StrideD* _stride_D_base)
-      : stride_A_base(_stride_A_base), stride_B_base(_stride_B_base), stride_D_base(_stride_D_base) {}
+  Sm100Mxfp8BlockScaledStrideFunctor(StrideA* _stride_A_base,
+                                     StrideB* _stride_B_base,
+                                     StrideD* _stride_D_base)
+      : stride_A_base(_stride_A_base),
+        stride_B_base(_stride_B_base),
+        stride_D_base(_stride_D_base) {}
 
   void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
     StrideA* stride_A = stride_A_base + expert_id;
@@ -123,9 +123,11 @@ struct Sm100Mxfp8BlockScaledStrideFunctor {
   }
 };
 
-template <typename OffsetFunctor, typename LayoutFunctor, typename StrideFunctor>
+template <typename OffsetFunctor, typename LayoutFunctor,
+          typename StrideFunctor>
 __global__ void sm100Mxfp8BlockscaledGroupedGemmPreComputeKernel(
-    int* problem_sizes, OffsetFunctor offset_functor, LayoutFunctor layout_functor, StrideFunctor stride_functor) {
+    int* problem_sizes, OffsetFunctor offset_functor,
+    LayoutFunctor layout_functor, StrideFunctor stride_functor) {
   int64_t expert_id = static_cast<int64_t>(threadIdx.x);
   int m = problem_sizes[expert_id * 3 + 0];
   int n = problem_sizes[expert_id * 3 + 1];
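A note on the precompute kernel in the last hunk: it indexes problem_sizes with threadIdx.x, so it is evidently meant to run with one thread per expert. A minimal launch sketch under that assumption; the host-side launcher is not shown in this diff, and all names below are illustrative:

// Illustrative launch: one block, one thread per expert, so threadIdx.x
// serves directly as expert_id in the kernel body. Template arguments are
// deduced from the functor instances, which the host launcher is assumed
// to have constructed (problem_sizes_ptr likewise).
// Inside, each thread applies the three functors for its expert, e.g. the
// offset functor computes a_stride = expert_offsets[expert_id] * k because
// each expert's tokens are packed consecutively in a.
sm100Mxfp8BlockscaledGroupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
    problem_sizes_ptr, offset_functor, layout_functor, stride_functor);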

csrc/moe/mxfp8_grouped_gemm/es_sm100_mxfp8_blockscaled_group_quant.cu

Lines changed: 13 additions & 10 deletions
@@ -8,11 +8,9 @@
 #include "es_sm100_mxfp8_blockscaled_group_quant.cuh"
 
 void es_sm100_mxfp8_blockscaled_grouped_quant(
-    const torch::Tensor& input,
-    const torch::Tensor& problem_sizes,
+    const torch::Tensor& input, const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets,
-    const torch::Tensor& blockscale_offsets,
-    torch::Tensor& quant_output,
+    const torch::Tensor& blockscale_offsets, torch::Tensor& quant_output,
     torch::Tensor& scale_factor) {
 #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
   TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
@@ -26,20 +24,25 @@ void es_sm100_mxfp8_blockscaled_grouped_quant(
       "expert_offsets must be 1D and have size equal to the number of groups");
   TORCH_CHECK(
       blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
-      "blockscale_offsets must be 1D and have size equal to the number of groups");
+      "blockscale_offsets must be 1D and have size equal to the number of "
+      "groups");
 
   auto stream = at::cuda::getCurrentCUDAStream();
   if (input.dtype() == torch::kBFloat16) {
-    expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<__nv_bfloat16>(
-        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor);
+    expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<
+        __nv_bfloat16>(input, problem_sizes, expert_offsets, blockscale_offsets,
+                       quant_output, scale_factor);
   } else if (input.dtype() == torch::kFloat16) {
-    expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<__half>(
-        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, scale_factor);
+    expert_specialization::launch_es_sm100_mxfp8_blockscaled_grouped_quant<
+        __half>(input, problem_sizes, expert_offsets, blockscale_offsets,
+                quant_output, scale_factor);
   } else {
     TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
 #else
-  TORCH_CHECK(false, "No implemented es_sm100_mxfp8_blockscaled_grouped_quant for current device");
+  TORCH_CHECK(false,
+              "No implemented es_sm100_mxfp8_blockscaled_grouped_quant for "
+              "current device");
 #endif
 }
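Putting the two entry points together: the quant kernel appears to produce the MXFP8 activations and scale factors that the grouped GEMM in the first file then consumes as a and sfa, since both take the same problem_sizes, expert_offsets, and blockscale_offsets. A hedged sketch of that host-side flow, using only the signatures visible in this diff; the wrapper function and the assumption that all buffers are preallocated are hypothetical:

#include <torch/torch.h>

// Forward declarations copied from the signatures in this commit.
void es_sm100_mxfp8_blockscaled_grouped_quant(
    const torch::Tensor& input, const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets, torch::Tensor& quant_output,
    torch::Tensor& scale_factor);
void es_sm100_mxfp8_blockscaled_grouped_mm(
    const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
    const torch::Tensor& sfb, torch::Tensor& d,
    const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets);

// Hypothetical wrapper: quantize the grouped activations, then run the
// blockscaled grouped GEMM on the quantized result.
void mxfp8_grouped_forward(
    const torch::Tensor& x,             // (num_tokens, k) bf16/fp16 input
    const torch::Tensor& b,             // (num_experts, k, n) fp8 weights
    const torch::Tensor& sfb,           // weight scale factors
    const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets,
    torch::Tensor& quant_output,        // fp8 activations (filled by quant)
    torch::Tensor& scale_factor,        // activation scale factors
    torch::Tensor& d) {                 // (num_tokens, n) bf16/fp16 output
  es_sm100_mxfp8_blockscaled_grouped_quant(x, problem_sizes, expert_offsets,
                                           blockscale_offsets, quant_output,
                                           scale_factor);
  es_sm100_mxfp8_blockscaled_grouped_mm(quant_output, b, scale_factor, sfb, d,
                                        problem_sizes, expert_offsets,
                                        blockscale_offsets);
}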
