|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +# Test that the interaction between EPLB and FusedMoE Layer is okay |
| 5 | + |
| 6 | +from dataclasses import dataclass |
| 7 | + |
| 8 | +import pytest |
| 9 | +import torch |
| 10 | + |
| 11 | +from vllm.config import VllmConfig, set_current_vllm_config |
| 12 | +from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace |
| 13 | +from vllm.distributed.parallel_state import ( |
| 14 | + ensure_model_parallel_initialized, |
| 15 | + get_tp_group, |
| 16 | +) |
| 17 | +from vllm.model_executor.layers.fused_moe.layer import FusedMoE |
| 18 | + |
| 19 | +from .eplb_utils import distributed_run, set_env_vars_and_device |
| 20 | + |
| 21 | + |
| 22 | +@dataclass |
| 23 | +class TestConfig: |
| 24 | + num_layers: int |
| 25 | + num_experts: int |
| 26 | + num_local_experts: int |
| 27 | + num_topk: int |
| 28 | + hidden_size: int |
| 29 | + intermediate_size: int |
| 30 | + weight_dtype: torch.dtype |
| 31 | + weight_scale_dtype: torch.dtype | None |
| 32 | + column_major_scales: bool |
| 33 | + |
| 34 | + |
| 35 | +def make_expert_weights( |
| 36 | + layer_idx: int, |
| 37 | + global_expert_idx: int, |
| 38 | + global_num_experts: int, |
| 39 | + tensor_shape: tuple[int, ...], |
| 40 | + tensor_dtype: torch.dtype, |
| 41 | + tensor_device: torch.device, |
| 42 | + is_column_major: bool, |
| 43 | +) -> torch.Tensor: |
| 44 | + assert len(tensor_shape) == 2 |
| 45 | + |
| 46 | + if is_column_major: |
| 47 | + tensor_shape = (tensor_shape[1], tensor_shape[0]) |
| 48 | + |
| 49 | + x = torch.empty(tensor_shape, dtype=tensor_dtype, device=tensor_device) |
| 50 | + value_offset = (layer_idx * global_num_experts + global_expert_idx) * x.numel() |
| 51 | + x.view(-1).copy_( |
| 52 | + torch.arange( |
| 53 | + value_offset, |
| 54 | + value_offset + x.numel(), |
| 55 | + dtype=tensor_dtype, |
| 56 | + device=tensor_device, |
| 57 | + ) |
| 58 | + ) |
| 59 | + |
| 60 | + if is_column_major: |
| 61 | + x = torch.transpose(x, 1, 0) |
| 62 | + assert not x.is_contiguous() |
| 63 | + return x |
| 64 | + |
| 65 | + |
| 66 | +def make_fused_moe_layer( |
| 67 | + rank: int, |
| 68 | + layer_idx: int, |
| 69 | + test_config: TestConfig, |
| 70 | +) -> FusedMoE: |
| 71 | + fml = FusedMoE( |
| 72 | + num_experts=test_config.num_experts, |
| 73 | + top_k=test_config.num_topk, |
| 74 | + hidden_size=test_config.hidden_size, |
| 75 | + intermediate_size=test_config.intermediate_size, |
| 76 | + prefix=f"dummy_layer_{layer_idx}", |
| 77 | + activation="silu", |
| 78 | + is_act_and_mul=True, |
| 79 | + params_dtype=test_config.weight_dtype, |
| 80 | + ) |
| 81 | + |
| 82 | + device = torch.device(f"cuda:{rank}") |
| 83 | + |
| 84 | + from functools import partial |
| 85 | + |
| 86 | + _make_expert_weights = partial( |
| 87 | + make_expert_weights, |
| 88 | + layer_idx=layer_idx, |
| 89 | + global_num_experts=test_config.num_experts, |
| 90 | + tensor_device=device, |
| 91 | + ) |
| 92 | + |
| 93 | + assert isinstance(fml.w13_weight.data, torch.Tensor) |
| 94 | + assert isinstance(fml.w2_weight.data, torch.Tensor) |
| 95 | + fml.w13_weight.data = fml.w13_weight.data.to(device=device) |
| 96 | + fml.w2_weight.data = fml.w2_weight.data.to(device=device) |
| 97 | + w13_weight = fml.w13_weight.data |
| 98 | + w2_weight = fml.w2_weight.data |
| 99 | + assert w13_weight.size(0) == test_config.num_local_experts |
| 100 | + for i in range(test_config.num_local_experts): |
| 101 | + g_i = rank * test_config.num_local_experts + i |
| 102 | + w13_weight_e = w13_weight[i] |
| 103 | + w2_weight_e = w2_weight[i] |
| 104 | + w13_weight_e.copy_( |
| 105 | + _make_expert_weights( |
| 106 | + global_expert_idx=g_i, |
| 107 | + tensor_shape=w13_weight_e.shape, |
| 108 | + tensor_dtype=w13_weight_e.dtype, |
| 109 | + is_column_major=False, |
| 110 | + ) |
| 111 | + ) |
| 112 | + w2_weight_e.copy_( |
| 113 | + _make_expert_weights( |
| 114 | + global_expert_idx=g_i, |
| 115 | + tensor_shape=w2_weight_e.shape, |
| 116 | + tensor_dtype=w2_weight_e.dtype, |
| 117 | + is_column_major=False, |
| 118 | + ) |
| 119 | + ) |
| 120 | + |
| 121 | + block_size = 16 |
| 122 | + |
| 123 | + def block_quant_scales_shape( |
| 124 | + shape: tuple[int, ...], is_column_major: bool |
| 125 | + ) -> tuple[int, ...]: |
| 126 | + assert len(shape) == 3 |
| 127 | + if not is_column_major: |
| 128 | + return (shape[0], shape[1] // block_size, shape[2] // block_size) |
| 129 | + else: |
| 130 | + return (shape[0], shape[2] // block_size, shape[1] // block_size) |
| 131 | + |
| 132 | + is_column_major = test_config.column_major_scales |
| 133 | + w13_weight_scale_inv = torch.empty( |
| 134 | + block_quant_scales_shape(w13_weight.shape, is_column_major), |
| 135 | + dtype=test_config.weight_dtype, |
| 136 | + device=device, |
| 137 | + ) |
| 138 | + w2_weight_scale_inv = torch.empty( |
| 139 | + block_quant_scales_shape(w2_weight.shape, is_column_major), |
| 140 | + dtype=test_config.weight_dtype, |
| 141 | + device=device, |
| 142 | + ) |
| 143 | + |
| 144 | + for i in range(test_config.num_local_experts): |
| 145 | + g_i = rank * test_config.num_local_experts + i |
| 146 | + w13_s_e = w13_weight_scale_inv[i] |
| 147 | + w2_s_e = w2_weight_scale_inv[i] |
| 148 | + w13_s_e.copy_( |
| 149 | + _make_expert_weights( |
| 150 | + global_expert_idx=g_i, |
| 151 | + tensor_shape=w13_s_e.shape, |
| 152 | + tensor_dtype=w13_s_e.dtype, |
| 153 | + # Fill data in row-major and then |
| 154 | + # transpose if test_config requires col-major. |
| 155 | + is_column_major=False, |
| 156 | + ) |
| 157 | + ) |
| 158 | + w2_s_e.copy_( |
| 159 | + _make_expert_weights( |
| 160 | + global_expert_idx=g_i, |
| 161 | + tensor_shape=w2_s_e.shape, |
| 162 | + tensor_dtype=w2_s_e.dtype, |
| 163 | + is_column_major=False, |
| 164 | + ) |
| 165 | + ) |
| 166 | + if is_column_major: |
| 167 | + w13_weight_scale_inv = torch.transpose(w13_weight_scale_inv, 1, 2) |
| 168 | + w2_weight_scale_inv = torch.transpose(w2_weight_scale_inv, 1, 2) |
| 169 | + assert not w13_weight_scale_inv.is_contiguous() |
| 170 | + assert not w2_weight_scale_inv.is_contiguous() |
| 171 | + |
| 172 | + # Add scales to the parameter list |
| 173 | + fml.w13_weight_scale_inv = torch.nn.Parameter( |
| 174 | + w13_weight_scale_inv, requires_grad=False |
| 175 | + ) |
| 176 | + fml.w2_weight_scale_inv = torch.nn.Parameter( |
| 177 | + w2_weight_scale_inv, requires_grad=False |
| 178 | + ) |
| 179 | + |
| 180 | + return fml |
| 181 | + |
| 182 | + |
| 183 | +def _test_eplb_fml(env, world_size: int, test_config: TestConfig): |
| 184 | + # Initialize model parallel (using tensor parallel as an entrypoint |
| 185 | + # to expert parallel) |
| 186 | + set_env_vars_and_device(env) |
| 187 | + |
| 188 | + vllm_config = VllmConfig() |
| 189 | + vllm_config.parallel_config.tensor_parallel_size = world_size |
| 190 | + vllm_config.parallel_config.enable_expert_parallel = True |
| 191 | + |
| 192 | + with set_current_vllm_config(vllm_config): |
| 193 | + ensure_model_parallel_initialized( |
| 194 | + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 |
| 195 | + ) |
| 196 | + |
| 197 | + ep_group = get_tp_group().cpu_group |
| 198 | + ep_rank = torch.distributed.get_rank() |
| 199 | + |
| 200 | + fml_layers = [ |
| 201 | + make_fused_moe_layer(ep_rank, layer_idx, test_config) |
| 202 | + for layer_idx in range(test_config.num_layers) |
| 203 | + ] |
| 204 | + rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers] |
| 205 | + |
| 206 | + indices = torch.zeros( |
| 207 | + test_config.num_layers, test_config.num_experts, dtype=torch.long |
| 208 | + ) |
| 209 | + for lidx in range(test_config.num_layers): |
| 210 | + indices[lidx] = torch.Tensor(range(test_config.num_experts)) |
| 211 | + |
| 212 | + shuffled_indices = torch.zeros_like(indices) |
| 213 | + for lidx in range(test_config.num_layers): |
| 214 | + shuffled_indices[lidx] = torch.randperm(test_config.num_experts) |
| 215 | + |
| 216 | + rearrange_expert_weights_inplace( |
| 217 | + indices, |
| 218 | + shuffled_indices, |
| 219 | + rank_expert_weights, |
| 220 | + ep_group, |
| 221 | + is_profile=False, |
| 222 | + ) |
| 223 | + |
| 224 | + num_local_experts = test_config.num_local_experts |
| 225 | + num_global_experts = test_config.num_experts |
| 226 | + for lidx, fml in enumerate(fml_layers): |
| 227 | + for name, w in fml.named_parameters(): |
| 228 | + for e in range(num_local_experts): |
| 229 | + g_e = shuffled_indices[lidx][ep_rank * num_local_experts + e] |
| 230 | + ref = make_expert_weights( |
| 231 | + layer_idx=lidx, |
| 232 | + global_expert_idx=int(g_e.item()), |
| 233 | + global_num_experts=num_global_experts, |
| 234 | + tensor_shape=w[e].shape, |
| 235 | + tensor_dtype=w[e].dtype, |
| 236 | + tensor_device=w[e].device, |
| 237 | + is_column_major=not w[e].is_contiguous(), |
| 238 | + ) |
| 239 | + assert w[e].shape == ref.shape and w[e].stride() == ref.stride(), ( |
| 240 | + f"w[{e}] {w[e].size()} {w[e].stride()} vs " |
| 241 | + f"ref {ref.size()} {ref.stride()}" |
| 242 | + ) |
| 243 | + torch.testing.assert_close(w[e], ref) |
| 244 | + |
| 245 | + |
| 246 | +@pytest.mark.parametrize("world_size", [2]) |
| 247 | +@pytest.mark.parametrize("num_layers", [4]) |
| 248 | +@pytest.mark.parametrize("num_experts", [16]) |
| 249 | +@pytest.mark.parametrize("hidden_size", [256]) |
| 250 | +@pytest.mark.parametrize("intermediate_size", [256]) |
| 251 | +@pytest.mark.parametrize("column_major_scales", [True, False]) |
| 252 | +def test_eplb_fml( |
| 253 | + world_size: int, |
| 254 | + num_layers: int, |
| 255 | + num_experts: int, |
| 256 | + hidden_size: int, |
| 257 | + intermediate_size: int, |
| 258 | + column_major_scales: bool, |
| 259 | +): |
| 260 | + if torch.cuda.device_count() < world_size: |
| 261 | + pytest.skip(f"Need at least {world_size} GPUs to run the test") |
| 262 | + |
| 263 | + num_local_experts = num_experts // world_size |
| 264 | + num_topk = 4 |
| 265 | + # The dtypes are fine as we are essentially just checking data-copies |
| 266 | + weight_dtype = torch.bfloat16 |
| 267 | + weight_scale_dtype = torch.bfloat16 |
| 268 | + |
| 269 | + test_config = TestConfig( |
| 270 | + num_layers=num_layers, |
| 271 | + num_experts=num_experts, |
| 272 | + num_local_experts=num_local_experts, |
| 273 | + num_topk=num_topk, |
| 274 | + hidden_size=hidden_size, |
| 275 | + intermediate_size=intermediate_size, |
| 276 | + weight_dtype=weight_dtype, |
| 277 | + weight_scale_dtype=weight_scale_dtype, |
| 278 | + column_major_scales=column_major_scales, |
| 279 | + ) |
| 280 | + |
| 281 | + distributed_run( |
| 282 | + _test_eplb_fml, |
| 283 | + world_size, |
| 284 | + test_config, |
| 285 | + ) |
0 commit comments