
[v0.9.1-dev] dynamic eplb metrics #2364


Open · wants to merge 13 commits into base: v0.9.1-dev
24 changes: 21 additions & 3 deletions vllm_ascend/eplb/core/eplb_worker.py
@@ -15,7 +15,7 @@
# This file is a part of the vllm-ascend project.
#

from multiprocessing import Process
from multiprocessing import Process, Queue
from typing import Any

import networkx as nx # type: ignore
@@ -39,6 +39,7 @@ def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
self.old_expert_maps = None
self.enable_d2d = enable_d2d
self.rank_id = dist.get_rank()
self.phy2log = None

def do_update(self):
        # put data into queue
@@ -67,6 +68,7 @@ def do_update(self):
self.num_local_experts)
_, _, new_placement = self.calculate_rebalance_experts(
load_info, old_placement)
self.phy2log = new_placement

if not torch.is_tensor(new_placement):
new_placement = torch.tensor(new_placement)
@@ -383,6 +385,9 @@ def pack_update_info(self, update_info_generator):

return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))

def get_phy2log(self):
return self.phy2log


class EplbProcess:

@@ -404,11 +409,21 @@ def __init__(self,
self.planner_q = planner_q
self.block_update_q = block_update_q

self.phy2log_q = Queue(maxsize=1)
self.phy2log = None

# Create EplbWorker instance
self.worker = EplbWorker(self.shared_dict, self.policy_type,
self.enable_d2d)

def worker_process(self, planner_q, block_update_q):
def get_phy2log(self):
if self.phy2log_q.empty():
return None
else:
self.phy2log = self.phy2log_q.get()
return self.phy2log

def worker_process(self, planner_q, block_update_q, phy2log_q):
"""
Subprocess entry: bind to specified NPU, loop waiting for planner_q to wake up, call do_update, then notify main process update is complete.
"""
@@ -418,12 +433,14 @@ def worker_process(self, planner_q, block_update_q):
planner_q.get()

packed_update_info = self.worker.do_update()
self.phy2log = self.worker.get_phy2log()

while True:
if not block_update_q.empty():
continue
block_update_q.put(packed_update_info)
break
phy2log_q.put(self.phy2log)

except Exception as e:
logger.warning(f"[EPLB subprocess Exiting due to error: {e}",
@@ -435,7 +452,8 @@ def _launch_process(self):
Use spawn method to launch subprocess and return (planner_q, block_update_q, proc).
"""
proc = Process(target=self.worker_process,
args=(self.planner_q, self.block_update_q),
args=(self.planner_q, self.block_update_q,
self.phy2log_q),
daemon=True)

proc.start()
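
For context on the new handoff in eplb_worker.py: the subprocess publishes the latest physical-to-logical placement through a single-slot Queue (maxsize=1), and the parent polls it via get_phy2log() without ever blocking, returning None when nothing new has been produced. Below is a minimal standalone sketch of that pattern; the producer/poll names and the dummy placement data are illustrative only, not part of this PR.

import time
from multiprocessing import Process, Queue


def worker(q):
    # Stand-in for the EPLB worker loop: publish one placement result per update.
    for step in range(3):
        placement = [[step, step + 1, step + 2]]  # dummy phy2log table
        q.put(placement)  # with maxsize=1, blocks until the previous result is consumed
        time.sleep(0.05)


def poll_latest(q):
    # Mirrors EplbProcess.get_phy2log: never block the caller.
    return None if q.empty() else q.get()


if __name__ == "__main__":
    q = Queue(maxsize=1)
    Process(target=worker, args=(q, ), daemon=True).start()
    for _ in range(10):
        result = poll_latest(q)  # None until the worker has published something new
        if result is not None:
            print(result)
        time.sleep(0.05)
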
182 changes: 182 additions & 0 deletions vllm_ascend/eplb/eplb_loggers.py
@@ -0,0 +1,182 @@
import json
import threading
import time
from typing import Optional

import numpy as np
import prometheus_client
import torch
from vllm.distributed.parallel_state import get_ep_group
from vllm.logger import logger

from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor

RECORDING_TIME = 10


class EplbStatLogger:
_instance = None
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter

def __init__(self, adaptor: VllmEplbAdaptor,
expert_map_path: Optional[str]):
self.rank = get_ep_group().rank
self.layers_num = adaptor.num_moe_layers
self.global_expert_num = adaptor.global_expert_num
self.ep_size = get_ep_group().world_size

if expert_map_path is None:
self.phy2log_map = [[i for i in range(self.global_expert_num)]
for _ in range(self.layers_num)]
else:
self.phy2log_map = self._expert_file_to_list(expert_map_path)
self.global_expert_num = len(self.phy2log_map[0])

self.local_expert_num = self.global_expert_num // self.ep_size

labelnames_phy_load = ["rank", "layer", "phy_expert_id"]
labelnames_phy2log = [
"rank", "layer", "phy_expert_id", "log_expert_id"
]

self.phy_expert = self._counter_cls(
name="vllm:phy_expert_heat",
documentation="Heat of each physical expert per rank",
labelnames=labelnames_phy_load)

self.phy2log = self._gauge_cls(
name="vllm:phy2log",
documentation="physical expert to logical expert per rank",
labelnames=labelnames_phy2log)

self.do_record_loop = threading.Thread(target=self.record_loop,
daemon=True)
self.moe_load = None

self.update_load = False
self.update_map = False

# only init in rank0
self.all_phy2log = []
if self.rank == 0:
for layer_id in range(self.layers_num):
for phy_expert_id in range(self.global_expert_num):
self.phy_expert.labels(
rank=phy_expert_id // self.local_expert_num,
layer=layer_id,
phy_expert_id=phy_expert_id % self.local_expert_num)

for layer_id in range(len(self.phy2log_map)):
local_phy2log = []
for phy_expert_id, log_expert_id in enumerate(
self.phy2log_map[layer_id]):
a = self.phy2log.labels(
rank=phy_expert_id // self.local_expert_num,
layer=layer_id,
phy_expert_id=phy_expert_id % self.local_expert_num,
log_expert_id=log_expert_id)
a.set(1)
local_phy2log.append(a)
self.all_phy2log.append(local_phy2log)

self.moe_load = torch.zeros(
(self.layers_num, self.ep_size, self.local_expert_num))

self.lock = threading.Lock()
self.start_loop()

@staticmethod
def get_instance():
if EplbStatLogger._instance is None:
raise ValueError(
"EplbStatLogger instance has not been initialized.")
return EplbStatLogger._instance

@staticmethod
def init_instance(adaptor: VllmEplbAdaptor,
expert_map_path: Optional[str]):
"""Initialize the singleton instance of ExpertLoadBalancer."""
EplbStatLogger._instance = EplbStatLogger(adaptor, expert_map_path)
return EplbStatLogger._instance

def record(self, moe_load, phy2log_map):
if self.rank != 0:
return
try:
with self.lock:
if moe_load is not None:
torch.add(self.moe_load, moe_load, out=self.moe_load)
self.update_load = True

if phy2log_map is not None:
self.phy2log_map = phy2log_map
self.update_map = True
except Exception as e:
logger.debug(f"Record moe_load or phy2log error, error result:{e}")

def record_loop(self):
while True:
try:
if self.update_load:
with self.lock:
self.update_load = False
moe_load = self.moe_load.tolist()
self.moe_load.zero_()
moe_load = np.array(moe_load)
res = np.zeros_like(moe_load)
res[..., 0] = moe_load[..., 0]
res[..., 1:] = moe_load[..., 1:] - moe_load[..., :-1]
res = res.reshape(self.layers_num, -1)
self.record_expert_load(res)

if self.update_map:
with self.lock:
self.update_map = False
phy2log_map = self.phy2log_map
phy2log_map = np.array(phy2log_map).reshape(
self.layers_num, -1)
self.record_phy2log(phy2log_map)
except Exception as e:
logger.debug(
f"Record moe_load or phy2log prometheus error, error result:{e}"
)
time.sleep(RECORDING_TIME)

def start_loop(self):
self.do_record_loop.start()

def record_phy2log(self, phy2log_map: list[list[int]]):
for layer_id in range(len(phy2log_map)):
for phy_expert_id, log_expert_id in enumerate(
phy2log_map[layer_id]):
self.all_phy2log[layer_id][phy_expert_id].set(0)

a = self.phy2log.labels(
rank=phy_expert_id // self.local_expert_num,
layer=layer_id,
phy_expert_id=phy_expert_id % self.local_expert_num,
log_expert_id=log_expert_id)
a.set(1)
self.all_phy2log[layer_id][phy_expert_id] = a

def record_expert_load(self, moe_load: list[list[int]]):
for layer_id in range(len(moe_load)):
for phy_expert_id, load in enumerate(moe_load[layer_id]):
self.phy_expert.labels(
rank=phy_expert_id // self.local_expert_num,
layer=layer_id,
phy_expert_id=phy_expert_id % self.local_expert_num,
).inc(load)

def _expert_file_to_list(self, expert_map_path: str):
with open(expert_map_path, "r") as f:
data = json.load(f)

phy2log_data = []
for layer in data["layer_list"]:
device_data = []
for device in layer["device_list"]:
device_data += device["device_expert"]
phy2log_data.append(device_data)
return phy2log_data
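
One detail worth calling out in record_loop: before flushing to the vllm:phy_expert_heat counter, it differences adjacent entries along the last (local-expert) axis, which only makes sense if the accumulated moe_load tensor holds a running prefix-sum over each rank's local experts. A self-contained numpy illustration of that transform, with made-up shapes and values:

import numpy as np

layers, ep_size, local_experts = 2, 2, 4
# Pretend each rank reports a prefix-sum of token counts over its local experts.
per_expert = np.arange(layers * ep_size * local_experts).reshape(
    layers, ep_size, local_experts)
moe_load = per_expert.cumsum(axis=-1)

# Same recovery as record_loop: first column kept, later columns differenced.
res = np.zeros_like(moe_load)
res[..., 0] = moe_load[..., 0]
res[..., 1:] = moe_load[..., 1:] - moe_load[..., :-1]

assert (res == per_expert).all()
print(res.reshape(layers, -1))  # flattened to (layers, ep_size * local_experts), as record_expert_load expects
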
7 changes: 7 additions & 0 deletions vllm_ascend/eplb/eplb_updator.py
@@ -150,15 +150,22 @@ def take_update_info_from_eplb_process(self):
self.update_info_all = self.block_update_queue.get()

def forward_end(self):
moe_load = None
if self.wakeup_eplb_worker_flag():
self.compute_and_set_moe_load(is_clear=True)
moe_load = self.shared_dict.get("moe_load", None)
self.wakeup_eplb_worker()

if self.update_expert_weight_flag():
self.eplb_loader.update_expert_map_and_weight(self.reqs)

self.update_iteration()

if moe_load is not None:
return moe_load, self.eplb.get_phy2log()
else:
return None, None

def compute_and_set_moe_load(self, is_clear=False):
local_load = self.adaptor.get_rank_expert_workload()

13 changes: 12 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -20,6 +20,7 @@
import copy
import gc
import math
import os
import time
import types
import weakref
@@ -80,6 +81,7 @@
from vllm_ascend.attention.utils import \
AscendCommonAttentionMetadata as CommonAttentionMetadata
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.eplb_loggers import EplbStatLogger
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
@@ -386,10 +388,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

#EPLB
self.dynamic_eplb = ascend_config.dynamic_eplb
self.dynamic_eplb_metrics = os.getenv("DYNAMIC_EXPERT_LOAD_METRICS",
"0") == "1"
if self.dynamic_eplb:
self.eplb_adaptor: Optional[VllmEplbAdaptor] = None
self.is_eplb_warmuped = False
self.eplb_updator = EplbUpdator(ascend_config.expert_map_path)
self.ep_loggers: Optional[EplbStatLogger] = None

# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
@@ -1495,7 +1500,9 @@ def execute_model(
" ".join(dr_str))

if self.dynamic_eplb:
self.eplb_updator.forward_end()
moe_load, phy2log_map = self.eplb_updator.forward_end()
if self.dynamic_eplb_metrics and moe_load is not None:
self.ep_loggers.record(moe_load, phy2log_map)

return model_runner_output

@@ -1836,6 +1843,10 @@ def eplb_warmup(self):
if self.dynamic_eplb and not self.is_eplb_warmuped:
self.is_eplb_warmuped = True
self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
if self.dynamic_eplb_metrics:
self.ep_loggers = EplbStatLogger.init_instance(
self.eplb_adaptor,
get_ascend_config().expert_map_path)
self.eplb_updator.set_adaptor(self.eplb_adaptor)
self.eplb_updator.warm_up_eplb()

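
Taken together, the metrics path is doubly opt-in: dynamic_eplb must be enabled in the Ascend config, and the DYNAMIC_EXPERT_LOAD_METRICS environment variable must be "1" before the model runner is constructed, since the flag is read once in __init__. A rough sketch of enabling it from a launcher script; the config keys mentioned in the comments are assumptions based on the attributes read in this diff, so check the vllm-ascend docs for the exact spelling:

import os

# Must be set before the model runner is created (the flag is read once in __init__).
os.environ["DYNAMIC_EXPERT_LOAD_METRICS"] = "1"

# Dynamic EPLB itself comes from the Ascend config (ascend_config.dynamic_eplb /
# ascend_config.expert_map_path in this diff), typically supplied through vLLM's
# additional-config mechanism when launching the server. With both switches on,
# rank 0 exports vllm:phy_expert_heat and vllm:phy2log via prometheus_client.
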