Commit fdd28c4

enable DeviceMemoryMonitor for all platforms
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
1 parent 3b0c21f commit fdd28c4

File tree

4 files changed: +73, -17 lines

tests/e2e/offline_inference/test_diffusion_cpu_offload.py

Lines changed: 5 additions & 5 deletions
@@ -5,7 +5,7 @@
 import torch
 from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
 
-from tests.utils import GPUMemoryMonitor, hardware_test
+from tests.utils import DeviceMemoryMonitor, hardware_test
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
 from vllm_omni.platforms import current_omni_platform
 
@@ -21,11 +21,11 @@
 
 def inference(model_name: str, offload: bool = True):
     current_omni_platform.empty_cache()
-    device_index = torch.cuda.current_device()
-    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
+    device_index = torch.accelerator.current_device_index()
+    monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
     monitor.start()
     m = Omni(model=model_name, enable_cpu_offload=offload)
-    torch.cuda.reset_peak_memory_stats(device=device_index)
+    torch.accelerator.reset_peak_memory_stats()
     height = 256
     width = 256
 
@@ -36,7 +36,7 @@ def inference(model_name: str, offload: bool = True):
             width=width,
             num_inference_steps=9,
             guidance_scale=0.0,
-            generator=torch.Generator("cuda").manual_seed(42),
+            generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
         ),
     )
     peak = monitor.peak_used_mb
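
At a glance, these are the substitutions the diff applies to make the test device-agnostic (a rough mapping; torch.accelerator is PyTorch's device-generic accelerator namespace in recent releases, and current_omni_platform comes from vllm_omni.platforms):

    # Old (CUDA-only)                         ->  New (device-agnostic)
    # torch.cuda.current_device()             ->  torch.accelerator.current_device_index()
    # torch.cuda.reset_peak_memory_stats(...) ->  torch.accelerator.reset_peak_memory_stats()
    # torch.Generator("cuda")                 ->  torch.Generator(device=current_omni_platform.device_type)
    # GPUMemoryMonitor(...)                   ->  DeviceMemoryMonitor.instantiate(...)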

tests/e2e/offline_inference/test_diffusion_layerwise_offload.py

Lines changed: 7 additions & 8 deletions
@@ -5,8 +5,9 @@
 import torch
 from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
 
-from tests.utils import GPUMemoryMonitor
+from tests.utils import DeviceMemoryMonitor
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.platforms import current_omni_platform
 
 # ruff: noqa: E402
 REPO_ROOT = Path(__file__).resolve().parents[2]
@@ -27,11 +28,9 @@ def run_inference(
     layerwise_offload: bool = False,
     num_inference_steps: int = 3,
 ) -> float:
-    # For now, only support on GPU, so apply torch.cuda operations here
-    # NPU / ROCm platforms are expected to be detected and skipped this test function
-    torch.cuda.empty_cache()
-    device_index = torch.cuda.current_device()
-    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
+    current_omni_platform.empty_cache()
+    device_index = torch.accelerator.current_device_index()
+    monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
     monitor.start()
 
     m = Omni(
@@ -41,7 +40,7 @@ def run_inference(
         flow_shift=5.0,
     )
 
-    torch.cuda.reset_peak_memory_stats(device=device_index)
+    torch.accelerator.reset_peak_memory_stats()
 
     # Refer to tests/e2e/offline_inference/test_t2v_model.py
     # Use minimal settings for testing
@@ -54,7 +53,7 @@ def run_inference(
         OmniDiffusionSamplingParams(
             height=height,
             width=width,
-            generator=torch.Generator("cuda").manual_seed(42),
+            generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
             guidance_scale=1.0,
             num_inference_steps=num_inference_steps,
             num_frames=num_frames,
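
The generator change here is the same portability fix as in the CPU-offload test: the seed is placed on whatever device type the active platform reports rather than a hard-coded "cuda" string. A minimal sketch (the example device strings are illustrative; the actual value comes from the platform object):

    import torch
    from vllm_omni.platforms import current_omni_platform

    # device_type is e.g. "cuda", "xpu", or "npu" depending on the detected platform
    gen = torch.Generator(device=current_omni_platform.device_type).manual_seed(42)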

tests/e2e/offline_inference/test_zimage_parallelism.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 from PIL import Image
 from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
 
-from tests.utils import GPUMemoryMonitor, hardware_test
+from tests.utils import DeviceMemoryMonitor, hardware_test
 from vllm_omni import Omni
 from vllm_omni.diffusion.data import DiffusionParallelConfig
 from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -95,7 +95,7 @@ def _run_zimage_generate(
 
     torch.cuda.empty_cache()
     device_index = torch.cuda.current_device()
-    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
+    monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
     monitor.start()
     m = Omni(
         model=_get_zimage_model(),

tests/utils.py

Lines changed: 59 additions & 2 deletions
@@ -20,6 +20,8 @@
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import cuda_device_count_stateless
 
+from vllm_omni.platforms import current_omni_platform
+
 _P = ParamSpec("_P")
 
 if current_platform.is_rocm():
@@ -504,8 +506,17 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
     return wrapper
 
 
-class GPUMemoryMonitor:
-    """Poll global device memory usage via CUDA APIs."""
+class DeviceMemoryMonitor:
+    """Poll global device memory usage."""
+
+    @classmethod
+    def instantiate(cls, **kwargs: Any) -> "DeviceMemoryMonitor":
+        if current_omni_platform.is_npu():
+            return NPUMemoryMonitor(**kwargs)
+        elif current_omni_platform.is_xpu():
+            return XPUMemoryMonitor(**kwargs)
+        else:
+            return cls(**kwargs)
 
     def __init__(self, device_index: int, interval: float = 0.05):
         self.device_index = device_index
@@ -543,3 +554,49 @@ def peak_used_mb(self) -> float:
 
     def __del__(self):
         self.stop()
+
+
+class NPUMemoryMonitor(DeviceMemoryMonitor):
+    def start(self) -> None:
+        def monitor_loop() -> None:
+            while not self._stop_event.is_set():
+                try:
+                    with torch.npu.device(self.device_index):
+                        free_bytes, total_bytes = torch.npu.mem_get_info()
+                        used_mb = (total_bytes - free_bytes) / (1024**2)
+                        self._peak_used_mb = max(self._peak_used_mb, used_mb)
+                except Exception:
+                    pass
+                time.sleep(self.interval)
+
+        self._thread = threading.Thread(target=monitor_loop, daemon=False)
+        self._thread.start()
+
+    @property
+    def peak_used_mb(self) -> float:
+        fallback_alloc = torch.npu.max_memory_allocated(device=self.device_index) / (1024**2)
+        fallback_reserved = torch.npu.max_memory_reserved(device=self.device_index) / (1024**2)
+        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
+
+
+class XPUMemoryMonitor(DeviceMemoryMonitor):
+    def start(self) -> None:
+        def monitor_loop() -> None:
+            while not self._stop_event.is_set():
+                try:
+                    with torch.xpu.device(self.device_index):
+                        free_bytes, total_bytes = torch.xpu.mem_get_info()
+                        used_mb = (total_bytes - free_bytes) / (1024**2)
+                        self._peak_used_mb = max(self._peak_used_mb, used_mb)
+                except Exception:
+                    pass
+                time.sleep(self.interval)
+
+        self._thread = threading.Thread(target=monitor_loop, daemon=False)
+        self._thread.start()
+
+    @property
+    def peak_used_mb(self) -> float:
+        fallback_alloc = torch.xpu.max_memory_allocated(device=self.device_index) / (1024**2)
+        fallback_reserved = torch.xpu.max_memory_reserved(device=self.device_index) / (1024**2)
+        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
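
A minimal usage sketch of the new factory, mirroring how the updated tests drive it (the 0.02 s polling interval is copied from the tests; the workload line is a placeholder):

    import torch

    from tests.utils import DeviceMemoryMonitor

    device_index = torch.accelerator.current_device_index()
    monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
    monitor.start()
    # ... run the workload whose peak device memory usage we want to measure ...
    peak_mb = monitor.peak_used_mb  # base (CUDA) monitor, or the NPU/XPU subclass as dispatched above
    monitor.stop()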
