@@ -38,18 +38,10 @@ struct Sm100Mxfp8BlockScaledOffsetFunctor {
3838
3939 Sm100Mxfp8BlockScaledOffsetFunctor () = default ;
4040 Sm100Mxfp8BlockScaledOffsetFunctor (
41- int * _expert_offsets,
42- int * _blockscale_offsets,
43- ElementA* _a_base,
44- ElementB* _b_base,
45- ElementSF* _sfa_base,
46- ElementSF* _sfb_base,
47- ElementD* _d_base,
48- ElementA** _a_offsets,
49- ElementB** _b_offsets,
50- ElementSF** _sfa_offsets,
51- ElementSF** _sfb_offsets,
52- ElementD** _d_offsets)
41+ int * _expert_offsets, int * _blockscale_offsets, ElementA* _a_base,
42+ ElementB* _b_base, ElementSF* _sfa_base, ElementSF* _sfb_base,
43+ ElementD* _d_base, ElementA** _a_offsets, ElementB** _b_offsets,
44+ ElementSF** _sfa_offsets, ElementSF** _sfb_offsets, ElementD** _d_offsets)
5345 : expert_offsets{_expert_offsets},
5446 blockscale_offsets{_blockscale_offsets},
5547 a_base (_a_base),
@@ -65,7 +57,8 @@ struct Sm100Mxfp8BlockScaledOffsetFunctor {
6557
6658 void CUTE_DEVICE operator ()(int64_t expert_id, int m, int n, int k) {
6759 int64_t expert_offset = static_cast <int64_t >(expert_offsets[expert_id]);
68- int64_t blockscale_offset = static_cast <int64_t >(blockscale_offsets[expert_id]);
60+ int64_t blockscale_offset =
61+ static_cast <int64_t >(blockscale_offsets[expert_id]);
6962 int64_t a_stride = expert_offset * k;
7063 int64_t b_stride = expert_id * k * n;
7164 int64_t d_stride = expert_offset * n;
@@ -89,14 +82,17 @@ struct Sm100Mxfp8BlockScaledLayoutFunctor {
8982 LayoutSFB* layout_sfb_base{nullptr };
9083
9184 Sm100Mxfp8BlockScaledLayoutFunctor () = default ;
92- Sm100Mxfp8BlockScaledLayoutFunctor (LayoutSFA* _layout_sfa_base, LayoutSFB* _layout_sfb_base)
85+ Sm100Mxfp8BlockScaledLayoutFunctor (LayoutSFA* _layout_sfa_base,
86+ LayoutSFB* _layout_sfb_base)
9387 : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {}
9488
9589 void CUTE_DEVICE operator ()(int64_t expert_id, int m, int n, int k) {
9690 LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id;
9791 LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id;
98- *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA (cute::make_shape (m, n, k, 1 ));
99- *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB (cute::make_shape (m, n, k, 1 ));
92+ *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA (
93+ cute::make_shape (m, n, k, 1 ));
94+ *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB (
95+ cute::make_shape (m, n, k, 1 ));
10096 }
10197};
10298
@@ -110,8 +106,12 @@ struct Sm100Mxfp8BlockScaledStrideFunctor {
110106 StrideD* stride_D_base{nullptr };
111107
112108 Sm100Mxfp8BlockScaledStrideFunctor () = default ;
113- Sm100Mxfp8BlockScaledStrideFunctor (StrideA* _stride_A_base, StrideB* _stride_B_base, StrideD* _stride_D_base)
114- : stride_A_base(_stride_A_base), stride_B_base(_stride_B_base), stride_D_base(_stride_D_base) {}
109+ Sm100Mxfp8BlockScaledStrideFunctor (StrideA* _stride_A_base,
110+ StrideB* _stride_B_base,
111+ StrideD* _stride_D_base)
112+ : stride_A_base(_stride_A_base),
113+ stride_B_base (_stride_B_base),
114+ stride_D_base(_stride_D_base) {}
115115
116116 void CUTE_DEVICE operator ()(int64_t expert_id, int m, int n, int k) {
117117 StrideA* stride_A = stride_A_base + expert_id;
@@ -123,9 +123,11 @@ struct Sm100Mxfp8BlockScaledStrideFunctor {
123123 }
124124};
125125
126- template <typename OffsetFunctor, typename LayoutFunctor, typename StrideFunctor>
126+ template <typename OffsetFunctor, typename LayoutFunctor,
127+ typename StrideFunctor>
127128__global__ void sm100Mxfp8BlockscaledGroupedGemmPreComputeKernel (
128- int * problem_sizes, OffsetFunctor offset_functor, LayoutFunctor layout_functor, StrideFunctor stride_functor) {
129+ int * problem_sizes, OffsetFunctor offset_functor,
130+ LayoutFunctor layout_functor, StrideFunctor stride_functor) {
129131 int64_t expert_id = static_cast <int64_t >(threadIdx .x );
130132 int m = problem_sizes[expert_id * 3 + 0 ];
131133 int n = problem_sizes[expert_id * 3 + 1 ];
0 commit comments