Skip to content

Commit 3137991

Browse files
varun-sundar-rabindranath and Varun Sundar Rabindranath authored
[BugFix] EPLB + B200 + DeepGEMM : Handle column-major scales tensor (#29162)
Signed-off-by: Varun Sundar Rabindranath <[email protected]> Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 57430fc commit 3137991

File tree

4 files changed

+377
-40
lines changed

4 files changed

+377
-40
lines changed

tests/distributed/eplb_utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import os
5+
import random
6+
7+
import torch
8+
import torch.multiprocessing as mp
9+
10+
from vllm.distributed.parallel_state import (
11+
init_distributed_environment,
12+
)
13+
from vllm.utils.system_utils import update_environment_variables
14+
15+
# Select the "spawn" start method for the worker processes launched by this
# module; force=True replaces any start method that was already chosen.
# NOTE(review): presumably required because the workers initialise CUDA -- confirm.
mp.set_start_method("spawn", force=True)
16+
17+
18+
def distributed_run(fn, world_size, *args):
    """Run ``fn`` in ``world_size`` worker processes and require that all succeed.

    Worker ``i`` is started as ``fn(env, world_size, *args)``, where ``env`` is a
    dict of string settings (RANK, LOCAL_RANK, WORLD_SIZE, LOCAL_WORLD_SIZE,
    MASTER_ADDR, MASTER_PORT) describing that worker. Both rank entries are
    ``i`` and both world-size entries are ``world_size``.

    Args:
        fn: Callable executed in every worker process.
        world_size: Number of worker processes to launch.
        *args: Extra positional arguments forwarded to ``fn``.

    Raises:
        AssertionError: If any worker exits with a non-zero exit code. The
            message lists every failing rank together with its exit code.
    """
    processes: list[mp.Process] = []
    for rank in range(world_size):
        env: dict[str, str] = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
        p = mp.Process(target=fn, args=(env, world_size, *args))
        processes.append(p)
        p.start()

    # Wait for every worker before inspecting exit codes so that a single
    # failure does not leave the remaining workers un-joined.
    for p in processes:
        p.join()

    # Report all failing ranks at once instead of stopping at the first one.
    failures = {
        rank: p.exitcode for rank, p in enumerate(processes) if p.exitcode != 0
    }
    assert not failures, f"Worker processes failed (rank -> exit code): {failures}"
38+
39+
40+
def set_env_vars_and_device(env: dict[str, str]) -> None:
    """Prepare the calling worker process for a distributed test.

    Applies ``env`` through ``update_environment_variables``, selects the CUDA
    device named by the ``LOCAL_RANK`` environment variable, calls
    ``init_distributed_environment`` and seeds both the ``random`` module and
    torch with a fixed value.

    Args:
        env: Environment settings for this worker.
    """
    update_environment_variables(env)

    # LOCAL_RANK is read back from the process environment rather than from
    # ``env`` directly.
    local_rank = os.environ["LOCAL_RANK"]
    torch.cuda.set_device(torch.device(f"cuda:{local_rank}"))
    init_distributed_environment()

    # Ensure each worker process has the same random seed
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)

tests/distributed/test_eplb_execute.py

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,19 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
import os
54
import random
65

76
import pytest
87
import torch
98
import torch.distributed
10-
import torch.multiprocessing as mp
119

