Skip to content

Commit 396bf17

Browse files
committed
enable DeviceMemoryMonitor for all platforms
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
1 parent 82e1bf2 commit 396bf17

File tree

4 files changed

+70
-12
lines changed

4 files changed

+70
-12
lines changed

tests/e2e/offline_inference/test_diffusion_cpu_offload.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
77

8-
from tests.utils import GPUMemoryMonitor, hardware_test
8+
from tests.utils import DeviceMemoryMonitor, hardware_test
99
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
1010
from vllm_omni.platforms import current_omni_platform
1111

@@ -21,11 +21,11 @@
2121

2222
def inference(model_name: str, offload: bool = True):
2323
current_omni_platform.empty_cache()
24-
device_index = torch.cuda.current_device()
25-
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
24+
device_index = torch.accelerator.current_device_index()
25+
monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
2626
monitor.start()
2727
m = Omni(model=model_name, enable_cpu_offload=offload)
28-
torch.cuda.reset_peak_memory_stats(device=device_index)
28+
torch.accelerator.reset_peak_memory_stats()
2929
height = 256
3030
width = 256
3131

@@ -36,7 +36,7 @@ def inference(model_name: str, offload: bool = True):
3636
width=width,
3737
num_inference_steps=9,
3838
guidance_scale=0.0,
39-
generator=torch.Generator("cuda").manual_seed(42),
39+
generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
4040
),
4141
)
4242
peak = monitor.peak_used_mb

tests/e2e/offline_inference/test_diffusion_layerwise_offload.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
77

8-
from tests.utils import GPUMemoryMonitor
8+
from tests.utils import DeviceMemoryMonitor
99
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
1010
from vllm_omni.platforms import current_omni_platform
1111

@@ -32,7 +32,7 @@ def run_inference(
3232
# NPU / ROCm platforms are expected to be detected, and this test function skipped
3333
torch.cuda.empty_cache()
3434
device_index = torch.cuda.current_device()
35-
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
35+
monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
3636
monitor.start()
3737

3838
m = Omni(
@@ -55,7 +55,7 @@ def run_inference(
5555
OmniDiffusionSamplingParams(
5656
height=height,
5757
width=width,
58-
generator=torch.Generator("cuda").manual_seed(42),
58+
generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
5959
guidance_scale=1.0,
6060
num_inference_steps=num_inference_steps,
6161
num_frames=num_frames,

tests/e2e/offline_inference/test_zimage_parallelism.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from PIL import Image
2323
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
2424

25-
from tests.utils import GPUMemoryMonitor, hardware_test
25+
from tests.utils import DeviceMemoryMonitor, hardware_test
2626
from vllm_omni import Omni
2727
from vllm_omni.diffusion.data import DiffusionParallelConfig
2828
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -95,7 +95,7 @@ def _run_zimage_generate(
9595

9696
torch.cuda.empty_cache()
9797
device_index = torch.cuda.current_device()
98-
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
98+
monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
9999
monitor.start()
100100
m = Omni(
101101
model=_get_zimage_model(),

tests/utils.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from vllm.platforms import current_platform
2121
from vllm.utils.torch_utils import cuda_device_count_stateless
2222

23+
from vllm_omni.platforms import current_omni_platform
24+
2325
_P = ParamSpec("_P")
2426

2527
if current_platform.is_rocm():
@@ -504,8 +506,17 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
504506
return wrapper
505507

506508

507-
class GPUMemoryMonitor:
508-
"""Poll global device memory usage via CUDA APIs."""
509+
class DeviceMemoryMonitor:
510+
"""Poll global device memory usage."""
511+
512+
@classmethod
def instantiate(cls, **kwargs: Any) -> "DeviceMemoryMonitor":
    """Factory: build the memory monitor matching the current platform.

    Dispatches on the detected accelerator; unknown platforms fall back to
    the CUDA-style base implementation. All keyword arguments are forwarded
    to the chosen subclass constructor unchanged.
    """
    if current_omni_platform.is_npu():
        return NPUMemoryMonitor(**kwargs)
    if current_omni_platform.is_xpu():
        return XPUMemoryMonitor(**kwargs)
    return cls(**kwargs)
509520

510521
def __init__(self, device_index: int, interval: float = 0.05):
511522
self.device_index = device_index
@@ -543,3 +554,50 @@ def peak_used_mb(self) -> float:
543554

544555
def __del__(self):
545556
self.stop()
557+
558+
559+
class NPUMemoryMonitor(DeviceMemoryMonitor):
    """Device-memory monitor backed by the Ascend NPU (`torch.npu`) APIs."""

    def start(self) -> None:
        """Launch a background thread that polls global NPU memory usage."""

        def _poll() -> None:
            while not self._stop_event.is_set():
                try:
                    with torch.npu.device(self.device_index):
                        free_bytes, total_bytes = torch.npu.mem_get_info()
                        sample_mb = (total_bytes - free_bytes) / (1024**2)
                        self._peak_used_mb = max(self._peak_used_mb, sample_mb)
                except Exception:
                    # Best-effort sampling: ignore transient query failures.
                    pass
                time.sleep(self.interval)

        self._thread = threading.Thread(target=_poll, daemon=False)
        self._thread.start()

    @property
    def peak_used_mb(self) -> float:
        """Peak used device memory in MiB.

        Combines the polled global samples with the allocator's own peak
        counters, so a spike missed between polls is still captured.
        """
        mib = 1024**2
        alloc_peak = torch.npu.max_memory_allocated(device=self.device_index) / mib
        reserved_peak = torch.npu.max_memory_reserved(device=self.device_index) / mib
        return max(self._peak_used_mb, alloc_peak, reserved_peak)
580+
581+
582+
class XPUMemoryMonitor(DeviceMemoryMonitor):
    """Device-memory monitor backed by the Intel XPU (`torch.xpu`) APIs."""

    def start(self) -> None:
        """Launch a background thread that polls global XPU memory usage."""

        def monitor_loop() -> None:
            while not self._stop_event.is_set():
                try:
                    with torch.xpu.device(self.device_index):
                        free_bytes, total_bytes = torch.xpu.mem_get_info()
                        used_mb = (total_bytes - free_bytes) / (1024**2)
                        self._peak_used_mb = max(self._peak_used_mb, used_mb)
                except Exception:
                    # Best-effort sampling: ignore transient query failures.
                    pass
                time.sleep(self.interval)

        self._thread = threading.Thread(target=monitor_loop, daemon=False)
        self._thread.start()

    @property
    def peak_used_mb(self) -> float:
        """Peak used device memory in MiB.

        Combines the polled global samples with the allocator's own peak
        counters, so a spike missed between polls is still captured.

        BUG FIX: the original called ``torch.xpu.reset_peak_memory_stats``
        *before* reading ``max_memory_allocated`` / ``max_memory_reserved``,
        which zeroed the very peaks this property is meant to report. Read
        first and do not reset here (matching the NPU sibling monitor).
        """
        mib = 1024**2
        fallback_alloc = torch.xpu.max_memory_allocated(device=self.device_index) / mib
        fallback_reserved = torch.xpu.max_memory_reserved(device=self.device_index) / mib
        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)

0 commit comments

Comments
 (0)