Skip to content

Commit 82216dc

Browse files
minosfuturetlrmchlsmthgemini-code-assist[bot]
authored
[Misc] Support routing logic simulation (#21990)
Signed-off-by: Ming Yang <[email protected]> Co-authored-by: Tyler Michael Smith <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 3706618 commit 82216dc

File tree

4 files changed

+481
-0
lines changed

4 files changed

+481
-0
lines changed

tests/test_routing_simulator.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: Apache-2.0
3+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
4+
"""
5+
Test script for the token-to-expert routing simulator.
6+
7+
This script demonstrates how to use the routing simulator to test
8+
different routing strategies and analyze their performance, including
9+
integration tests with FusedMoE layer.
10+
"""
11+
12+
import pytest
13+
import torch
14+
15+
from vllm.model_executor.layers.fused_moe.routing_simulator import (
16+
DistributionBasedRouting, RoutingSimulator)
17+
18+
19+
@pytest.fixture
20+
def device():
21+
"""Fixture to provide the appropriate device for testing."""
22+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
23+
24+
25+
@pytest.mark.parametrize("num_tokens", [1, 16, 256])
26+
@pytest.mark.parametrize("hidden_size", [64, 1024])
27+
@pytest.mark.parametrize("num_experts", [16, 128])
28+
@pytest.mark.parametrize("top_k", [1, 4])
29+
def test_basic_functionality(
30+
num_tokens: int,
31+
hidden_size: int,
32+
num_experts: int,
33+
top_k: int,
34+
device,
35+
):
36+
"""Test basic functionality of the routing simulator."""
37+
# Test each routing strategy
38+
strategies = RoutingSimulator.get_available_strategies()
39+
40+
hidden_states = torch.randn(num_tokens, hidden_size, device=device)
41+
router_logits = torch.randn(num_tokens, num_experts, device=device)
42+
43+
for strategy in strategies:
44+
# Simulate routing
45+
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
46+
hidden_states=hidden_states,
47+
router_logits=router_logits,
48+
strategy_name=strategy,
49+
top_k=top_k,
50+
)
51+
52+
# Check output shapes
53+
assert topk_weights.shape == (
54+
num_tokens,
55+
top_k,
56+
), f"Wrong weights shape for {strategy}"
57+
assert topk_ids.shape == (
58+
num_tokens,
59+
top_k,
60+
), f"Wrong ids shape for {strategy}"
61+
62+
# Check that expert IDs are valid
63+
assert (topk_ids.min()
64+
>= 0), f"Invalid expert ID (negative) for {strategy}"
65+
assert (topk_ids.max()
66+
< num_experts), f"Invalid expert ID (too large) for {strategy}"
67+
68+
69+
def test_routing_strategy_integration(monkeypatch, device):
70+
"""Test that the routing strategy environment variable works with
71+
FusedMoE."""
72+
pytest.importorskip("vllm.model_executor.layers.fused_moe.layer")
73+
74+
import vllm.envs as envs
75+
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
76+
77+
# Test parameters
78+
num_tokens = 32
79+
hidden_size = 16
80+
num_experts = 4
81+
top_k = 2
82+
83+
# Create test data
84+
hidden_states = torch.randn(num_tokens, hidden_size, device=device)
85+
router_logits = torch.randn(num_tokens, num_experts, device=device)
86+
87+
# Test different routing strategies
88+
strategies = RoutingSimulator.get_available_strategies()
89+
90+
for strategy in strategies:
91+
# Set environment variable
92+
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
93+
monkeypatch.setenv(env_name, strategy)
94+
95+
# Force reload of environment variable
96+
envs.environment_variables[env_name] = lambda s=strategy: s
97+
98+
# Test the select_experts method
99+
topk_weights, topk_ids = FusedMoE.select_experts(
100+
hidden_states=hidden_states,
101+
router_logits=router_logits,
102+
top_k=top_k,
103+
use_grouped_topk=False,
104+
renormalize=True,
105+
indices_type=torch.long)
106+
107+
# Verify output shapes
108+
assert topk_weights.shape == (
109+
num_tokens, top_k), f"Wrong weights shape for {strategy}"
110+
assert topk_ids.shape == (num_tokens,
111+
top_k), f"Wrong ids shape for {strategy}"
112+
113+
# Verify expert IDs are valid
114+
assert topk_ids.min(
115+
) >= 0, f"Invalid expert ID (negative) for {strategy}"
116+
assert topk_ids.max(
117+
) < num_experts, f"Invalid expert ID (too large) for {strategy}"
118+
119+
120+
def test_distribution_based_routing_with_custom_strategy():
121+
"""Test registering and using DistributionBasedRouting with custom
122+
parameters."""
123+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124+
125+
# Register custom distribution-based strategy
126+
custom_strategy = DistributionBasedRouting(distribution="normal",
127+
mean=2.0,
128+
std=0.5)
129+
RoutingSimulator.register_strategy("custom_normal", custom_strategy)
130+
131+
# Test data
132+
num_tokens = 60
133+
hidden_size = 48
134+
num_experts = 6
135+
top_k = 3
136+
137+
hidden_states = torch.randn(num_tokens, hidden_size, device=device)
138+
router_logits = torch.randn(num_tokens, num_experts, device=device)
139+
140+
# Use the custom strategy
141+
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
142+
hidden_states=hidden_states,
143+
router_logits=router_logits,
144+
strategy_name="custom_normal",
145+
top_k=top_k)
146+
147+
# Check output shapes
148+
assert topk_weights.shape == (num_tokens, top_k)
149+
assert topk_ids.shape == (num_tokens, top_k)
150+
151+
# Check that expert IDs are valid
152+
assert topk_ids.min() >= 0
153+
assert topk_ids.max() < num_experts
154+
155+
156+
def test_instance_compatibility():
157+
"""Test that static methods work correctly."""
158+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
159+
160+
# Test static method directly
161+
hidden_states = torch.randn(10, 8, device=device)
162+
router_logits = torch.randn(10, 4, device=device)
163+
164+
topk_weights, topk_ids = RoutingSimulator.simulate_routing(
165+
hidden_states=hidden_states,
166+
router_logits=router_logits,
167+
strategy_name="uniform_random",
168+
top_k=2)
169+
170+
assert topk_weights.shape == (10, 2)
171+
assert topk_ids.shape == (10, 2)

