Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/e2e/offline_inference/test_diffusion_cpu_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

from tests.utils import GPUMemoryMonitor, hardware_test
from tests.utils import DeviceMemoryMonitor, hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform

Expand All @@ -21,11 +21,11 @@

def inference(model_name: str, offload: bool = True):
current_omni_platform.empty_cache()
device_index = torch.cuda.current_device()
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
device_index = current_omni_platform.current_device()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
m = Omni(model=model_name, enable_cpu_offload=offload)
torch.cuda.reset_peak_memory_stats(device=device_index)
current_omni_platform.reset_peak_memory_stats()
height = 256
width = 256

Expand All @@ -36,7 +36,7 @@ def inference(model_name: str, offload: bool = True):
width=width,
num_inference_steps=9,
guidance_scale=0.0,
generator=torch.Generator("cuda").manual_seed(42),
generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
),
)
peak = monitor.peak_used_mb
Expand Down
12 changes: 5 additions & 7 deletions tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

from tests.utils import GPUMemoryMonitor
from tests.utils import DeviceMemoryMonitor
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform

Expand All @@ -28,11 +28,9 @@ def run_inference(
layerwise_offload: bool = False,
num_inference_steps: int = 3,
) -> float:
# For now, only support on GPU, so apply torch.cuda operations here
# NPU / ROCm platforms are expected to be detected and skipped this test function
current_omni_platform.empty_cache()
device_index = torch.cuda.current_device()
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
device_index = current_omni_platform.current_device()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()

m = Omni(
Expand All @@ -42,7 +40,7 @@ def run_inference(
flow_shift=5.0,
)

torch.cuda.reset_peak_memory_stats(device=device_index)
current_omni_platform.reset_peak_memory_stats()

# Refer to tests/e2e/offline_inference/test_t2v_model.py
# Use minimal settings for testing
Expand All @@ -55,7 +53,7 @@ def run_inference(
OmniDiffusionSamplingParams(
height=height,
width=width,
generator=torch.Generator("cuda").manual_seed(42),
generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
guidance_scale=1.0,
num_inference_steps=num_inference_steps,
num_frames=num_frames,
Expand Down
16 changes: 8 additions & 8 deletions tests/e2e/offline_inference/test_zimage_parallelism.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: test_zimage_tensor_parallel_tp2 and test_zimage_vae_patch_parallel_tp2 still have torch.cuda.is_available() / torch.cuda.device_count() in their skip guards. Not a blocker since those tests are explicitly CUDA-only, but worth a follow-up if XPU/NPU should run them too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, I have updated cuda to current_omni_platform.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from PIL import Image
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

from tests.utils import GPUMemoryMonitor, hardware_test
from tests.utils import DeviceMemoryMonitor, hardware_test
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
Expand Down Expand Up @@ -93,9 +93,9 @@ def _run_zimage_generate(
if num_requests < 2:
raise ValueError("num_requests must be >= 2 (1 warmup + >=1 timed)")

torch.cuda.empty_cache()
device_index = torch.cuda.current_device()
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
current_omni_platform.empty_cache()
device_index = current_omni_platform.current_device()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
m = Omni(
model=_get_zimage_model(),
Expand Down Expand Up @@ -161,8 +161,8 @@ def _run_zimage_generate(
def test_zimage_tensor_parallel_tp2(tmp_path: Path):
if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
pytest.skip("Z-Image TP e2e test is only supported on CUDA for now.")
if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.")
if not current_omni_platform.is_available() or current_omni_platform.device_count() < 2:
pytest.skip("Z-Image TP=2 requires >= 2 devices.")

enforce_eager = False

Expand Down Expand Up @@ -223,8 +223,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
def test_zimage_vae_patch_parallel_tp2(tmp_path: Path):
if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA for now.")
if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 CUDA devices.")
if not current_omni_platform.is_available() or current_omni_platform.device_count() < 2:
pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 devices.")

enforce_eager = False

Expand Down
14 changes: 8 additions & 6 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless

from vllm_omni.platforms import current_omni_platform

_P = ParamSpec("_P")

if current_platform.is_rocm():
Expand Down Expand Up @@ -504,8 +506,8 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
return wrapper


class GPUMemoryMonitor:
"""Poll global device memory usage via CUDA APIs."""
class DeviceMemoryMonitor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Consider making base class GPU-specific

Since the base DeviceMemoryMonitor class uses current_omni_platform APIs that may not work uniformly across all platforms, consider either:

  1. Making it handle only GPU/CUDA cases explicitly
  2. Making it abstract and requiring all platforms to subclass
  3. Adding platform capability checks before using platform-specific APIs

This would make the design more robust.

"""Poll global device memory usage."""

def __init__(self, device_index: int, interval: float = 0.05):
self.device_index = device_index
Expand All @@ -518,8 +520,8 @@ def start(self) -> None:
def monitor_loop() -> None:
while not self._stop_event.is_set():
try:
with torch.cuda.device(self.device_index):
free_bytes, total_bytes = torch.cuda.mem_get_info()
with current_omni_platform.device(self.device_index):
free_bytes, total_bytes = current_omni_platform.mem_get_info()
used_mb = (total_bytes - free_bytes) / (1024**2)
self._peak_used_mb = max(self._peak_used_mb, used_mb)
except Exception:
Expand All @@ -537,8 +539,8 @@ def stop(self) -> None:

@property
def peak_used_mb(self) -> float:
fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2)
fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2)
fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2)
fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2)
return max(self._peak_used_mb, fallback_alloc, fallback_reserved)

def __del__(self):
Expand Down