39 changes: 39 additions & 0 deletions vllm_omni/diffusion/worker/diffusion_worker.py
@@ -9,10 +9,12 @@
"""

import gc
import json
import multiprocessing as mp
import os
from collections.abc import Iterable
from contextlib import AbstractContextManager, nullcontext
from pathlib import Path
from typing import Any

import torch
@@ -58,6 +60,43 @@ class DiffusionWorker:
    delegated to DiffusionModelRunner.
    """

    @staticmethod
    def predict_resource_usage(od_config: OmniDiffusionConfig) -> dict[str, float]:
        """Estimate this stage's VRAM footprint before any weights are loaded.

        Returns static_gb (weights), dynamic_gb (activation headroom), and
        their sum total_gb, all in GiB.
        """
        import torch
        from vllm.utils.mem_utils import GiB_bytes

        # Prefer an exact parameter count from the checkpoint's config files.
        total_params = 0
        try:
            model_path = Path(od_config.model)
            for cfg_name in ["config.json", "llm_config.json", "diffusion_config.json"]:
                cfg_file = model_path / cfg_name
                if cfg_file.exists():
                    with open(cfg_file) as f:
                        data = json.load(f)
                    total_params = data.get("num_parameters", 0) or data.get("total_params", 0)
                    if total_params > 0:
                        break
        except Exception:
            pass

        # Fall back to per-family heuristics when no config reports a count.
        if total_params == 0:
            m_name = str(od_config.model).lower()
            if "bagel" in m_name:
                total_params = 13.5e9
            elif "flux" in m_name:
                total_params = 12.0e9
            else:
                total_params = 10.0e9

        # Weights: 2 bytes per parameter for half-precision dtypes, else 4.
        dtype = getattr(od_config, "dtype", torch.bfloat16)
        bytes_per_param = 2 if dtype in [torch.bfloat16, torch.float16] else 4
        static_gb = (total_params * bytes_per_param) / GiB_bytes

        # Activations: roughly 2.5 GiB per megapixel of output resolution.
        h, w = getattr(od_config, "height", 1024), getattr(od_config, "width", 1024)
        dynamic_gb = 2.5 * (h * w / (1024 * 1024))

        return {
            "static_gb": round(static_gb, 2),
            "dynamic_gb": round(dynamic_gb, 2),
            "total_gb": round(static_gb + dynamic_gb, 2),
        }

    def __init__(
        self,
        local_rank: int,
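A minimal usage sketch of the new estimator (hypothetical values; od_config stands for the same OmniDiffusionConfig the worker is constructed with):

from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker

# No GPU or weights are touched; the estimate comes from config files or name heuristics.
usage = DiffusionWorker.predict_resource_usage(od_config)
# A 12e9-parameter bf16 model at the default 1024x1024 resolution yields
# {"static_gb": 22.35, "dynamic_gb": 2.5, "total_gb": 24.85}.
print(f"Reserving {usage['total_gb']} GiB for the diffusion stage")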
37 changes: 37 additions & 0 deletions vllm_omni/entrypoints/omni.py
@@ -297,6 +297,41 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st

        return config_path, stage_configs

    def _coordinate_vram_resources(self) -> None:
        """Coordinate VRAM resource reservations across stages based on their type and configs."""
        import torch
        from vllm.utils.mem_utils import GiB_bytes

        # Sum the predicted budgets of every stage that can estimate its own usage.
        total_reserved_gb = 0.0
        for cfg in self.stage_configs:
            s_type = getattr(cfg, "stage_type", None)
            if s_type == "diffusion":
                # Currently only Diffusion implements the predict_resource_usage interface.
                from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker

                prediction = DiffusionWorker.predict_resource_usage(cfg.engine_args)
                total_reserved_gb += prediction["total_gb"]
                logger.info(
                    f"[Coordinator] Stage-{cfg.stage_id} ({s_type.capitalize()}) "
                    f"predicted budget: {prediction['total_gb']:.2f} GiB"
                )
            # Other stage types can plug in their own predictors here.

        if not torch.cuda.is_available():
            return

        # Boost each LLM stage's gpu_memory_utilization so the memory already
        # reserved by co-located diffusion stages does not shrink the LLM's
        # KV-cache budget.
        physical_vram_gb = torch.cuda.get_device_properties(0).total_memory / GiB_bytes
        for cfg in self.stage_configs:
            if getattr(cfg, "stage_type", None) == "llm":
                original_util = cfg.engine_args.get("gpu_memory_utilization", 0.9)
                reserved_util_ratio = total_reserved_gb / physical_vram_gb
                # (Physical_Used + Logical_KV_Buffer) / Total_VRAM, capped at 0.95.
                adjusted_util = min(0.95, original_util + reserved_util_ratio)
                cfg.engine_args["gpu_memory_utilization"] = round(adjusted_util, 3)
                logger.info(
                    f"[Coordinator] LLM Stage-{cfg.stage_id} dynamic boost: "
                    f"{original_util} -> {cfg.engine_args['gpu_memory_utilization']} "
                    f"(compensating a {reserved_util_ratio:.2f} utilization ratio "
                    f"for cross-modal isolation)"
                )

    def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
        """Initialize stage list management."""
        stage_init_timeout = kwargs.get("stage_init_timeout", 20)
@@ -316,6 +351,8 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
        # Resolve stage configs shared by orchestrator/headless paths.
        self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs)

        self._coordinate_vram_resources()

        # Initialize connectors
        self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors(
            self.config_path, worker_backend=worker_backend, shm_threshold_bytes=shm_threshold_bytes
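A worked example of the boost arithmetic (hypothetical numbers): on an 80 GiB GPU, a diffusion stage predicted at 24.85 GiB gives reserved_util_ratio = 24.85 / 80 ≈ 0.311, so an LLM stage configured with gpu_memory_utilization = 0.6 is raised to min(0.95, 0.6 + 0.311) = 0.911. Because the boost is capped at 0.95, very large reservations can still eat into the LLM's effective KV-cache budget.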
4 changes: 4 additions & 0 deletions vllm_omni/platforms/cuda/platform.py
@@ -112,6 +112,10 @@ def get_free_memory(cls, device: torch.device | None = None) -> int:
        free, _ = torch.cuda.mem_get_info(device)
        return free

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        return torch.cuda.get_device_properties(device_id).total_memory

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)
4 changes: 4 additions & 0 deletions vllm_omni/platforms/rocm/platform.py
@@ -99,3 +99,7 @@ def synchronize(cls) -> None:
    def get_free_memory(cls, device: torch.device | None = None) -> int:
        free, _ = torch.cuda.mem_get_info(device)
        return free

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        return torch.cuda.get_device_properties(device_id).total_memory
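The ROCm override is intentionally identical to the CUDA one: ROCm builds of PyTorch expose HIP devices through the torch.cuda namespace, so the same calls work on both platforms. A quick sanity check of the new hook (the CudaPlatform import path and class name are assumptions, not confirmed by this diff):

import torch
from vllm_omni.platforms.cuda.platform import CudaPlatform  # hypothetical class name

if torch.cuda.is_available():
    total = CudaPlatform.get_device_total_memory(0)
    print(f"device 0 total VRAM: {total / (1 << 30):.1f} GiB")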