Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions vllm_omni/diffusion/worker/diffusion_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
"""

import gc
import json
import multiprocessing as mp
import os
from collections.abc import Iterable
from contextlib import AbstractContextManager, nullcontext
from pathlib import Path
from typing import Any

import torch
Expand Down Expand Up @@ -58,6 +60,43 @@ class DiffusionWorker:
delegated to DiffusionModelRunner.
"""

@staticmethod
def predict_resource_usage(od_config: OmniDiffusionConfig) -> dict[str, float]:
import torch
from vllm.utils.mem_utils import GiB_bytes

total_params = 0
try:
model_path = Path(od_config.model)
for cfg_name in ["config.json", "llm_config.json", "diffusion_config.json"]:
cfg_file = model_path / cfg_name
if cfg_file.exists():
with open(cfg_file) as f:
data = json.load(f)
total_params = data.get("num_parameters", 0) or data.get("total_params", 0)
if total_params > 0:
break
except Exception:
pass
if total_params == 0:
m_name = str(od_config.model).lower()
if "bagel" in m_name:
total_params = 13.5e9
elif "flux" in m_name:
total_params = 12.0e9
else:
total_params = 10.0e9
dtype = getattr(od_config, "dtype", torch.bfloat16)
bytes_per_param = 2 if dtype in [torch.bfloat16, torch.float16] else 4
static_gb = (total_params * bytes_per_param) / GiB_bytes
h, w = getattr(od_config, "height", 1024), getattr(od_config, "width", 1024)
dynamic_gb = 2.5 * (h * w / (1024 * 1024))
return {
"static_gb": round(static_gb, 2),
"dynamic_gb": round(dynamic_gb, 2),
"total_gb": round(static_gb + dynamic_gb, 2),
}

def __init__(
self,
local_rank: int,
Expand Down
53 changes: 53 additions & 0 deletions vllm_omni/entrypoints/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,57 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st

return config_path, stage_configs

def _coordinate_vram_resources(self) -> None:
    """Coordinate VRAM budgets across stages before they start.

    Pass 1: for every diffusion stage, predict its memory footprint and
    tally the reserved GiB per device id. Pass 2 (CUDA only): for every
    LLM stage whose master device carries a diffusion reservation, raise
    its ``gpu_memory_utilization`` by the reserved fraction of that card,
    capped at 0.95.

    NOTE(review): the adjustment *increases* the LLM utilization when other
    stages reserve memory on the same card -- confirm this matches how the
    downstream engine interprets ``gpu_memory_utilization``.
    """
    import torch
    from vllm.utils.mem_utils import GiB_bytes

    from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker

    # GiB reserved by diffusion stages, keyed by device id.
    reserved_gb_per_device: dict[int, float] = {}

    # Honor single-stage mode: coordinate only the selected stage.
    active_configs = self.stage_configs
    if getattr(self, "_single_stage_id", None) is not None:
        active_configs = [c for c in self.stage_configs if c.stage_id == self._single_stage_id]

    # Pass 1: tally predicted diffusion budgets per device.
    for stage in active_configs:
        s_type = getattr(stage, "stage_type", None)
        if s_type != "diffusion":
            continue
        prediction = DiffusionWorker.predict_resource_usage(stage.engine_args)
        raw_devices = getattr(stage.runtime, "devices", "0")
        try:
            devices = [int(token.strip()) for token in raw_devices.split(",")]
        except (ValueError, AttributeError):
            # Non-string / malformed device spec: assume device 0.
            devices = [0]
        for device_id in devices:
            previous = reserved_gb_per_device.get(device_id, 0.0)
            reserved_gb_per_device[device_id] = previous + prediction["total_gb"]
        logger.info(
            f"[Coordinator] Stage-{stage.stage_id} ({s_type.capitalize()}) "
            f"on devices {devices} predicted budget: {prediction['total_gb']:.2f} GiB"
        )

    if not torch.cuda.is_available():
        return

    # Pass 2: bump each LLM stage's utilization by the share reserved on its card.
    for stage in active_configs:
        if getattr(stage, "stage_type", None) != "llm":
            continue
        # Master device is the first entry of the stage's device list.
        raw_devices = getattr(stage.runtime, "devices", "0")
        try:
            target_device = int(raw_devices.split(",")[0].strip())
        except (ValueError, AttributeError):
            target_device = 0
        # Physical capacity of the card hosting this LLM stage.
        physical_vram_gb = torch.cuda.get_device_properties(target_device).total_memory / GiB_bytes
        reserved_here = reserved_gb_per_device.get(target_device, 0.0)
        if reserved_here <= 0:
            continue
        original_util = stage.engine_args.get("gpu_memory_utilization", 0.9)
        reserved_util_ratio = reserved_here / physical_vram_gb
        stage.engine_args["gpu_memory_utilization"] = round(
            min(0.95, original_util + reserved_util_ratio), 3
        )
        logger.info(
            f"[Coordinator] LLM Stage-{stage.stage_id} on Device {target_device} dynamic boost: "
            f"{original_util} -> {stage.engine_args['gpu_memory_utilization']} "
            f"(Compensating {reserved_util_ratio:.2f} ratio for resource domain isolation)"
        )

def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
"""Initialize stage list management."""
stage_init_timeout = kwargs.get("stage_init_timeout", 20)
Expand All @@ -316,6 +367,8 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
# Resolve stage configs shared by orchestrator/headless paths.
self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs)

self._coordinate_vram_resources()

# Initialize connectors
self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors(
self.config_path, worker_backend=worker_backend, shm_threshold_bytes=shm_threshold_bytes
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/platforms/cuda/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def get_free_memory(cls, device: torch.device | None = None) -> int:
free, _ = torch.cuda.mem_get_info(device)
return free

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the total physical memory of CUDA device ``device_id``, in bytes."""
    device_props = torch.cuda.get_device_properties(device_id)
    return device_props.total_memory

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id)
4 changes: 4 additions & 0 deletions vllm_omni/platforms/rocm/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,7 @@ def synchronize(cls) -> None:
def get_free_memory(cls, device: torch.device | None = None) -> int:
free, _ = torch.cuda.mem_get_info(device)
return free

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the total physical memory of device ``device_id``, in bytes.

    NOTE(review): queries via ``torch.cuda`` -- on ROCm builds of PyTorch
    this namespace is backed by HIP, so the call should work unchanged;
    confirm on an actual ROCm host.
    """
    return torch.cuda.get_device_properties(device_id).total_memory