From edf03134928630d0f8c395544b8f162a36615d28 Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 11:07:43 +0800 Subject: [PATCH 01/13] metrics --- vllm_ascend/eplb/core/eplb_worker.py | 23 +++- vllm_ascend/eplb/eplb_loggers.py | 179 ++++++++++++++++++++++++++ vllm_ascend/eplb/eplb_updator.py | 7 + vllm_ascend/worker/model_runner_v1.py | 10 +- 4 files changed, 215 insertions(+), 4 deletions(-) create mode 100644 vllm_ascend/eplb/eplb_loggers.py diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index ad37f35b9a..ddf1244b2b 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. # -from multiprocessing import Process +from multiprocessing import Process, Queue from typing import Any import networkx as nx # type: ignore @@ -39,6 +39,7 @@ def __init__(self, shared_dict, policy_type, enable_d2d: bool = True): self.old_expert_maps = None self.enable_d2d = enable_d2d self.rank_id = dist.get_rank() + self.phy2log = None def do_update(self): # put data in to queue @@ -67,6 +68,7 @@ def do_update(self): self.num_local_experts) _, _, new_placement = self.calculate_rebalance_experts( load_info, old_placement) + self.phy2log = new_placement if not torch.is_tensor(new_placement): new_placement = torch.tensor(new_placement) @@ -383,6 +385,9 @@ def pack_update_info(self, update_info_generator): return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids)) + def get_phy2log(self): + return self.phy2log + class EplbProcess: @@ -404,11 +409,21 @@ def __init__(self, self.planner_q = planner_q self.block_update_q = block_update_q + self.phy2log_q = Queue(maxsize=1) + self.phy2log = None + # Create EplbWorker instance self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d) - def worker_process(self, planner_q, block_update_q): + def get_phy2log(self): + if self.phy2log_q.empty(): + return None + else: + self.phy2log = self.phy2log_q.get() + return self.phy2log + + def worker_process(self, planner_q, block_update_q, phy2log_q): """ Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete. """ @@ -418,12 +433,14 @@ def worker_process(self, planner_q, block_update_q): planner_q.get() packed_update_info = self.worker.do_update() + self.phy2log = self.worker.get_phy2log() while True: if not block_update_q.empty(): continue block_update_q.put(packed_update_info) break + phy2log_q.put(self.phy2log) except Exception as e: logger.warning(f"[EPLB subprocess Exiting due to error: {e}", @@ -435,7 +452,7 @@ def _launch_process(self): Use spawn method to launch subprocess and return (planner_q, block_update_q, proc). 
""" proc = Process(target=self.worker_process, - args=(self.planner_q, self.block_update_q), + args=(self.planner_q, self.block_update_q, self.phy2log_q), daemon=True) proc.start() diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py new file mode 100644 index 0000000000..b4d8cd0983 --- /dev/null +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -0,0 +1,179 @@ +import time +import json +import threading +from typing import Optional + +import torch +import torch_npu +import prometheus_client +import numpy as np + +from vllm.logger import logger +from vllm.distributed.parallel_state import get_ep_group +from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor + +RECORDING_TIME = 10 + + +class EplbStatLogger: + _instance = None + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + + def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): + self.rank = get_ep_group().rank + self.layers_num = adaptor.num_moe_layers + self.global_expert_num = adaptor.global_expert_num + self.ep_size = get_ep_group().world_size + + if expert_map_path is None: + # self.phy2log_map = torch.arange(self.global_expert_num).repeat(self.layers_num, 1) + self.phy2log_map = [[i for i in range(self.global_expert_num)] for _ in range(self.layers_num)] + else: + self.phy2log_map = self._expert_file_to_list(expert_map_path) + self.global_expert_num = len(self.phy2log_map[0]) + + self.local_expert_num = self.global_expert_num // self.ep_size + + labelnames_phy_load = ["rank", "layer", "phy_expert_id"] + labelnames_phy2log = ["rank", "layer", "phy_expert_id", "log_expert_id"] + + self.phy_expert = self._counter_cls( + name="vllm:phy_expert_heat", + documentation="Heat of each physical expert per rank", + labelnames=labelnames_phy_load) + + self.phy2log = self._gauge_cls( + name="vllm:phy2log", + documentation="physical expert to logical expert per rank", + labelnames=labelnames_phy2log) + + self.do_record_loop = threading.Thread(target=self.record_loop) + self.moe_load = None + + self.update_load = False + self.update_map = False + + # only init in rank0 + self.all_phy2log = [] + if self.rank == 0: + for layer_id in range(self.layers_num): + for phy_expert_id in range(self.global_expert_num): + self.phy_expert.labels(rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num) + + for layer_id in range(len(self.phy2log_map)): + local_phy2log = [] + for phy_expert_id, log_expert_id in enumerate(self.phy2log_map[layer_id]): + a = self.phy2log.labels(rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num, + log_expert_id=log_expert_id) + a.set(1) + local_phy2log.append(a) + self.all_phy2log.append(local_phy2log) + + self.moe_load = torch.zeros((self.layers_num, self.ep_size, self.local_expert_num)) + + self.lock = threading.Lock() + self.start_loop() + + @staticmethod + def get_instance(): + if EplbStatLogger._instance is None: + raise ValueError( + "ExpertLoadBalancer instance has not been initialized.") + return EplbStatLogger + + @staticmethod + def init_instance(adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): + """Initialize the singleton instance of ExpertLoadBalancer.""" + EplbStatLogger._instance = EplbStatLogger( + adaptor, expert_map_path) + return EplbStatLogger._instance + + def record(self, moe_load, phy2log_map): + if self.rank != 0: + return + try: + with self.lock: + if moe_load is not None: + 
torch.add(self.moe_load, moe_load, out=self.moe_load) + self.update_load = True + + if phy2log_map is not None: + self.phy2log_map = phy2log_map + self.update_map = True + except Exception as e: + logger.debug(f"Record moe_load or phy2log error, error result:{e}") + + def record_loop(self): + while True: + try: + if self.update_load: + with self.lock: + self.update_load = False + moe_load = self.moe_load.tolist() + self.moe_load.zero_() + moe_load = np.array(moe_load) + res = np.zeros_like(moe_load) + res[..., 0] = moe_load[..., 0] + res[..., 1:] = moe_load[..., 1:] - moe_load[..., :-1] + res = res.reshape(self.layers_num, -1) + self.record_expert_load(res) + + if self.update_map: + with self.lock: + self.update_map = False + phy2log_map = self.phy2log_map + phy2log_map = np.array(phy2log_map).reshape(self.layers_num, -1) + self.record_phy2log(phy2log_map) + except Exception as e: + logger.debug(f"Record moe_load or phy2log prometheus error, error result:{e}") + time.sleep(RECORDING_TIME) + + def start_loop(self): + self.do_record_loop.start() + + def record_phy2log(self, phy2log_map: list[list[int]]): + for layer_id in range(len(phy2log_map)): + for phy_expert_id, log_expert_id in enumerate(phy2log_map[layer_id]): + self.all_phy2log[layer_id][phy_expert_id].set(0) + + a = self.phy2log.labels( + rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num, + log_expert_id=log_expert_id + ) + a.set(1) + self.all_phy2log[layer_id][phy_expert_id] = a + + def record_expert_load(self, moe_load: list[list[int]]): + for layer_id in range(len(moe_load)): + for phy_expert_id, load in enumerate(moe_load[layer_id]): + self.phy_expert.labels( + rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num, + ).inc(load) + + def _expert_file_to_list(self, expert_map_path: str): + with open(expert_map_path, "r") as f: + data = json.load(f) + + phy2log_data = [] + for layer in data["layer_list"]: + device_data = [] + for device in layer["device_list"]: + device_data += device["device_expert"] + phy2log_data.append(device_data) + return phy2log_data + + def clear(self): + for layer_id in range(self.layers_num): + for phy_expert_id in range(self.global_expert_num): + self.phy_expert.labels(rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num).reset() diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index 5081d969a0..122d583d57 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -150,8 +150,10 @@ def take_update_info_from_eplb_process(self): self.update_info_all = self.block_update_queue.get() def forward_end(self): + moe_load = None if self.wakeup_eplb_worker_flag(): self.compute_and_set_moe_load(is_clear=True) + moe_load = self.shared_dict.get("moe_load", None) self.wakeup_eplb_worker() if self.update_expert_weight_flag(): @@ -159,6 +161,11 @@ def forward_end(self): self.update_iteration() + if moe_load is not None: + return moe_load, self.eplb.get_phy2log() + else: + return None, None + def compute_and_set_moe_load(self, is_clear=False): local_load = self.adaptor.get_rank_expert_workload() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 347f9126a5..f9e98c7d5b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -17,6 +17,7 @@ # Adapted from 
vllm-project/vllm/vllm/worker/gpu_model_runner.py # +import os import copy import gc import math @@ -81,6 +82,7 @@ AscendCommonAttentionMetadata as CommonAttentionMetadata from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.eplb_updator import EplbUpdator +from vllm_ascend.eplb.eplb_loggers import EplbStatLogger from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler @@ -386,10 +388,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): #EPLB self.dynamic_eplb = ascend_config.dynamic_eplb + self.dynamic_eplb_metrics = os.getenv("DYNAMIC_EXPERT_LOAD_METRICS", False) if self.dynamic_eplb: self.eplb_adaptor: Optional[VllmEplbAdaptor] = None self.is_eplb_warmuped = False self.eplb_updator = EplbUpdator(ascend_config.expert_map_path) + self.ep_loggers: Optional[EplbStatLogger] = None # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True self.in_profile_run = False @@ -1495,7 +1499,9 @@ def execute_model( " ".join(dr_str)) if self.dynamic_eplb: - self.eplb_updator.forward_end() + moe_load, phy2log_map = self.eplb_updator.forward_end() + if self.dynamic_eplb_metrics and moe_load is not None: + self.ep_loggers.record(moe_load, phy2log_map) return model_runner_output @@ -1836,6 +1842,8 @@ def eplb_warmup(self): if self.dynamic_eplb and not self.is_eplb_warmuped: self.is_eplb_warmuped = True self.eplb_adaptor = VllmEplbAdaptor(model=self.model) + if self.dynamic_eplb_metrics: + self.ep_loggers = EplbStatLogger.init_instance(self.eplb_adaptor, get_ascend_config().expert_map_path) self.eplb_updator.set_adaptor(self.eplb_adaptor) self.eplb_updator.warm_up_eplb() From f9bd6a352a0c3c38157c841a767aac6566e0bd6d Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 11:45:10 +0800 Subject: [PATCH 02/13] youhua --- vllm_ascend/eplb/eplb_loggers.py | 5 ++--- vllm_ascend/worker/model_runner_v1.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index b4d8cd0983..0419d75a3b 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -27,7 +27,6 @@ def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): self.ep_size = get_ep_group().world_size if expert_map_path is None: - # self.phy2log_map = torch.arange(self.global_expert_num).repeat(self.layers_num, 1) self.phy2log_map = [[i for i in range(self.global_expert_num)] for _ in range(self.layers_num)] else: self.phy2log_map = self._expert_file_to_list(expert_map_path) @@ -48,7 +47,7 @@ def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): documentation="physical expert to logical expert per rank", labelnames=labelnames_phy2log) - self.do_record_loop = threading.Thread(target=self.record_loop) + self.do_record_loop = threading.Thread(target=self.record_loop, daemon=True) self.moe_load = None self.update_load = False @@ -83,7 +82,7 @@ def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): def get_instance(): if EplbStatLogger._instance is None: raise ValueError( - "ExpertLoadBalancer instance has not been initialized.") + "EplbStatLogger instance has not been initialized.") return EplbStatLogger @staticmethod diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f9e98c7d5b..ef1f7cac7e 100644 --- 
a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -388,7 +388,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): #EPLB self.dynamic_eplb = ascend_config.dynamic_eplb - self.dynamic_eplb_metrics = os.getenv("DYNAMIC_EXPERT_LOAD_METRICS", False) + self.dynamic_eplb_metrics = os.getenv("DYNAMIC_EXPERT_LOAD_METRICS", "0") == "1" if self.dynamic_eplb: self.eplb_adaptor: Optional[VllmEplbAdaptor] = None self.is_eplb_warmuped = False From 5054361373f9c62f127e160141053ad353375c70 Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 11:56:19 +0800 Subject: [PATCH 03/13] bug --- vllm_ascend/eplb/eplb_loggers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index 0419d75a3b..770c34106b 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -83,7 +83,7 @@ def get_instance(): if EplbStatLogger._instance is None: raise ValueError( "EplbStatLogger instance has not been initialized.") - return EplbStatLogger + return EplbStatLogger._instance @staticmethod def init_instance(adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): From d7ee913c6622d06f34d92b760da80a4464d32701 Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 14:15:33 +0800 Subject: [PATCH 04/13] bug --- vllm_ascend/eplb/eplb_loggers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index 770c34106b..29e38a7136 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -4,7 +4,6 @@ from typing import Optional import torch -import torch_npu import prometheus_client import numpy as np From b0274b2d5713e1fdb3878a2a9b128e013640d287 Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 15:04:12 +0800 Subject: [PATCH 05/13] geshi --- vllm_ascend/eplb/eplb_loggers.py | 8 ++++---- vllm_ascend/worker/model_runner_v1.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index 29e38a7136..62dd7e2d14 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -1,14 +1,14 @@ -import time import json +import time import threading from typing import Optional -import torch -import prometheus_client import numpy as np - +import prometheus_client +import torch from vllm.logger import logger from vllm.distributed.parallel_state import get_ep_group + from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor RECORDING_TIME = 10 diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ef1f7cac7e..514722fd30 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -17,10 +17,10 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # -import os import copy import gc import math +import os import time import types import weakref @@ -81,8 +81,8 @@ from vllm_ascend.attention.utils import \ AscendCommonAttentionMetadata as CommonAttentionMetadata from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor -from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.eplb_loggers import EplbStatLogger +from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import 
AscendRejectionSampler From 94d2d9c303b2979a6eaca5458d0a9125eb6c67cd Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 15:09:57 +0800 Subject: [PATCH 06/13] geshi --- vllm_ascend/eplb/eplb_loggers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index 62dd7e2d14..beec3a29bd 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -1,13 +1,13 @@ import json -import time import threading +import time from typing import Optional import numpy as np import prometheus_client import torch -from vllm.logger import logger from vllm.distributed.parallel_state import get_ep_group +from vllm.logger import logger from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor From 4fb7aec7a3144412e3ab2773c605a37773bd0d8a Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 17:04:02 +0800 Subject: [PATCH 07/13] geshi --- vllm_ascend/eplb/core/eplb_worker.py | 3 +- vllm_ascend/eplb/eplb_loggers.py | 57 +++++++++++++++------------ vllm_ascend/worker/model_runner_v1.py | 7 +++- 3 files changed, 38 insertions(+), 29 deletions(-) diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index ddf1244b2b..6242e105d8 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -452,7 +452,8 @@ def _launch_process(self): Use spawn method to launch subprocess and return (planner_q, block_update_q, proc). """ proc = Process(target=self.worker_process, - args=(self.planner_q, self.block_update_q, self.phy2log_q), + args=(self.planner_q, self.block_update_q, + self.phy2log_q), daemon=True) proc.start() diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index beec3a29bd..cb9469c00c 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -26,7 +26,8 @@ def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): self.ep_size = get_ep_group().world_size if expert_map_path is None: - self.phy2log_map = [[i for i in range(self.global_expert_num)] for _ in range(self.layers_num)] + self.phy2log_map = [[i for i in range(self.global_expert_num)] + for _ in range(self.layers_num)] else: self.phy2log_map = self._expert_file_to_list(expert_map_path) self.global_expert_num = len(self.phy2log_map[0]) @@ -34,7 +35,9 @@ def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): self.local_expert_num = self.global_expert_num // self.ep_size labelnames_phy_load = ["rank", "layer", "phy_expert_id"] - labelnames_phy2log = ["rank", "layer", "phy_expert_id", "log_expert_id"] + labelnames_phy2log = [ + "rank", "layer", "phy_expert_id", "log_expert_id" + ] self.phy_expert = self._counter_cls( name="vllm:phy_expert_heat", @@ -46,7 +49,8 @@ def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): documentation="physical expert to logical expert per rank", labelnames=labelnames_phy2log) - self.do_record_loop = threading.Thread(target=self.record_loop, daemon=True) + self.do_record_loop = threading.Thread(target=self.record_loop, + daemon=True) self.moe_load = None self.update_load = False @@ -57,22 +61,26 @@ def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): if self.rank == 0: for layer_id in range(self.layers_num): for phy_expert_id in range(self.global_expert_num): - self.phy_expert.labels(rank=phy_expert_id // self.local_expert_num, - layer=layer_id, - 
phy_expert_id=phy_expert_id % self.local_expert_num) + self.phy_expert.labels( + rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num) for layer_id in range(len(self.phy2log_map)): local_phy2log = [] - for phy_expert_id, log_expert_id in enumerate(self.phy2log_map[layer_id]): - a = self.phy2log.labels(rank=phy_expert_id // self.local_expert_num, - layer=layer_id, - phy_expert_id=phy_expert_id % self.local_expert_num, - log_expert_id=log_expert_id) + for phy_expert_id, log_expert_id in enumerate( + self.phy2log_map[layer_id]): + a = self.phy2log.labels( + rank=phy_expert_id // self.local_expert_num, + layer=layer_id, + phy_expert_id=phy_expert_id % self.local_expert_num, + log_expert_id=log_expert_id) a.set(1) local_phy2log.append(a) self.all_phy2log.append(local_phy2log) - self.moe_load = torch.zeros((self.layers_num, self.ep_size, self.local_expert_num)) + self.moe_load = torch.zeros( + (self.layers_num, self.ep_size, self.local_expert_num)) self.lock = threading.Lock() self.start_loop() @@ -85,10 +93,10 @@ def get_instance(): return EplbStatLogger._instance @staticmethod - def init_instance(adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): + def init_instance(adaptor: VllmEplbAdaptor, + expert_map_path: Optional[str]): """Initialize the singleton instance of ExpertLoadBalancer.""" - EplbStatLogger._instance = EplbStatLogger( - adaptor, expert_map_path) + EplbStatLogger._instance = EplbStatLogger(adaptor, expert_map_path) return EplbStatLogger._instance def record(self, moe_load, phy2log_map): @@ -125,10 +133,13 @@ def record_loop(self): with self.lock: self.update_map = False phy2log_map = self.phy2log_map - phy2log_map = np.array(phy2log_map).reshape(self.layers_num, -1) + phy2log_map = np.array(phy2log_map).reshape( + self.layers_num, -1) self.record_phy2log(phy2log_map) except Exception as e: - logger.debug(f"Record moe_load or phy2log prometheus error, error result:{e}") + logger.debug( + f"Record moe_load or phy2log prometheus error, error result:{e}" + ) time.sleep(RECORDING_TIME) def start_loop(self): @@ -136,15 +147,15 @@ def start_loop(self): def record_phy2log(self, phy2log_map: list[list[int]]): for layer_id in range(len(phy2log_map)): - for phy_expert_id, log_expert_id in enumerate(phy2log_map[layer_id]): + for phy_expert_id, log_expert_id in enumerate( + phy2log_map[layer_id]): self.all_phy2log[layer_id][phy_expert_id].set(0) a = self.phy2log.labels( rank=phy_expert_id // self.local_expert_num, layer=layer_id, phy_expert_id=phy_expert_id % self.local_expert_num, - log_expert_id=log_expert_id - ) + log_expert_id=log_expert_id) a.set(1) self.all_phy2log[layer_id][phy_expert_id] = a @@ -169,9 +180,3 @@ def _expert_file_to_list(self, expert_map_path: str): phy2log_data.append(device_data) return phy2log_data - def clear(self): - for layer_id in range(self.layers_num): - for phy_expert_id in range(self.global_expert_num): - self.phy_expert.labels(rank=phy_expert_id // self.local_expert_num, - layer=layer_id, - phy_expert_id=phy_expert_id % self.local_expert_num).reset() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 514722fd30..0bd7ada446 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -388,7 +388,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): #EPLB self.dynamic_eplb = ascend_config.dynamic_eplb - self.dynamic_eplb_metrics = os.getenv("DYNAMIC_EXPERT_LOAD_METRICS", "0") == "1" + 
self.dynamic_eplb_metrics = os.getenv("DYNAMIC_EXPERT_LOAD_METRICS", + "0") == "1" if self.dynamic_eplb: self.eplb_adaptor: Optional[VllmEplbAdaptor] = None self.is_eplb_warmuped = False @@ -1843,7 +1844,9 @@ def eplb_warmup(self): self.is_eplb_warmuped = True self.eplb_adaptor = VllmEplbAdaptor(model=self.model) if self.dynamic_eplb_metrics: - self.ep_loggers = EplbStatLogger.init_instance(self.eplb_adaptor, get_ascend_config().expert_map_path) + self.ep_loggers = EplbStatLogger.init_instance( + self.eplb_adaptor, + get_ascend_config().expert_map_path) self.eplb_updator.set_adaptor(self.eplb_adaptor) self.eplb_updator.warm_up_eplb() From f421f65cba77f37c9b7b80fd2b351d60dd3ca4ef Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 17:11:11 +0800 Subject: [PATCH 08/13] geshi --- vllm_ascend/eplb/eplb_loggers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index cb9469c00c..de06ca22a8 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -19,7 +19,8 @@ class EplbStatLogger: _gauge_cls = prometheus_client.Gauge _counter_cls = prometheus_client.Counter - def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): + def __init__(self, adaptor: VllmEplbAdaptor, + expert_map_path: Optional[str]): self.rank = get_ep_group().rank self.layers_num = adaptor.num_moe_layers self.global_expert_num = adaptor.global_expert_num From 0b70e05e6e14c18384c4268d3075176f8747ce59 Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 17:15:52 +0800 Subject: [PATCH 09/13] geshi --- vllm_ascend/eplb/eplb_loggers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index de06ca22a8..f99750fb3e 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -180,4 +180,3 @@ def _expert_file_to_list(self, expert_map_path: str): device_data += device["device_expert"] phy2log_data.append(device_data) return phy2log_data - From 271fc677b17323f95d5b206c01c95b96daa7490e Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 17:42:26 +0800 Subject: [PATCH 10/13] geshi --- vllm_ascend/eplb/core/eplb_worker.py | 2 +- vllm_ascend/eplb/eplb_loggers.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index 6242e105d8..8ced007dbc 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -409,7 +409,7 @@ def __init__(self, self.planner_q = planner_q self.block_update_q = block_update_q - self.phy2log_q = Queue(maxsize=1) + self.phy2log_q: Queue = Queue(maxsize=1) self.phy2log = None # Create EplbWorker instance diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index f99750fb3e..f048bbd4ad 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -118,7 +118,7 @@ def record(self, moe_load, phy2log_map): def record_loop(self): while True: try: - if self.update_load: + if self.update_load and self.moe_load is not None: with self.lock: self.update_load = False moe_load = self.moe_load.tolist() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0bd7ada446..f563d3f57e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1501,7 +1501,7 @@ def 
execute_model( if self.dynamic_eplb: moe_load, phy2log_map = self.eplb_updator.forward_end() - if self.dynamic_eplb_metrics and moe_load is not None: + if self.ep_loggers is not None and moe_load is not None: self.ep_loggers.record(moe_load, phy2log_map) return model_runner_output From 79acf0dd1d6c050ff475ef59ccf4b8c864df4a91 Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 18:53:25 +0800 Subject: [PATCH 11/13] geshi --- vllm_ascend/worker/model_runner_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f563d3f57e..0bd7ada446 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1501,7 +1501,7 @@ def execute_model( if self.dynamic_eplb: moe_load, phy2log_map = self.eplb_updator.forward_end() - if self.ep_loggers is not None and moe_load is not None: + if self.dynamic_eplb_metrics and moe_load is not None: self.ep_loggers.record(moe_load, phy2log_map) return model_runner_output From 5ab63ce55eac92062f8cf25618343f0f7098be62 Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 19:17:16 +0800 Subject: [PATCH 12/13] Revert "geshi" This reverts commit 79acf0dd1d6c050ff475ef59ccf4b8c864df4a91. --- vllm_ascend/worker/model_runner_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0bd7ada446..f563d3f57e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1501,7 +1501,7 @@ def execute_model( if self.dynamic_eplb: moe_load, phy2log_map = self.eplb_updator.forward_end() - if self.dynamic_eplb_metrics and moe_load is not None: + if self.ep_loggers is not None and moe_load is not None: self.ep_loggers.record(moe_load, phy2log_map) return model_runner_output From 7aa09017038ce6364105ada43854c5bd25630172 Mon Sep 17 00:00:00 2001 From: hyh02297474 Date: Thu, 14 Aug 2025 19:17:31 +0800 Subject: [PATCH 13/13] Revert "geshi" This reverts commit 271fc677b17323f95d5b206c01c95b96daa7490e. 
--- vllm_ascend/eplb/core/eplb_worker.py | 2 +- vllm_ascend/eplb/eplb_loggers.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index 8ced007dbc..6242e105d8 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -409,7 +409,7 @@ def __init__(self, self.planner_q = planner_q self.block_update_q = block_update_q - self.phy2log_q: Queue = Queue(maxsize=1) + self.phy2log_q = Queue(maxsize=1) self.phy2log = None # Create EplbWorker instance diff --git a/vllm_ascend/eplb/eplb_loggers.py b/vllm_ascend/eplb/eplb_loggers.py index f048bbd4ad..f99750fb3e 100644 --- a/vllm_ascend/eplb/eplb_loggers.py +++ b/vllm_ascend/eplb/eplb_loggers.py @@ -118,7 +118,7 @@ def record(self, moe_load, phy2log_map): def record_loop(self): while True: try: - if self.update_load and self.moe_load is not None: + if self.update_load: with self.lock: self.update_load = False moe_load = self.moe_load.tolist() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f563d3f57e..0bd7ada446 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1501,7 +1501,7 @@ def execute_model( if self.dynamic_eplb: moe_load, phy2log_map = self.eplb_updator.forward_end() - if self.ep_loggers is not None and moe_load is not None: + if self.dynamic_eplb_metrics and moe_load is not None: self.ep_loggers.record(moe_load, phy2log_map) return model_runner_output
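
Reviewer note, appended after the series (not part of any patch above): a minimal sketch of how the new metrics can be exercised, assuming the patches are applied as-is. The file path and the tiny expert map below are made up for illustration; real expert-map JSON files may carry additional per-device fields beyond the "layer_list" / "device_list" / "device_expert" keys that _expert_file_to_list() reads. The metrics themselves (vllm:phy_expert_heat and vllm:phy2log, gated by DYNAMIC_EXPERT_LOAD_METRICS=1 together with dynamic_eplb) are only registered on EP rank 0 and are only refreshed by the background thread every RECORDING_TIME (10 s), so expect an initial delay before samples appear; in a deployed server they would normally be scraped from the /metrics endpoint rather than read in-process as done here.

import json
import prometheus_client

# Minimal expert-map file in the shape _expert_file_to_list() parses:
# "layer_list" -> "device_list" -> "device_expert". Two MoE layers,
# two devices, two physical experts per device; values are logical ids.
expert_map = {
    "layer_list": [
        {"device_list": [{"device_expert": [0, 1]}, {"device_expert": [2, 3]}]},
        {"device_list": [{"device_expert": [3, 2]}, {"device_expert": [1, 0]}]},
    ]
}
with open("/tmp/expert_map.json", "w") as f:
    json.dump(expert_map, f)

# Once a worker has run with DYNAMIC_EXPERT_LOAD_METRICS=1 and dynamic_eplb
# enabled, EP rank 0 publishes the new series to the default
# prometheus_client registry; the counter is exposed with the usual
# "_total" suffix. Dump whatever is registered in this process:
text = prometheus_client.generate_latest(prometheus_client.REGISTRY).decode()
for line in text.splitlines():
    if line.startswith(("vllm:phy_expert_heat", "vllm:phy2log")):
        print(line)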