@@ -39,20 +39,21 @@ using namespace AscendC;
3939using namespace Catcoc ;
4040using namespace sglang ::npu_kernel;
4141
42-
4342template <DataFormatMode dMode, WeightFormatMode wMode>
44- class CatCocMatmulAllreduce {
43+ class CatCocMatmulAllreduce
44+ {
4545public:
4646 uint64_t fftsAddr_;
4747 int32_t teamIdx_;
4848
49- __aicore__ inline void Init (uint64_t fftsAddr, uint64_t teamIdx) {
49+ __aicore__ inline void Init (uint64_t fftsAddr, uint64_t teamIdx)
50+ {
5051 fftsAddr_ = fftsAddr;
5152 teamIdx_ = (int32_t )teamIdx;
5253 }
5354
54- __aicore__ inline void Process (GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmD,
55- GM_ADDR gmSymmetric, GM_ADDR gmTiling) {
55+ __aicore__ inline void Process (GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric, GM_ADDR gmTiling)
56+ {
5657 // Set FFTS address
5758 AscendC::SetSyncBaseAddr (fftsAddr_);
5859 // Set shmem config
@@ -97,15 +98,16 @@ class CatCocMatmulAllreduce {
9798 */
9899
99100 // switch cases
100- using ElementA = typename std::conditional_t <dMode == DataFormatMode::FP16, half,
101- typename std::conditional_t <dMode == DataFormatMode::BF16, __bf16, float >>;
101+ using ElementA =
102+ typename std::conditional_t <dMode == DataFormatMode::FP16, half,
103+ typename std::conditional_t <dMode == DataFormatMode::BF16, __bf16, float >>;
102104 using ElementB = ElementA;
103105 using ElementC = ElementA;
104106 using ElementD = ElementA;
105107
106108 using LayoutA = Catlass::layout::RowMajor;
107- using LayoutB = typename std::conditional_t <wMode == WeightFormatMode::WEIGHT_ND,
108- Catlass::layout::RowMajor, Catlass::layout::zN>;
109+ using LayoutB = typename std::conditional_t <wMode == WeightFormatMode::WEIGHT_ND, Catlass::layout::RowMajor,
110+ Catlass::layout::zN>;
109111 using LayoutC = Catlass::layout::RowMajor;
110112 using LayoutD = Catlass::layout::RowMajor;
111113
@@ -115,7 +117,7 @@ class CatCocMatmulAllreduce {
115117 using DType = Catlass::Gemm::GemmType<ElementD, LayoutD>;
116118
117119 LayoutA layoutA{m, k};
118- LayoutB layoutB = LayoutB::template MakeLayout<ElementB>(k, n); // adapted for both nz and nd
120+ LayoutB layoutB = LayoutB::template MakeLayout<ElementB>(k, n); // adapted for both NZ and ND
119121 LayoutD layoutD{m, n};
120122 Catlass::GemmCoord problemShape{m, n, k};
121123
@@ -125,17 +127,19 @@ class CatCocMatmulAllreduce {
125127 using L1TileShape = Catlass::GemmShape<128 , 256 , 256 >;
126128 using L0TileShape = Catlass::GemmShape<128 , 256 , 64 >;
127129 using BlockMmad =
128- Catlass::Gemm::Block::BlockMmad<MmadDispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
130+ Catlass::Gemm::Block::BlockMmad<MmadDispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
129131
130132 constexpr uint32_t SWIZZLE_GROUP_SIZE = 7 ;
131133 constexpr uint32_t SWIZZLE_DIRECTION = 1 ;
132- using BlockMmadScheduler = Catlass::Gemm::Block::GemmIdentityBlockSwizzle<SWIZZLE_GROUP_SIZE, SWIZZLE_DIRECTION>;
134+ using BlockMmadScheduler =
135+ Catlass::Gemm::Block::GemmIdentityBlockSwizzle<SWIZZLE_GROUP_SIZE, SWIZZLE_DIRECTION>;
133136 using BlockEpilogueScheduler = Catcoc::CommEpilogue::Block::BlockCommSwizzle<0 >;
134137
135138 using RemoteSrcType = CType;
136139 using RemoteDstType = DType;
137140 using CopyDirect = Catcoc::detail::CopyDirect;
138- using TileRemoteCopy = CommEpilogue::Tile::TileRemoteCopy<ArchTag, RemoteSrcType, RemoteDstType, CopyDirect::Get>;
141+ using TileRemoteCopy =
142+ CommEpilogue::Tile::TileRemoteCopy<ArchTag, RemoteSrcType, RemoteDstType, CopyDirect::Get>;
139143 using TileScheduler = Catlass::Epilogue::Tile::EpilogueIdentityTileSwizzle;
140144
141145 constexpr uint32_t COMM_BLOCK_ROWS = 8 ;
@@ -150,27 +154,27 @@ class CatCocMatmulAllreduce {
150154 constexpr uint32_t SCATTER_TILE_COLUMNS = 256 ;
151155 using EpilogueReduceScatterTileShape = Catlass::MatrixShape<SCATTER_TILE_ROWS, SCATTER_TILE_COLUMNS>;
152156 using EpilogueReduceScatterDispatch =
153- CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Scatter>;
157+ CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Scatter>;
154158 using BlockEpilogueReduceScatter =
155- CommEpilogue::Block::CommBlockEpilogue<EpilogueReduceScatterDispatch, RemoteSrcType, RemoteDstType,
156- CommCoreSplit, CommBlockShape, EpilogueReduceScatterTileShape,
157- TileRemoteCopy, TileScheduler>;
159+ CommEpilogue::Block::CommBlockEpilogue<EpilogueReduceScatterDispatch, RemoteSrcType, RemoteDstType,
160+ CommCoreSplit, CommBlockShape, EpilogueReduceScatterTileShape,
161+ TileRemoteCopy, TileScheduler>;
158162
159163 constexpr uint32_t ALLGATHER_TILE_ROWS = 4 ;
160164 constexpr uint32_t ALLGATHER_TILE_COLUMNS = 256 ;
161165 using EpilogueAllGatherTileShape = Catlass::MatrixShape<ALLGATHER_TILE_ROWS, ALLGATHER_TILE_COLUMNS>;
162166 using EpilogueAllGatherDispatch =
163- CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Gather>;
167+ CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Gather>;
164168 using BlockEpilogueAllGather =
165- CommEpilogue::Block::CommBlockEpilogue<EpilogueAllGatherDispatch, RemoteSrcType, RemoteDstType, CommCoreSplit ,
166- CommBlockShape, EpilogueAllGatherTileShape, TileRemoteCopy ,
167- TileScheduler>;
169+ CommEpilogue::Block::CommBlockEpilogue<EpilogueAllGatherDispatch, RemoteSrcType, RemoteDstType,
170+ CommCoreSplit, CommBlockShape, EpilogueAllGatherTileShape ,
171+ TileRemoteCopy, TileScheduler>;
168172
169173 constexpr uint32_t WORKSPACE_STAGES = 2 ;
170174 constexpr uint32_t COMM_INTERVAL = 4 ;
171175 using MatmulAllReduceKernel =
172- DGemm::Kernel::MatmulAllReduce<BlockMmad, BlockEpilogueReduceScatter, BlockEpilogueAllGather,
173- BlockMmadScheduler, BlockEpilogueScheduler, WORKSPACE_STAGES>;
176+ DGemm::Kernel::MatmulAllReduce<BlockMmad, BlockEpilogueReduceScatter, BlockEpilogueAllGather,
177+ BlockMmadScheduler, BlockEpilogueScheduler, WORKSPACE_STAGES>;
174178
175179 typename BlockEpilogueReduceScatter::Params reduceScatterParams{};
176180 typename BlockEpilogueAllGather::Params allGatherParams{};
@@ -195,10 +199,9 @@ class CatCocMatmulAllreduce {
195199};
196200
197201template <DataFormatMode dMode, WeightFormatMode wMode>
198- __aicore__ void catcoc_matmul_allreduce_impl (uint64_t fftsAddr, uint64_t teamIdx,
199- GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmD,
200- GM_ADDR gmSymmetric, GM_ADDR gmWorkspace,
201- GM_ADDR gmTiling) {
202+ __aicore__ void catcoc_matmul_allreduce_impl (uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmD,
203+ GM_ADDR gmSymmetric, GM_ADDR gmWorkspace, GM_ADDR gmTiling)
204+ {
202205 // gmWorkspace is a dummy input for ascendc compile with tiling, catcoc ops use gmSymmetric as actual workspace
203206 CatCocMatmulAllreduce<dMode, wMode> op;
204207 op.Init (fftsAddr, teamIdx);
@@ -207,28 +210,32 @@ __aicore__ void catcoc_matmul_allreduce_impl(uint64_t fftsAddr, uint64_t teamIdx
207210
208211extern " C" __global__ __aicore__ void catcoc_matmul_allreduce_fp16 (uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
209212 GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric,
210- GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
211- catcoc_matmul_allreduce_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_ND>
212- (fftsAddr, teamIdx, gmA, gmB, gmD, gmSymmetric, gmWorkspace, gmTiling);
213+ GM_ADDR gmWorkspace, GM_ADDR gmTiling)
214+ {
215+ catcoc_matmul_allreduce_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_ND>(fftsAddr, teamIdx, gmA, gmB, gmD,
216+ gmSymmetric, gmWorkspace, gmTiling);
213217}
214218
215219extern " C" __global__ __aicore__ void catcoc_matmul_allreduce_fp16_wnz (uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
216220 GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric,
217- GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
218- catcoc_matmul_allreduce_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_NZ>
219- (fftsAddr, teamIdx, gmA, gmB, gmD, gmSymmetric, gmWorkspace, gmTiling);
221+ GM_ADDR gmWorkspace, GM_ADDR gmTiling)
222+ {
223+ catcoc_matmul_allreduce_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_NZ>(fftsAddr, teamIdx, gmA, gmB, gmD,
224+ gmSymmetric, gmWorkspace, gmTiling);
220225}
221226
222227extern " C" __global__ __aicore__ void catcoc_matmul_allreduce_bf16 (uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
223228 GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric,
224- GM_ADDR gmWorkspace, GM_ADDR gmTiling){
225- catcoc_matmul_allreduce_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_ND>
226- (fftsAddr, teamIdx, gmA, gmB, gmD, gmSymmetric, gmWorkspace, gmTiling);
229+ GM_ADDR gmWorkspace, GM_ADDR gmTiling)
230+ {
231+ catcoc_matmul_allreduce_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_ND>(fftsAddr, teamIdx, gmA, gmB, gmD,
232+ gmSymmetric, gmWorkspace, gmTiling);
227233}
228234
229235extern " C" __global__ __aicore__ void catcoc_matmul_allreduce_bf16_wnz (uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
230236 GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric,
231- GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
232- catcoc_matmul_allreduce_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_NZ>
233- (fftsAddr, teamIdx, gmA, gmB, gmD, gmSymmetric, gmWorkspace, gmTiling);
237+ GM_ADDR gmWorkspace, GM_ADDR gmTiling)
238+ {
239+ catcoc_matmul_allreduce_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_NZ>(fftsAddr, teamIdx, gmA, gmB, gmD,
240+ gmSymmetric, gmWorkspace, gmTiling);
234241}
0 commit comments