vllm/envs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,15 @@ def get_vllm_port() -> Optional[int]:
989989
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
990990
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
991991

992+
# MoE routing strategy selector.
993+
# See `RoutingSimulator.get_available_strategies()` # for available
994+
# strategies.
995+
# Cutstom routing strategies can be registered by
996+
# RoutingSimulator.register_strategy()
997+
# Note: custom strategies may not produce correct model outputs
998+
"VLLM_MOE_ROUTING_SIMULATION_STRATEGY":
999+
lambda: os.environ.get("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "").lower(),
1000+
9921001
# Regex timeout for use by the vLLM tool parsing plugins.
9931002
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
9941003
lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
2929
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
3030
is_rocm_aiter_moe_enabled)
31+
from vllm.model_executor.layers.fused_moe.routing_simulator import (
32+
RoutingSimulator)
3133
from vllm.model_executor.layers.quantization.base_config import (
3234
QuantizationConfig, QuantizeMethodBase)
3335
from vllm.model_executor.utils import set_weight_attrs
@@ -1362,6 +1364,16 @@ def select_experts(
13621364
"""
13631365
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
13641366

1367+
# Check if we should use a routing simulation strategy
1368+
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
1369+
if routing_strategy != "":
1370+
return RoutingSimulator.simulate_routing(
1371+
hidden_states=hidden_states,
1372+
router_logits=router_logits,
1373+
strategy_name=routing_strategy,
1374+
top_k=top_k,
1375+
indices_type=indices_type)
1376+
13651377
# DeepSeekv2 uses grouped_top_k
13661378
if use_grouped_topk:
13671379
assert topk_group is not None

0 commit comments

Comments
 (0)