diff --git a/vllm_ascend/eplb/core/policy/policy_dynamic_ep.py b/vllm_ascend/eplb/core/policy/policy_dynamic_ep.py
index 5efbfe6b9c..76f58d1d94 100644
--- a/vllm_ascend/eplb/core/policy/policy_dynamic_ep.py
+++ b/vllm_ascend/eplb/core/policy/policy_dynamic_ep.py
@@ -46,11 +46,11 @@ def add_redundant(current_expert_table, expert_workload,
         return workload_new
 
     @staticmethod
-    # Split hot (high-load) experts into redundant experts
-    def original_compute_balanced_pack_redundancy(origin_weights, card_num,
-                                                  num_redundancy_expert):
+    # Split hotspot experts into redundant experts
+    def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert, card_per_host=16):
         # Step 1: Sort the items by weight in descending order (we are sorting by weight now)
         # Sort based on the second element (the second value of each tuple)
+        host_num = card_num // card_per_host
         route_expert_num = len(origin_weights)
         route_expert_redundancy: list[list[int]] = [
             [] for _ in range(route_expert_num)
@@ -78,6 +78,7 @@ def original_compute_balanced_pack_redundancy(origin_weights, card_num,
         box_weights = [0] * card_num  # To store the total weight of each box
         box_counts = [0] * card_num  # To store the number of items in each box
         index = 0
+        expert_in_hosts = [[0] * host_num for _ in range(route_expert_num)]
         for i in range(route_expert_num):
             redundancy_num = len(route_expert_redundancy[i])
             for _ in range(redundancy_num):
@@ -90,17 +91,26 @@ def original_compute_balanced_pack_redundancy(origin_weights, card_num,
                 boxes_weights[index].append(cur_weight)
                 box_weights[index] += cur_weight
                 box_counts[index] += 1
-                index += 1
+                # consider per host balance for redundant experts
+                expert_in_hosts[i][index // card_per_host] += 1
+                index = (index + 1 + card_per_host) % card_num
 
         sorted_indices = np.argsort([t[1] for t in origin_weights],
                                     kind='stable')[::-1]
         origin_weights = [origin_weights[idx] for idx in sorted_indices]
+        check_pair = [[(i-1) % host_num, (i+1) % host_num] for i in range(host_num)]
         # Step 4: Distribute items into boxes based on weight
         for item_id, weight in origin_weights:
             # Find the box with the least items but not full
             min_box_index = -1
             for i in range(card_num):
-                if item_id in boxes[i]:
+                host_id = i // card_per_host
+                max_count = max(expert_in_hosts[item_id])
+                min_count = min(expert_in_hosts[item_id])
+                if (expert_in_hosts[item_id][host_id] > expert_in_hosts[item_id][check_pair[host_id][0]]
+                        or expert_in_hosts[item_id][host_id] > expert_in_hosts[item_id][check_pair[host_id][1]]
+                        or (max_count != min_count and expert_in_hosts[item_id][host_id] != min_count)
+                        or item_id in boxes[i]):
                     continue
                 # Only choose boxes that still have space (box_counts[i] < items_per_box)
                 if box_counts[i] < items_per_box or (box_counts[i]
@@ -363,7 +373,8 @@ def rebalance_experts(self, current_expert_table, expert_workload):
 
             # Obtain the globally balanced placement strategy for each layer
             result, layer_deployment = self.original_compute_balanced_pack_redundancy(
-                weights, num_npus, num_redundancy_expert)
+                weights, num_npus, num_redundancy_expert, self.config.num_die_per_host
+            )
             global_deployment[layer] = layer_deployment
 
             max_heat_per_layer_after[layer] = max(
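
Illustration (not part of the patch): the hunks above stop placing redundant replicas on consecutive cards (index += 1) and instead advance the card index by card_per_host + 1 modulo card_num, while tracking per-host replica counts in expert_in_hosts so the later packing step can prefer hosts that hold fewer copies of a given expert. The sketch below isolates just that stride idea; place_redundant_replicas and its arguments are invented names for the example, not code from the patch.

# Minimal sketch of the host-striding replica placement, assuming 16 cards per host.
def place_redundant_replicas(replicas_per_expert, card_num, card_per_host=16):
    host_num = card_num // card_per_host
    expert_in_hosts = [[0] * host_num for _ in replicas_per_expert]
    placement = []  # (expert_id, card_id) pairs
    index = 0
    for expert_id, num_replicas in enumerate(replicas_per_expert):
        for _ in range(num_replicas):
            placement.append((expert_id, index))
            expert_in_hosts[expert_id][index // card_per_host] += 1
            # Step by one host plus one card so consecutive replicas of the
            # same expert land on different hosts.
            index = (index + 1 + card_per_host) % card_num
    return placement, expert_in_hosts

# Example: 32 cards on 2 hosts, expert 0 gets 3 replicas, expert 1 gets 2.
# place_redundant_replicas([3, 2], 32) -> ([(0, 0), (0, 17), (0, 2), (1, 19), (1, 4)],
#                                          [[2, 1], [1, 1]])
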
diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py
index 5081d969a0..5dc9c7f1e8 100644
--- a/vllm_ascend/eplb/eplb_updator.py
+++ b/vllm_ascend/eplb/eplb_updator.py
@@ -87,6 +87,8 @@ def init_eplb(self, expert_map_path):
 
         self.eplb_process = self.eplb._launch_process()
 
+        self.compute_moe_load_async = torch.npu.Stream()
+
         logger.info(
             f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})"
         )
@@ -164,21 +166,20 @@ def compute_and_set_moe_load(self, is_clear=False):
             self._gather_buffer = None
 
         if dist.is_initialized():
-            self.world_size = dist.get_world_size()
-            self.device = local_load.device
-            if self._gather_buffer is None:
-                shape = (self.world_size, *local_load.shape)
-                self._gather_buffer = torch.empty(shape,
-                                                  dtype=local_load.dtype,
-                                                  device=self.device)
-
-            dist.all_gather_into_tensor(self._gather_buffer, local_load)
-
-            moe_load = self._gather_buffer.permute(1, 0, 2)
-            self.shared_dict["moe_load"] = moe_load.cpu()
-            logger.debug(
-                f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
-            )
+            with torch.npu.stream(self.compute_moe_load_async):
+                self.world_size = dist.get_world_size()
+                self.device = local_load.device
+                if self._gather_buffer is None:
+                    shape = (self.world_size, *local_load.shape)
+                    self._gather_buffer = torch.empty(shape,
+                                                      dtype=local_load.dtype,
+                                                      device=self.device)
+
+                dist.all_gather_into_tensor(self._gather_buffer, local_load)
+
+                moe_load = self._gather_buffer.permute(1, 0, 2)
+                self.shared_dict["moe_load"] = moe_load.cpu()
+                logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
         else:
             moe_load = local_load.unsqueeze(1)
             self.shared_dict["moe_load"] = moe_load.cpu()
diff --git a/vllm_ascend/ops/expert_load_balancer.py b/vllm_ascend/ops/expert_load_balancer.py
index c6eec64a36..bd789f4559 100644
--- a/vllm_ascend/ops/expert_load_balancer.py
+++ b/vllm_ascend/ops/expert_load_balancer.py
@@ -3,6 +3,7 @@
 from typing import Dict, List
 
 import torch
+from vllm.distributed import get_ep_group
 
 
 class ExpertLoadBalancer(object):
@@ -56,7 +57,7 @@ def generate_expert_placement_map(self):
             dtype=torch.int32)
         return expert_placement_map
 
-    def generate_log2phy_expert_map(self, layer_id):
+    def generate_log2phy_expert_map(self, layer_id, rank_id):
         concatenated = torch.flatten(self.expert_map_tensor[layer_id])
         rank_expert_to_global = self.generate_index_dicts(
             self.expert_map_tensor[layer_id])
@@ -67,18 +68,95 @@ def generate_log2phy_expert_map(self, layer_id):
                 result_dict[key] = []
             result_dict[key].append(idx)
 
-        log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
-                                 -1,
+        max_num_experts = max(len(locs) for locs in result_dict.values())
+        log2phy_map = torch.full((self.global_expert_num, max_num_experts),
+                                 0,
                                  dtype=torch.int32)
-        for rank in range(self.ranks_num):
-            for key in result_dict:
-                indices_in_concat = result_dict[key]
-                if key in rank_expert_to_global[rank]:
-                    log2phy_map[rank][key] = rank_expert_to_global[rank][key]
+        num_experts = torch.ones(self.global_expert_num, dtype=torch.int32)
+        # self.update_expert_map(result_dict, log2phy_map, max_num_experts, rank_id)
+        self.update_expert_loc_map_v1(result_dict, rank_id)
+        for log_ids, phy_ids in result_dict.items():
+            log2phy_map[log_ids, :len(phy_ids)] = torch.tensor(phy_ids)
+            num_experts[log_ids] = len(phy_ids)
+        return log2phy_map, num_experts
+
+    def update_expert_map(self, expert_loc, log2phy_map, max_num_dups, rank_id):
+        ep_size = get_ep_group().world_size
+        redundancy_shared_expert_num = self.get_global_redundant_expert_num()
+        n_total_experts = self.global_expert_num + redundancy_shared_expert_num
+
+        for i in range(self.global_expert_num):
+            same_rank_candidates = []
+            same_node_candidates = []
+            experts_per_device = n_total_experts // ep_size
+            all_candidates = []
+            phy_list = expert_loc[i]
+            current_device = rank_id
+            for phy in phy_list:
+                phy_device = phy // experts_per_device
+                if phy_device == current_device:
+                    same_rank_candidates.append(phy)
+                elif (phy_device // self.ranks_num) == (current_device // self.ranks_num):
+                    same_node_candidates.append(phy)
+                else:
+                    all_candidates.append(phy)
+            tmp_expert_loc_map = torch.zeros([max_num_dups], dtype=torch.int32)
+
+            if same_rank_candidates:
+                expert_loc[i] = same_rank_candidates
+            elif same_node_candidates:
+                expert_loc[i] = same_node_candidates
+            tmp_expert_loc_map[: len(expert_loc[i])] = torch.tensor(expert_loc[i], dtype=torch.int32)
+
+            log2phy_map[i] = tmp_expert_loc_map
+
+    def update_expert_loc_map_v1(self, expert_loc, current_rank):
+
+        device_per_host = 16
+        ep_size = get_ep_group().world_size
+        current_node, current_rank_in_node = current_rank // device_per_host, current_rank % device_per_host
+        redundancy_shared_expert_num = self.get_global_redundant_expert_num()
+        n_total_experts = self.global_expert_num + redundancy_shared_expert_num
+        experts_per_device = n_total_experts // ep_size
+        num_hosts = self.ranks_num // device_per_host
+        for i in range(self.global_expert_num):
+            same_rank_candidates, same_node_candidates, all_candidates = [], [], []
+            phy_list, num_replicas = expert_loc[i], len(expert_loc[i])
+
+            for phy in phy_list:
+                phy_device = phy // experts_per_device
+                if phy_device == current_rank:
+                    same_rank_candidates.append(phy)
+                elif (phy_device // device_per_host) == (current_rank // device_per_host):
+                    same_node_candidates.append(phy)
                 else:
-                    chosen_index = random.choice(indices_in_concat)
-                    log2phy_map[rank][key] = chosen_index
-        return log2phy_map
+                    all_candidates.append(phy)
+
+            is_imbalanced = False
+            if num_replicas > num_hosts and num_replicas % num_hosts != 0:
+                replica_per_node = {}
+                for phy in phy_list:
+                    phy_device = phy // experts_per_device
+                    phy_node = phy_device // device_per_host
+                    local_rank = phy_device % device_per_host
+                    if phy_node not in replica_per_node:
+                        replica_per_node[phy_node] = []
+                    replica_per_node[phy_node].append(local_rank)
+                base_replicas_per_host = num_replicas // num_hosts
+                if len(replica_per_node[current_node]) == base_replicas_per_host:
+                    available_ranks = list(set(range(device_per_host)) - set(replica_per_node[current_node]))
+                    expected_load = round(device_per_host / (base_replicas_per_host + 1))
+                    if current_rank_in_node in available_ranks:
+                        if available_ranks.index(current_rank_in_node) >= (expected_load - 1) * base_replicas_per_host:
+                            is_imbalanced = True
+
+            if same_rank_candidates:
+                expert_loc[i] = same_rank_candidates
+            elif same_node_candidates and not is_imbalanced:
+                expert_loc[i] = same_node_candidates
+
+        return expert_loc
+
     def get_rank_placement_map(self, layer_id, rank_id):
         expert_placement_map = self.generate_expert_placement_map()
@@ -89,8 +167,8 @@ def get_rank_placement_map(self, layer_id, rank_id):
         return rank_local_expert_num, rank_expert_map
 
     def get_rank_log2phy_map(self, layer_id, rank_id):
-        layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
-        return layer_log2phy_map[rank_id]
+        layer_log2phy_map = self.generate_log2phy_expert_map(layer_id, rank_id)
+        return layer_log2phy_map
 
     def get_global_redundant_expert_num(self):
         global_redundant_expert_num = (
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 39ee9acbda..978925cbbe 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -1200,7 +1200,11 @@ def __init__(
                 expert_load_balancer.get_rank_placement_map(
                     self.moe_instance_id, self.ep_rank))
             self.log2phy = expert_load_balancer.get_rank_log2phy_map(
-                self.moe_instance_id, self.ep_rank).npu()
+                self.moe_instance_id, get_ep_group().rank_in_group)
+            log2phy_map, num_experts = self.log2phy
+            log2phy_map = log2phy_map.npu()
+            num_experts = num_experts.npu()
+            self.log2phy = log2phy_map, num_experts
             self.global_redundant_expert_num = (
                 expert_load_balancer.get_global_redundant_expert_num())
         else:
@@ -1431,8 +1435,16 @@ def forward(
             e_hidden_states, expert_token_num, group_list_type = e_hidden_states
 
         if self.dynamic_eplb:
-            self.moe_load += expert_token_num if group_list_type else \
-                torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
+            if is_prefill:
+                token_nums = expert_token_num if group_list_type else \
+                    torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
+                self.moe_load.add_(token_nums)
+            else:
+                with npu_stream_switch("moe_load_async", 0):
+                    npu_wait_tensor(hidden_states, expert_token_num)
+                    token_nums = expert_token_num if group_list_type else \
+                        torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
+                    self.moe_load.add_(token_nums)
 
         if not self.enable_prefill_optimizations and fused_moe_state != FusedMoEState.AllGather and not enable_sp:
             if tp_size > 1:
@@ -1470,7 +1482,7 @@ def get_map(self):
         return self.expert_map
 
     def get_log2phy_map(self):
-        return self.log2phy
+        return self.log2phy[0]
 
     def clear_moe_load(self):
         if self.moe_load is not None:
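
A small CPU-only illustration of the load bookkeeping in the forward() hunk above: when group_list_type is falsy, expert_token_num is treated as a cumulative group list, and the torch.cat([...]) expression first-differences it back into per-expert token counts before accumulating into moe_load. The tensor values below are invented for the example.

import torch

cumulative = torch.tensor([3, 3, 7, 12])  # running token total after each expert
per_expert = torch.cat([cumulative[:1], cumulative[1:] - cumulative[:-1]])
# per_expert == tensor([3, 0, 4, 5]); the second expert received no tokens

moe_load = torch.zeros_like(per_expert)
moe_load.add_(per_expert)  # accumulate in place, as the dynamic-EPLB path does
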
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 0c9436d28f..c8d2defbfb 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -220,11 +220,13 @@ def fused_experts_with_mc2(
     hidden_states_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    token_selector: torch.Tensor = None,
 ) -> Union[Tuple[torch.Tensor, torch.Tensor, int], Tuple[
         torch.Tensor, torch.Tensor, torch.Tensor, int]]:
     assert mc2_mask is not None
     if log2phy is not None:
-        topk_ids = log2phy[topk_ids]
+        log2phy_map, num_experts = log2phy
+        topk_ids = log2phy_map[topk_ids, token_selector[: topk_ids.shape[0]] % num_experts[topk_ids]]
     quant_mode = 2
     ep_group = get_mc2_group()
     ep_rank_id = ep_group.rank_in_group
@@ -387,7 +389,11 @@ def fused_prefill_experts_with_mc2(
     hidden_states_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    token_selector: torch.Tensor = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    if log2phy is not None:
+        log2phy_map, num_experts = log2phy
+        topk_ids = log2phy_map[topk_ids, token_selector[: topk_ids.shape[0]] % num_experts[topk_ids]]
     assert mc2_mask is not None
 
     max_num_chunks = get_forward_context().max_num_chunks
@@ -496,9 +502,11 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
                                log2phy: torch.Tensor = None,
                                global_redundant_expert_num: int = 0,
                                w1_scale_bias: torch.Tensor = None,
-                               w2_scale_bias: torch.Tensor = None):
+                               w2_scale_bias: torch.Tensor = None,
+                               token_selector: torch.Tensor = None,):
     if log2phy is not None:
-        topk_ids = log2phy[topk_ids]
+        log2phy_map, num_experts = log2phy
+        topk_ids = log2phy_map[topk_ids, token_selector[: topk_ids.shape[0]] % num_experts[topk_ids]]
     original_shape = hidden_states.shape
     if len(original_shape) == 3:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -829,6 +837,11 @@ def __init__(self):
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+
+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+        self.max_token_nums = vllm_config.scheduler_config.max_num_batched_tokens
+        self.token_selector = torch.arange(0, self.max_token_nums, dtype=torch.int32).view(-1, 1).npu()
         self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout
 
         try:
@@ -978,6 +991,7 @@ def apply(
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
+                token_selector=self.token_selector,
                 is_torchair=self.torchair_graph_enabled,
                 hidden_states_for_share=shared_gate_up,
                 dynamic_scale_for_share=shared_dequant_scale,
@@ -998,6 +1012,7 @@ def apply(
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
+                token_selector=self.token_selector,
                 hidden_states_for_share=shared_gate_up,
                 dynamic_scale_for_share=shared_dequant_scale,
                 mc2_mask=kwargs.get("mc2_mask", None))
@@ -1029,6 +1044,7 @@ def apply(
                 ep_group=self.ep_group,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
+                token_selector=self.token_selector,
             )
 
     def process_weights_after_loading(self, layer):
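
To summarize the lookup introduced across the w8a8_dynamic.py hunks above: log2phy is now a pair of a padded [num_logical_experts, max_replicas] table and a per-expert replica count, and each token row picks a physical replica of its chosen logical expert by token position modulo that count, so successive token slots rotate over the copies of a hot expert. The toy placement below is invented for illustration and uses default integer dtypes; the patch itself builds the table in ExpertLoadBalancer and the int32 token_selector in the __init__ hunk above.

import torch

# 4 logical experts; logical expert 2 has two physical replicas (5 and 9), padded with 0.
log2phy_map = torch.tensor([[0, 0],
                            [1, 0],
                            [5, 9],
                            [3, 0]])
num_experts = torch.tensor([1, 1, 2, 1])  # replica count per logical expert

topk_ids = torch.tensor([[2, 1],
                         [2, 3],
                         [2, 0]])             # 3 tokens, top-2 logical experts each
token_selector = torch.arange(8).view(-1, 1)  # one row index per token slot

phys_ids = log2phy_map[topk_ids,
                       token_selector[:topk_ids.shape[0]] % num_experts[topk_ids]]
# phys_ids == tensor([[5, 1],
#                     [9, 3],
#                     [5, 0]])  -- logical expert 2 alternates between replicas 5 and 9
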