@@ -2,7 +2,7 @@
 import functools
 import json
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import triton
@@ -137,7 +137,7 @@ def fused_moe_kernel(
 
 def moe_align_block_size(
         topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
 
@@ -185,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int, config: dict):
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any]) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
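For context on the annotation change to moe_align_block_size above: a parenthesized tuple such as (torch.Tensor, torch.Tensor, torch.Tensor) evaluates to an ordinary tuple object when used as a return annotation, and static checkers like mypy reject it; typing.Tuple[...] is the checkable form (tuple[...] also works on Python 3.9+). A minimal sketch of the same pattern, using a toy function that is not part of this file:

from typing import Tuple

import torch


def split_halves(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Tuple[torch.Tensor, torch.Tensor] is a proper generic alias that
    # static checkers understand; annotating the return type as a bare
    # (torch.Tensor, torch.Tensor) would build a tuple *instance*, which
    # mypy reports as a syntax error in the annotation.
    mid = x.shape[0] // 2
    return x[:mid], x[mid:]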