-
Notifications
You must be signed in to change notification settings - Fork 389
[v0.9.1-dev]dynamic eplb metrics #2364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v0.9.1-dev
Are you sure you want to change the base?
Changes from 1 commit
edf0313
f9bd6a3
5054361
d7ee913
b0274b2
94d2d9c
4fb7aec
f421f65
0b70e05
271fc67
79acf0d
5ab63ce
7aa0901
87c09d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,179 @@ | ||||||||||||||||||||||
import time | ||||||||||||||||||||||
import json | ||||||||||||||||||||||
import threading | ||||||||||||||||||||||
from typing import Optional | ||||||||||||||||||||||
|
||||||||||||||||||||||
import torch | ||||||||||||||||||||||
import torch_npu | ||||||||||||||||||||||
import prometheus_client | ||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||
|
||||||||||||||||||||||
from vllm.logger import logger | ||||||||||||||||||||||
from vllm.distributed.parallel_state import get_ep_group | ||||||||||||||||||||||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor | ||||||||||||||||||||||
|
||||||||||||||||||||||
RECORDING_TIME = 10 | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class EplbStatLogger: | ||||||||||||||||||||||
_instance = None | ||||||||||||||||||||||
_gauge_cls = prometheus_client.Gauge | ||||||||||||||||||||||
_counter_cls = prometheus_client.Counter | ||||||||||||||||||||||
|
||||||||||||||||||||||
def __init__(self, adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): | ||||||||||||||||||||||
self.rank = get_ep_group().rank | ||||||||||||||||||||||
self.layers_num = adaptor.num_moe_layers | ||||||||||||||||||||||
self.global_expert_num = adaptor.global_expert_num | ||||||||||||||||||||||
self.ep_size = get_ep_group().world_size | ||||||||||||||||||||||
|
||||||||||||||||||||||
if expert_map_path is None: | ||||||||||||||||||||||
# self.phy2log_map = torch.arange(self.global_expert_num).repeat(self.layers_num, 1) | ||||||||||||||||||||||
self.phy2log_map = [[i for i in range(self.global_expert_num)] for _ in range(self.layers_num)] | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
self.phy2log_map = self._expert_file_to_list(expert_map_path) | ||||||||||||||||||||||
self.global_expert_num = len(self.phy2log_map[0]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
self.local_expert_num = self.global_expert_num // self.ep_size | ||||||||||||||||||||||
|
||||||||||||||||||||||
labelnames_phy_load = ["rank", "layer", "phy_expert_id"] | ||||||||||||||||||||||
labelnames_phy2log = ["rank", "layer", "phy_expert_id", "log_expert_id"] | ||||||||||||||||||||||
|
||||||||||||||||||||||
self.phy_expert = self._counter_cls( | ||||||||||||||||||||||
name="vllm:phy_expert_heat", | ||||||||||||||||||||||
documentation="Heat of each physical expert per rank", | ||||||||||||||||||||||
labelnames=labelnames_phy_load) | ||||||||||||||||||||||
|
||||||||||||||||||||||
self.phy2log = self._gauge_cls( | ||||||||||||||||||||||
name="vllm:phy2log", | ||||||||||||||||||||||
documentation="physical expert to logical expert per rank", | ||||||||||||||||||||||
labelnames=labelnames_phy2log) | ||||||||||||||||||||||
|
||||||||||||||||||||||
self.do_record_loop = threading.Thread(target=self.record_loop) | ||||||||||||||||||||||
self.moe_load = None | ||||||||||||||||||||||
|
||||||||||||||||||||||
self.update_load = False | ||||||||||||||||||||||
self.update_map = False | ||||||||||||||||||||||
|
||||||||||||||||||||||
# only init in rank0 | ||||||||||||||||||||||
self.all_phy2log = [] | ||||||||||||||||||||||
if self.rank == 0: | ||||||||||||||||||||||
for layer_id in range(self.layers_num): | ||||||||||||||||||||||
for phy_expert_id in range(self.global_expert_num): | ||||||||||||||||||||||
self.phy_expert.labels(rank=phy_expert_id // self.local_expert_num, | ||||||||||||||||||||||
layer=layer_id, | ||||||||||||||||||||||
phy_expert_id=phy_expert_id % self.local_expert_num) | ||||||||||||||||||||||
|
||||||||||||||||||||||
for layer_id in range(len(self.phy2log_map)): | ||||||||||||||||||||||
local_phy2log = [] | ||||||||||||||||||||||
for phy_expert_id, log_expert_id in enumerate(self.phy2log_map[layer_id]): | ||||||||||||||||||||||
a = self.phy2log.labels(rank=phy_expert_id // self.local_expert_num, | ||||||||||||||||||||||
layer=layer_id, | ||||||||||||||||||||||
phy_expert_id=phy_expert_id % self.local_expert_num, | ||||||||||||||||||||||
log_expert_id=log_expert_id) | ||||||||||||||||||||||
a.set(1) | ||||||||||||||||||||||
local_phy2log.append(a) | ||||||||||||||||||||||
self.all_phy2log.append(local_phy2log) | ||||||||||||||||||||||
|
||||||||||||||||||||||
self.moe_load = torch.zeros((self.layers_num, self.ep_size, self.local_expert_num)) | ||||||||||||||||||||||
|
||||||||||||||||||||||
self.lock = threading.Lock() | ||||||||||||||||||||||
self.start_loop() | ||||||||||||||||||||||
|
||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||
def get_instance(): | ||||||||||||||||||||||
if EplbStatLogger._instance is None: | ||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||
"ExpertLoadBalancer instance has not been initialized.") | ||||||||||||||||||||||
return EplbStatLogger | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||
|
||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||
def init_instance(adaptor: VllmEplbAdaptor, expert_map_path: Optional[str]): | ||||||||||||||||||||||
"""Initialize the singleton instance of ExpertLoadBalancer.""" | ||||||||||||||||||||||
EplbStatLogger._instance = EplbStatLogger( | ||||||||||||||||||||||
adaptor, expert_map_path) | ||||||||||||||||||||||
return EplbStatLogger._instance | ||||||||||||||||||||||
|
||||||||||||||||||||||
def record(self, moe_load, phy2log_map): | ||||||||||||||||||||||
if self.rank != 0: | ||||||||||||||||||||||
return | ||||||||||||||||||||||
try: | ||||||||||||||||||||||
with self.lock: | ||||||||||||||||||||||
if moe_load is not None: | ||||||||||||||||||||||
torch.add(self.moe_load, moe_load, out=self.moe_load) | ||||||||||||||||||||||
self.update_load = True | ||||||||||||||||||||||
|
||||||||||||||||||||||
if phy2log_map is not None: | ||||||||||||||||||||||
self.phy2log_map = phy2log_map | ||||||||||||||||||||||
self.update_map = True | ||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||
logger.debug(f"Record moe_load or phy2log error, error result:{e}") | ||||||||||||||||||||||
|
||||||||||||||||||||||
def record_loop(self): | ||||||||||||||||||||||
while True: | ||||||||||||||||||||||
try: | ||||||||||||||||||||||
if self.update_load: | ||||||||||||||||||||||
with self.lock: | ||||||||||||||||||||||
self.update_load = False | ||||||||||||||||||||||
moe_load = self.moe_load.tolist() | ||||||||||||||||||||||
self.moe_load.zero_() | ||||||||||||||||||||||
moe_load = np.array(moe_load) | ||||||||||||||||||||||
res = np.zeros_like(moe_load) | ||||||||||||||||||||||
res[..., 0] = moe_load[..., 0] | ||||||||||||||||||||||
res[..., 1:] = moe_load[..., 1:] - moe_load[..., :-1] | ||||||||||||||||||||||
res = res.reshape(self.layers_num, -1) | ||||||||||||||||||||||
self.record_expert_load(res) | ||||||||||||||||||||||
|
||||||||||||||||||||||
if self.update_map: | ||||||||||||||||||||||
with self.lock: | ||||||||||||||||||||||
self.update_map = False | ||||||||||||||||||||||
phy2log_map = self.phy2log_map | ||||||||||||||||||||||
phy2log_map = np.array(phy2log_map).reshape(self.layers_num, -1) | ||||||||||||||||||||||
self.record_phy2log(phy2log_map) | ||||||||||||||||||||||
except Exception as e: | ||||||||||||||||||||||
logger.debug(f"Record moe_load or phy2log prometheus error, error result:{e}") | ||||||||||||||||||||||
time.sleep(RECORDING_TIME) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def start_loop(self): | ||||||||||||||||||||||
self.do_record_loop.start() | ||||||||||||||||||||||
|
||||||||||||||||||||||
def record_phy2log(self, phy2log_map: list[list[int]]): | ||||||||||||||||||||||
for layer_id in range(len(phy2log_map)): | ||||||||||||||||||||||
for phy_expert_id, log_expert_id in enumerate(phy2log_map[layer_id]): | ||||||||||||||||||||||
self.all_phy2log[layer_id][phy_expert_id].set(0) | ||||||||||||||||||||||
|
||||||||||||||||||||||
a = self.phy2log.labels( | ||||||||||||||||||||||
rank=phy_expert_id // self.local_expert_num, | ||||||||||||||||||||||
layer=layer_id, | ||||||||||||||||||||||
phy_expert_id=phy_expert_id % self.local_expert_num, | ||||||||||||||||||||||
log_expert_id=log_expert_id | ||||||||||||||||||||||
) | ||||||||||||||||||||||
a.set(1) | ||||||||||||||||||||||
self.all_phy2log[layer_id][phy_expert_id] = a | ||||||||||||||||||||||
|
||||||||||||||||||||||
def record_expert_load(self, moe_load: list[list[int]]): | ||||||||||||||||||||||
for layer_id in range(len(moe_load)): | ||||||||||||||||||||||
for phy_expert_id, load in enumerate(moe_load[layer_id]): | ||||||||||||||||||||||
self.phy_expert.labels( | ||||||||||||||||||||||
rank=phy_expert_id // self.local_expert_num, | ||||||||||||||||||||||
layer=layer_id, | ||||||||||||||||||||||
phy_expert_id=phy_expert_id % self.local_expert_num, | ||||||||||||||||||||||
).inc(load) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _expert_file_to_list(self, expert_map_path: str): | ||||||||||||||||||||||
with open(expert_map_path, "r") as f: | ||||||||||||||||||||||
data = json.load(f) | ||||||||||||||||||||||
|
||||||||||||||||||||||
phy2log_data = [] | ||||||||||||||||||||||
for layer in data["layer_list"]: | ||||||||||||||||||||||
device_data = [] | ||||||||||||||||||||||
for device in layer["device_list"]: | ||||||||||||||||||||||
device_data += device["device_expert"] | ||||||||||||||||||||||
phy2log_data.append(device_data) | ||||||||||||||||||||||
return phy2log_data | ||||||||||||||||||||||
|
||||||||||||||||||||||
def clear(self): | ||||||||||||||||||||||
for layer_id in range(self.layers_num): | ||||||||||||||||||||||
for phy_expert_id in range(self.global_expert_num): | ||||||||||||||||||||||
self.phy_expert.labels(rank=phy_expert_id // self.local_expert_num, | ||||||||||||||||||||||
layer=layer_id, | ||||||||||||||||||||||
phy_expert_id=phy_expert_id % self.local_expert_num).reset() |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -17,6 +17,7 @@ | |||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py | ||||||
# | ||||||
|
||||||
import os | ||||||
import copy | ||||||
import gc | ||||||
import math | ||||||
|
@@ -81,6 +82,7 @@ | |||||
AscendCommonAttentionMetadata as CommonAttentionMetadata | ||||||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor | ||||||
from vllm_ascend.eplb.eplb_updator import EplbUpdator | ||||||
from vllm_ascend.eplb.eplb_loggers import EplbStatLogger | ||||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index | ||||||
from vllm_ascend.platform import NPUPlatform | ||||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler | ||||||
|
@@ -386,10 +388,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): | |||||
|
||||||
#EPLB | ||||||
self.dynamic_eplb = ascend_config.dynamic_eplb | ||||||
self.dynamic_eplb_metrics = os.getenv("DYNAMIC_EXPERT_LOAD_METRICS", False) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
if self.dynamic_eplb: | ||||||
self.eplb_adaptor: Optional[VllmEplbAdaptor] = None | ||||||
self.is_eplb_warmuped = False | ||||||
self.eplb_updator = EplbUpdator(ascend_config.expert_map_path) | ||||||
self.ep_loggers: Optional[EplbStatLogger] = None | ||||||
|
||||||
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True | ||||||
self.in_profile_run = False | ||||||
|
@@ -1495,7 +1499,9 @@ def execute_model( | |||||
" ".join(dr_str)) | ||||||
|
||||||
if self.dynamic_eplb: | ||||||
self.eplb_updator.forward_end() | ||||||
moe_load, phy2log_map = self.eplb_updator.forward_end() | ||||||
if self.dynamic_eplb_metrics and moe_load is not None: | ||||||
self.ep_loggers.record(moe_load, phy2log_map) | ||||||
|
||||||
return model_runner_output | ||||||
|
||||||
|
@@ -1836,6 +1842,8 @@ def eplb_warmup(self): | |||||
if self.dynamic_eplb and not self.is_eplb_warmuped: | ||||||
self.is_eplb_warmuped = True | ||||||
self.eplb_adaptor = VllmEplbAdaptor(model=self.model) | ||||||
if self.dynamic_eplb_metrics: | ||||||
self.ep_loggers = EplbStatLogger.init_instance(self.eplb_adaptor, get_ascend_config().expert_map_path) | ||||||
self.eplb_updator.set_adaptor(self.eplb_adaptor) | ||||||
self.eplb_updator.warm_up_eplb() | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The background thread
do_record_loop
is not configured as a daemon thread. This can prevent the application from shutting down cleanly, as the non-daemon thread will keep the process alive. It's recommended to create it as a daemon thread to ensure the process can exit properly.