diff --git a/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h b/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
index da113b627..8621a2c1a 100644
--- a/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
+++ b/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h
@@ -431,6 +431,8 @@ __aicore__ inline void CamMoeDistributeDispatchA2Layered()
+        // Copy topkIds (can be omitted)
+        DataCopyPad(tokenTempTensorU8_[expOffsetInStruct_], expertIdsGMTensorU8_[(startTokenId + i) * realLenInStruct_], expCopyParams, expPadParams);
diff --git a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
index a6d5be494..da4938cfc 100644
--- a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
+++ b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp
@@ -26,37 +26,12 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(
     GM_ADDR recvX, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut,
     GM_ADDR epRecvCountOut, GM_ADDR expandScalesOut, GM_ADDR dispatchWaitRecvCostStatsOut, GM_ADDR workspace,
     GM_ADDR tiling)
 {
-    // REGISTER_TILING_DEFAULT(NotifyDispatchA2TilingData);
-    // GET_TILING_DATA_WITH_STRUCT(NotifyDispatchA2TilingData, tilingData, tiling);
     REGISTER_TILING_DEFAULT(CamMoeDistributeDispatchA2TilingData);
-    GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);

     // hcomm will set magic later in init
     uint32_t magic = 1;
     GM_ADDR commArgs = nullptr;
-    // int localRank = tilingData.notifyDispatchInfoA2.localRankId;
-    // int localRankSize = tilingData.notifyDispatchInfoA2.localRankSize;
-    // int rank = tilingData.notifyDispatchInfoA2.rankId;
-    // int rankSize = tilingData.notifyDispatchInfoA2.rankSize;
-    // int64_t len = tilingData.notifyDispatchInfoA2.sendCount;
-    // int64_t numTokens = tilingData.notifyDispatchInfoA2.numTokens;
-    // int64_t topkNum = tilingData.notifyDispatchInfoA2.topkNum;
-    // int64_t numExperts = tilingData.notifyDispatchInfoA2.numExperts;
-
-    // GM_ADDR sendDataInput = sendData;
-    // GM_ADDR tokenPerExpertDataInput = tokenPerExpertData;
-    // GM_ADDR sendDataOffsetOutput = sendDataOffset;
-    // GM_ADDR recvDataOutput = recvData;
-    // GM_ADDR tokenServerIdxOutput = tokenServerIdx;
-    // GM_ADDR tokensUniquePerServerOutput = tokensUniquePerServer;
-    // GM_ADDR epRankTokenCntOutput = epRankTokenCnt;
-    // GM_ADDR localEpTokenCntOutput = localEpTokenCnt;
-    // GM_ADDR srcOffsetRankTokenIdxOutput = srcOffsetRankTokenIdx;
-    // GM_ADDR dstOffsetRankTokenIdxOutput = dstOffsetRankTokenIdx;
-    // GM_ADDR offsetInnerOutput = offsetInner;
-    // GM_ADDR countOuterOutput = countOuter;
-
     // fill in unused args
     uint32_t extraFlag = 0;
     GM_ADDR scale = nullptr;
@@ -69,28 +44,15 @@ extern "C" __global__ __aicore__ void dispatch_normal_a2(
     TPipe pipe;
     if (TILING_KEY_IS(2100001000)) {
-        // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag);
-        // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
-        // opKernel.Process();
-        CamMoeDistributeDispatchA2Layered op;
-        op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
-                srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
-                epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
-        op.Process();
-    } else if (TILING_KEY_IS(2000000000)) {
-        // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag);
-        // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
-        // opKernel.Process();
+        GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
         CamMoeDistributeDispatchA2Layered op;
         op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
                 srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
                 epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
         op.Process();
-    } else if (TILING_KEY_IS(2000001000)) {
-        // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag);
-        // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL());
-        // opKernel.Process();
-        CamMoeDistributeDispatchA2Layered op;
+    } else if (TILING_KEY_IS(2100001002)) {
+        GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling);
+        CamMoeDistributeDispatchA2Layered op;
         op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt,
                 srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut,
                 epRecvCountOut, expandScalesOut, workspace, &pipe, tiling);
diff --git a/tests/python/deepep/test_internode.py b/tests/python/deepep/test_internode.py
index a3bbc6699..38c34a0c1 100644
--- a/tests/python/deepep/test_internode.py
+++ b/tests/python/deepep/test_internode.py
@@ -45,10 +45,20 @@ def test_main(
     assert num_tokens <= MAX_BATCH_SIZE
     if local_rank == 0:
         print(
-            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}, active_ranks={args.active_ranks}",
+            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}, active_ranks={args.active_ranks}, use_int8_quant={args.use_int8_quant}",
             flush=True,
         )

+    # Set the environment variable so the operator can read the quantization mode
+    os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1" if args.use_int8_quant else "0"
+
+    # Record the quantization mode
+    USE_INT8_QUANT = args.use_int8_quant
+    print(
+        f"[Config] Quantization mode: {'INT8' if USE_INT8_QUANT else 'BF16'}",
+        flush=True,
+    )
+
     experts_per_rank = num_experts // num_ranks

     if args.active_ranks:
@@ -243,9 +253,18 @@ def check_layout_a2_data(notify_send_data):
     for i in range(num_ranks):
         num_tokens_per_rank[i] = (rank_idx == i).sum()
         token_sel = (rank_idx == i).max(dim=-1)[0]
-        count = token_sel.sum().item()
-        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
-        tokens[:count] = torch.sort(tokens[:count])[0]
+        count = token_sel.sum().cpu().item()
+
+        # Perform the sorting on CPU
+        token_sel_cpu = token_sel.to(torch.int).cpu()
+        tokens_cpu = torch.sort(token_sel_cpu, descending=True)[1]
+        sorted_tokens = torch.sort(tokens_cpu[:count])[0]
+
+        # Move the results back to the NPU
+        tokens = tokens_cpu.to("npu")
+        tokens[:count] = sorted_tokens.to("npu")
+
+        # Ensure the size of the arange matches the count
         token_idx_in_rank[i][tokens[:count]] = torch.arange(
             count, dtype=torch.long, device="npu"
         )
@@ -282,7 +301,7 @@ def check_layout_a2_data(notify_send_data):
         ), f"Assertion num_tokens_per_rank failed on rank {rank}: Expected {num_tokens_per_rank}, Actual {ref_num_tokens_per_rank}"
         assert torch.allclose(
             ref_num_tokens_per_expert, num_tokens_per_expert
-        ), f"Assertion num_tokens_per_expert failed on rank {rank}: Expected {num_tokens_per_expert}, Actual {ref_num_tokens_per_expert}"
+        ), f"Assertion num_tokens_per_expert failed on rank {rank}: Expected {ref_num_tokens_per_expert}, Actual {num_tokens_per_expert}"
         assert torch.allclose(
             ref_is_token_in_rank, is_token_in_rank
         ), f"Assertion is_token_in_rank failed on rank {rank}: Expected {is_token_in_rank}, Actual {ref_is_token_in_rank}"
@@ -378,8 +397,9 @@ def test_diagnose(
 def test_correctness():
     for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x)):
         if local_rank == 0:
+            quant_mode_str = "INT8" if USE_INT8_QUANT else "BF16"
             print(
-                f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, with top-k {num_topk} ...',
+                f"[testing] Running with {quant_mode_str} quantization, with top-k {num_topk} ...",
                 flush=True,
             )
         # Test dispatch
@@ -430,14 +450,25 @@ def test_correctness():
         combined_x, combined_topk_weights, event = buffer.combine(**combine_args)
         check_x = combined_x.float()
         ref_x = x_pure_rand if current_x is x_pure_rand else x
-        assert (
-            calc_diff(
-                check_x,
-                ref_x
-                * handle[4].masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1),
-            )
-            < 5e-5
-        )
+        # Per-token scale: the sum of routing weights over valid (non -1) experts
+        masked_values = (
+            handle[4].masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1)
+        )
+        scaled_ref_x = ref_x * masked_values
+
+        # Compute the difference against the scaled reference
+        diff_result = calc_diff(check_x, scaled_ref_x)
+
+        if args.debug:
+            print(f"Debug - diff_result: {diff_result}, threshold: {5e-5}")
+            print(
+                f"Debug - masked_values (first 10): {masked_values.flatten()[:10]}"
+            )
+            print(f"Debug - ref_x (first 10): {ref_x.flatten()[:10]}")
+            print(f"Debug - scaled_ref_x (first 10): {scaled_ref_x.flatten()[:10]}")
+            print(f"Debug - check_x (first 10): {check_x.flatten()[:10]}")
+
+        assert diff_result < 5e-5
         if local_rank == 0:
             print(" passed", flush=True)
@@ -446,7 +477,9 @@ def test_correctness():

 def test_tuning():
     # Tune dispatch performance
-    fp8_factor = (1 + 4 / 128) / 2
+    quant_factor = (
+        (1 + 4 / 128) / 2 if USE_INT8_QUANT else 1.0
+    )  # INT8: (1 + 4 / 128) / 2, i.e. ~51% of the BF16 bytes; BF16: 100%
     config = deep_ep.Config(24, 8, buffer_size)

     dispatch_args = {
@@ -470,13 +503,13 @@ def test_tuning():
     # Tune dispatch performance
     for current_x in filter(lambda elem: elem is not None, (x,)):
         recv_bytes = (
-            (dispatch_bf16_recv_bytes * fp8_factor)
-            if isinstance(current_x, tuple)
+            (dispatch_bf16_recv_bytes * quant_factor)
+            if USE_INT8_QUANT  # apply the compression factor under INT8 quantization
             else dispatch_bf16_recv_bytes
         )
         rdma_send_bytes = (
-            (dispatch_bf16_rdma_send_bytes * fp8_factor)
-            if isinstance(current_x, tuple)
+            (dispatch_bf16_rdma_send_bytes * quant_factor)
+            if USE_INT8_QUANT  # apply the compression factor under INT8 quantization
             else dispatch_bf16_rdma_send_bytes
         )
@@ -495,8 +528,9 @@ def test_tuning():
             ("DispatchNormalA2", "NotifyDispatchA2"),
         )
         if local_rank == 0:
+            quant_mode_str = "INT8" if USE_INT8_QUANT else "BF16"
             print(
-                f'[tuning] Dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}) {recv_bytes / 1e9 / t:.2f} GB/s (HCCS), '
+                f"[tuning] Dispatch ({quant_mode_str}) {recv_bytes / 1e9 / t:.2f} GB/s (HCCS), "
                 f"{rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), avg_t: {t * 1e6:.2f} us, notify_t: {notify_t * 1e6:.2f} us",
                 flush=True,
             )
@@ -587,6 +621,19 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
         action="store_true",
         help="Whether to enable diagnose for testing",
     )
+    parser.add_argument(
+        "--use-int8-quant",
+        action="store_true",
+        default=False,
+        help="Enable internal INT8 quantization instead of BF16 (default: BF16)",
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_true",
+        default=False,
+        help="Enable debug logging.",
+    )
+
     args = parser.parse_args()

     num_processes = args.num_processes
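The quant_factor in test_tuning is consistent with an INT8 wire format that carries one byte per element plus one float32 scale per 128-element group, against two bytes per element for BF16. A minimal sketch of that byte accounting (the 128-element scale group is inferred from the (1 + 4 / 128) / 2 expression, not stated elsewhere in the patch):

    # Assumed layout: INT8 payload + one float32 scale per 128 elements,
    # compared with 2-byte BF16 elements.
    int8_bytes_per_elem = 1 + 4 / 128  # payload byte + amortized scale bytes
    bf16_bytes_per_elem = 2
    quant_factor = int8_bytes_per_elem / bf16_bytes_per_elem
    print(f"quant_factor = {quant_factor:.4f}")  # 0.5156, i.e. ~51% of the BF16 traffic

With the new flags, a typical invocation of the test (the process count is illustrative; only --use-int8-quant and --debug are added by this patch) would be:

    python tests/python/deepep/test_internode.py --num-processes 16 --use-int8-quant --debug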