|
| 1 | +**文件**:`buffer.py` |
| 2 | + |
| 3 | +**核心类**:`Buffer` |
| 4 | + |
| 5 | +**依赖**:`torch`, `deep_ep_cpp` |
| 6 | + |
| 7 | +**目的**:normal模式下dispatch和combine之前的数据预处理。 |
| 8 | + |
| 9 | +# get_dispatch_layout |
| 10 | + |
| 11 | +## 接口功能简述 |
| 12 | + |
| 13 | +根据传入的`topk_idx`计算Normal模式下后续的Dispatch和Combine需要的参数的本地副本 |
| 14 | + |
| 15 | +## 接口定义 |
| 16 | + |
| 17 | +```python |
| 18 | +def get_dispatch_layout( |
| 19 | + self, |
| 20 | + topk_idx: torch.Tensor, |
| 21 | + num_experts: int, |
| 22 | + previous_event: Optional[EventOverlap] = None, |
| 23 | + async_finish: bool = False, |
| 24 | + allocate_on_comm_stream: bool = False, |
| 25 | + ) -> Tuple[ |
| 26 | + torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap |
| 27 | + ]: |
| 28 | +``` |
| 29 | + |
| 30 | +```cpp |
| 31 | +std::tuple< |
| 32 | +torch::Tensor, // num_tokens_per_rank |
| 33 | +std::optional[torch::Tensor](torch::Tensor), // num_tokens_per_rdma_rank (预留字段) |
| 34 | +torch::Tensor, // num_tokens_per_expert |
| 35 | +torch::Tensor, // is_token_in_rank |
| 36 | +std::optional<EventHandle> // output_event (暂未使用) |
| 37 | +> |
| 38 | + |
| 39 | +Buffer::get_dispatch_layout( |
| 40 | +const torch::Tensor& topk_idx, |
| 41 | +int num_experts, |
| 42 | +std::optional<EventHandle>& previous_event, |
| 43 | +bool async, |
| 44 | +bool allocate_on_comm_stream |
| 45 | +) |
| 46 | + |
| 47 | +``` |
| 48 | + |
| 49 | +--- |
| 50 | + |
| 51 | +## 输入参数说明 |
| 52 | + |
| 53 | +| 参数名 | 类型 | 说明 | |
| 54 | +| ------------------------- | --------------------------------------------------- | ---------------------------------------------------------- | |
| 55 | +| `topk_idx` | `torch::Tensor` (`int64`, `[num_tokens, num_topk]`) | 每个 token 的 top-k expert 索引(必须是连续的二维张量)第二维大小取值范围[1, 16] | |
| 56 | +| `num_experts` | `int` | 系统中总的 expert 数量,取值范围[1, 512],且能被 `num_ranks` 整除 | |
| 57 | +| `previous_event` | `std::optional<EventHandle>&` | 异步执行用的前置事件(当前未使用,传入 `std::nullopt`) | |
| 58 | +| `async` | `bool` | 是否启用异步模式(当前未使用) | |
| 59 | +| `allocate_on_comm_stream` | `bool` | 是否在通信流上分配内存(当前未使用) | |
| 60 | + |
| 61 | +### 输入约束 |
| 62 | + |
| 63 | +* num_tokens: 表示batch sequence size,即本卡输入输出的token数量,在输入中体现为topk_idx的第一维。 |
| 64 | + * A2系列双机取值范围:(0, 4096];单机取值范围:(0, 8192]; |
| 65 | + * A3系列取值范围,不开蚂蚁搬家:(0, 8192],开蚂蚁搬家:(0, 32k]; |
| 66 | + |
| 67 | +--- |
| 68 | + |
| 69 | +## 返回值说明 |
| 70 | + |
| 71 | +| 返回值 | 类型 | 说明 | |
| 72 | +| -------------------------- | --------------------------------------------------- | ----------------------------------- | |
| 73 | +| `num_tokens_per_rank` | `torch::Tensor` (`int32`, `[num_ranks]`) | 每个 rank 中被分配的 token 数量 | |
| 74 | +| `num_tokens_per_rdma_rank` | `std::optional<torch::Tensor>` | 保留字段,当前始终为 `std::nullopt` | |
| 75 | +| `num_tokens_per_expert` | `torch::Tensor` (`int32`, `[num_experts]`) | 每个 expert 接收到的 token 数量 | |
| 76 | +| `is_token_in_rank` | `torch::Tensor` (`bool`, `[num_tokens, num_ranks]`) | 指示每个 token 是否属于某个 rank | |
| 77 | +| `output_event` | `std::optional<EventHandle>` | 保留字段,当前为 `std::nullopt` | |
| 78 | + |
| 79 | +--- |
| 80 | + |
| 81 | +## 内部逻辑简述 |
| 82 | + |
| 83 | +1. 将`topk_idx`搬到每个核的UB buffer上,第一遍遍历计算需要的部分参数; |
| 84 | +2. 使用DataCopy将计算结果搬到GM上,利用原子加或者分地址传输的方法聚合各核的数据; |
| 85 | +3. 将部分计算完的GM数据(如前缀和等)搬回UB,第二遍遍历计算剩下的参数; |
| 86 | +4. 计算出的结果tensor搬回GM。 |
| 87 | + |
| 88 | +--- |
| 89 | + |
| 90 | +## 多核策略 |
| 91 | + |
| 92 | +- 直接将`topk_idx`的第0维,即token数目按照核数目进行尽可能平均的划分,如果token数小于核数,则只使用前token数目个核,以此来实现数据并行。 |
| 93 | + |
| 94 | +--- |
| 95 | + |
| 96 | +## 示例用法 |
| 97 | + |
| 98 | +```cpp |
| 99 | +auto topk_idx = torch::randint(0, 256, {4096, 8}, torch::dtype(torch::kInt64).device(torch::kCUDA)); |
| 100 | +int num_experts = 256; |
| 101 | + |
| 102 | +std::optional<EventHandle> dummy_event = std::nullopt; |
| 103 | + |
| 104 | +auto [tokens_per_rank, _, tokens_per_expert, token_in_rank, _] = |
| 105 | + buffer.get_dispatch_layout(topk_idx, num_experts, dummy_event, false, false); |
| 106 | +``` |
| 107 | +
|
| 108 | +--- |
| 109 | +
|
| 110 | +## 注意事项 |
| 111 | +
|
| 112 | +- `topk_idx` 必须是 `int64` 类型并位于 NPU 上; |
| 113 | +- 当前支持的运行卡数`num_ranks`最大值为384; |
| 114 | +- 当前实现不使用 `async`、`previous_event`、`allocate_on_comm_stream` 等参数; |
| 115 | +- 若需要使用 `RDMA`、`异步通信` 或 `事件调度`,需扩展本接口; |
| 116 | +- 若 `num_experts` 不能被 `num_ranks` 整除,会导致逻辑错误; |
| 117 | +- 返回的所有 tensor 默认与输入 tensor 位于相同设备上; |
| 118 | +- A3机器和A2机器上layout实现并不完全相同,但都是计算后续需要的参数,算子中配置了根据环境选择,但仍要确保使用对应机器的算子。 |
| 119 | +
|
| 120 | +--- |
| 121 | +
|
| 122 | +## 扩展建议 |
| 123 | +
|
| 124 | +- 实现 `async` 执行和前置事件依赖(提高流水线并行度); |
| 125 | +- RDMA rank 支持后完善 `num_tokens_per_rdma_rank` 输出。 |
0 commit comments