Topology-Aware Expert Load Balancing (TA-ELB) optimization for EPLB #2351
```diff
@@ -87,6 +87,8 @@ def init_eplb(self, expert_map_path):

         self.eplb_process = self.eplb._launch_process()

+        self.compute_moe_load_async = torch.npu.Stream()
+
         logger.info(
             f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})"
         )
```
```diff
@@ -164,21 +166,20 @@ def compute_and_set_moe_load(self, is_clear=False):

             self._gather_buffer = None
         if dist.is_initialized():
-            self.world_size = dist.get_world_size()
-            self.device = local_load.device
-            if self._gather_buffer is None:
-                shape = (self.world_size, *local_load.shape)
-                self._gather_buffer = torch.empty(shape,
-                                                  dtype=local_load.dtype,
-                                                  device=self.device)
-
-            dist.all_gather_into_tensor(self._gather_buffer, local_load)
-
-            moe_load = self._gather_buffer.permute(1, 0, 2)
-            self.shared_dict["moe_load"] = moe_load.cpu()
-            logger.debug(
-                f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
-            )
+            with torch.npu.stream(self.compute_moe_load_async):
+                self.world_size = dist.get_world_size()
+                self.device = local_load.device
+                if self._gather_buffer is None:
+                    shape = (self.world_size, *local_load.shape)
+                    self._gather_buffer = torch.empty(shape,
+                                                      dtype=local_load.dtype,
+                                                      device=self.device)
+
+                dist.all_gather_into_tensor(self._gather_buffer, local_load)
+
+                moe_load = self._gather_buffer.permute(1, 0, 2)
+                self.shared_dict["moe_load"] = moe_load.cpu()
+                logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
         else:
             moe_load = local_load.unsqueeze(1)
             self.shared_dict["moe_load"] = moe_load.cpu()
```

Review comment on lines +169 to +182: The logic within this …
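The hunk above moves the load gather onto the dedicated stream created in `init_eplb`, so the all-gather can overlap with kernels queued on the default stream. A minimal sketch of that pattern, assuming the `torch.npu` stream API mirrors `torch.cuda` and that `local_load` is a `(num_layers, num_experts)` tensor (both assumptions, not stated in the diff):

```python
import torch
import torch.distributed as dist

# Dedicated side stream, created once at init time (mirrors self.compute_moe_load_async).
side_stream = torch.npu.Stream()

def gather_moe_load_async(local_load: torch.Tensor) -> torch.Tensor:
    """All-gather per-rank expert load on a side stream so it can overlap with
    work queued on the default stream. Illustrative sketch only."""
    world_size = dist.get_world_size()
    buffer = torch.empty((world_size, *local_load.shape),
                         dtype=local_load.dtype,
                         device=local_load.device)
    # Make sure local_load is ready before the side stream reads it.
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        dist.all_gather_into_tensor(buffer, local_load)
        # (num_layers, world_size, num_experts), matching the permute in the diff.
        moe_load = buffer.permute(1, 0, 2)
    # Make the default stream wait for the gather before moe_load is consumed.
    torch.npu.current_stream().wait_stream(side_stream)
    return moe_load
```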
```diff
@@ -1,8 +1,9 @@
 import json
 import random
 from typing import Dict, List

 import torch
+from vllm.distributed import get_ep_group


 class ExpertLoadBalancer(object):
```
```diff
@@ -56,9 +57,9 @@
                                              dtype=torch.int32)
         return expert_placement_map

-    def generate_log2phy_expert_map(self, layer_id):
+    def generate_log2phy_expert_map(self, layer_id, rank_id):
         concatenated = torch.flatten(self.expert_map_tensor[layer_id])
         rank_expert_to_global = self.generate_index_dicts(
             self.expert_map_tensor[layer_id])
         result_dict: Dict[int, List[int]] = {}
         for idx, value in enumerate(concatenated):
```
```diff
@@ -67,18 +68,95 @@
                 result_dict[key] = []
             result_dict[key].append(idx)

-        log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
-                                 -1,
+        max_num_experts = max(len(locs) for locs in result_dict.values())
+        log2phy_map = torch.full((self.global_expert_num, max_num_experts),
+                                 0,
                                  dtype=torch.int32)
-        for rank in range(self.ranks_num):
-            for key in result_dict:
-                indices_in_concat = result_dict[key]
-                if key in rank_expert_to_global[rank]:
-                    log2phy_map[rank][key] = rank_expert_to_global[rank][key]
-                else:
-                    chosen_index = random.choice(indices_in_concat)
-                    log2phy_map[rank][key] = chosen_index
-        return log2phy_map
+        num_experts = torch.ones(self.global_expert_num, dtype=torch.int32)
+        # self.update_expert_map(result_dict, log2phy_map, max_num_experts, rank_id)
+        self.update_expert_loc_map_v1(result_dict, rank_id)
+        for log_ids, phy_ids in result_dict.items():
+            log2phy_map[log_ids, :len(phy_ids)] = torch.tensor(phy_ids)
+            num_experts[log_ids] = len(phy_ids)
+        return log2phy_map, num_experts
+
+    def update_expert_map(self, expert_loc, log2phy_map, max_num_dups, rank_id):
+        ep_size = get_ep_group().world_size
+        redundancy_shared_expert_num = self.get_global_redundant_expert_num()
+        n_total_experts = self.global_expert_num + redundancy_shared_expert_num
+
+        for i in range(self.global_expert_num):
+            same_rank_candidates = []
+            same_node_candidates = []
+            experts_per_device = n_total_experts // ep_size
+            all_candidates = []
+            phy_list = expert_loc[i]
+            current_device = rank_id
+            for phy in phy_list:
+                phy_device = phy // experts_per_device
+                if phy_device == current_device:
+                    same_rank_candidates.append(phy)
+                elif (phy_device // self.ranks_num) == (current_device // self.ranks_num):
+                    same_node_candidates.append(phy)
+                else:
+                    all_candidates.append(phy)
+            tmp_expert_loc_map = torch.zeros([max_num_dups], dtype=torch.int32)
+
+            if same_rank_candidates:
+                expert_loc[i] = same_rank_candidates
+            elif same_node_candidates:
+                expert_loc[i] = same_node_candidates
+            tmp_expert_loc_map[:len(expert_loc[i])] = torch.tensor(expert_loc[i], dtype=torch.int32)
+
+            log2phy_map[i] = tmp_expert_loc_map
+
+    def update_expert_loc_map_v1(self, expert_loc, current_rank):
+
+        device_per_host = 16
+
+        ep_size = get_ep_group().world_size
+        current_node, current_rank_in_node = current_rank // device_per_host, current_rank % device_per_host
+        redundancy_shared_expert_num = self.get_global_redundant_expert_num()
+        n_total_experts = self.global_expert_num + redundancy_shared_expert_num
+        experts_per_device = n_total_experts // ep_size
+        num_hosts = self.ranks_num // device_per_host
+        for i in range(self.global_expert_num):
+            same_rank_candidates, same_node_candidates, all_candidates = [], [], []
+            phy_list, num_replicas = expert_loc[i], len(expert_loc[i])
+
+            for phy in phy_list:
+                phy_device = phy // experts_per_device
+                if phy_device == current_rank:
+                    same_rank_candidates.append(phy)
+                elif (phy_device // device_per_host) == (current_rank // device_per_host):
+                    same_node_candidates.append(phy)
+                else:
+                    all_candidates.append(phy)
+
+            is_imbalanced = False
+            if num_replicas > num_hosts and num_replicas % num_hosts != 0:
+                replica_per_node = {}
+                for phy in phy_list:
+                    phy_device = phy // experts_per_device
+                    phy_node = phy_device // device_per_host
+                    local_rank = phy_device % device_per_host
+                    if phy_node not in replica_per_node:
+                        replica_per_node[phy_node] = []
+                    replica_per_node[phy_node].append(local_rank)
+                base_replicas_per_host = num_replicas // num_hosts
+                if len(replica_per_node[current_node]) == base_replicas_per_host:
+                    available_ranks = list(set(range(device_per_host)) - set(replica_per_node[current_node]))
+                    expected_load = round(device_per_host / (base_replicas_per_host + 1))
+                    if current_rank_in_node in available_ranks:
+                        if available_ranks.index(current_rank_in_node) >= (expected_load - 1) * base_replicas_per_host:
+                            is_imbalanced = True
+
+            if same_rank_candidates:
+                expert_loc[i] = same_rank_candidates
+            elif same_node_candidates and not is_imbalanced:
+                expert_loc[i] = same_node_candidates
+
+        return expert_loc


     def get_rank_placement_map(self, layer_id, rank_id):
         expert_placement_map = self.generate_expert_placement_map()
```

Review comment on lines +83 to +112: The function …

Review comment on lines +135 to +151: The logic to determine …
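To make the index arithmetic in `update_expert_loc_map_v1` concrete, here is a small self-contained walkthrough with illustrative numbers (2 hosts, 16 devices per host, 8 physical experts per device; none of these values come from the PR):

```python
# Toy illustration of the candidate classification in update_expert_loc_map_v1.
device_per_host = 16
experts_per_device = 8
current_rank = 3                                  # global rank of this device
current_node = current_rank // device_per_host    # -> host 0

# Physical replica ids of one logical expert (illustrative values).
phy_list = [25, 70, 200]

same_rank, same_node, elsewhere = [], [], []
for phy in phy_list:
    phy_device = phy // experts_per_device        # 25 -> 3, 70 -> 8, 200 -> 25
    phy_node = phy_device // device_per_host      # 3 -> 0,  8 -> 0,  25 -> 1
    if phy_device == current_rank:
        same_rank.append(phy)
    elif phy_node == current_node:
        same_node.append(phy)
    else:
        elsewhere.append(phy)

print(same_rank, same_node, elsewhere)            # [25] [70] [200]
```

The PR then prefers `same_rank` candidates, falls back to `same_node` candidates (unless `is_imbalanced`), and otherwise keeps the full replica list.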
```diff
@@ -89,8 +167,8 @@
         return rank_local_expert_num, rank_expert_map

     def get_rank_log2phy_map(self, layer_id, rank_id):
-        layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
-        return layer_log2phy_map[rank_id]
+        layer_log2phy_map = self.generate_log2phy_expert_map(layer_id, rank_id)
+        return layer_log2phy_map

     def get_global_redundant_expert_num(self):
         global_redundant_expert_num = (
```
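`get_rank_log2phy_map` now returns the zero-padded per-logical-expert map together with the number of valid replicas per row. A toy illustration of how such a pair could be consumed (the random replica selection is an assumption about the consumer, not something shown in this diff):

```python
import torch

# Padded layout as produced by generate_log2phy_expert_map: row i holds the
# physical replica ids of logical expert i, zero-padded on the right, and
# num_experts[i] says how many leading entries are valid.
log2phy_map = torch.tensor([[7, 19, 0],     # logical expert 0 -> replicas {7, 19}
                            [3, 0, 0]],     # logical expert 1 -> replica {3}
                           dtype=torch.int32)
num_experts = torch.tensor([2, 1], dtype=torch.int32)

logical_id = 0
n = int(num_experts[logical_id])
replicas = log2phy_map[logical_id, :n]
chosen = replicas[torch.randint(n, (1,))]   # pick one valid replica at random
print(replicas.tolist(), int(chosen))
```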
```diff
@@ -1200,7 +1200,11 @@ def __init__(
                 expert_load_balancer.get_rank_placement_map(
                     self.moe_instance_id, self.ep_rank))
             self.log2phy = expert_load_balancer.get_rank_log2phy_map(
-                self.moe_instance_id, self.ep_rank).npu()
+                self.moe_instance_id, get_ep_group().rank_in_group).npu()
+            log2phy_map, num_experts = self.log2phy
+            log2phy_map = log2phy_map.npu()
+            num_experts = num_experts.npu()
+            self.log2phy = log2phy_map, num_experts
             self.global_redundant_expert_num = (
                 expert_load_balancer.get_global_redundant_expert_num())
         else:
```

Review comment (with a suggested change): The method …
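One observation on the hunk above: `generate_log2phy_expert_map` now returns a `(log2phy_map, num_experts)` tuple, so `get_rank_log2phy_map(...)` yields a tuple and chaining `.npu()` onto it would raise `AttributeError`; the following lines then move each element to the NPU separately. A hedged sketch of how the call site could avoid the intermediate `.npu()` call, using only the names already present in the hunk (this is an illustration, not the reviewer's attached suggestion):

```python
# Sketch only: unpack the tuple first, then move each tensor to the NPU.
log2phy_map, num_experts = expert_load_balancer.get_rank_log2phy_map(
    self.moe_instance_id, get_ep_group().rank_in_group)
self.log2phy = (log2phy_map.npu(), num_experts.npu())
```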
```diff
@@ -1431,8 +1435,16 @@ def forward(
             e_hidden_states, expert_token_num, group_list_type = e_hidden_states

         if self.dynamic_eplb:
-            self.moe_load += expert_token_num if group_list_type else \
-                torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
+            if is_prefill:
+                token_nums = expert_token_num if group_list_type else \
+                    torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
+                self.moe_load.add_(token_nums)
+            else:
+                with npu_stream_switch("moe_load_async", 0):
+                    npu_wait_tensor(hidden_states, expert_token_num)
+                    token_nums = expert_token_num if group_list_type else \
+                        torch.cat([expert_token_num[:1], expert_token_num[1:] - expert_token_num[:-1]])
+                    self.moe_load.add_(token_nums)

         if not self.enable_prefill_optimizations and fused_moe_state != FusedMoEState.AllGather and not enable_sp:
             if tp_size > 1:
```
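When `group_list_type` is falsy, `expert_token_num` appears to hold cumulative token counts, and the `torch.cat` expression above converts them into per-expert counts before accumulating into `moe_load`. A small self-contained check with made-up numbers:

```python
import torch

# Experts received 4, 0, 3, 5 tokens -> cumulative counts [4, 4, 7, 12].
cumulative = torch.tensor([4, 4, 7, 12])

# Same expression as in the diff: keep the first entry, then take adjacent differences.
per_expert = torch.cat([cumulative[:1], cumulative[1:] - cumulative[:-1]])
print(per_expert.tolist())   # [4, 0, 3, 5]

# Accumulate into a running load counter, as self.moe_load.add_(token_nums) does.
moe_load = torch.zeros(4, dtype=torch.int64)
moe_load.add_(per_expert)
```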
```diff
@@ -1470,7 +1482,7 @@ def get_map(self):
         return self.expert_map

     def get_log2phy_map(self):
-        return self.log2phy
+        return self.log2phy[0]

     def clear_moe_load(self):
         if self.moe_load is not None:
```
Review comment: The condition used to filter eligible boxes (cards) for expert placement is very complex and difficult to understand, which makes the code hard to maintain and debug. Consider refactoring this logic for clarity and simplicity. The core goal seems to be balancing experts across hosts; a simpler condition might be to only allow placing an expert on a host that has the minimum number of instances of that expert.
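As a rough sketch of the simpler rule the reviewer describes, eligibility could be reduced to "the hosts currently holding the fewest replicas of this expert"; the helper below is hypothetical and not part of the PR:

```python
from collections import Counter
from typing import List

def eligible_hosts(replica_hosts: List[int], num_hosts: int) -> List[int]:
    """Hosts that currently hold the minimum number of replicas of one expert.

    replica_hosts: host index of each existing replica of that expert.
    """
    counts = Counter(replica_hosts)
    per_host = [counts.get(h, 0) for h in range(num_hosts)]
    min_count = min(per_host)
    return [h for h in range(num_hosts) if per_host[h] == min_count]

# Example: 3 replicas on hosts [0, 0, 1] across 4 hosts -> hosts 2 and 3 are eligible.
print(eligible_hosts([0, 0, 1], num_hosts=4))   # [2, 3]
```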