diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py
index ad37f35b9a..6242e105d8 100644
--- a/vllm_ascend/eplb/core/eplb_worker.py
+++ b/vllm_ascend/eplb/core/eplb_worker.py
@@ -15,7 +15,7 @@
 # This file is a part of the vllm-ascend project.
 #

-from multiprocessing import Process
+from multiprocessing import Process, Queue
 from typing import Any

 import networkx as nx  # type: ignore
@@ -39,6 +39,7 @@ def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
         self.old_expert_maps = None
         self.enable_d2d = enable_d2d
         self.rank_id = dist.get_rank()
+        self.phy2log = None

     def do_update(self):
         # put data in to queue
@@ -67,6 +68,7 @@ def do_update(self):
                 self.num_local_experts)
             _, _, new_placement = self.calculate_rebalance_experts(
                 load_info, old_placement)
+            self.phy2log = new_placement

             if not torch.is_tensor(new_placement):
                 new_placement = torch.tensor(new_placement)
@@ -383,6 +385,9 @@ def pack_update_info(self, update_info_generator):

         return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))

+    def get_phy2log(self):
+        return self.phy2log
+

 class EplbProcess:

@@ -404,11 +409,21 @@ def __init__(self,
         self.planner_q = planner_q
         self.block_update_q = block_update_q

+        self.phy2log_q = Queue(maxsize=1)
+        self.phy2log = None
+
         # Create EplbWorker instance
         self.worker = EplbWorker(self.shared_dict, self.policy_type,
                                  self.enable_d2d)

-    def worker_process(self, planner_q, block_update_q):
+    def get_phy2log(self):
+        if self.phy2log_q.empty():
+            return None
+        else:
+            self.phy2log = self.phy2log_q.get()
+            return self.phy2log
+
+    def worker_process(self, planner_q, block_update_q, phy2log_q):
         """
         Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete.
         """
@@ -418,12 +433,14 @@
                 planner_q.get()

                 packed_update_info = self.worker.do_update()
+                self.phy2log = self.worker.get_phy2log()

                 while True:
                     if not block_update_q.empty():
                         continue
                     block_update_q.put(packed_update_info)
                     break
+                phy2log_q.put(self.phy2log)

             except Exception as e:
                 logger.warning(f"[EPLB subprocess Exiting due to error: {e}",
@@ -435,7 +452,8 @@ def _launch_process(self):
         """
         Use spawn method to launch subprocess and return (planner_q, block_update_q, proc).
""" proc = Process(target=self.worker_process, - args=(self.planner_q, self.block_update_q), + args=(self.planner_q, self.block_update_q, + self.phy2log_q), daemon=True) proc.start() diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py new file mode 100644 index 0000000000..02f16174a3 --- /dev/null +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -0,0 +1,182 @@ +import json +import threading +import time +from typing import Optional + +import numpy as np +import prometheus_client +import torch +from vllm.distributed.parallel_state import get_ep_group +from vllm.logger import logger + +from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor + +RECORDING_TIME = 10 + + +class EplbStatLogger: + _instance = None + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + + def __init__(self, adaptor: VllmEplbAdaptor, + expert_map_path: Optional[str]): + self.rank = get_ep_group().rank + self.layers_num = adaptor.num_moe_layers + self.global_expert_num = adaptor.global_expert_num + self.ep_size = get_ep_group().world_size + + if expert_map_path is None: + self.phy2log_map = [[i for i in range(self.global_expert_num)] + for _ in range(self.layers_num)] + else: + self.phy2log_map = self._expert_file_to_list(expert_map_path) + self.global_expert_num = len(self.phy2log_map[0]) + + self.local_expert_num = self.global_expert_num // self.ep_size + + labelnames_phy_load = ["rank", "layer", "phy_expert_id"] + labelnames_phy2log = [ + "rank", "layer", "phy_expert_id", "log_expert_id" + ] + + self.phy_expert = self._counter_cls( + name="vllm:phy_expert_heat", + documentation="Heat of each physical expert per rank", + labelnames=labelnames_phy_load) + + self.phy2log = self._gauge_cls( + name="vllm:phy2log", + documentation="physical expert to logical expert per rank", + labelnames=labelnames_phy2log) + + self.do_record_loop = threading.Thread(target=self.record_loop, + daemon=True) + self.moe_load = None + + self.update_load = False + self.update_map = False + + # only init in rank0 + self.all_phy2log = [] + if self.rank == 0: + for layer_id in range(self.layers_num): + for phy_expert_id in range(self.global_expert_num): + self.phy_expert.labels( + rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num) + + for layer_id in range(len(self.phy2log_map)): + local_phy2log = [] + for phy_expert_id, log_expert_id in enumerate( + self.phy2log_map[layer_id]): + a = self.phy2log.labels( + rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num, + log_expert_id=log_expert_id) + a.set(1) + local_phy2log.append(a) + self.all_phy2log.append(local_phy2log) + + self.moe_load = torch.zeros( + (self.layers_num, self.ep_size, self.local_expert_num)) + + self.lock = threading.Lock() + self.start_loop() + + @staticmethod + def get_instance(): + if EplbStatLogger._instance is None: + raise ValueError( + "EplbStatLogger instance has not been initialized.") + return EplbStatLogger._instance + + @staticmethod + def init_instance(adaptor: VllmEplbAdaptor, + expert_map_path: Optional[str]): + """Initialize the singleton instance of ExpertLoadBalancer.""" + EplbStatLogger._instance = EplbStatLogger(adaptor, expert_map_path) + return EplbStatLogger._instance + + def record(self, moe_load, phy2log_map): + if self.rank != 0: + return + try: + with self.lock: + if phy2log_map is not None: + self.phy2log_map = phy2log_map + self.update_map = True + + if moe_load 
is not None: + torch.add(self.moe_load, moe_load, out=self.moe_load) + self.update_load = True + except Exception as e: + logger.debug(f"Record moe_load or phy2log error, error result:{e}") + + def record_loop(self): + while True: + try: + if self.update_map: + with self.lock: + self.update_map = False + phy2log_map = self.phy2log_map + phy2log_map = np.array(phy2log_map).reshape( + self.layers_num, -1) + self.record_phy2log(phy2log_map) + + if self.update_load: + with self.lock: + self.update_load = False + moe_load = self.moe_load.tolist() + self.moe_load.zero_() + moe_load = np.array(moe_load) + res = np.zeros_like(moe_load) + res[..., 0] = moe_load[..., 0] + res[..., 1:] = moe_load[..., 1:] - moe_load[..., :-1] + res = res.reshape(self.layers_num, -1) + self.record_expert_load(res) + except Exception as e: + logger.debug( + f"Record moe_load or phy2log prometheus error, error result:{e}" + ) + time.sleep(RECORDING_TIME) + + def start_loop(self): + self.do_record_loop.start() + + def record_phy2log(self, phy2log_map: list[list[int]]): + for layer_id in range(len(phy2log_map)): + for phy_expert_id, log_expert_id in enumerate( + phy2log_map[layer_id]): + self.all_phy2log[layer_id][phy_expert_id].set(0) + + a = self.phy2log.labels( + rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num, + log_expert_id=log_expert_id) + a.set(1) + self.all_phy2log[layer_id][phy_expert_id] = a + + def record_expert_load(self, moe_load: list[list[int]]): + for layer_id in range(len(moe_load)): + for phy_expert_id, load in enumerate(moe_load[layer_id]): + self.phy_expert.labels( + rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num, + ).inc(load) + + def _expert_file_to_list(self, expert_map_path: str): + with open(expert_map_path, "r") as f: + data = json.load(f) + + phy2log_data = [] + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data += device["device_expert"] + phy2log_data.append(device_data) + return phy2log_data diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index 5081d969a0..122d583d57 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -150,8 +150,10 @@ def take_update_info_from_eplb_process(self): self.update_info_all = self.block_update_queue.get() def forward_end(self): + moe_load = None if self.wakeup_eplb_worker_flag(): self.compute_and_set_moe_load(is_clear=True) + moe_load = self.shared_dict.get("moe_load", None) self.wakeup_eplb_worker() if self.update_expert_weight_flag(): @@ -159,6 +161,11 @@ def forward_end(self): self.update_iteration() + if moe_load is not None: + return moe_load, self.eplb.get_phy2log() + else: + return None, None + def compute_and_set_moe_load(self, is_clear=False): local_load = self.adaptor.get_rank_expert_workload() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 347f9126a5..0bd7ada446 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -20,6 +20,7 @@ import copy import gc import math +import os import time import types import weakref @@ -80,6 +81,7 @@ from vllm_ascend.attention.utils import \ AscendCommonAttentionMetadata as CommonAttentionMetadata from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor +from vllm_ascend.eplb.eplb_loggers import EplbStatLogger from vllm_ascend.eplb.eplb_updator import EplbUpdator 
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
@@ -386,10 +388,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

         #EPLB
         self.dynamic_eplb = ascend_config.dynamic_eplb
+        self.dynamic_eplb_metrics = os.getenv("DYNAMIC_EXPERT_LOAD_METRICS",
+                                              "0") == "1"
         if self.dynamic_eplb:
             self.eplb_adaptor: Optional[VllmEplbAdaptor] = None
             self.is_eplb_warmuped = False
             self.eplb_updator = EplbUpdator(ascend_config.expert_map_path)
+            self.ep_loggers: Optional[EplbStatLogger] = None

         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False
@@ -1495,7 +1500,9 @@ def execute_model(
                         " ".join(dr_str))

         if self.dynamic_eplb:
-            self.eplb_updator.forward_end()
+            moe_load, phy2log_map = self.eplb_updator.forward_end()
+            if self.dynamic_eplb_metrics and moe_load is not None:
+                self.ep_loggers.record(moe_load, phy2log_map)

         return model_runner_output

@@ -1836,6 +1843,10 @@ def eplb_warmup(self):
         if self.dynamic_eplb and not self.is_eplb_warmuped:
             self.is_eplb_warmuped = True
             self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
+            if self.dynamic_eplb_metrics:
+                self.ep_loggers = EplbStatLogger.init_instance(
+                    self.eplb_adaptor,
+                    get_ascend_config().expert_map_path)
             self.eplb_updator.set_adaptor(self.eplb_adaptor)
             self.eplb_updator.warm_up_eplb()
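Reviewer note on the phy2log handoff: the patch passes the latest physical-to-logical placement from the EPLB worker subprocess back to the main process through a Queue(maxsize=1) and a non-blocking read. The snippet below is a minimal, self-contained sketch of that pattern only; produce_placement and Consumer are invented names for illustration and are not part of the patch.

from multiprocessing import Process, Queue


def produce_placement(phy2log_q: Queue) -> None:
    # Stand-in for the EPLB worker: publish the newest phy2log mapping.
    phy2log_q.put([[0, 1, 2, 3]])  # dummy mapping for a single layer


class Consumer:
    # Mirrors the Queue(maxsize=1) + empty()/get() idiom of EplbProcess.get_phy2log().

    def __init__(self) -> None:
        self.phy2log_q: Queue = Queue(maxsize=1)
        self.phy2log = None

    def get_phy2log(self):
        # Non-blocking: return None until the worker has published something new.
        if self.phy2log_q.empty():
            return None
        self.phy2log = self.phy2log_q.get()
        return self.phy2log


if __name__ == "__main__":
    consumer = Consumer()
    proc = Process(target=produce_placement,
                   args=(consumer.phy2log_q,),
                   daemon=True)
    proc.start()
    proc.join()
    # After join() the child has flushed its queue, so the read sees the item.
    print(consumer.get_phy2log())  # expected: [[0, 1, 2, 3]]

The maxsize=1 bound means only the most recent placement is kept; the parent drains it once per forward_end(), so stale mappings never accumulate.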
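The new EplbStatLogger publishes expert heat and placement through prometheus_client Counters and Gauges keyed by rank/layer/physical-expert labels. Below is a minimal sketch of that labelled-metric usage outside of vLLM, reusing the metric and label names from the patch; the concrete values are made up, and generate_latest() simply dumps the text format a Prometheus scrape would return.

import prometheus_client

# Same metric/label names as eplb_loggers.py; values are illustrative.
phy_expert = prometheus_client.Counter(
    name="vllm:phy_expert_heat",
    documentation="Heat of each physical expert per rank",
    labelnames=["rank", "layer", "phy_expert_id"])

phy2log = prometheus_client.Gauge(
    name="vllm:phy2log",
    documentation="physical expert to logical expert per rank",
    labelnames=["rank", "layer", "phy_expert_id", "log_expert_id"])

# Accumulate some load for physical expert 3 of layer 0 on rank 1 ...
phy_expert.labels(rank=1, layer=0, phy_expert_id=3).inc(42)
# ... and mark that this physical slot currently hosts logical expert 7.
phy2log.labels(rank=1, layer=0, phy_expert_id=3, log_expert_id=7).set(1)

# Text exposition format that a Prometheus server would scrape.
print(prometheus_client.generate_latest().decode())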
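For reference, _expert_file_to_list() reads only the "layer_list" / "device_list" / "device_expert" fields of the expert map file and concatenates each layer's per-device expert lists into one phy2log row. The sketch below replays that flattening on a hand-written two-layer, two-device example; the JSON content is invented for illustration, and any other fields a real expert map carries are omitted.

import json

example_map = """
{
  "layer_list": [
    {"device_list": [{"device_expert": [0, 1, 2, 3]},
                     {"device_expert": [4, 5, 6, 0]}]},
    {"device_list": [{"device_expert": [3, 2, 1, 0]},
                     {"device_expert": [7, 6, 5, 4]}]}
  ]
}
"""


def expert_file_to_list(data: dict) -> list[list[int]]:
    # Same flattening as EplbStatLogger._expert_file_to_list, minus the file I/O.
    phy2log_data = []
    for layer in data["layer_list"]:
        device_data: list[int] = []
        for device in layer["device_list"]:
            device_data += device["device_expert"]
        phy2log_data.append(device_data)
    return phy2log_data


print(expert_file_to_list(json.loads(example_map)))
# [[0, 1, 2, 3, 4, 5, 6, 0], [3, 2, 1, 0, 7, 6, 5, 4]]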