1210
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
1311
from vllm.distributed.parallel_state import (
1412
ensure_model_parallel_initialized,
1513
get_tp_group,
16-
init_distributed_environment,
1714
)
18-
from vllm.utils.system_utils import update_environment_variables
19-
20-
mp.set_start_method("spawn", force=True)
21-
22-
23-
def distributed_run(fn, world_size, *args):
24-
number_of_processes = world_size
25-
processes: list[mp.Process] = []
26-
for i in range(number_of_processes):
27-
env: dict[str, str] = {}
28-
env["RANK"] = str(i)
29-
env["LOCAL_RANK"] = str(i)
30-
env["WORLD_SIZE"] = str(number_of_processes)
31-
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
32-
env["MASTER_ADDR"] = "localhost"
33-
env["MASTER_PORT"] = "12345"
34-
p = mp.Process(target=fn, args=(env, world_size, *args))
35-
processes.append(p)
36-
p.start()
37-
38-
for p in processes:
39-
p.join()
40-
41-
for p in processes:
42-
assert p.exitcode == 0
43-
44-
45-
def set_env_vars_and_device(env: dict[str, str]) -> None:
46-
update_environment_variables(env)
47-
local_rank = os.environ["LOCAL_RANK"]
48-
device = torch.device(f"cuda:{local_rank}")
49-
torch.cuda.set_device(device)
50-
init_distributed_environment()
51-
52-
# Ensure each worker process has the same random seed
53-
random.seed(42)
54-
torch.manual_seed(42)
15+
16+
from .eplb_utils import distributed_run, set_env_vars_and_device
5517

5618

