@@ -431,6 +431,7 @@ __aicore__ inline void CamMoeDistributeDispatchA2Layered<TemplateMC2TypeA2layere
DataCopyPad(tokenTempTensorU8_, xGMTensorU8_[(startTokenId + i) * tokenLenInStruct_],
tokenCopyParamsNoQuant, tokenPadParams);
}
SyncFunc<AscendC::HardEvent::V_MTE2>();
Collaborator:
Add a comment explaining what problems would occur if this synchronization is not added, and why adding this synchronization would prevent these problems.

Contributor Author:
added

// Copy topkIds (can be omitted)
DataCopyPad(tokenTempTensorU8_[expOffsetInStruct_], expertIdsGMTensorU8_[(startTokenId + i) * realLenInStruct_],
expCopyParams, expPadParams);
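
For context on the reviewer's request above, a minimal sketch of what such an explanatory comment could look like. The exact wording the author added is not shown here, and the hazard description assumes the preceding Vector-pipe instructions still reference tokenTempTensorU8_ when the next GM-to-UB copy is issued:

// Sketch only; not the comment actually added in the later commit.
// Without this V -> MTE2 sync, the DataCopyPad below (MTE2 pipe, GM -> UB) could start
// while earlier Vector-pipe instructions are still reading tokenTempTensorU8_; the copy
// would then overwrite unified-buffer data that is still in use, corrupting the token
// payload prepared for dispatch. Posting the V_MTE2 event makes the MTE2 queue wait
// until the Vector pipe has drained before the copy begins.
SyncFunc<AscendC::HardEvent::V_MTE2>();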
46 changes: 4 additions & 42 deletions csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
@@ -26,37 +26,12 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(
GM_ADDR recvX, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR epRecvCountOut,
GM_ADDR expandScalesOut, GM_ADDR dispatchWaitRecvCostStatsOut, GM_ADDR workspace, GM_ADDR tiling)
{
// REGISTER_TILING_DEFAULT(NotifyDispatchA2TilingData);
// GET_TILING_DATA_WITH_STRUCT(NotifyDispatchA2TilingData, tilingData, tiling);
REGISTER_TILING_DEFAULT(CamMoeDistributeDispatchA2TilingData);
GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);

// hcomm will set magic later in init
uint32_t magic = 1;
GM_ADDR commArgs = nullptr;

// int localRank = tilingData.notifyDispatchInfoA2.localRankId;
// int localRankSize = tilingData.notifyDispatchInfoA2.localRankSize;
// int rank = tilingData.notifyDispatchInfoA2.rankId;
// int rankSize = tilingData.notifyDispatchInfoA2.rankSize;
// int64_t len = tilingData.notifyDispatchInfoA2.sendCount;
// int64_t numTokens = tilingData.notifyDispatchInfoA2.numTokens;
// int64_t topkNum = tilingData.notifyDispatchInfoA2.topkNum;
// int64_t numExperts = tilingData.notifyDispatchInfoA2.numExperts;

// GM_ADDR sendDataInput = sendData;
// GM_ADDR tokenPerExpertDataInput = tokenPerExpertData;
// GM_ADDR sendDataOffsetOutput = sendDataOffset;
// GM_ADDR recvDataOutput = recvData;
// GM_ADDR tokenServerIdxOutput = tokenServerIdx;
// GM_ADDR tokensUniquePerServerOutput = tokensUniquePerServer;
// GM_ADDR epRankTokenCntOutput = epRankTokenCnt;
// GM_ADDR localEpTokenCntOutput = localEpTokenCnt;
// GM_ADDR srcOffsetRankTokenIdxOutput = srcOffsetRankTokenIdx;
// GM_ADDR dstOffsetRankTokenIdxOutput = dstOffsetRankTokenIdx;
// GM_ADDR offsetInnerOutput = offsetInner;
// GM_ADDR countOuterOutput = countOuter;

// fill in unused args
uint32_t extraFlag = 0;
GM_ADDR scale = nullptr;
@@ -69,28 +44,15 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(

TPipe pipe;
if (TILING_KEY_IS(2100001000)) {
// NotifyDispatchA2<int> opKernel(rank, rankSize, extraFlag);
// opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
// opKernel.Process();
CamMoeDistributeDispatchA2Layered<bfloat16_t, bfloat16_t, false, false, false> op;
op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
op.Process();
} else if (TILING_KEY_IS(2000000000)) {
// NotifyDispatchA2<int> opKernel(rank, rankSize, extraFlag);
// opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
// opKernel.Process();
GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
CamMoeDistributeDispatchA2Layered<bfloat16_t, bfloat16_t, false, false, false> op;
op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
op.Process();
} else if (TILING_KEY_IS(2000001000)) {
// NotifyDispatchA2<int> opKernel(rank, rankSize, extraFlag);
// opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
// opKernel.Process();
CamMoeDistributeDispatchA2Layered<bfloat16_t, bfloat16_t, false, false, false> op;
} else if (TILING_KEY_IS(2100001002)) {
GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
CamMoeDistributeDispatchA2Layered<bfloat16_t, int8_t, false, true, false> op;
op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
47 changes: 37 additions & 10 deletions tests/python/deepep/test_internode.py
@@ -243,9 +243,18 @@ def check_layout_a2_data(notify_send_data):
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
count = token_sel.sum().cpu().item()

# Perform sorting on CPU
token_sel_cpu = token_sel.to(torch.int).cpu()
tokens_cpu = torch.sort(token_sel_cpu, descending=True)[1]
sorted_tokens = torch.sort(tokens_cpu[:count])[0]

# Move the results back to the NPU
tokens = tokens_cpu.to("npu")
tokens[:count] = sorted_tokens.to("npu")

# Ensure the size of the arange result matches the count.
token_idx_in_rank[i][tokens[:count]] = torch.arange(
count, dtype=torch.long, device="npu"
)
@@ -430,14 +439,25 @@ def test_correctness():
combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
check_x = combined_x.float()
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert (
calc_diff(
check_x,
ref_x
* handle[4].masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1),
)
< 5e-5
# Compute the per-token scaling factors applied to the reference
masked_values = (
handle[4].masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1)
)
scaled_ref_x = ref_x * masked_values

# Calculate the difference
diff_result = calc_diff(check_x, scaled_ref_x)

if args.debug:
print(f"Debug - diff_result: {diff_result}, threshold: {5e-5}")
print(
f"Debug - masked_values (first 10): {masked_values.flatten()[:10]}"
)
print(f"Debug - ref_x (first 10): {ref_x.flatten()[:10]}")
print(f"Debug - scaled_ref_x (first 10): {scaled_ref_x.flatten()[:10]}")
print(f"Debug - check_x (first 10): {check_x.flatten()[:10]}")

assert diff_result < 5e-5

if local_rank == 0:
print(" passed", flush=True)
@@ -587,6 +607,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
action="store_true",
help="Whether to enable diagnose for testing",
)
parser.add_argument(
Collaborator:
The switch for enabling or disabling quantization needs to be added to the test script.

Contributor Author:
added

"--debug",
action="store_true",
default=False,
help="Enable debug logging.",
)

args = parser.parse_args()

num_processes = args.num_processes
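
Regarding the reviewer's request above for a quantization switch in the test script, a minimal self-contained sketch of an argparse on/off flag; the name --use-quant and its help text are assumptions for illustration, not the switch actually added in the PR:

import argparse

parser = argparse.ArgumentParser()
# Hypothetical quantization switch; the flag name and semantics are illustrative only.
parser.add_argument(
    "--use-quant",
    action="store_true",
    default=False,
    help="Enable the quantized (int8) dispatch path; omit to test the bf16 path.",
)
args = parser.parse_args()
# The test could then choose dispatch arguments based on args.use_quant.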