Skip to content

Commit afba188

Browse files
committed
Merge remote-tracking branch 'upstream/main' into 0106
* upstream/main: fix little batchsize and int8 quant on ci (sgl-project#302) optimize sinks attention (sgl-project#260) add swiglu_oai_triton (sgl-project#270) update tag to 2026.01.12 (sgl-project#312) feat:add performance compare (sgl-project#311) support add_gemma_rms_norm (sgl-project#310) optimize gdn gating and fused_qkvzba_split_reshape_cat (sgl-project#306) fix layout numTokensPerExpertTensor partial Initialization bug (sgl-project#303) Supplement A2 doc, software and hardware compatibility info (sgl-project#294) Added an environment variable to control whether to enable the Combine Ant Migration feature. (sgl-project#304)
2 parents 2ec9a7b + fca808e commit afba188

File tree

17 files changed

+937
-253
lines changed

17 files changed

+937
-253
lines changed

.github/workflows/pr-test-npu.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ jobs:
7474
HCCL_BUFFSIZE: 1913
7575
run: |
7676
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py
77+
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py --num-tokens=1
78+
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py --num-tokens=2
7779
7880
- name: Run test base fused deep moe
7981
timeout-minutes: 10
@@ -168,6 +170,8 @@ jobs:
168170
HCCL_BUFFSIZE: 1913
169171
run: |
170172
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py
173+
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py --num-tokens=1
174+
python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py --num-tokens=2
171175
172176
- name: Run test base fused deep moe
173177
timeout-minutes: 10

config.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[global]
2-
version = 2025.12.25
2+
version = 2026.01.12

csrc/deepep/deep_ep.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Buffer::Buffer(int64_t rank, int64_t num_ranks, int64_t num_nvl_bytes, int64_t n
4848
this->shared_expert_rank_num = get_value_from_env("MOE_SHARED_EXPERT_RANK_NUM", 0);
4949
const char *roundEnv = std::getenv("DEEPEP_NORMAL_LONG_SEQ_ROUND");
5050
const char *tokensEnv = std::getenv("DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS");
51+
this->combine_enable_long_seq = get_value_from_env("DEEPEP_NORMAL_COMBINE_ENABLE_LONG_SEQ", 0);
5152
bool roundSet = (roundEnv != nullptr);
5253
bool tokensSet = (tokensEnv != nullptr);
5354

@@ -602,6 +603,8 @@ Buffer::intranode_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
602603
std::optional<torch::Tensor> recv_topk_weights;
603604
std::optional<EventHandle> event;
604605

606+
int32_t round = this->combine_enable_long_seq ? this->round : 1;
607+
int32_t per_round_tokens = this->combine_enable_long_seq ? this->per_round_tokens : MAX_TOKENS_PER_ROUND;
605608
EXEC_NPU_CMD(aclnnCamMoeCombineNormal, recv_x, token_src_info, ep_send_counts, expert_scales, tp_send_counts,
606609
hcom_ep_name, num_ranks, rank, hcom_ep_name, tp_world_size, tp_rankId, moe_expert_number, real_max_bs,
607610
round, per_round_tokens, combined_x, combine_send_cost_stats_out);

csrc/deepep/deep_ep.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct Buffer {
2424

2525
int32_t round;
2626
int32_t per_round_tokens;
27+
bool combine_enable_long_seq = false; // Whether to enable the Combine Ant Migration feature
2728

2829
bool low_latency_mode = false;
2930
bool is_padding = false;

csrc/deepep/ops/op_kernel/dispatch_layout.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class DispatchLayout
127127
SyncFunc<AscendC::HardEvent::MTE3_V>();
128128
Duplicate<T>(numTokensPerRankTensor, 0, numRanks_);
129129
Duplicate<T>(isTokenInRankTensor, 0, tempTokens_ * numRanks_);
130-
Duplicate<T>(numTokensPerExpertTensor, 0, numExperts_);
130+
Duplicate<T>(numTokensPerExpertTensor, 0, numTokensPerExpert32AlignIntLen_ / sizeof(T));
131131
SyncFunc<AscendC::HardEvent::V_S>();
132132
SyncFunc<AscendC::HardEvent::V_MTE3>();
133133
const DataCopyExtParams clearGmParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U};

csrc/deepep/ops/op_kernel/dispatch_layout_a2.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ class DispatchLayoutA2
168168
LocalTensor<T> countExpertTensor = countExpertBuf_.AllocTensor<T>();
169169
Duplicate<T>(countExpertTensor, 0, numExperts_);
170170
Duplicate<T>(numTokensPerRankTensor, 0, numRanks_);
171-
Duplicate<T>(numTokensPerExpertTensor, 0, numExperts_);
172-
Duplicate<T>(prefixCountPerExpertTensor, 0, numExperts_);
171+
Duplicate<T>(numTokensPerExpertTensor, 0, numTokensPerExpert32AlignIntLen_ / sizeof(T));
172+
Duplicate<T>(prefixCountPerExpertTensor, 0, numTokensPerExpert32AlignIntLen_ / sizeof(T));
173173
Duplicate<T>(isTokenInRankTensor, 0, tempTokens_ * numRanks_);
174174
Duplicate<T>(localTokenServerOffsetTensor, 0, localTokenServerOffset32AlignIntLen_ / sizeof(T));
175175
Duplicate<T>(sendTokenIdxTensor, 0, sendTokenIdx32AlignIntLen_ / sizeof(T));

csrc/deepep/ops2/op_kernel/dispatch_layout.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class DispatchLayout
9898
LocalTensor<T> seenRankTensor = seenRankBuf_.AllocTensor<T>();
9999
LocalTensor<T> sendTokenIdxSmallTensor = sendTokenIdxSmallBuf_.AllocTensor<T>();
100100
Duplicate<T>(numTokensPerRankTensor, 0, numRanks_);
101-
Duplicate<T>(numTokensPerExpertTensor, 0, numExperts_);
101+
Duplicate<T>(numTokensPerExpertTensor, 0, numTokensPerExpert32AlignIntLen_ / sizeof(T));
102102
Duplicate<T>(isTokenInRankTensor, 0, tempTokens_ * numRanks_);
103103
SyncFunc<AscendC::HardEvent::V_S>();
104104

csrc/deepep/ops2/op_kernel/dispatch_layout_a2.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ class DispatchLayoutA2
168168
LocalTensor<T> countExpertTensor = countExpertBuf_.AllocTensor<T>();
169169
Duplicate<T>(countExpertTensor, 0, numExperts_);
170170
Duplicate<T>(numTokensPerRankTensor, 0, numRanks_);
171-
Duplicate<T>(numTokensPerExpertTensor, 0, numExperts_);
172-
Duplicate<T>(prefixCountPerExpertTensor, 0, numExperts_);
171+
Duplicate<T>(numTokensPerExpertTensor, 0, numTokensPerExpert32AlignIntLen_ / sizeof(T));
172+
Duplicate<T>(prefixCountPerExpertTensor, 0, numTokensPerExpert32AlignIntLen_ / sizeof(T));
173173
Duplicate<T>(isTokenInRankTensor, 0, tempTokens_ * numRanks_);
174174
Duplicate<T>(localTokenServerOffsetTensor, 0, localTokenServerOffset32AlignIntLen_ / sizeof(T));
175175
Duplicate<T>(sendTokenIdxTensor, 0, sendTokenIdx32AlignIntLen_ / sizeof(T));

python/deep_ep/A2_DEEPEP_CN.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11

22
A2场景下使用DeepEp说明
33

4+
# 软硬件配套说明
5+
硬件型号支持:Atlas A2 系列产品
6+
平台:aarch64/x86
7+
配套软件
8+
- 驱动 Ascend HDK ≥ 25.3.RC1、CANN ≥ 8.3.RC1
9+
410
# 构建DeepEp包
511
执行工程构建脚本 build.sh
612
```bash
@@ -47,6 +53,12 @@ DeepEp 向上层提供以下核心接口:
4753
export HCCL_BUFFSIZE=1024
4854
```
4955

56+
A2场景下叠加deepep,需**禁用**环境变量`HCCL_OP_EXPANSION_MODE`,否则会出现未知算子错误。
57+
```bash
58+
# A2下需要去除该环境变量
59+
# export HCCL_OP_EXPANSION_MODE=AIV
60+
```
61+
5062
## A2单机
5163

5264
### 框架接入建议
Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,88 @@
1-
# This file contains swiglu for OpenAI models.
2-
# It will be optimized using Triton in the future.
31
import torch
2+
import triton
3+
import triton.language as tl
4+
from sgl_kernel_npu.utils.triton_utils import get_device_properties
45

56

6-
def swiglu_oai(layer, hidden_states):
7+
@triton.jit
8+
def swiglu_oai_kernel(
9+
hidden_states,
10+
gated_output,
11+
gemm1_alpha,
12+
gemm1_clamp_limit,
13+
output_dim: tl.constexpr,
14+
BLOCK_SIZE: tl.constexpr,
15+
MINIBLOCK_SIZE: tl.constexpr,
16+
BS: tl.constexpr,
17+
):
18+
i_block = tl.program_id(0)
19+
20+
for i_miniblock in range(0, BLOCK_SIZE, MINIBLOCK_SIZE):
21+
offset_bs = i_block * BLOCK_SIZE + i_miniblock + tl.arange(0, MINIBLOCK_SIZE)
22+
mask_bs = offset_bs < BS
23+
24+
offset_gate = tl.arange(0, output_dim) * 2
25+
offset_up = tl.arange(0, output_dim) * 2 + 1
26+
27+
gate = tl.load(
28+
hidden_states + offset_bs[:, None] * output_dim * 2 + offset_gate[None, :],
29+
mask=mask_bs[:, None],
30+
)
31+
up = tl.load(
32+
hidden_states + offset_bs[:, None] * output_dim * 2 + offset_up[None, :],
33+
mask=mask_bs[:, None],
34+
)
35+
36+
gate = tl.where(gate > gemm1_clamp_limit, gemm1_clamp_limit, gate)
37+
up = tl.where(up > gemm1_clamp_limit, gemm1_clamp_limit, up)
38+
up = tl.where(up < -gemm1_clamp_limit, -gemm1_clamp_limit, up)
39+
sig = 1.0 / (1.0 + tl.exp(-gate * gemm1_alpha))
40+
glu = gate * sig
41+
out = (up + 1) * glu
42+
43+
tl.store(
44+
gated_output
45+
+ offset_bs[:, None] * output_dim
46+
+ tl.arange(0, output_dim)[None, :],
47+
out,
48+
mask=mask_bs[:, None],
49+
)
50+
51+
52+
def swiglu_oai_triton(
53+
hidden_states,
54+
dim,
55+
gemm1_alpha,
56+
gemm1_clamp_limit,
57+
):
58+
hidden_states = hidden_states.view(-1, dim)
59+
BS = hidden_states.shape[0]
60+
output_dim = dim // 2
61+
gated_output = torch.empty(
62+
(BS, output_dim),
63+
dtype=hidden_states.dtype,
64+
device=hidden_states.device,
65+
)
66+
67+
kernel_num = get_device_properties()[0]
68+
MINIBLOCK_SIZE = 16
69+
BLOCK_SIZE = triton.cdiv(BS, MINIBLOCK_SIZE * kernel_num) * MINIBLOCK_SIZE
70+
BLOCK_NUM = triton.cdiv(BS, BLOCK_SIZE)
71+
72+
swiglu_oai_kernel[(BLOCK_NUM,)](
73+
hidden_states,
74+
gated_output,
75+
gemm1_alpha,
76+
gemm1_clamp_limit,
77+
output_dim,
78+
BLOCK_SIZE,
79+
MINIBLOCK_SIZE,
80+
BS,
81+
)
82+
return gated_output
83+
84+
85+
def swiglu_oai_native(layer, hidden_states):
786
E, N, _ = layer.w13_weight.size()
887
gate_up = hidden_states.view(-1, N)
988
alpha = layer.moe_runner_config.gemm1_alpha
@@ -14,3 +93,12 @@ def swiglu_oai(layer, hidden_states):
1493
glu = gate * torch.sigmoid(gate * alpha)
1594
gated_output = (up + 1) * glu
1695
return gated_output
96+
97+
98+
def swiglu_oai(layer, hidden_states):
99+
return swiglu_oai_triton(
100+
hidden_states,
101+
layer.w13_weight.shape[1],
102+
layer.moe_runner_config.gemm1_alpha,
103+
layer.moe_runner_config.gemm1_clamp_limit,
104+
)

0 commit comments

Comments
 (0)