1+ # Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/platforms
12import torch
3+ from abc import ABC
24
35
4- class CpuPlatform :
class BasePlatform(ABC):
    """Common interface every hardware platform backend implements.

    Concrete subclasses (CPU, CUDA, NPU, ...) fill in the class attributes
    and override the static methods below.  Each default implementation
    raises ``NotImplementedError`` so an unsupported operation on a given
    platform fails loudly instead of silently doing nothing.
    """

    # Identifiers a concrete platform supplies (declared, not assigned here):
    device_type: str             # torch device string, e.g. "cpu" / "cuda"
    device_control_env_var: str  # env var controlling device visibility
    dispatch_key: str            # torch dispatch key for this backend
    dist_backend: str            # distributed backend name
    full_dist_backend: str       # fully-qualified backend, e.g. "cpu:gloo"

    # NOTE(review): the stubs below accept *args/**kwargs so every subclass
    # can choose its own signature; all of them are placeholders that a
    # concrete platform is expected to override.

    @staticmethod
    def empty_cache(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def ipc_collect(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def get_device_name():
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def device_ctx(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def default_device(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def synchronize(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def device_count(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def is_accelerator_available(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def current_device(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def reset_peak_memory_stats(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def max_memory_allocated(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def get_device_properties(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError

    @staticmethod
    def set_device(*args, **kwargs):
        """Placeholder; concrete platforms must override."""
        raise NotImplementedError
64+
65+
66+ class CpuPlatform (BasePlatform ):
567 device_type : str = "cpu"
68+ dispatch_key : str = "CPU"
69+ device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
670 dist_backend : str = "gloo"
771 full_dist_backend : str = "cpu:gloo"
872
@@ -19,7 +83,7 @@ def is_accelerator_available():
1983 return False
2084
2185
22- class CudaPlatform :
86+ class CudaPlatform ( BasePlatform ) :
2387 device_type : str = "cuda"
2488 device_control_env_var : str = "CUDA_VISIBLE_DEVICES"
2589 dispatch_key : str = "CUDA"
@@ -79,7 +143,7 @@ def set_device(device):
79143 return torch .cuda .set_device (device )
80144
81145
82- class NPUPlatform :
146+ class NPUPlatform ( BasePlatform ) :
83147 device_type : str = "npu"
84148 device_control_env_var : str = "ASCEND_RT_VISIBLE_DEVICES"
85149 dispatch_key : str = "PrivateUse1"
0 commit comments