
Commit 94096a4

[UX] Separate marlin moe config logic from triton moe (#23006)
1 parent a258ad8 commit 94096a4

2 files changed: +7 -22 lines

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 6 additions & 14 deletions
@@ -1,14 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Fused MoE utilities for GPTQ."""
-import functools
 from typing import Optional

 import torch

 import vllm._custom_ops as ops
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    moe_align_block_size, try_get_optimal_moe_config)
+from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
 from vllm.scalar_type import ScalarType, scalar_types

@@ -98,17 +96,11 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     N = w2.shape[1] * 16
     topk = topk_ids.shape[1]

-    get_config_func = functools.partial(
-        try_get_optimal_moe_config,
-        w1.shape,
-        w2.shape,
-        topk_ids.shape[1],
-        None,
-        is_marlin=True,
-    )
-    config = get_config_func(M)
-
-    block_size_m = config["BLOCK_SIZE_M"]
+    # M block size selection logic
+    # TODO: tune this further for specific models
+    for block_size_m in [8, 16, 32, 48, 64]:
+        if M * topk / E / block_size_m < 0.9:
+            break

     if global_num_experts == -1:
         global_num_experts = E
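
For reference, the new heuristic is simple enough to exercise in isolation. Below is a minimal standalone sketch; the helper name select_marlin_block_size_m is hypothetical (the commit inlines the loop directly in fused_marlin_moe), but the candidate sizes and the 0.9 threshold are taken from the diff above.

def select_marlin_block_size_m(M: int, topk: int, E: int) -> int:
    # Pick the smallest candidate BLOCK_SIZE_M for which the average number of
    # tokens routed to each expert (M * topk / E) fills less than ~90% of a
    # block; if every candidate is saturated, the loop falls through to 64.
    for block_size_m in [8, 16, 32, 48, 64]:
        if M * topk / E / block_size_m < 0.9:
            break
    return block_size_m

# Example: M=32 tokens with top-2 routing over E=8 experts averages
# 32 * 2 / 8 = 8 tokens per expert; 8 / 8 = 1.0 is not < 0.9 but
# 8 / 16 = 0.5 is, so BLOCK_SIZE_M = 16 is selected.
print(select_marlin_block_size_m(M=32, topk=2, E=8))  # -> 16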

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 1 addition & 8 deletions
@@ -801,7 +801,6 @@ def get_default_config(
     K: int,
     topk: int,
     dtype: Optional[str],
-    is_marlin: bool,
     block_shape: Optional[list[int]] = None,
 ) -> dict[str, int]:
     if dtype == "fp8_w8a8" and block_shape is not None:

@@ -832,11 +831,6 @@ def get_default_config(
             config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
         else:
             config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
-    elif is_marlin:
-        for block_size_m in [8, 16, 32, 48, 64]:
-            if M * topk / E / block_size_m < 0.9:
-                break
-        return {"BLOCK_SIZE_M": block_size_m}
     elif M <= E:
         config = {
             "BLOCK_SIZE_M": 16,

@@ -860,7 +854,6 @@ def try_get_optimal_moe_config(
     top_k: int,
     dtype: Optional[str],
     M: int,
-    is_marlin: bool = False,
     block_shape: Optional[list[int]] = None,
 ) -> dict[str, int]:
     from vllm.model_executor.layers.fused_moe import get_config

@@ -883,7 +876,7 @@ def try_get_optimal_moe_config(
         else:
             # Else use the default config
             config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
-                                         is_marlin, block_shape)
+                                         block_shape)
     return config
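
Since is_marlin is gone from both helpers, callers on the Triton path now build their config lookup without the flag. Below is a minimal usage sketch, assuming a vLLM checkout at this commit; the toy weight shapes (E=8 experts, hidden size 64, intermediate size 128) and the variable names are illustrative, not taken from the diff.

import functools

import torch

from vllm.model_executor.layers.fused_moe.fused_moe import (
    try_get_optimal_moe_config)

# Toy Triton-layout expert weights: w1 is (E, 2 * N, K), w2 is (E, K, N).
w1 = torch.empty(8, 2 * 128, 64)
w2 = torch.empty(8, 64, 128)

# The partial no longer passes is_marlin; the positional arguments follow the
# signature shown above: w1_shape, w2_shape, top_k, dtype, then M at call time.
get_config_func = functools.partial(
    try_get_optimal_moe_config,
    w1.shape,
    w2.shape,
    2,     # top_k
    None,  # dtype -> fall back to the unquantized defaults
)
config = get_config_func(32)  # M = 32 tokens in the batch
print(config["BLOCK_SIZE_M"])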