5719
def create_expert_indices_with_redundancy(
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
# Test that the interaction between EPLB and FusedMoE Layer is okay
5+
6+
from dataclasses import dataclass
7+
8+
import pytest
9+
import torch
10+
11+
from vllm.config import VllmConfig, set_current_vllm_config
12+
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
13+
from vllm.distributed.parallel_state import (
14+
ensure_model_parallel_initialized,
15+
get_tp_group,
16+
)
17+
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
18+
19+
from .eplb_utils import distributed_run, set_env_vars_and_device
20+
21+
22+
@dataclass
23+
class TestConfig:
24+
num_layers: int
25+
num_experts: int
26+
num_local_experts: int
27+
num_topk: int
28+
hidden_size: int
29+
intermediate_size: int
30+
weight_dtype: torch.dtype
31+
weight_scale_dtype: torch.dtype | None
32+
column_major_scales: bool
33+
34+
35+
def make_expert_weights(
    layer_idx: int,
    global_expert_idx: int,
    global_num_experts: int,
    tensor_shape: tuple[int, ...],
    tensor_dtype: torch.dtype,
    tensor_device: torch.device,
    is_column_major: bool,
) -> torch.Tensor:
    """Create a deterministic 2-D tensor identifying one expert of one layer.

    The tensor holds consecutive values starting at
    ``(layer_idx * global_num_experts + global_expert_idx) * numel``.

    Args:
        layer_idx: Index of the layer the expert belongs to.
        global_expert_idx: Global index of the expert.
        global_num_experts: Total number of experts per layer.
        tensor_shape: Two-dimensional shape of the returned tensor.
        tensor_dtype: dtype of the returned tensor.
        tensor_device: Device on which the tensor is created.
        is_column_major: When True, the values are laid out in the transposed
            shape and a transposed (non-contiguous) view is returned.

    Returns:
        A tensor of shape ``tensor_shape``; contiguous unless
        ``is_column_major`` is True.
    """
    assert len(tensor_shape) == 2

    # Column-major results are produced by filling the transposed shape in
    # row-major order and flipping the two dimensions at the end.
    if is_column_major:
        fill_shape = (tensor_shape[1], tensor_shape[0])
    else:
        fill_shape = (tensor_shape[0], tensor_shape[1])

    count = fill_shape[0] * fill_shape[1]
    first_value = (layer_idx * global_num_experts + global_expert_idx) * count
    filled = torch.arange(
        first_value,
        first_value + count,
        dtype=tensor_dtype,
        device=tensor_device,
    ).reshape(fill_shape)

    if not is_column_major:
        return filled

    flipped = filled.t()
    assert not flipped.is_contiguous()
    return flipped
64+
65+
66+
def make_fused_moe_layer(
    rank: int,
    layer_idx: int,
    test_config: TestConfig,
) -> FusedMoE:
    """Build a FusedMoE layer on ``cuda:{rank}`` filled with deterministic data.

    The layer's ``w13_weight`` and ``w2_weight`` are moved to the rank's device
    and each local expert slice is overwritten with the pattern produced by
    ``make_expert_weights`` for the matching global expert. Two additional
    parameters, ``w13_weight_scale_inv`` and ``w2_weight_scale_inv``, are
    created with block-quantization-style shapes (block size 16) and filled the
    same way. When ``test_config.column_major_scales`` is set, the scale
    tensors are transposed over their last two dimensions so that they are
    non-contiguous.

    Args:
        rank: Rank of the calling process; selects the CUDA device and the
            range of global expert indices owned by this layer.
        layer_idx: Index of the layer, used in the layer prefix and the data
            pattern.
        test_config: Sizes, dtypes and scale layout for the layer.

    Returns:
        The populated FusedMoE layer.
    """
    from functools import partial

    fml = FusedMoE(
        num_experts=test_config.num_experts,
        top_k=test_config.num_topk,
        hidden_size=test_config.hidden_size,
        intermediate_size=test_config.intermediate_size,
        prefix=f"dummy_layer_{layer_idx}",
        activation="silu",
        is_act_and_mul=True,
        params_dtype=test_config.weight_dtype,
    )

    device = torch.device(f"cuda:{rank}")
    num_local_experts = test_config.num_local_experts

    _make_expert_weights = partial(
        make_expert_weights,
        layer_idx=layer_idx,
        global_num_experts=test_config.num_experts,
        tensor_device=device,
    )

    def fill_local_experts(stacked: torch.Tensor) -> None:
        # Overwrite every local expert slice of ``stacked`` in place with the
        # row-major pattern of the corresponding global expert.
        for i in range(num_local_experts):
            expert = stacked[i]
            expert.copy_(
                _make_expert_weights(
                    global_expert_idx=rank * num_local_experts + i,
                    tensor_shape=expert.shape,
                    tensor_dtype=expert.dtype,
                    is_column_major=False,
                )
            )

    assert isinstance(fml.w13_weight.data, torch.Tensor)
    assert isinstance(fml.w2_weight.data, torch.Tensor)
    fml.w13_weight.data = fml.w13_weight.data.to(device=device)
    fml.w2_weight.data = fml.w2_weight.data.to(device=device)
    w13_weight = fml.w13_weight.data
    w2_weight = fml.w2_weight.data
    assert w13_weight.size(0) == num_local_experts
    fill_local_experts(w13_weight)
    fill_local_experts(w2_weight)

    block_size = 16

    def block_quant_scales_shape(
        shape: tuple[int, ...], is_column_major: bool
    ) -> tuple[int, ...]:
        # One scale per (block_size x block_size) weight block. For the
        # column-major case the last two dims are swapped here so that the
        # transpose applied later restores the row-major logical shape.
        assert len(shape) == 3
        if not is_column_major:
            return (shape[0], shape[1] // block_size, shape[2] // block_size)
        else:
            return (shape[0], shape[2] // block_size, shape[1] // block_size)

    is_column_major = test_config.column_major_scales
    # Use the configured scale dtype; fall back to the weight dtype when the
    # configuration leaves it unset.
    scale_dtype = (
        test_config.weight_scale_dtype
        if test_config.weight_scale_dtype is not None
        else test_config.weight_dtype
    )
    w13_weight_scale_inv = torch.empty(
        block_quant_scales_shape(w13_weight.shape, is_column_major),
        dtype=scale_dtype,
        device=device,
    )
    w2_weight_scale_inv = torch.empty(
        block_quant_scales_shape(w2_weight.shape, is_column_major),
        dtype=scale_dtype,
        device=device,
    )

    # Fill data in row-major and then
    # transpose if test_config requires col-major.
    fill_local_experts(w13_weight_scale_inv)
    fill_local_experts(w2_weight_scale_inv)
    if is_column_major:
        w13_weight_scale_inv = torch.transpose(w13_weight_scale_inv, 1, 2)
        w2_weight_scale_inv = torch.transpose(w2_weight_scale_inv, 1, 2)
        assert not w13_weight_scale_inv.is_contiguous()
        assert not w2_weight_scale_inv.is_contiguous()

    # Add scales to the parameter list
    fml.w13_weight_scale_inv = torch.nn.Parameter(
        w13_weight_scale_inv, requires_grad=False
    )
    fml.w2_weight_scale_inv = torch.nn.Parameter(
        w2_weight_scale_inv, requires_grad=False
    )

    return fml
181+
182+
183+
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
    """Worker body: rearrange FusedMoE expert weights and verify the result.

    Builds ``test_config.num_layers`` FusedMoE layers on this rank, asks
    ``rearrange_expert_weights_inplace`` to move the experts from the identity
    placement to a random permutation per layer, and then checks that every
    local expert slice of every layer parameter matches, in shape, stride and
    value, the reference tensor for the expert that should now live there.

    Args:
        env: Environment settings for this worker process.
        world_size: Number of participating ranks.
        test_config: Sizes, dtypes and scale layout for the layers.
    """
    # Initialize model parallel (using tensor parallel as an entrypoint
    # to expert parallel)
    set_env_vars_and_device(env)

    vllm_config = VllmConfig()
    vllm_config.parallel_config.tensor_parallel_size = world_size
    vllm_config.parallel_config.enable_expert_parallel = True

    with set_current_vllm_config(vllm_config):
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
        )

        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()

        fml_layers = [
            make_fused_moe_layer(ep_rank, layer_idx, test_config)
            for layer_idx in range(test_config.num_layers)
        ]
        rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]

        # Old placement: every layer holds experts 0..num_experts-1 in order.
        indices = torch.arange(test_config.num_experts, dtype=torch.long).repeat(
            test_config.num_layers, 1
        )

        # New placement: an independent random permutation per layer. Workers
        # share a seed, so each one draws the same permutations.
        shuffled_indices = torch.zeros_like(indices)
        for lidx in range(test_config.num_layers):
            shuffled_indices[lidx] = torch.randperm(test_config.num_experts)

        rearrange_expert_weights_inplace(
            indices,
            shuffled_indices,
            rank_expert_weights,
            ep_group,
            is_profile=False,
        )

        num_local_experts = test_config.num_local_experts
        num_global_experts = test_config.num_experts
        for lidx, fml in enumerate(fml_layers):
            for name, w in fml.named_parameters():
                for e in range(num_local_experts):
                    actual = w[e]
                    g_e = shuffled_indices[lidx][ep_rank * num_local_experts + e]
                    ref = make_expert_weights(
                        layer_idx=lidx,
                        global_expert_idx=int(g_e.item()),
                        global_num_experts=num_global_experts,
                        tensor_shape=actual.shape,
                        tensor_dtype=actual.dtype,
                        tensor_device=actual.device,
                        # A non-contiguous slice means the parameter is kept
                        # in the transposed (column-major) layout.
                        is_column_major=not actual.is_contiguous(),
                    )
                    assert (
                        actual.shape == ref.shape
                        and actual.stride() == ref.stride()
                    ), (
                        f"layer {lidx} param {name}: "
                        f"w[{e}] {actual.size()} {actual.stride()} vs "
                        f"ref {ref.size()} {ref.stride()}"
                    )
                    torch.testing.assert_close(actual, ref)
244+
245+
246+
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("num_layers", [4])
@pytest.mark.parametrize("num_experts", [16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("column_major_scales", [True, False])
def test_eplb_fml(
    world_size: int,
    num_layers: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    column_major_scales: bool,
):
    """Run the EPLB + FusedMoE rearrangement check across ``world_size`` GPUs.

    The test is skipped when fewer than ``world_size`` CUDA devices are
    visible. Otherwise a ``TestConfig`` is assembled from the parameters and
    ``_test_eplb_fml`` is executed in one worker process per rank.
    """
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    config = TestConfig(
        num_layers=num_layers,
        num_experts=num_experts,
        # Experts are split evenly across the ranks.
        num_local_experts=num_experts // world_size,
        num_topk=4,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        # The dtypes are fine as we are essentially just checking data-copies
        weight_dtype=torch.bfloat16,
        weight_scale_dtype=torch.bfloat16,
        column_major_scales=column_major_scales,
    )

    distributed_run(_test_eplb_fml, world_size, config)

0 commit comments

Comments
 (0)