From 5d3ba92062e79dcee6d9d174025503348aa51890 Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Sun, 4 Jan 2026 16:57:39 +0800
Subject: [PATCH 1/9] dispatch fp8

---
 .../ops2/op_kernel/dispatch_normal_a2.cpp | 48 +++----------------
 1 file changed, 6 insertions(+), 42 deletions(-)

diff --git a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
index a6d5be494..0fafe0079 100644
--- a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
+++ b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
@@ -26,37 +26,14 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(
     GM_ADDR recvX, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR epRecvCountOut,
     GM_ADDR expandScalesOut, GM_ADDR dispatchWaitRecvCostStatsOut, GM_ADDR workspace, GM_ADDR tiling)
 {
-    // REGISTER_TILING_DEFAULT(NotifyDispatchA2TilingData);
-    // GET_TILING_DATA_WITH_STRUCT(NotifyDispatchA2TilingData, tilingData, tiling);
+    printf("[dispatch_normal_a2] blockId: %d\n", GetBlockIdx());
     REGISTER_TILING_DEFAULT(CamMoeDistributeDispatchA2TilingData);
-    GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
+    // GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
 
     // hcomm will set magic later in init
     uint32_t magic = 1;
     GM_ADDR commArgs = nullptr;
-    // int localRank = tilingData.notifyDispatchInfoA2.localRankId;
-    // int localRankSize = tilingData.notifyDispatchInfoA2.localRankSize;
-    // int rank = tilingData.notifyDispatchInfoA2.rankId;
-    // int rankSize = tilingData.notifyDispatchInfoA2.rankSize;
-    // int64_t len = tilingData.notifyDispatchInfoA2.sendCount;
-    // int64_t numTokens = tilingData.notifyDispatchInfoA2.numTokens;
-    // int64_t topkNum = tilingData.notifyDispatchInfoA2.topkNum;
-    // int64_t numExperts = tilingData.notifyDispatchInfoA2.numExperts;
-
-    // GM_ADDR sendDataInput = sendData;
-    // GM_ADDR tokenPerExpertDataInput = tokenPerExpertData;
-    // GM_ADDR sendDataOffsetOutput = sendDataOffset;
-    // GM_ADDR recvDataOutput = recvData;
-    // GM_ADDR tokenServerIdxOutput = tokenServerIdx;
-    // GM_ADDR tokensUniquePerServerOutput = tokensUniquePerServer;
-    // GM_ADDR epRankTokenCntOutput = epRankTokenCnt;
-    // GM_ADDR localEpTokenCntOutput = localEpTokenCnt;
-    // GM_ADDR srcOffsetRankTokenIdxOutput = srcOffsetRankTokenIdx;
-    // GM_ADDR dstOffsetRankTokenIdxOutput = dstOffsetRankTokenIdx;
-    // GM_ADDR offsetInnerOutput = offsetInner;
-    // GM_ADDR countOuterOutput = countOuter;
-
     // fill in unused args
     uint32_t extraFlag = 0;
     GM_ADDR scale = nullptr;
@@ -69,28 +46,15 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(
     TPipe pipe;
 
     if (TILING_KEY_IS(2100001000)) {
-        // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag);
-        // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
-        // opKernel.Process();
-        CamMoeDistributeDispatchA2Layered op;
-        op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
-            srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
-            epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
-        op.Process();
-    } else if (TILING_KEY_IS(2000000000)) {
-        // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag);
-        // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
-        // opKernel.Process();
+        GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
         CamMoeDistributeDispatchA2Layered op;
        op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
            srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
            epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
         op.Process();
-    } else if (TILING_KEY_IS(2000001000)) {
-        // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag);
-        // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
-        // opKernel.Process();
-        CamMoeDistributeDispatchA2Layered op;
+    } else if (TILING_KEY_IS(2100001002)) {
+        GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
+        CamMoeDistributeDispatchA2Layered op;
         op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
             srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
             epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
         op.Process();

From 6f37be90026b837bc0687e4f6553b547d557ff63 Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Mon, 5 Jan 2026 09:33:48 +0800
Subject: [PATCH 2/9] fix

---
 csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
index 0fafe0079..f93db1071 100644
--- a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
+++ b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
@@ -26,7 +26,6 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(
     GM_ADDR recvX, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR epRecvCountOut,
     GM_ADDR expandScalesOut, GM_ADDR dispatchWaitRecvCostStatsOut, GM_ADDR workspace, GM_ADDR tiling)
 {
-    printf("[dispatch_normal_a2] blockId: %d\n", GetBlockIdx());
     REGISTER_TILING_DEFAULT(CamMoeDistributeDispatchA2TilingData);
     // GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);

From 36136c3fb7d35bf5fb46bb153a22fc51cf327c56 Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Mon, 5 Jan 2026 14:46:09 +0800
Subject: [PATCH 3/9] fix

---
 csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
index f93db1071..da4938cfc 100644
--- a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
+++ b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
@@ -27,7 +27,6 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(
     GM_ADDR expandScalesOut, GM_ADDR dispatchWaitRecvCostStatsOut, GM_ADDR workspace, GM_ADDR tiling)
 {
     REGISTER_TILING_DEFAULT(CamMoeDistributeDispatchA2TilingData);
-    // GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
 
     // hcomm will set magic later in init
     uint32_t magic = 1;

From da5a0e3380f89ad6e8b85ad32657edd87c51dcf7 Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Sat, 17 Jan 2026 16:37:05 +0800
Subject: [PATCH 4/9] SyncFunc();

---
 .../ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h b/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
index da113b627..404ec22bf 100644
--- a/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
+++ b/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
@@ -431,6 +431,7 @@ __aicore__ inline void CamMoeDistributeDispatchA2Layered
+        SyncFunc();
         // copy topkIds (can be omitted)
        DataCopyPad(tokenTempTensorU8_[expOffsetInStruct_], expertIdsGMTensorU8_[(startTokenId + i) * realLenInStruct_],
                     expCopyParams, expPadParams);

From 6abb089e3c09f46e9f720716b8acf209ef97c75a Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Sat, 17 Jan 2026 17:13:05 +0800
Subject: [PATCH 5/9] add debug mode

---
 tests/python/deepep/test_internode.py | 49 ++++++++++++++++++++-------
 1 file changed, 37 insertions(+), 12 deletions(-)

diff --git a/tests/python/deepep/test_internode.py b/tests/python/deepep/test_internode.py
index a3bbc6699..8b90f367f 100644
--- a/tests/python/deepep/test_internode.py
+++ b/tests/python/deepep/test_internode.py
@@ -243,9 +243,18 @@ def check_layout_a2_data(notify_send_data):
     for i in range(num_ranks):
         num_tokens_per_rank[i] = (rank_idx == i).sum()
         token_sel = (rank_idx == i).max(dim=-1)[0]
-        count = token_sel.sum().item()
-        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
-        tokens[:count] = torch.sort(tokens[:count])[0]
+        count = token_sel.sum().cpu().item()
+
+        # Perform the sorting on the CPU
+        token_sel_cpu = token_sel.to(torch.int).cpu()
+        tokens_cpu = torch.sort(token_sel_cpu, descending=True)[1]
+        sorted_tokens = torch.sort(tokens_cpu[:count])[0]
+
+        # Move the results back to the NPU
+        tokens = tokens_cpu.to("npu")
+        tokens[:count] = sorted_tokens.to("npu")
+
+        # Ensure the size of the arange matches the count
         token_idx_in_rank[i][tokens[:count]] = torch.arange(
             count, dtype=torch.long, device="npu"
         )
@@ -430,14 +439,22 @@ def test_correctness():
             combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
             check_x = combined_x.float()
             ref_x = x_pure_rand if current_x is x_pure_rand else x
-            assert (
-                calc_diff(
-                    check_x,
-                    ref_x
-                    * handle[4].masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1),
-                )
-                < 5e-5
-            )
+            # Calculate the intermediate values of each item
+            masked_values = handle[4].masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1)
+            scaled_ref_x = ref_x * masked_values
+
+
+            # Calculate the difference
+            diff_result = calc_diff(check_x, scaled_ref_x)
+
+            if args.debug:
+                print(f"Debug - diff_result: {diff_result}, threshold: {5e-5}")
+                print(f"Debug - masked_values (first 10): {masked_values.flatten()[:10]}")
+                print(f"Debug - ref_x (first 10): {ref_x.flatten()[:10]}")
+                print(f"Debug - scaled_ref_x (first 10): {scaled_ref_x.flatten()[:10]}")
+                print(f"Debug - check_x (first 10): {check_x.flatten()[:10]}")
+
+            assert diff_result < 5e-5

         if local_rank == 0:
             print(" passed", flush=True)
@@ -519,7 +536,8 @@ def test_tuning():
         print("", flush=True)

    test_correctness()
-    test_tuning()
+    #test_correctness_with_saved_data()  # test BF16
+    #test_tuning()

     # Diagnose test
     if enable_diagnose:
@@ -587,6 +605,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
         action="store_true",
         help="Whether to enable diagnose for testing",
     )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        default=False,
+        help="Enable debug logging.",
+    )
+
     args = parser.parse_args()

     num_processes = args.num_processes

From c3169a663a12b4e5bbbfd5525893c324d4fee89d Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Sat, 17 Jan 2026 17:13:56 +0800
Subject: [PATCH 6/9] fix

---
 tests/python/deepep/test_internode.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/python/deepep/test_internode.py b/tests/python/deepep/test_internode.py
index 8b90f367f..61b1fb7b9 100644
--- a/tests/python/deepep/test_internode.py
+++ b/tests/python/deepep/test_internode.py
@@ -536,8 +536,7 @@ def test_tuning():
         print("", flush=True)

    test_correctness()
-    #test_correctness_with_saved_data()  # test BF16
-    #test_tuning()
+    test_tuning()

     # Diagnose test
     if enable_diagnose:

From c4d7af21a4859bc02e5b1a5ba89ef2bea03b95aa Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Sat, 17 Jan 2026 17:29:43 +0800
Subject: [PATCH 7/9] fix lint

---
 tests/python/deepep/test_internode.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/python/deepep/test_internode.py b/tests/python/deepep/test_internode.py
index 61b1fb7b9..c1dc9883c 100644
--- a/tests/python/deepep/test_internode.py
+++ b/tests/python/deepep/test_internode.py
@@ -440,16 +440,19 @@ def test_correctness():
             check_x = combined_x.float()
             ref_x = x_pure_rand if current_x is x_pure_rand else x
             # Calculate the intermediate values of each item
-            masked_values = handle[4].masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1)
+            masked_values = (
+                handle[4].masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1)
+            )
             scaled_ref_x = ref_x * masked_values

-
             # Calculate the difference
             diff_result = calc_diff(check_x, scaled_ref_x)

             if args.debug:
                 print(f"Debug - diff_result: {diff_result}, threshold: {5e-5}")
-                print(f"Debug - masked_values (first 10): {masked_values.flatten()[:10]}")
+                print(
+                    f"Debug - masked_values (first 10): {masked_values.flatten()[:10]}"
+                )
                 print(f"Debug - ref_x (first 10): {ref_x.flatten()[:10]}")
                 print(f"Debug - scaled_ref_x (first 10): {scaled_ref_x.flatten()[:10]}")
                 print(f"Debug - check_x (first 10): {check_x.flatten()[:10]}")

From 30105a95b5e1e43e23a9437c742d01d32d7803de Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Mon, 19 Jan 2026 14:20:44 +0800
Subject: [PATCH 8/9] fix commit

---
 .../cam_moe_distribute_dispatch_a2_layered.h |  1 +
 tests/python/deepep/test_internode.py        | 38 ++++++++++++++-----
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h b/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
index 404ec22bf..8621a2c1a 100644
--- a/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
+++ b/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
@@ -431,6 +431,7 @@ __aicore__ inline void CamMoeDistributeDispatchA2Layered
         // copy topkIds (can be omitted)
         DataCopyPad(tokenTempTensorU8_[expOffsetInStruct_], expertIdsGMTensorU8_[(startTokenId + i) * realLenInStruct_],
                     expCopyParams, expPadParams);

diff --git a/tests/python/deepep/test_internode.py b/tests/python/deepep/test_internode.py
index c1dc9883c..7e2c0110a 100644
--- a/tests/python/deepep/test_internode.py
+++ b/tests/python/deepep/test_internode.py
@@ -45,10 +45,20 @@ def test_main(
     assert num_tokens <= MAX_BATCH_SIZE
     if local_rank == 0:
         print(
-            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}, active_ranks={args.active_ranks}",
+            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}, active_ranks={args.active_ranks}, use_int8_quant={args.use_int8_quant}",
             flush=True,
         )

+    # Set the environment variable so that the operator can read it
+    os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1" if args.use_int8_quant else "0"
+
+    # Check the quantization mode
+    USE_INT8_QUANT = args.use_int8_quant
+    print(
+        f"[Config] Quantization mode: {'INT8' if USE_INT8_QUANT else 'BF16'}",
+        flush=True,
+    )
+
     experts_per_rank = num_experts // num_ranks

     if args.active_ranks:
@@ -291,7 +301,7 @@ def check_layout_a2_data(notify_send_data):
     ), f"Assertion num_tokens_per_rank failed on rank {rank}: Expected {num_tokens_per_rank}, Actual {ref_num_tokens_per_rank}"
     assert torch.allclose(
         ref_num_tokens_per_expert, num_tokens_per_expert
-    ), f"Assertion num_tokens_per_expert failed on rank {rank}: Expected {num_tokens_per_expert}, Actual {ref_num_tokens_per_expert}"
+    ), f"Assertion num_tokens_per_expert failed on rank {rank}: Expected {num_tokens_per_expert}, Actual {num_tokens_per_expert}"
     assert torch.allclose(
         ref_is_token_in_rank, is_token_in_rank
     ), f"Assertion is_token_in_rank failed on rank {rank}: Expected {is_token_in_rank}, Actual {ref_is_token_in_rank}"
@@ -387,8 +397,9 @@ def test_diagnose(
     def test_correctness():
         for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x)):
             if local_rank == 0:
+                quant_mode_str = "INT8" if USE_INT8_QUANT else "BF16"
                 print(
-                    f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, with top-k {num_topk} ...',
+                    f"[testing] Running with {quant_mode_str} quantization, with top-k {num_topk} ...",
                     flush=True,
                 )
             # Test dispatch
@@ -466,7 +477,9 @@ def test_correctness():

     def test_tuning():
         # Tune dispatch performance
-        fp8_factor = (1 + 4 / 128) / 2
+        quant_factor = (
+            0.5 if USE_INT8_QUANT else 1.0
+        )  # INT8: 50% bandwidth, BF16: 100% bandwidth
         config = deep_ep.Config(24, 8, buffer_size)

         dispatch_args = {
@@ -490,13 +503,13 @@ def test_tuning():
         # Tune dispatch performance
         for current_x in filter(lambda elem: elem is not None, (x,)):
             recv_bytes = (
-                (dispatch_bf16_recv_bytes * fp8_factor)
-                if isinstance(current_x, tuple)
+                (dispatch_bf16_recv_bytes * quant_factor)
+                if USE_INT8_QUANT  # use the compression factor when INT8 quantization is enabled
                 else dispatch_bf16_recv_bytes
             )
             rdma_send_bytes = (
-                (dispatch_bf16_rdma_send_bytes * fp8_factor)
-                if isinstance(current_x, tuple)
+                (dispatch_bf16_rdma_send_bytes * quant_factor)
+                if USE_INT8_QUANT  # use the compression factor when INT8 quantization is enabled
                 else dispatch_bf16_rdma_send_bytes
             )
@@ -515,8 +528,9 @@ def test_tuning():
                 ("DispatchNormalA2", "NotifyDispatchA2"),
             )
             if local_rank == 0:
+                quant_mode_str = "INT8" if USE_INT8_QUANT else "BF16"
                 print(
-                    f'[tuning] Dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}) {recv_bytes / 1e9 / t:.2f} GB/s (HCCS), '
+                    f"[tuning] Dispatch ({quant_mode_str}) {recv_bytes / 1e9 / t:.2f} GB/s (HCCS), "
                     f"{rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), avg_t: {t * 1e6:.2f} us, notify_t: {notify_t * 1e6:.2f} us",
                     flush=True,
                 )
@@ -607,6 +621,12 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
         action="store_true",
         help="Whether to enable diagnose for testing",
     )
+    parser.add_argument(
+        "--use-int8-quant",
+        action="store_true",
+        default=False,
+        help="Enable internal INT8 quantization instead of BF16 (default: BF16)",
+    )
     parser.add_argument(
         "--debug",

From bc5a3b12237e07083f583a48dd10e9a3107903e3 Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Mon, 19 Jan 2026 14:35:40 +0800
Subject: [PATCH 9/9] fix

---
 tests/python/deepep/test_internode.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/deepep/test_internode.py b/tests/python/deepep/test_internode.py
index 7e2c0110a..38c34a0c1 100644
--- a/tests/python/deepep/test_internode.py
+++ b/tests/python/deepep/test_internode.py
@@ -478,8 +478,8 @@ def test_tuning():
     def test_tuning():
         # Tune dispatch performance
         quant_factor = (
-            0.5 if USE_INT8_QUANT else 1.0
-        )  # INT8: 50% bandwidth, BF16: 100% bandwidth
+            (1 + 4 / 128) / 2 if USE_INT8_QUANT else 1.0
+        )  # INT8: (1 + 4 / 128) / 2 ≈ 51% bandwidth, BF16: 100% bandwidth
         config = deep_ep.Config(24, 8, buffer_size)

         dispatch_args = {
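
Note on the bandwidth factor in test_tuning(): the quant_factor restored by PATCH 9/9, (1 + 4 / 128) / 2, corresponds to moving 1 byte per quantized element plus a 4-byte scale shared by every 128 elements, measured against 2-byte BF16 elements, i.e. roughly 51% of the BF16 traffic; that is why it scales dispatch_bf16_recv_bytes and dispatch_bf16_rdma_send_bytes only when USE_INT8_QUANT is set. The sketch below merely restates that arithmetic; the helper name and the reading of 128 as a per-group element count are illustrative assumptions taken from the formula in the patch, not from any DeepEP API.

    # Illustrative sketch of the quant_factor arithmetic from PATCH 9/9 (not part of the patches).
    def quant_bandwidth_factor(use_int8_quant: bool, group_size: int = 128) -> float:
        if not use_int8_quant:
            return 1.0  # BF16 path: the full 2 bytes per element are moved
        int8_bytes_per_elem = 1.0                # quantized payload
        scale_bytes_per_elem = 4.0 / group_size  # one float32 scale shared by each group
        bf16_bytes_per_elem = 2.0
        return (int8_bytes_per_elem + scale_bytes_per_elem) / bf16_bytes_per_elem

    print(quant_bandwidth_factor(True))   # 0.515625 -> the "~51% bandwidth" comment
    print(quant_bandwidth_factor(False))  # 1.0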