24 changes: 21 additions & 3 deletions src/cache_dit/platforms/__init__.py
@@ -1,9 +1,10 @@
 import torch
 import importlib
-from typing import Any
+from typing import TYPE_CHECKING
+from .platform import BasePlatform
 
 
-def resolve_obj_by_qualname(qualname: str) -> Any:
+def resolve_obj_by_qualname(qualname: str) -> BasePlatform:
     """
     Resolve an object by its fully-qualified class name.
     """
@@ -23,7 +24,11 @@ def resolve_current_platform_cls_qualname() -> str:
return "cache_dit.platforms.platform.CpuPlatform"


_current_platform = None
_current_platform: BasePlatform = None


if TYPE_CHECKING:
current_platform: BasePlatform


def __getattr__(name: str):
@@ -37,3 +42,16 @@ def __getattr__(name: str):
         return globals()[name]
     else:
         raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
+
+
+def __setattr__(name: str, value):
+    if name == "current_platform":
+        global _current_platform
+        _current_platform = value
+    elif name in globals():
+        globals()[name] = value
+    else:
+        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
+
+
+__all__ = ["BasePlatform", "current_platform"]
70 changes: 67 additions & 3 deletions src/cache_dit/platforms/platform.py
@@ -1,8 +1,72 @@
 # Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/platforms
 import torch
+from abc import ABC
 
 
-class CpuPlatform:
+class BasePlatform(ABC):
+    device_type: str
+    device_control_env_var: str
+    dispatch_key: str
+    dist_backend: str
+    full_dist_backend: str
+
+    @staticmethod
+    def empty_cache(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def ipc_collect(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def get_device_name():
+        raise NotImplementedError
+
+    @staticmethod
+    def device_ctx(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def default_device(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def synchronize(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def device_count(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def is_accelerator_available(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def current_device(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def reset_peak_memory_stats(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def max_memory_allocated(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def get_device_properties(*args, **kwargs):
+        raise NotImplementedError
+
+    @staticmethod
+    def set_device(*args, **kwargs):
+        raise NotImplementedError
+
+
+class CpuPlatform(BasePlatform):
     device_type: str = "cpu"
     dispatch_key: str = "CPU"
+    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
     dist_backend: str = "gloo"
+    full_dist_backend: str = "cpu:gloo"

@@ -19,7 +83,7 @@ def is_accelerator_available():
         return False
 
 
-class CudaPlatform:
+class CudaPlatform(BasePlatform):
     device_type: str = "cuda"
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
     dispatch_key: str = "CUDA"
@@ -79,7 +143,7 @@ def set_device(device):
         return torch.cuda.set_device(device)
 
 
-class NPUPlatform:
+class NPUPlatform(BasePlatform):
     device_type: str = "npu"
     device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
     dispatch_key: str = "PrivateUse1"
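
With the shared `BasePlatform` interface, wiring in an additional backend reduces to subclassing and overriding the relevant static methods; anything left un-overridden falls back to `BasePlatform`'s `NotImplementedError` stubs. Below is a hypothetical sketch of such a backend. `XpuPlatform`, the `ZE_AFFINITY_MASK` variable, and the `xccl` backend strings are illustrative assumptions, not part of this PR; the `torch.xpu` namespace it leans on ships with PyTorch 2.4+.

```python
# Hypothetical sketch: an Intel-GPU backend plugged into the BasePlatform
# interface introduced by this PR. All names below except BasePlatform
# are assumptions for illustration.
import torch

from cache_dit.platforms.platform import BasePlatform


class XpuPlatform(BasePlatform):
    device_type: str = "xpu"
    device_control_env_var: str = "ZE_AFFINITY_MASK"  # Level Zero device mask
    dispatch_key: str = "XPU"
    dist_backend: str = "xccl"
    full_dist_backend: str = "xpu:xccl"

    @staticmethod
    def is_accelerator_available():
        return torch.xpu.is_available()

    @staticmethod
    def device_count():
        return torch.xpu.device_count()

    @staticmethod
    def current_device():
        return torch.xpu.current_device()

    @staticmethod
    def set_device(device):
        return torch.xpu.set_device(device)

    @staticmethod
    def synchronize(*args, **kwargs):
        return torch.xpu.synchronize(*args, **kwargs)

    @staticmethod
    def empty_cache():
        return torch.xpu.empty_cache()
```

For such a class to actually be picked up, `resolve_current_platform_cls_qualname()` would also need a branch returning its qualified name (e.g. `"cache_dit.platforms.platform.XpuPlatform"`) when the device is detected.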