23 changes: 17 additions & 6 deletions vllm_ascend/eplb/core/policy/policy_dynamic_ep.py
@@ -46,11 +46,11 @@ def add_redundant(current_expert_table, expert_workload,
return workload_new

@staticmethod
# Split hot (high-load) experts into redundant experts
def original_compute_balanced_pack_redundancy(origin_weights, card_num,
num_redundancy_expert):
# Split hotspot experts into redundant experts
def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redundancy_expert, card_per_host=16):
# Step 1: Sort the items by weight in descending order (we are sorting by weight now)
# Sort based on the second element (the second value of each tuple)
host_num = card_num // card_per_host
route_expert_num = len(origin_weights)
route_expert_redundancy: list[list[int]] = [
[] for _ in range(route_expert_num)
@@ -78,6 +78,7 @@ def original_compute_balanced_pack_redundancy(origin_weights, card_num,
box_weights = [0] * card_num # To store the total weight of each box
box_counts = [0] * card_num # To store the number of items in each box
index = 0
expert_in_hosts = [[0] * host_num for _ in range(route_expert_num)]
for i in range(route_expert_num):
redundancy_num = len(route_expert_redundancy[i])
for _ in range(redundancy_num):
@@ -90,17 +91,26 @@ def original_compute_balanced_pack_redundancy(origin_weights, card_num,
boxes_weights[index].append(cur_weight)
box_weights[index] += cur_weight
box_counts[index] += 1
index += 1
# consider per host balance for redundant experts
expert_in_hosts[i][index // card_per_host] += 1
index = (index + 1 + card_per_host) % card_num

sorted_indices = np.argsort([t[1] for t in origin_weights],
kind='stable')[::-1]
origin_weights = [origin_weights[idx] for idx in sorted_indices]
check_pair = [[(i-1) % host_num, (i+1) % host_num] for i in range(host_num)]
# Step 4: Distribute items into boxes based on weight
for item_id, weight in origin_weights:
# Find the box with the least items but not full
min_box_index = -1
for i in range(card_num):
if item_id in boxes[i]:
host_id = i // card_per_host
max_count = max(expert_in_hosts[item_id])
min_count = min(expert_in_hosts[item_id])
if (expert_in_hosts[item_id][host_id] > expert_in_hosts[item_id][check_pair[host_id][0]]
or expert_in_hosts[item_id][host_id] > expert_in_hosts[item_id][check_pair[host_id][1]]
or (max_count != min_count and expert_in_hosts[item_id][host_id] != min_count)
or item_id in boxes[i]):
Comment on lines +110 to +113 (Contributor, severity: high):
The condition to filter eligible boxes (cards) for expert placement is very complex and difficult to understand. This level of complexity can make the code hard to maintain and debug. Consider refactoring this logic for clarity and simplicity. The core goal seems to be balancing experts across hosts. A simpler condition might be to only allow placing an expert on a host that has the minimum number of instances of that expert.
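A minimal sketch of that simpler rule, reusing the names already in this hunk (illustrative only, not the PR's code, and it would change placement behavior):

def host_has_min_replicas(expert_in_hosts, item_id, card_idx, card_per_host):
    """Return True when the card's host currently holds the fewest replicas of this expert."""
    host_id = card_idx // card_per_host
    return expert_in_hosts[item_id][host_id] == min(expert_in_hosts[item_id])

# In the placement loop, the complex condition could then reduce to:
#     if not host_has_min_replicas(expert_in_hosts, item_id, i, card_per_host) \
#             or item_id in boxes[i]:
#         continue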

continue
# Only choose boxes that still have space (box_counts[i] < items_per_box)
if box_counts[i] < items_per_box or (box_counts[i]
@@ -363,7 +373,8 @@ def rebalance_experts(self, current_expert_table, expert_workload):

# Obtain the globally balanced placement strategy for each layer
result, layer_deployment = self.original_compute_balanced_pack_redundancy(
weights, num_npus, num_redundancy_expert)
weights, num_npus, num_redundancy_expert, self.config.num_die_per_host
)

global_deployment[layer] = layer_deployment
max_heat_per_layer_after[layer] = max(
31 changes: 16 additions & 15 deletions vllm_ascend/eplb/eplb_updator.py
@@ -87,6 +87,8 @@ def init_eplb(self, expert_map_path):

self.eplb_process = self.eplb._launch_process()

self.compute_moe_load_async = torch.npu.Stream()

logger.info(
f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})"
)
@@ -164,21 +166,20 @@ def compute_and_set_moe_load(self, is_clear=False):

self._gather_buffer = None
if dist.is_initialized():
self.world_size = dist.get_world_size()
self.device = local_load.device
if self._gather_buffer is None:
shape = (self.world_size, *local_load.shape)
self._gather_buffer = torch.empty(shape,
dtype=local_load.dtype,
device=self.device)

dist.all_gather_into_tensor(self._gather_buffer, local_load)

moe_load = self._gather_buffer.permute(1, 0, 2)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(
f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
)
with torch.npu.stream(self.compute_moe_load_async):
self.world_size = dist.get_world_size()
self.device = local_load.device
if self._gather_buffer is None:
shape = (self.world_size, *local_load.shape)
self._gather_buffer = torch.empty(shape,
dtype=local_load.dtype,
device=self.device)

dist.all_gather_into_tensor(self._gather_buffer, local_load)

moe_load = self._gather_buffer.permute(1, 0, 2)
self.shared_dict["moe_load"] = moe_load.cpu()
logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
Comment on lines +169 to +182 (Contributor, severity: high):
The logic within this with block relies on self._gather_buffer for caching, but it is reset to None on every call to compute_and_set_moe_load at line 167. This leads to inefficient re-allocation of the buffer because the if self._gather_buffer is None: check at line 172 will always be true. To fix this, self._gather_buffer should be initialized as an instance attribute in init_eplb and not reset within this function.
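A minimal sketch of that fix, mirroring the code above with the per-call reset removed (the exact spot in init_eplb is an assumption):

# In init_eplb (runs once): start with no buffer so it is allocated lazily.
self._gather_buffer = None

# In compute_and_set_moe_load: reuse the cached buffer instead of resetting it.
if dist.is_initialized():
    with torch.npu.stream(self.compute_moe_load_async):
        self.world_size = dist.get_world_size()
        if self._gather_buffer is None:  # only true on the first call now
            shape = (self.world_size, *local_load.shape)
            self._gather_buffer = torch.empty(shape,
                                              dtype=local_load.dtype,
                                              device=local_load.device)
        dist.all_gather_into_tensor(self._gather_buffer, local_load)
        moe_load = self._gather_buffer.permute(1, 0, 2)
        self.shared_dict["moe_load"] = moe_load.cpu()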

else:
moe_load = local_load.unsqueeze(1)
self.shared_dict["moe_load"] = moe_load.cpu()
104 changes: 91 additions & 13 deletions vllm_ascend/ops/expert_load_balancer.py
@@ -1,8 +1,9 @@
import json
import random

Check failure (GitHub Actions / lint (3.11, v0.9.1)), Ruff F401: vllm_ascend/ops/expert_load_balancer.py:2:8: `random` imported but unused
from typing import Dict, List

import torch
from vllm.distributed import get_ep_group


class ExpertLoadBalancer(object):
@@ -56,9 +57,9 @@
dtype=torch.int32)
return expert_placement_map

def generate_log2phy_expert_map(self, layer_id):
def generate_log2phy_expert_map(self, layer_id, rank_id):
concatenated = torch.flatten(self.expert_map_tensor[layer_id])
rank_expert_to_global = self.generate_index_dicts(

Check failure (GitHub Actions / lint (3.11, v0.9.1)), Ruff F841: vllm_ascend/ops/expert_load_balancer.py:62:9: Local variable `rank_expert_to_global` is assigned to but never used
self.expert_map_tensor[layer_id])
result_dict: Dict[int, List[int]] = {}
for idx, value in enumerate(concatenated):
@@ -67,18 +68,95 @@
result_dict[key] = []
result_dict[key].append(idx)

log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
-1,
max_num_experts = max(len(locs) for locs in result_dict.values())
log2phy_map = torch.full((self.global_expert_num, max_num_experts),
0,
dtype=torch.int32)
for rank in range(self.ranks_num):
for key in result_dict:
indices_in_concat = result_dict[key]
if key in rank_expert_to_global[rank]:
log2phy_map[rank][key] = rank_expert_to_global[rank][key]
num_experts = torch.ones(self.global_expert_num, dtype=torch.int32)
# self.update_expert_map(result_dict, log2phy_map, max_num_experts, rank_id)
self.update_expert_loc_map_v1(result_dict, rank_id)
for log_ids, phy_ids in result_dict.items():
log2phy_map[log_ids, :len(phy_ids)] = torch.tensor(phy_ids)
num_experts[log_ids] = len(phy_ids)
return log2phy_map, num_experts

def update_expert_map(self, expert_loc, log2phy_map, max_num_dups, rank_id):
ep_size = get_ep_group().world_size
redundancy_shared_expert_num = self.get_global_redundant_expert_num()
n_total_experts = self.global_expert_num + redundancy_shared_expert_num

for i in range(self.global_expert_num):
same_rank_candidates = []
same_node_candidates = []
experts_per_device = n_total_experts // ep_size
all_candidates = []
phy_list = expert_loc[i]
current_device = rank_id
for phy in phy_list:
phy_device = phy // experts_per_device
if phy_device == current_device:
same_rank_candidates.append(phy)
elif (phy_device // self.ranks_num) == (current_device // self.ranks_num):
same_node_candidates.append(phy)
else:
all_candidates.append(phy)
tmp_expert_loc_map = torch.zeros([max_num_dups], dtype=torch.int32)

if same_rank_candidates:
expert_loc[i] = same_rank_candidates
elif same_node_candidates:
expert_loc[i] = same_node_candidates
tmp_expert_loc_map[: len(expert_loc[i])] = torch.tensor(expert_loc[i], dtype=torch.int32)

log2phy_map[i] = tmp_expert_loc_map

Comment on lines +83 to +112 (Contributor, severity: high):
The function update_expert_map appears to be dead code, as its only call site is commented out. It also seems to contain a bug where self.ranks_num is used to determine the node, which is likely incorrect. This unused and potentially buggy code should be removed to improve maintainability and reduce clutter.
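For reference, the suspect expression is the node check phy_device // self.ranks_num; the newer update_expert_loc_map_v1 below divides by the per-host device count instead, which is presumably what was intended:

# Likely intent (sketch, mirroring update_expert_loc_map_v1):
same_node = (phy_device // device_per_host) == (current_device // device_per_host)
# instead of the current:
# same_node = (phy_device // self.ranks_num) == (current_device // self.ranks_num)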

def update_expert_loc_map_v1(self, expert_loc, current_rank):

device_per_host = 16
Comment (Contributor, severity: high):
The device_per_host is hardcoded to 16. This makes the logic less flexible and difficult to adapt to different hardware configurations. This value should be passed in as a parameter or read from a configuration to improve portability.
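A minimal sketch of that change; the parameter default and the config field name (num_die_per_host, as used in policy_dynamic_ep.py in this PR) are assumptions:

def update_expert_loc_map_v1(self, expert_loc, current_rank,
                             device_per_host: int = 16):
    # The per-host device count now comes from the caller, e.g.
    #   balancer.update_expert_loc_map_v1(expert_loc, rank, config.num_die_per_host)
    # rather than a literal inside the function.
    ep_size = get_ep_group().world_size
    current_node = current_rank // device_per_host
    current_rank_in_node = current_rank % device_per_host
    # ... rest of the existing body unchanged ...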

ep_size = get_ep_group().world_size
current_node, current_rank_in_node = current_rank // device_per_host, current_rank % device_per_host
redundancy_shared_expert_num = self.get_global_redundant_expert_num()
n_total_experts = self.global_expert_num + redundancy_shared_expert_num
experts_per_device = n_total_experts // ep_size
num_hosts = self.ranks_num // device_per_host
for i in range(self.global_expert_num):
same_rank_candidates, same_node_candidates, all_candidates = [], [], []
phy_list, num_replicas = expert_loc[i], len(expert_loc[i])

for phy in phy_list:
phy_device = phy // experts_per_device
if phy_device == current_rank:
same_rank_candidates.append(phy)
elif (phy_device // device_per_host) == (current_rank // device_per_host):
same_node_candidates.append(phy)
else:
chosen_index = random.choice(indices_in_concat)
log2phy_map[rank][key] = chosen_index
return log2phy_map
all_candidates.append(phy)

is_imbalanced = False
if num_replicas > num_hosts and num_replicas % num_hosts != 0:
replica_per_node = {}
for phy in phy_list:
phy_device = phy // experts_per_device
phy_node = phy_device // device_per_host
local_rank = phy_device % device_per_host
if phy_node not in replica_per_node:
replica_per_node[phy_node] = []
replica_per_node[phy_node].append(local_rank)
base_replicas_per_host = num_replicas // num_hosts
if len(replica_per_node[current_node]) == base_replicas_per_host:
available_ranks = list(set(range(device_per_host)) - set(replica_per_node[current_node]))
expected_load = round(device_per_host / (base_replicas_per_host + 1))
if current_rank_in_node in available_ranks:
if available_ranks.index(current_rank_in_node) >= (expected_load - 1) * base_replicas_per_host:
is_imbalanced = True
Comment on lines +135 to +151 (Contributor, severity: high):
The logic to determine is_imbalanced is very complex and difficult to understand, which impacts maintainability. The purpose of these calculations is not clear from the code alone. Please consider refactoring this section for simplicity and adding comments to explain the heuristic being implemented for handling imbalanced expert replica distributions.
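One possible refactor along those lines, shown only as a sketch: pull the check into a named helper that documents the heuristic. It mirrors the logic above (including the unspecified ordering of list(set(...)) it relies on); the only deliberate change is .get(), which avoids a KeyError the original would hit if the current host held no replica of this expert.

def _is_replica_layout_imbalanced(phy_list, experts_per_device, device_per_host,
                                  num_hosts, current_node, current_rank_in_node):
    """Heuristic behind is_imbalanced: skip same-node candidates when an
    expert's replicas cannot spread evenly across hosts, the current host
    already holds its base share, and this rank falls past the expected-load
    cut-off among the host's replica-free local ranks."""
    num_replicas = len(phy_list)
    if num_replicas <= num_hosts or num_replicas % num_hosts == 0:
        return False
    replicas_on_host = {}  # host id -> local ranks holding a replica
    for phy in phy_list:
        device = phy // experts_per_device
        replicas_on_host.setdefault(device // device_per_host,
                                    []).append(device % device_per_host)
    base_per_host = num_replicas // num_hosts
    if len(replicas_on_host.get(current_node, [])) != base_per_host:
        return False
    # list(set(...)) keeps the original's unspecified ordering.
    available_ranks = list(
        set(range(device_per_host)) - set(replicas_on_host[current_node]))
    expected_load = round(device_per_host / (base_per_host + 1))
    return (current_rank_in_node in available_ranks
            and available_ranks.index(current_rank_in_node)
            >= (expected_load - 1) * base_per_host)

The loop above would then set is_imbalanced = _is_replica_layout_imbalanced(phy_list, experts_per_device, device_per_host, num_hosts, current_node, current_rank_in_node).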


if same_rank_candidates:
expert_loc[i] = same_rank_candidates
elif same_node_candidates and not is_imbalanced:
expert_loc[i] = same_node_candidates

return expert_loc


def get_rank_placement_map(self, layer_id, rank_id):
expert_placement_map = self.generate_expert_placement_map()
@@ -89,8 +167,8 @@
return rank_local_expert_num, rank_expert_map

def get_rank_log2phy_map(self, layer_id, rank_id):
layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
return layer_log2phy_map[rank_id]
layer_log2phy_map = self.generate_log2phy_expert_map(layer_id, rank_id)
return layer_log2phy_map

def get_global_redundant_expert_num(self):
global_redundant_expert_num = (
20 changes: 16 additions & 4 deletions vllm_ascend/ops/fused_moe.py
@@ -1200,7 +1200,11 @@ def __init__(
expert_load_balancer.get_rank_placement_map(
self.moe_instance_id, self.ep_rank))
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id, self.ep_rank).npu()
self.moe_instance_id, get_ep_group().rank_in_group).npu()
Comment (Contributor, severity: critical):
The method get_rank_log2phy_map now returns a tuple of tensors (log2phy_map, num_experts). Calling .npu() on this tuple will raise an AttributeError. The .npu() call should be removed from this line. The subsequent lines correctly move the individual tensors from the tuple to the NPU device.

Suggested change:
- self.moe_instance_id, get_ep_group().rank_in_group).npu()
+ self.moe_instance_id, get_ep_group().rank_in_group)

log2phy_map, num_experts = self.log2phy
log2phy_map = log2phy_map.npu()
num_experts = num_experts.npu()
self.log2phy = log2phy_map, num_experts
self.global_redundant_expert_num = (
expert_load_balancer.get_global_redundant_expert_num())
else:
@@ -1431,8 +1435,16 @@ def forward(
e_hidden_states, expert_token_num, group_list_type = e_hidden_states

if self.dynamic_eplb:
self.moe_load += expert_token_num if group_list_type else \
torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
if is_prefill:
token_nums = expert_token_num if group_list_type else \
torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
self.moe_load.add_(token_nums)
else:
with npu_stream_switch("moe_load_async", 0):
npu_wait_tensor(hidden_states, expert_token_num)
token_nums = expert_token_num if group_list_type else \
torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
self.moe_load.add_(token_nums)

if not self.enable_prefill_optimizations and fused_moe_state != FusedMoEState.AllGather and not enable_sp:
if tp_size > 1:
@@ -1470,7 +1482,7 @@ def get_map(self):
return self.expert_map

def get_log2phy_map(self):
return self.log2phy
return self.log2phy[0]

def clear_moe_load(self):
if self.moe_load is not None:
22 changes: 19 additions & 3 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -220,11 +220,13 @@ def fused_experts_with_mc2(
hidden_states_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
token_selector: torch.Tensor = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, int], Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, int]]:
assert mc2_mask is not None
if log2phy is not None:
topk_ids = log2phy[topk_ids]
log2phy_map, num_experts = log2phy
topk_ids = log2phy_map[topk_ids, token_selector[: topk_ids.shape[0]] % num_experts[topk_ids]]
quant_mode = 2
ep_group = get_mc2_group()
ep_rank_id = ep_group.rank_in_group
@@ -387,7 +389,11 @@ def fused_prefill_experts_with_mc2(
hidden_states_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
token_selector: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if log2phy is not None:
log2phy_map, num_experts = log2phy
topk_ids = log2phy_map[topk_ids, token_selector[: topk_ids.shape[0]] % num_experts[topk_ids]]
assert mc2_mask is not None
max_num_chunks = get_forward_context().max_num_chunks

@@ -496,9 +502,11 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None):
w2_scale_bias: torch.Tensor = None,
token_selector: torch.Tensor = None,):
if log2phy is not None:
topk_ids = log2phy[topk_ids]
log2phy_map, num_experts = log2phy
topk_ids = log2phy_map[topk_ids, token_selector[: topk_ids.shape[0]] % num_experts[topk_ids]]
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -829,6 +837,11 @@ def __init__(self):

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
self.max_token_nums = vllm_config.scheduler_config.max_num_batched_tokens
self.token_selector = torch.arange(0, self.max_token_nums, dtype=torch.int32).view(-1, 1).npu()
self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout

try:
@@ -978,6 +991,7 @@ def apply(
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
token_selector=self.token_selector,
is_torchair=self.torchair_graph_enabled,
hidden_states_for_share=shared_gate_up,
dynamic_scale_for_share=shared_dequant_scale,
@@ -998,6 +1012,7 @@ def apply(
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
token_selector=self.token_selector,
hidden_states_for_share=shared_gate_up,
dynamic_scale_for_share=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None))
@@ -1029,6 +1044,7 @@ def apply(
ep_group=self.ep_group,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
token_selector=self.token_selector,
)

def process_weights_after_loading(self, layer):