Commit d62b0aa

feat: add abstract platform (#653)
1 parent a4371b3 commit d62b0aa

2 files changed: +88 −6 lines

src/cache_dit/platforms/__init__.py

Lines changed: 21 additions & 3 deletions
@@ -1,9 +1,10 @@
 import torch
 import importlib
-from typing import Any
+from typing import TYPE_CHECKING
+from .platform import BasePlatform


-def resolve_obj_by_qualname(qualname: str) -> Any:
+def resolve_obj_by_qualname(qualname: str) -> BasePlatform:
     """
     Resolve an object by its fully-qualified class name.
     """
@@ -23,7 +24,11 @@ def resolve_current_platform_cls_qualname() -> str:
     return "cache_dit.platforms.platform.CpuPlatform"


-_current_platform = None
+_current_platform: BasePlatform = None
+
+
+if TYPE_CHECKING:
+    current_platform: BasePlatform


 def __getattr__(name: str):
@@ -37,3 +42,16 @@ def __getattr__(name: str):
         return globals()[name]
     else:
         raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
+
+
+def __setattr__(name: str, value):
+    if name == "current_platform":
+        global _current_platform
+        _current_platform = value
+    elif name in globals():
+        globals()[name] = value
+    else:
+        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
+
+
+__all__ = ["BasePlatform", "current_platform"]
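
The first hunk changes resolve_obj_by_qualname's return annotation but does not show its body. A minimal sketch of the usual importlib pattern such a resolver follows (the rsplit/import_module/getattr steps are an assumption; only the signature and docstring appear in this diff):

import importlib


def resolve_obj_by_qualname(qualname: str):
    """
    Resolve an object by its fully-qualified class name.
    """
    # Split e.g. "cache_dit.platforms.platform.CpuPlatform" into a
    # module path and an attribute name at the last dot.
    module_name, obj_name = qualname.rsplit(".", 1)
    # Import the module, then look the attribute up on it.
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)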

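With the module-level __getattr__ (PEP 562), current_platform behaves as a lazily resolved module attribute: the first access resolves the platform class via resolve_current_platform_cls_qualname and caches it in _current_platform, and the TYPE_CHECKING annotation gives type checkers a static view of it. A hedged consumer-side sketch (which concrete class is resolved depends on the runtime environment):

from cache_dit.platforms import current_platform  # first access triggers module __getattr__

# Dispatch through the platform-agnostic interface; the concrete class
# depends on the machine (CpuPlatform is the documented fallback).
if current_platform.is_accelerator_available():
    current_platform.set_device(0)
    current_platform.synchronize()
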
src/cache_dit/platforms/platform.py

Lines changed: 67 additions & 3 deletions
@@ -1,8 +1,72 @@
+# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/platforms
 import torch
+from abc import ABC


-class CpuPlatform:
+class BasePlatform(ABC):
+    device_type: str
+    device_control_env_var: str
+    dispatch_key: str
+    dist_backend: str
+    full_dist_backend: str
+
+    @staticmethod
+    def empty_cache(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def ipc_collect(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def get_device_name():
+        raise NotImplementedError
+
+    @staticmethod
+    def device_ctx(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def default_device(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def synchronize(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def device_count(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def is_accelerator_available(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def current_device(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def reset_peak_memory_stats(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def max_memory_allocated(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def get_device_properties(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def set_device(*args, **kwargs):
+        raise NotImplementedError
+
+
+class CpuPlatform(BasePlatform):
     device_type: str = "cpu"
+    dispatch_key: str = "CPU"
+    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
     dist_backend: str = "gloo"
     full_dist_backend: str = "cpu:gloo"

@@ -19,7 +83,7 @@ def is_accelerator_available():
         return False


-class CudaPlatform:
+class CudaPlatform(BasePlatform):
     device_type: str = "cuda"
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
     dispatch_key: str = "CUDA"
@@ -79,7 +143,7 @@ def set_device(device):
         return torch.cuda.set_device(device)


-class NPUPlatform:
+class NPUPlatform(BasePlatform):
     device_type: str = "npu"
     device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
     dispatch_key: str = "PrivateUse1"
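
BasePlatform spells out the contract a new backend has to satisfy. A hypothetical sketch of what a further platform could look like (MpsPlatform and its attribute values are illustrations, not part of this commit):

import torch

from cache_dit.platforms.platform import BasePlatform


class MpsPlatform(BasePlatform):  # hypothetical example, not in this commit
    device_type: str = "mps"
    device_control_env_var: str = "MPS_VISIBLE_DEVICES"  # illustrative name
    dispatch_key: str = "MPS"
    dist_backend: str = "gloo"
    full_dist_backend: str = "mps:gloo"

    @staticmethod
    def is_accelerator_available():
        # torch ships an MPS availability check on Apple-silicon builds.
        return torch.backends.mps.is_available()

    @staticmethod
    def synchronize():
        return torch.mps.synchronize()

    @staticmethod
    def get_device_name():
        return "mps"

Because the base class raises NotImplementedError rather than using abstractmethod, a subclass only needs to override the operations it supports; anything left out fails loudly at call time instead of silently misbehaving.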
