diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
index bf36b78ff4..d06ccb6ae3 100644
--- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
@@ -5,7 +5,7 @@
 import torch
 from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

-from tests.utils import GPUMemoryMonitor, hardware_test
+from tests.utils import DeviceMemoryMonitor, hardware_test
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
 from vllm_omni.platforms import current_omni_platform

@@ -21,11 +21,11 @@ def inference(model_name: str, offload: bool = True):

     current_omni_platform.empty_cache()
-    device_index = torch.cuda.current_device()
-    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
+    device_index = current_omni_platform.current_device()
+    monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
     monitor.start()

     m = Omni(model=model_name, enable_cpu_offload=offload)
-    torch.cuda.reset_peak_memory_stats(device=device_index)
+    current_omni_platform.reset_peak_memory_stats()

     height = 256
     width = 256
@@ -36,7 +36,7 @@ def inference(model_name: str, offload: bool = True):
             width=width,
             num_inference_steps=9,
             guidance_scale=0.0,
-            generator=torch.Generator("cuda").manual_seed(42),
+            generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
         ),
     )
     peak = monitor.peak_used_mb
diff --git a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
index 998e6232ec..ad9955cc2b 100644
--- a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
@@ -5,7 +5,7 @@
 import torch
 from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

-from tests.utils import GPUMemoryMonitor
+from tests.utils import DeviceMemoryMonitor
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
 from vllm_omni.platforms import current_omni_platform

@@ -28,11 +28,9 @@ def run_inference(
     layerwise_offload: bool = False,
     num_inference_steps: int = 3,
 ) -> float:
-    # For now, only support on GPU, so apply torch.cuda operations here
-    # NPU / ROCm platforms are expected to be detected and skipped this test function
     current_omni_platform.empty_cache()
-    device_index = torch.cuda.current_device()
-    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
+    device_index = current_omni_platform.current_device()
+    monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
     monitor.start()

     m = Omni(
@@ -42,7 +40,7 @@ def run_inference(
         flow_shift=5.0,
     )

-    torch.cuda.reset_peak_memory_stats(device=device_index)
+    current_omni_platform.reset_peak_memory_stats()

     # Refer to tests/e2e/offline_inference/test_t2v_model.py
     # Use minimal settings for testing
@@ -55,7 +53,7 @@ def run_inference(
         OmniDiffusionSamplingParams(
             height=height,
             width=width,
-            generator=torch.Generator("cuda").manual_seed(42),
+            generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
             guidance_scale=1.0,
             num_inference_steps=num_inference_steps,
             num_frames=num_frames,
diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py
index c43a5d2240..47de955aa7 100644
--- a/tests/e2e/offline_inference/test_zimage_parallelism.py
+++ b/tests/e2e/offline_inference/test_zimage_parallelism.py
@@ -22,7 +22,7 @@
 from PIL import Image
 from vllm.distributed.parallel_state import cleanup_dist_env_and_memory

-from tests.utils import GPUMemoryMonitor, hardware_test
+from tests.utils import DeviceMemoryMonitor, hardware_test
 from vllm_omni import Omni
 from vllm_omni.diffusion.data import DiffusionParallelConfig
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -93,9 +93,9 @@ def _run_zimage_generate(
     if num_requests < 2:
         raise ValueError("num_requests must be >= 2 (1 warmup + >=1 timed)")

-    torch.cuda.empty_cache()
-    device_index = torch.cuda.current_device()
-    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
+    current_omni_platform.empty_cache()
+    device_index = current_omni_platform.current_device()
+    monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
     monitor.start()
     m = Omni(
         model=_get_zimage_model(),
@@ -161,8 +161,8 @@ def _run_zimage_generate(

 def test_zimage_tensor_parallel_tp2(tmp_path: Path):
     if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
         pytest.skip("Z-Image TP e2e test is only supported on CUDA for now.")
-    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
-        pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.")
+    if not current_omni_platform.is_available() or current_omni_platform.device_count() < 2:
+        pytest.skip("Z-Image TP=2 requires >= 2 devices.")

     enforce_eager = False
@@ -223,8 +223,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):

 def test_zimage_vae_patch_parallel_tp2(tmp_path: Path):
     if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
         pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA for now.")
-    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
-        pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 CUDA devices.")
+    if not current_omni_platform.is_available() or current_omni_platform.device_count() < 2:
+        pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 devices.")

     enforce_eager = False
diff --git a/tests/utils.py b/tests/utils.py
index 3219821b6f..e4e3e14c59 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,6 +20,8 @@
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import cuda_device_count_stateless

+from vllm_omni.platforms import current_omni_platform
+
 _P = ParamSpec("_P")

 if current_platform.is_rocm():
@@ -504,8 +506,8 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
     return wrapper


-class GPUMemoryMonitor:
-    """Poll global device memory usage via CUDA APIs."""
+class DeviceMemoryMonitor:
+    """Poll global device memory usage."""

     def __init__(self, device_index: int, interval: float = 0.05):
         self.device_index = device_index
@@ -518,8 +520,8 @@ def start(self) -> None:
         def monitor_loop() -> None:
             while not self._stop_event.is_set():
                 try:
-                    with torch.cuda.device(self.device_index):
-                        free_bytes, total_bytes = torch.cuda.mem_get_info()
+                    with current_omni_platform.device(self.device_index):
+                        free_bytes, total_bytes = current_omni_platform.mem_get_info()
                     used_mb = (total_bytes - free_bytes) / (1024**2)
                     self._peak_used_mb = max(self._peak_used_mb, used_mb)
                 except Exception:
@@ -537,8 +539,8 @@ def stop(self) -> None:

     @property
     def peak_used_mb(self) -> float:
-        fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2)
-        fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2)
+        fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2)
+        fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2)
         return max(self._peak_used_mb, fallback_alloc, fallback_reserved)

     def __del__(self):
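Below is a minimal usage sketch of the renamed DeviceMemoryMonitor against the current_omni_platform helpers this diff exercises (empty_cache, current_device, reset_peak_memory_stats, device_type, peak_used_mb). The stand-in matmul workload and the top-level import paths from the repo root are assumptions for illustration, not part of the change.

# Sketch only: drives DeviceMemoryMonitor the same way the updated tests do.
import torch

from tests.utils import DeviceMemoryMonitor
from vllm_omni.platforms import current_omni_platform

current_omni_platform.empty_cache()
device_index = current_omni_platform.current_device()

monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
current_omni_platform.reset_peak_memory_stats()

# Stand-in workload; the real tests run Omni(...).generate(...) here.
x = torch.randn(2048, 2048, device=f"{current_omni_platform.device_type}:{device_index}")
(x @ x).sum()

monitor.stop()
print(f"peak device memory: {monitor.peak_used_mb:.1f} MiB")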