Skip to content

Commit d34a733

Browse files
committed
remove torch.accelerator apis
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
1 parent fdd28c4 commit d34a733

File tree

7 files changed

+44
-4
lines changed

7 files changed

+44
-4
lines changed

tests/e2e/offline_inference/test_diffusion_cpu_offload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121

2222
def inference(model_name: str, offload: bool = True):
2323
current_omni_platform.empty_cache()
24-
device_index = torch.accelerator.current_device_index()
24+
device_index = current_omni_platform.current_device_index()
2525
monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
2626
monitor.start()
2727
m = Omni(model=model_name, enable_cpu_offload=offload)
28-
torch.accelerator.reset_peak_memory_stats()
28+
current_omni_platform.reset_peak_memory_stats()
2929
height = 256
3030
width = 256
3131

tests/e2e/offline_inference/test_diffusion_layerwise_offload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def run_inference(
2929
num_inference_steps: int = 3,
3030
) -> float:
3131
current_omni_platform.empty_cache()
32-
device_index = torch.accelerator.current_device_index()
32+
device_index = current_omni_platform.current_device_index()
3333
monitor = DeviceMemoryMonitor.instantiate(device_index=device_index, interval=0.02)
3434
monitor.start()
3535

@@ -40,7 +40,7 @@ def run_inference(
4040
flow_shift=5.0,
4141
)
4242

43-
torch.accelerator.reset_peak_memory_stats()
43+
current_omni_platform.reset_peak_memory_stats()
4444

4545
# Refer to tests/e2e/offline_inference/test_t2v_model.py
4646
# Use minimal settings for testing

vllm_omni/platforms/cuda/platform.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,11 @@ def get_free_memory(cls, device: torch.device | None = None) -> int:
115115
@classmethod
116116
def get_device_name(cls, device_id: int = 0) -> str:
117117
return torch.cuda.get_device_name(device_id)
118+
119+
@classmethod
120+
def reset_peak_memory_stats(cls, device: torch.device | None = None) -> None:
121+
torch.cuda.reset_peak_memory_stats(device)
122+
123+
@classmethod
124+
def current_device_index(cls) -> int:
125+
return torch.cuda.current_device()

vllm_omni/platforms/interface.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ def synchronize(cls) -> None:
9898
def get_free_memory(cls, device: torch.device | None = None) -> int:
9999
raise NotImplementedError
100100

101+
@classmethod
102+
def reset_peak_memory_stats(cls, device: torch.device | None = None) -> None:
103+
raise NotImplementedError
104+
105+
@classmethod
106+
def current_device_index(cls) -> int:
107+
raise NotImplementedError
108+
101109

102110
class UnspecifiedOmniPlatform(OmniPlatform):
103111
_omni_enum = OmniPlatformEnum.UNSPECIFIED

vllm_omni/platforms/npu/platform.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ def get_free_memory(cls, device: torch.device | None = None) -> int:
8282
free, _ = torch.npu.mem_get_info(device)
8383
return free
8484

85+
@classmethod
86+
def reset_peak_memory_stats(cls, device: torch.device | None = None) -> None:
87+
torch.npu.reset_peak_memory_stats(device)
88+
89+
@classmethod
90+
def current_device_index(cls) -> int:
91+
return torch.npu.current_device()
92+
8593
@classmethod
8694
def get_device_total_memory(cls, device_id: int = 0) -> int:
8795
device_props = torch.npu.get_device_properties(device_id)

vllm_omni/platforms/rocm/platform.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,11 @@ def synchronize(cls) -> None:
9999
def get_free_memory(cls, device: torch.device | None = None) -> int:
100100
free, _ = torch.cuda.mem_get_info(device)
101101
return free
102+
103+
@classmethod
104+
def reset_peak_memory_stats(cls, device: torch.device | None = None) -> None:
105+
torch.cuda.reset_peak_memory_stats(device)
106+
107+
@classmethod
108+
def current_device_index(cls) -> int:
109+
return torch.cuda.current_device()

vllm_omni/platforms/xpu/platform.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ def get_device_version(cls) -> str | None:
7171
def synchronize(cls) -> None:
7272
torch.xpu.synchronize()
7373

74+
@classmethod
75+
def reset_peak_memory_stats(cls, device: torch.device | None = None) -> None:
76+
torch.xpu.reset_peak_memory_stats(device)
77+
78+
@classmethod
79+
def current_device_index(cls) -> int:
80+
return torch.xpu.current_device()
81+
7482
@classmethod
7583
def get_free_memory(cls, device: torch.device | None = None) -> int:
7684
if device is None:

0 commit comments

Comments
 (0)