Skip to content

Commit 002b35a

Browse files
committed
apply pre-commit
1 parent 2e45ffe commit 002b35a

File tree

4 files changed

+95
-115
lines changed

4 files changed

+95
-115
lines changed

csrc/catcoc/ops/op_kernel/catcoc_allgather_matmul_kernel.hpp

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,21 @@ using namespace AscendC;
3939
using namespace Catcoc;
4040
using namespace sglang::npu_kernel;
4141

42-
4342
template <DataFormatMode dMode, WeightFormatMode wMode>
44-
class CatCocAllgatherMatmul {
43+
class CatCocAllgatherMatmul
44+
{
4545
public:
4646
uint64_t fftsAddr_;
4747
int32_t teamIdx_;
4848

49-
__aicore__ inline void Init(uint64_t fftsAddr, uint64_t teamIdx) {
49+
__aicore__ inline void Init(uint64_t fftsAddr, uint64_t teamIdx)
50+
{
5051
fftsAddr_ = fftsAddr;
5152
teamIdx_ = (int32_t)teamIdx;
5253
}
5354

54-
__aicore__ inline void Process(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC,
55-
GM_ADDR gmSymmetric, GM_ADDR gmTiling) {
55+
__aicore__ inline void Process(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, GM_ADDR gmSymmetric, GM_ADDR gmTiling)
56+
{
5657
// Set FFTS address
5758
AscendC::SetSyncBaseAddr(fftsAddr_);
5859
// Set shmem config
@@ -97,22 +98,23 @@ class CatCocAllgatherMatmul {
9798
*/
9899

99100
// switch cases
100-
using ElementA = typename std::conditional_t<dMode == DataFormatMode::FP16, half,
101-
typename std::conditional_t<dMode == DataFormatMode::BF16, __bf16, float>>;
101+
using ElementA =
102+
typename std::conditional_t<dMode == DataFormatMode::FP16, half,
103+
typename std::conditional_t<dMode == DataFormatMode::BF16, __bf16, float>>;
102104
using ElementB = ElementA;
103105
using ElementC = ElementA;
104106

105107
using LayoutA = Catlass::layout::RowMajor;
106-
using LayoutB = typename std::conditional_t<wMode == WeightFormatMode::WEIGHT_ND,
107-
Catlass::layout::RowMajor, Catlass::layout::zN>;
108+
using LayoutB = typename std::conditional_t<wMode == WeightFormatMode::WEIGHT_ND, Catlass::layout::RowMajor,
109+
Catlass::layout::zN>;
108110
using LayoutC = Catlass::layout::RowMajor;
109111

110112
using AType = Catlass::Gemm::GemmType<ElementA, LayoutA>;
111113
using BType = Catlass::Gemm::GemmType<ElementB, LayoutB>;
112114
using CType = Catlass::Gemm::GemmType<ElementC, LayoutC>;
113115

114116
LayoutA layoutA{m, k};
115-
LayoutB layoutB = LayoutB::template MakeLayout<ElementB>(k, n); // adapted for both nz and nd
117+
LayoutB layoutB = LayoutB::template MakeLayout<ElementB>(k, n); // adapted for both NZ and ND
116118
LayoutC layoutC{m * rankSize, n};
117119
Catlass::GemmCoord problemShape{m, n, k};
118120

@@ -133,15 +135,17 @@ class CatCocAllgatherMatmul {
133135
using RemoteSrcType = AType;
134136
using RemoteDstType = AType;
135137
using CopyDirect = Catcoc::detail::CopyDirect;
136-
using TileRemoteCopy = CommEpilogue::Tile::TileRemoteCopy<ArchTag, RemoteSrcType, RemoteDstType, CopyDirect::Put>;
138+
using TileRemoteCopy =
139+
CommEpilogue::Tile::TileRemoteCopy<ArchTag, RemoteSrcType, RemoteDstType, CopyDirect::Put>;
137140
using TileSchedulerForAllgather = Catlass::Epilogue::Tile::EpilogueIdentityTileSwizzle;
138141

139142
using CommBlockShape = Catlass::MatrixShape<64, UINT_MAX / 2>;
140143
using CommCoreSplit = Catlass::MatrixShape<20, 1>;
141144

142145
constexpr uint32_t UB_STAGES = 2;
143146
using AllGatherTileShape = Catlass::MatrixShape<32, 256>;
144-
using AllGatherDispatch = CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Gather>;
147+
using AllGatherDispatch =
148+
CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Gather>;
145149
using BlockEpilogueAllGather =
146150
CommEpilogue::Block::CommBlockEpilogue<AllGatherDispatch, RemoteSrcType, RemoteDstType, CommCoreSplit,
147151
CommBlockShape, AllGatherTileShape, TileRemoteCopy,
@@ -166,44 +170,44 @@ class CatCocAllgatherMatmul {
166170
}
167171
};
168172

169-
170173
template <DataFormatMode dMode, WeightFormatMode wMode>
171-
__aicore__ void catcoc_allgather_matmul_impl(uint64_t fftsAddr, uint64_t teamIdx,
172-
GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC,
173-
GM_ADDR gmSymmetric, GM_ADDR gmWorkspace,
174-
GM_ADDR gmTiling) {
174+
__aicore__ void catcoc_allgather_matmul_impl(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC,
175+
GM_ADDR gmSymmetric, GM_ADDR gmWorkspace, GM_ADDR gmTiling)
176+
{
175177
// gmWorkspace is a dummy input for ascendc compile with tiling, catcoc ops use gmSymmetric as actual workspace
176178
CatCocAllgatherMatmul<dMode, wMode> op;
177179
op.Init(fftsAddr, teamIdx);
178180
op.Process(gmA, gmB, gmC, gmSymmetric, gmTiling);
179181
}
180182

181-
182183
extern "C" __global__ __aicore__ void catcoc_allgather_matmul_fp16(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
183184
GM_ADDR gmB, GM_ADDR gmC, GM_ADDR gmSymmetric,
184-
GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
185-
catcoc_allgather_matmul_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_ND>
186-
(fftsAddr, teamIdx, gmA, gmB, gmC, gmSymmetric, gmWorkspace, gmTiling);
185+
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
186+
{
187+
catcoc_allgather_matmul_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_ND>(fftsAddr, teamIdx, gmA, gmB, gmC,
188+
gmSymmetric, gmWorkspace, gmTiling);
187189
}
188190

189-
190191
extern "C" __global__ __aicore__ void catcoc_allgather_matmul_fp16_wnz(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
191192
GM_ADDR gmB, GM_ADDR gmC, GM_ADDR gmSymmetric,
192-
GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
193-
catcoc_allgather_matmul_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_NZ>
194-
(fftsAddr, teamIdx, gmA, gmB, gmC, gmSymmetric, gmWorkspace, gmTiling);
193+
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
194+
{
195+
catcoc_allgather_matmul_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_NZ>(fftsAddr, teamIdx, gmA, gmB, gmC,
196+
gmSymmetric, gmWorkspace, gmTiling);
195197
}
196198

197199
extern "C" __global__ __aicore__ void catcoc_allgather_matmul_bf16(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
198200
GM_ADDR gmB, GM_ADDR gmC, GM_ADDR gmSymmetric,
199-
GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
200-
catcoc_allgather_matmul_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_ND>
201-
(fftsAddr, teamIdx, gmA, gmB, gmC, gmSymmetric, gmWorkspace, gmTiling);
201+
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
202+
{
203+
catcoc_allgather_matmul_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_ND>(fftsAddr, teamIdx, gmA, gmB, gmC,
204+
gmSymmetric, gmWorkspace, gmTiling);
202205
}
203206

204207
extern "C" __global__ __aicore__ void catcoc_allgather_matmul_bf16_wnz(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
205208
GM_ADDR gmB, GM_ADDR gmC, GM_ADDR gmSymmetric,
206-
GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
207-
catcoc_allgather_matmul_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_NZ>
208-
(fftsAddr, teamIdx, gmA, gmB, gmC, gmSymmetric, gmWorkspace, gmTiling);
209+
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
210+
{
211+
catcoc_allgather_matmul_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_NZ>(fftsAddr, teamIdx, gmA, gmB, gmC,
212+
gmSymmetric, gmWorkspace, gmTiling);
209213
}

csrc/catcoc/ops/op_kernel/catcoc_matmul_allreduce_kernel.hpp

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,21 @@ using namespace AscendC;
3939
using namespace Catcoc;
4040
using namespace sglang::npu_kernel;
4141

42-
4342
template <DataFormatMode dMode, WeightFormatMode wMode>
44-
class CatCocMatmulAllreduce {
43+
class CatCocMatmulAllreduce
44+
{
4545
public:
4646
uint64_t fftsAddr_;
4747
int32_t teamIdx_;
4848

49-
__aicore__ inline void Init(uint64_t fftsAddr, uint64_t teamIdx) {
49+
__aicore__ inline void Init(uint64_t fftsAddr, uint64_t teamIdx)
50+
{
5051
fftsAddr_ = fftsAddr;
5152
teamIdx_ = (int32_t)teamIdx;
5253
}
5354

54-
__aicore__ inline void Process(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmD,
55-
GM_ADDR gmSymmetric, GM_ADDR gmTiling) {
55+
__aicore__ inline void Process(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric, GM_ADDR gmTiling)
56+
{
5657
// Set FFTS address
5758
AscendC::SetSyncBaseAddr(fftsAddr_);
5859
// Set shmem config
@@ -97,15 +98,16 @@ class CatCocMatmulAllreduce {
9798
*/
9899

99100
// switch cases
100-
using ElementA = typename std::conditional_t<dMode == DataFormatMode::FP16, half,
101-
typename std::conditional_t<dMode == DataFormatMode::BF16, __bf16, float>>;
101+
using ElementA =
102+
typename std::conditional_t<dMode == DataFormatMode::FP16, half,
103+
typename std::conditional_t<dMode == DataFormatMode::BF16, __bf16, float>>;
102104
using ElementB = ElementA;
103105
using ElementC = ElementA;
104106
using ElementD = ElementA;
105107

106108
using LayoutA = Catlass::layout::RowMajor;
107-
using LayoutB = typename std::conditional_t<wMode == WeightFormatMode::WEIGHT_ND,
108-
Catlass::layout::RowMajor, Catlass::layout::zN>;
109+
using LayoutB = typename std::conditional_t<wMode == WeightFormatMode::WEIGHT_ND, Catlass::layout::RowMajor,
110+
Catlass::layout::zN>;
109111
using LayoutC = Catlass::layout::RowMajor;
110112
using LayoutD = Catlass::layout::RowMajor;
111113

@@ -115,7 +117,7 @@ class CatCocMatmulAllreduce {
115117
using DType = Catlass::Gemm::GemmType<ElementD, LayoutD>;
116118

117119
LayoutA layoutA{m, k};
118-
LayoutB layoutB = LayoutB::template MakeLayout<ElementB>(k, n); // adapted for both nz and nd
120+
LayoutB layoutB = LayoutB::template MakeLayout<ElementB>(k, n); // adapted for both NZ and ND
119121
LayoutD layoutD{m, n};
120122
Catlass::GemmCoord problemShape{m, n, k};
121123

@@ -125,17 +127,19 @@ class CatCocMatmulAllreduce {
125127
using L1TileShape = Catlass::GemmShape<128, 256, 256>;
126128
using L0TileShape = Catlass::GemmShape<128, 256, 64>;
127129
using BlockMmad =
128-
Catlass::Gemm::Block::BlockMmad<MmadDispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
130+
Catlass::Gemm::Block::BlockMmad<MmadDispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
129131

130132
constexpr uint32_t SWIZZLE_GROUP_SIZE = 7;
131133
constexpr uint32_t SWIZZLE_DIRECTION = 1;
132-
using BlockMmadScheduler = Catlass::Gemm::Block::GemmIdentityBlockSwizzle<SWIZZLE_GROUP_SIZE, SWIZZLE_DIRECTION>;
134+
using BlockMmadScheduler =
135+
Catlass::Gemm::Block::GemmIdentityBlockSwizzle<SWIZZLE_GROUP_SIZE, SWIZZLE_DIRECTION>;
133136
using BlockEpilogueScheduler = Catcoc::CommEpilogue::Block::BlockCommSwizzle<0>;
134137

135138
using RemoteSrcType = CType;
136139
using RemoteDstType = DType;
137140
using CopyDirect = Catcoc::detail::CopyDirect;
138-
using TileRemoteCopy = CommEpilogue::Tile::TileRemoteCopy<ArchTag, RemoteSrcType, RemoteDstType, CopyDirect::Get>;
141+
using TileRemoteCopy =
142+
CommEpilogue::Tile::TileRemoteCopy<ArchTag, RemoteSrcType, RemoteDstType, CopyDirect::Get>;
139143
using TileScheduler = Catlass::Epilogue::Tile::EpilogueIdentityTileSwizzle;
140144

141145
constexpr uint32_t COMM_BLOCK_ROWS = 8;
@@ -150,27 +154,27 @@ class CatCocMatmulAllreduce {
150154
constexpr uint32_t SCATTER_TILE_COLUMNS = 256;
151155
using EpilogueReduceScatterTileShape = Catlass::MatrixShape<SCATTER_TILE_ROWS, SCATTER_TILE_COLUMNS>;
152156
using EpilogueReduceScatterDispatch =
153-
CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Scatter>;
157+
CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Scatter>;
154158
using BlockEpilogueReduceScatter =
155-
CommEpilogue::Block::CommBlockEpilogue<EpilogueReduceScatterDispatch, RemoteSrcType, RemoteDstType,
156-
CommCoreSplit, CommBlockShape, EpilogueReduceScatterTileShape,
157-
TileRemoteCopy, TileScheduler>;
159+
CommEpilogue::Block::CommBlockEpilogue<EpilogueReduceScatterDispatch, RemoteSrcType, RemoteDstType,
160+
CommCoreSplit, CommBlockShape, EpilogueReduceScatterTileShape,
161+
TileRemoteCopy, TileScheduler>;
158162

159163
constexpr uint32_t ALLGATHER_TILE_ROWS = 4;
160164
constexpr uint32_t ALLGATHER_TILE_COLUMNS = 256;
161165
using EpilogueAllGatherTileShape = Catlass::MatrixShape<ALLGATHER_TILE_ROWS, ALLGATHER_TILE_COLUMNS>;
162166
using EpilogueAllGatherDispatch =
163-
CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Gather>;
167+
CommEpilogue::EpilogueAtlasA2CommRemoteCopy<UB_STAGES, Catcoc::detail::CopyMode::Gather>;
164168
using BlockEpilogueAllGather =
165-
CommEpilogue::Block::CommBlockEpilogue<EpilogueAllGatherDispatch, RemoteSrcType, RemoteDstType, CommCoreSplit,
166-
CommBlockShape, EpilogueAllGatherTileShape, TileRemoteCopy,
167-
TileScheduler>;
169+
CommEpilogue::Block::CommBlockEpilogue<EpilogueAllGatherDispatch, RemoteSrcType, RemoteDstType,
170+
CommCoreSplit, CommBlockShape, EpilogueAllGatherTileShape,
171+
TileRemoteCopy, TileScheduler>;
168172

169173
constexpr uint32_t WORKSPACE_STAGES = 2;
170174
constexpr uint32_t COMM_INTERVAL = 4;
171175
using MatmulAllReduceKernel =
172-
DGemm::Kernel::MatmulAllReduce<BlockMmad, BlockEpilogueReduceScatter, BlockEpilogueAllGather,
173-
BlockMmadScheduler, BlockEpilogueScheduler, WORKSPACE_STAGES>;
176+
DGemm::Kernel::MatmulAllReduce<BlockMmad, BlockEpilogueReduceScatter, BlockEpilogueAllGather,
177+
BlockMmadScheduler, BlockEpilogueScheduler, WORKSPACE_STAGES>;
174178

175179
typename BlockEpilogueReduceScatter::Params reduceScatterParams{};
176180
typename BlockEpilogueAllGather::Params allGatherParams{};
@@ -195,10 +199,9 @@ class CatCocMatmulAllreduce {
195199
};
196200

197201
template <DataFormatMode dMode, WeightFormatMode wMode>
198-
__aicore__ void catcoc_matmul_allreduce_impl(uint64_t fftsAddr, uint64_t teamIdx,
199-
GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmD,
200-
GM_ADDR gmSymmetric, GM_ADDR gmWorkspace,
201-
GM_ADDR gmTiling) {
202+
__aicore__ void catcoc_matmul_allreduce_impl(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmD,
203+
GM_ADDR gmSymmetric, GM_ADDR gmWorkspace, GM_ADDR gmTiling)
204+
{
202205
// gmWorkspace is a dummy input for ascendc compile with tiling, catcoc ops use gmSymmetric as actual workspace
203206
CatCocMatmulAllreduce<dMode, wMode> op;
204207
op.Init(fftsAddr, teamIdx);
@@ -207,28 +210,32 @@ __aicore__ void catcoc_matmul_allreduce_impl(uint64_t fftsAddr, uint64_t teamIdx
207210

208211
extern "C" __global__ __aicore__ void catcoc_matmul_allreduce_fp16(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
209212
GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric,
210-
GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
211-
catcoc_matmul_allreduce_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_ND>
212-
(fftsAddr, teamIdx, gmA, gmB, gmD, gmSymmetric, gmWorkspace, gmTiling);
213+
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
214+
{
215+
catcoc_matmul_allreduce_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_ND>(fftsAddr, teamIdx, gmA, gmB, gmD,
216+
gmSymmetric, gmWorkspace, gmTiling);
213217
}
214218

215219
extern "C" __global__ __aicore__ void catcoc_matmul_allreduce_fp16_wnz(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
216220
GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric,
217-
GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
218-
catcoc_matmul_allreduce_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_NZ>
219-
(fftsAddr, teamIdx, gmA, gmB, gmD, gmSymmetric, gmWorkspace, gmTiling);
221+
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
222+
{
223+
catcoc_matmul_allreduce_impl<DataFormatMode::FP16, WeightFormatMode::WEIGHT_NZ>(fftsAddr, teamIdx, gmA, gmB, gmD,
224+
gmSymmetric, gmWorkspace, gmTiling);
220225
}
221226

222227
extern "C" __global__ __aicore__ void catcoc_matmul_allreduce_bf16(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
223228
GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric,
224-
GM_ADDR gmWorkspace, GM_ADDR gmTiling){
225-
catcoc_matmul_allreduce_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_ND>
226-
(fftsAddr, teamIdx, gmA, gmB, gmD, gmSymmetric, gmWorkspace, gmTiling);
229+
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
230+
{
231+
catcoc_matmul_allreduce_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_ND>(fftsAddr, teamIdx, gmA, gmB, gmD,
232+
gmSymmetric, gmWorkspace, gmTiling);
227233
}
228234

229235
extern "C" __global__ __aicore__ void catcoc_matmul_allreduce_bf16_wnz(uint64_t fftsAddr, uint64_t teamIdx, GM_ADDR gmA,
230236
GM_ADDR gmB, GM_ADDR gmD, GM_ADDR gmSymmetric,
231-
GM_ADDR gmWorkspace, GM_ADDR gmTiling) {
232-
catcoc_matmul_allreduce_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_NZ>
233-
(fftsAddr, teamIdx, gmA, gmB, gmD, gmSymmetric, gmWorkspace, gmTiling);
237+
GM_ADDR gmWorkspace, GM_ADDR gmTiling)
238+
{
239+
catcoc_matmul_allreduce_impl<DataFormatMode::BF16, WeightFormatMode::WEIGHT_NZ>(fftsAddr, teamIdx, gmA, gmB, gmD,
240+
gmSymmetric, gmWorkspace, gmTiling);
234241
}

0 commit comments

Comments
 (0)