Skip to content

Commit 6c98042

Browse files
committed
enable DeviceMemoryMonitor for all platforms
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
1 parent 82e1bf2 commit 6c98042

File tree

4 files changed

+72
-17
lines changed

4 files changed

+72
-17
lines changed

tests/e2e/offline_inference/test_diffusion_cpu_offload.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
77

8-
from tests.utils import GPUMemoryMonitor, hardware_test
8+
from tests.utils import DeviceMemoryMonitor, hardware_test
99
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
1010
from vllm_omni.platforms import current_omni_platform
1111

@@ -21,11 +21,11 @@
2121

2222
def inference(model_name: str, offload: bool = True):
2323
current_omni_platform.empty_cache()
24-
device_index = torch.cuda.current_device()
25-
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
24+
device_index = torch.accelerator.current_device_index()
25+
monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
2626
monitor.start()
2727
m = Omni(model=model_name, enable_cpu_offload=offload)
28-
torch.cuda.reset_peak_memory_stats(device=device_index)
28+
torch.accelerator.reset_peak_memory_stats()
2929
height = 256
3030
width = 256
3131

@@ -36,7 +36,7 @@ def inference(model_name: str, offload: bool = True):
3636
width=width,
3737
num_inference_steps=9,
3838
guidance_scale=0.0,
39-
generator=torch.Generator("cuda").manual_seed(42),
39+
generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
4040
),
4141
)
4242
peak = monitor.peak_used_mb

tests/e2e/offline_inference/test_diffusion_layerwise_offload.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
77

8-
from tests.utils import GPUMemoryMonitor
8+
from tests.utils import DeviceMemoryMonitor
99
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
1010
from vllm_omni.platforms import current_omni_platform
1111

@@ -28,11 +28,9 @@ def run_inference(
2828
layerwise_offload: bool = False,
2929
num_inference_steps: int = 3,
3030
) -> float:
31-
# For now, only support on GPU, so apply torch.cuda operations here
32-
# NPU / ROCm platforms are expected to be detected and skipped this test function
33-
torch.cuda.empty_cache()
34-
device_index = torch.cuda.current_device()
35-
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
31+
torch.accelerator.empty_cache()
32+
device_index = torch.accelerator.current_device_index()
33+
monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
3634
monitor.start()
3735

