|
10 | 10 | import torch |
11 | 11 |
|
12 | 12 | import vllm.envs as envs |
| 13 | +from vllm.config import get_current_vllm_config |
| 14 | +from vllm.forward_context import get_forward_context, is_forward_context_available |
| 15 | +from vllm.logger import init_logger |
13 | 16 | from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig |
14 | 17 | from vllm.model_executor.layers.fused_moe.utils import ( |
15 | 18 | _resize_cache, |
|
26 | 29 | dbo_yield, |
27 | 30 | ) |
28 | 31 |
|
| 32 | +logger = init_logger(__name__) |
| 33 | + |
29 | 34 | # |
30 | 35 | # This file defines a set of base classes used to make MoE kernels more modular. |
31 | 36 | # The goal is to be able to utilize different communication mechanisms with |
@@ -798,6 +803,42 @@ def _allocate_buffers( |
798 | 803 | buffers = self.shared_buffers[ubatch_idx] |
799 | 804 | workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) |
800 | 805 |
|
| 806 | +    # Force worst-case workspace allocation during the profiling run for
| 807 | +    # "mk.FusedMoEModularKernel.Standard" formats, whose workspace size is
| 808 | +    # bounded only by `VLLM_FUSED_MOE_CHUNK_SIZE`; with DP+EP the worst case
| 809 | +    # may never be observed during profiling due to random token routing. |
| 810 | + is_profile_run = ( |
| 811 | + is_forward_context_available() |
| 812 | + and get_forward_context().attn_metadata is None |
| 813 | + ) |
| 814 | + if is_profile_run and self.fused_experts.supports_chunking(): |
| 815 | + parallel_config = get_current_vllm_config().parallel_config |
| 816 | + is_dp_ep = ( |
| 817 | + parallel_config.data_parallel_size > 1 |
| 818 | + and parallel_config.enable_expert_parallel |
| 819 | + ) |
| 820 | + if is_dp_ep: |
| 821 | + max_workspace_13, max_workspace_2, max_fused_out_shape = ( |
| 822 | + self.fused_experts.workspace_shapes( |
| 823 | + envs.VLLM_FUSED_MOE_CHUNK_SIZE, |
| 824 | + N, |
| 825 | + K, |
| 826 | + top_k, |
| 827 | + global_num_experts, |
| 828 | + local_num_experts, |
| 829 | + expert_tokens_meta, |
| 830 | + ) |
| 831 | + ) |
| 832 | + buffers.workspace13.get( |
| 833 | + max_workspace_13, device=device, dtype=workspace_dtype |
| 834 | + ) |
| 835 | + buffers.workspace2.get( |
| 836 | + max_workspace_2, device=device, dtype=workspace_dtype |
| 837 | + ) |
| 838 | + buffers.fused_out.get( |
| 839 | + max_fused_out_shape, device=device, dtype=workspace_dtype |
| 840 | + ) |
| 841 | + |
801 | 842 | # Get intermediate workspace shapes based off the chunked M size. |
802 | 843 | workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes( |
803 | 844 | M_chunk, |
|
0 commit comments