|
20 | 20 | from vllm.platforms import current_platform |
21 | 21 | from vllm.utils.torch_utils import cuda_device_count_stateless |
22 | 22 |
|
| 23 | +from vllm_omni.platforms import current_omni_platform |
| 24 | + |
23 | 25 | _P = ParamSpec("_P") |
24 | 26 |
|
25 | 27 | if current_platform.is_rocm(): |
@@ -504,8 +506,17 @@ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: |
504 | 506 | return wrapper |
505 | 507 |
|
506 | 508 |
|
507 | | -class GPUMemoryMonitor: |
508 | | - """Poll global device memory usage via CUDA APIs.""" |
| 509 | +class DeviceMemoryMonitor: |
| 510 | + """Poll global device memory usage.""" |
| 511 | + |
| 512 | + @classmethod |
| 513 | + def instantiate(cls, **kwargs: Any) -> "DeviceMemoryMonitor": |
| 514 | + if current_omni_platform.is_npu(): |
| 515 | + return NPUMemoryMonitor(**kwargs) |
| 516 | + elif current_omni_platform.is_xpu(): |
| 517 | + return XPUMemoryMonitor(**kwargs) |
| 518 | + else: |
| 519 | + return cls(**kwargs) |
509 | 520 |
|
510 | 521 | def __init__(self, device_index: int, interval: float = 0.05): |
511 | 522 | self.device_index = device_index |
@@ -543,3 +554,50 @@ def peak_used_mb(self) -> float: |
543 | 554 |
|
    def __del__(self):
        """Best-effort cleanup: stop the polling thread on garbage collection.

        NOTE(review): ``__del__`` is not guaranteed to run (reference cycles,
        interpreter shutdown), so callers should still invoke ``stop()``
        explicitly when they are done with the monitor.
        """
        self.stop()
| 557 | + |
| 558 | + |
class NPUMemoryMonitor(DeviceMemoryMonitor):
    """Poll global NPU (Ascend) memory usage via ``torch.npu`` APIs."""

    def start(self) -> None:
        """Spawn the background thread that samples device-wide memory use."""

        def monitor_loop() -> None:
            # Sample free/total device memory until stop() sets the event.
            while not self._stop_event.is_set():
                try:
                    with torch.npu.device(self.device_index):
                        free_bytes, total_bytes = torch.npu.mem_get_info()
                        used_mb = (total_bytes - free_bytes) / (1024**2)
                        self._peak_used_mb = max(self._peak_used_mb, used_mb)
                except Exception:
                    # Best-effort sampling: ignore transient query failures.
                    pass
                time.sleep(self.interval)

        # daemon=True is required here: monitor_loop closes over self, so the
        # thread keeps this instance alive and __del__ (which calls stop())
        # can never fire on its own. A non-daemon thread would therefore hang
        # interpreter shutdown whenever stop() is not called explicitly.
        self._thread = threading.Thread(target=monitor_loop, daemon=True)
        self._thread.start()

    @property
    def peak_used_mb(self) -> float:
        """Peak observed device memory use in MiB.

        Combines the sampled global peak with the torch allocator's own peak
        counters, which catch short-lived spikes that occur between polls.
        """
        fallback_alloc = torch.npu.max_memory_allocated(device=self.device_index) / (1024**2)
        fallback_reserved = torch.npu.max_memory_reserved(device=self.device_index) / (1024**2)
        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
| 580 | + |
| 581 | + |
class XPUMemoryMonitor(DeviceMemoryMonitor):
    """Poll global XPU (Intel GPU) memory usage via ``torch.xpu`` APIs."""

    def start(self) -> None:
        """Spawn the background thread that samples device-wide memory use."""

        def monitor_loop() -> None:
            # Sample free/total device memory until stop() sets the event.
            while not self._stop_event.is_set():
                try:
                    with torch.xpu.device(self.device_index):
                        free_bytes, total_bytes = torch.xpu.mem_get_info()
                        used_mb = (total_bytes - free_bytes) / (1024**2)
                        self._peak_used_mb = max(self._peak_used_mb, used_mb)
                except Exception:
                    # Best-effort sampling: ignore transient query failures.
                    pass
                time.sleep(self.interval)

        # daemon=True is required here: monitor_loop closes over self, so the
        # thread keeps this instance alive and __del__ (which calls stop())
        # can never fire on its own. A non-daemon thread would therefore hang
        # interpreter shutdown whenever stop() is not called explicitly.
        self._thread = threading.Thread(target=monitor_loop, daemon=True)
        self._thread.start()

    @property
    def peak_used_mb(self) -> float:
        """Peak observed device memory use in MiB.

        Combines the sampled global peak with the torch allocator's own peak
        counters, which catch short-lived spikes that occur between polls.

        Bug fix: the previous version called
        ``torch.xpu.reset_peak_memory_stats`` BEFORE reading the peak
        counters, wiping the very values it was about to query (and mutating
        allocator state from a property getter). The reset is removed, which
        also matches the NPU sibling implementation.
        """
        fallback_alloc = torch.xpu.max_memory_allocated(device=self.device_index) / (1024**2)
        fallback_reserved = torch.xpu.max_memory_reserved(device=self.device_index) / (1024**2)
        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
0 commit comments