Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 6 additions & 42 deletions csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,14 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(
GM_ADDR recvX, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR epRecvCountOut,
GM_ADDR expandScalesOut, GM_ADDR dispatchWaitRecvCostStatsOut, GM_ADDR workspace, GM_ADDR tiling)
{
// REGISTER_TILING_DEFAULT(NotifyDispatchA2TilingData);
// GET_TILING_DATA_WITH_STRUCT(NotifyDispatchA2TilingData, tilingData, tiling);
printf("[dispatch_normal_a2] blockId: %d\n", GetBlockIdx());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This printf statement appears to be for debugging purposes. It should be removed before merging to avoid spamming logs and potential performance degradation in production environments, especially within a kernel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

solved

REGISTER_TILING_DEFAULT(CamMoeDistributeDispatchA2TilingData);
GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
// GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);

// hcomm will set magic later in init
uint32_t magic = 1;
GM_ADDR commArgs = nullptr;

// int localRank = tilingData.notifyDispatchInfoA2.localRankId;
// int localRankSize = tilingData.notifyDispatchInfoA2.localRankSize;
// int rank = tilingData.notifyDispatchInfoA2.rankId;
// int rankSize = tilingData.notifyDispatchInfoA2.rankSize;
// int64_t len = tilingData.notifyDispatchInfoA2.sendCount;
// int64_t numTokens = tilingData.notifyDispatchInfoA2.numTokens;
// int64_t topkNum = tilingData.notifyDispatchInfoA2.topkNum;
// int64_t numExperts = tilingData.notifyDispatchInfoA2.numExperts;

// GM_ADDR sendDataInput = sendData;
// GM_ADDR tokenPerExpertDataInput = tokenPerExpertData;
// GM_ADDR sendDataOffsetOutput = sendDataOffset;
// GM_ADDR recvDataOutput = recvData;
// GM_ADDR tokenServerIdxOutput = tokenServerIdx;
// GM_ADDR tokensUniquePerServerOutput = tokensUniquePerServer;
// GM_ADDR epRankTokenCntOutput = epRankTokenCnt;
// GM_ADDR localEpTokenCntOutput = localEpTokenCnt;
// GM_ADDR srcOffsetRankTokenIdxOutput = srcOffsetRankTokenIdx;
// GM_ADDR dstOffsetRankTokenIdxOutput = dstOffsetRankTokenIdx;
// GM_ADDR offsetInnerOutput = offsetInner;
// GM_ADDR countOuterOutput = countOuter;

// fill in unused args
uint32_t extraFlag = 0;
GM_ADDR scale = nullptr;
Expand All @@ -69,28 +46,15 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(

TPipe pipe;
if (TILING_KEY_IS(2100001000)) {
// NotifyDispatchA2<int> opKernel(rank, rankSize, extraFlag);
// opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
// opKernel.Process();
CamMoeDistributeDispatchA2Layered<bfloat16_t, bfloat16_t, false, false, false> op;
op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
op.Process();
} else if (TILING_KEY_IS(2000000000)) {
// NotifyDispatchA2<int> opKernel(rank, rankSize, extraFlag);
// opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
// opKernel.Process();
GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
CamMoeDistributeDispatchA2Layered<bfloat16_t, bfloat16_t, false, false, false> op;
op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
op.Process();
} else if (TILING_KEY_IS(2000001000)) {
// NotifyDispatchA2<int> opKernel(rank, rankSize, extraFlag);
// opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
// opKernel.Process();
CamMoeDistributeDispatchA2Layered<bfloat16_t, bfloat16_t, false, false, false> op;
} else if (TILING_KEY_IS(2100001002)) {
GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
CamMoeDistributeDispatchA2Layered<bfloat16_t, int8_t, false, true, false> op;
op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
Expand Down