3836
m = Omni(
@@ -42,7 +40,7 @@ def run_inference(
4240
flow_shift=5.0,
4341
)
4442

45-
torch.cuda.reset_peak_memory_stats(device=device_index)
43+
torch.accelerator.reset_peak_memory_stats()
4644

4745
# Refer to tests/e2e/offline_inference/test_t2v_model.py
4846
# Use minimal settings for testing
@@ -55,7 +53,7 @@ def run_inference(
5553
OmniDiffusionSamplingParams(
5654
height=height,
5755
width=width,
58-
generator=torch.Generator("cuda").manual_seed(42),
56+
generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
5957
guidance_scale=1.0,
6058
num_inference_steps=num_inference_steps,
6159
num_frames=num_frames,

tests/e2e/offline_inference/test_zimage_parallelism.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from PIL import Image
2323
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
2424

25-
from tests.utils import GPUMemoryMonitor, hardware_test
25+
from tests.utils import DeviceMemoryMonitor, hardware_test
2626
from vllm_omni import Omni
2727
from vllm_omni.diffusion.data import DiffusionParallelConfig
2828
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -95,7 +95,7 @@ def _run_zimage_generate(
9595

9696
torch.cuda.empty_cache()
9797
device_index = torch.cuda.current_device()
98-
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
98+
monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
9999
monitor.start()
100100
m = Omni(
101101
model=_get_zimage_model(),

tests/utils.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from vllm.platforms import current_platform
2121
from vllm.utils.torch_utils import cuda_device_count_stateless
2222

23+
from vllm_omni.platforms import current_omni_platform
24+
2325
_P = ParamSpec("_P")
2426

2527
if current_platform.is_rocm():
@@ -504,8 +506,17 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
504506
return wrapper
505507

506508

507-
class GPUMemoryMonitor:
508-
"""Poll global device memory usage via CUDA APIs."""
509+
class DeviceMemoryMonitor:
510+
"""Poll global device memory usage."""
511+
512+
@classmethod
513+
def instantiate(cls, **kwargs: Any) -> "DeviceMemoryMonitor":
514+
if current_omni_platform.is_npu():
515+
return NPUMemoryMonitor(**kwargs)
516+
elif current_omni_platform.is_xpu():
517+
return XPUMemoryMonitor(**kwargs)
518+
else:
519+
return cls(**kwargs)
509520

510521
def __init__(self, device_index: int, interval: float = 0.05):
511522
self.device_index = device_index
@@ -543,3 +554,49 @@ def peak_used_mb(self) -> float:
543554

544555
def __del__(self):
545556
self.stop()
557+
558+
559+
class NPUMemoryMonitor(DeviceMemoryMonitor):
    """DeviceMemoryMonitor specialization polling Ascend NPU memory.

    Samples global device memory via ``torch.npu.mem_get_info`` in a
    background thread and folds in the torch allocator's peak statistics
    as a fallback, since short allocation spikes can fall between polls.
    """

    def start(self) -> None:
        """Begin polling device-wide memory usage on a background thread."""

        def monitor_loop() -> None:
            while not self._stop_event.is_set():
                try:
                    with torch.npu.device(self.device_index):
                        free_bytes, total_bytes = torch.npu.mem_get_info()
                    used_mb = (total_bytes - free_bytes) / (1024**2)
                    self._peak_used_mb = max(self._peak_used_mb, used_mb)
                except Exception:
                    # Best-effort sampling: a transient driver/runtime error
                    # must never break the test being monitored.
                    pass
                time.sleep(self.interval)

        # daemon=True so a monitor that is never stop()-ed (e.g. a test that
        # fails before cleanup) cannot keep the interpreter alive at exit.
        self._thread = threading.Thread(target=monitor_loop, daemon=True)
        self._thread.start()

    @property
    def peak_used_mb(self) -> float:
        """Peak device memory usage in MiB observed so far.

        Returns the maximum of the polled global usage and the allocator's
        peak allocated/reserved counters for this device.
        """
        fallback_alloc = torch.npu.max_memory_allocated(device=self.device_index) / (1024**2)
        fallback_reserved = torch.npu.max_memory_reserved(device=self.device_index) / (1024**2)
        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
580+
581+
582+
class XPUMemoryMonitor(DeviceMemoryMonitor):
    """DeviceMemoryMonitor specialization polling Intel XPU memory.

    Samples global device memory via ``torch.xpu.mem_get_info`` in a
    background thread and folds in the torch allocator's peak statistics
    as a fallback, since short allocation spikes can fall between polls.
    """

    def start(self) -> None:
        """Begin polling device-wide memory usage on a background thread."""

        def monitor_loop() -> None:
            while not self._stop_event.is_set():
                try:
                    with torch.xpu.device(self.device_index):
                        free_bytes, total_bytes = torch.xpu.mem_get_info()
                    used_mb = (total_bytes - free_bytes) / (1024**2)
                    self._peak_used_mb = max(self._peak_used_mb, used_mb)
                except Exception:
                    # Best-effort sampling: a transient driver/runtime error
                    # must never break the test being monitored.
                    pass
                time.sleep(self.interval)

        # daemon=True so a monitor that is never stop()-ed (e.g. a test that
        # fails before cleanup) cannot keep the interpreter alive at exit.
        self._thread = threading.Thread(target=monitor_loop, daemon=True)
        self._thread.start()

    @property
    def peak_used_mb(self) -> float:
        """Peak device memory usage in MiB observed so far.

        Returns the maximum of the polled global usage and the allocator's
        peak allocated/reserved counters for this device.
        """
        fallback_alloc = torch.xpu.max_memory_allocated(device=self.device_index) / (1024**2)
        fallback_reserved = torch.xpu.max_memory_reserved(device=self.device_index) / (1024**2)
        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)

0 commit comments

Comments (0)