feat: support ascend npu #651
Conversation
Pull request overview
This PR adds support for Ascend NPU hardware by introducing a platform abstraction layer that handles device-specific operations across CPU, CUDA, and NPU platforms. The implementation enables seamless switching between different hardware accelerators without code duplication.
Key Changes:
- Introduced a platform abstraction layer with `CpuPlatform`, `CudaPlatform`, and `NPUPlatform` classes that encapsulate device-specific operations
- Replaced hardcoded CUDA API calls with platform-agnostic methods throughout the codebase
- Added automatic platform detection based on available hardware (CUDA, NPU via torch_npu, or fallback to CPU)
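For reference, the detection and lazy-initialization flow can be sketched as follows. This is a minimal illustration, not the PR's actual code: the class names and the `current_platform` singleton come from the PR description, while the placeholder class bodies and the `_detect_platform` / `get_current_platform` helpers are assumptions.

```python
# Sketch of automatic platform detection (illustrative; not the PR's exact code).
import torch


class CpuPlatform:
    device_type = "cpu"


class CudaPlatform:
    device_type = "cuda"


class NPUPlatform:
    device_type = "npu"


def _detect_platform():
    # Prefer Ascend NPU when torch_npu is importable and reports a usable device.
    try:
        import torch_npu  # noqa: F401  # extends torch with an `npu` namespace

        if hasattr(torch, "npu") and torch.npu.is_available():
            return NPUPlatform()
    except ImportError:
        pass
    # Otherwise fall back to CUDA, then CPU.
    if torch.cuda.is_available():
        return CudaPlatform()
    return CpuPlatform()


_current_platform = None


def get_current_platform():
    """Lazily create and cache the platform singleton on first use."""
    global _current_platform
    if _current_platform is None:
        _current_platform = _detect_platform()
    return _current_platform
```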
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/cache_dit/platforms/platform.py | Defines three platform classes (CpuPlatform, CudaPlatform, NPUPlatform) that encapsulate device-specific operations like memory management, device synchronization, and distributed backend configuration |
| src/cache_dit/platforms/__init__.py | Implements platform detection logic that automatically selects the appropriate platform based on available hardware and provides lazy initialization of the current_platform singleton |
| src/cache_dit/utils.py | Updates cache management functions to use platform-agnostic empty_cache() and ipc_collect() methods instead of hardcoded torch.cuda calls (see the sketch after this table) |
| src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux2.py | Replaces hardcoded "cuda" device strings with current_platform.device_type for device placement during tensor parallelism operations |
| src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_zimage.py | Updates torch.amp.autocast to use current_platform.device_type instead of hardcoded "cuda" |
| src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_zimage_controlnet.py | Updates torch.amp.autocast to use current_platform.device_type for mixed precision training |
| src/cache_dit/parallelism/attention/_distributed_primitives.py | Wraps kernel imports in try-except for NPU compatibility and updates device selection for distributed communication to use platform abstraction |
| examples/utils.py | Updates MemoryTracker, device initialization, and distributed backend configuration to use platform-agnostic methods |
| examples/base.py | Updates device name retrieval in summary output to use current_platform.get_device_name() |
| examples/api/run_transformer_only.py | Replaces hardcoded "cuda" with current_platform.device_type for model placement |
| examples/api/run_steps_mask.py | Updates pipeline device placement to use platform-agnostic device_type |
| examples/api/run_cache_refresh_wan_2.2.py | Updates device count checks and device placement to use platform abstraction methods |
| examples/api/run_cache_refresh_flux.py | Updates model placement to use current_platform.device_type |
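Taken together, the call-site changes above follow one pattern: hardcoded `torch.cuda.*` calls and `"cuda"` strings are replaced with the platform object. A hedged sketch of that pattern (the `current_platform` attributes match the table above; the import path, the function names, and the `some_cuda_kernels` module are illustrative assumptions):

```python
# Illustrative before/after for the call-site migration (not the PR's exact code).
import torch

from cache_dit.platforms import current_platform  # assumed import path


def free_device_memory():
    # Before: torch.cuda.empty_cache(); torch.cuda.ipc_collect()
    current_platform.empty_cache()
    current_platform.ipc_collect()


def run_block(block, hidden_states):
    # Before: torch.amp.autocast("cuda", ...) and hidden_states.to("cuda")
    device = current_platform.device_type
    with torch.amp.autocast(device, dtype=torch.bfloat16):
        return block(hidden_states.to(device))


# NPU builds may lack CUDA-only kernels, so those imports are guarded:
try:
    from some_cuda_kernels import fused_attention  # hypothetical module/function
except ImportError:
    fused_attention = None  # callers fall back to a portable implementation
```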
Copilot AI · Jan 7, 2026
The CpuPlatform class is missing several methods that are used throughout the codebase, including: current_device(), device_count(), set_device(), reset_peak_memory_stats(), max_memory_allocated(), get_device_properties(), synchronize(), empty_cache(), and ipc_collect(). These methods are called in examples/utils.py (e.g., MemoryTracker, GiB, get_rank_device) and other files. Without these methods, the code will raise AttributeError when CpuPlatform is selected and these methods are invoked.
```python
@staticmethod
def empty_cache():
    """
    CPU backend does not use a global device cache in the same way as CUDA/NPU.
    This is a no-op provided for API compatibility.
    """
    return None

@staticmethod
def ipc_collect():
    """
    CPU backend does not provide IPC memory management analogous to CUDA/NPU.
    This is a no-op provided for API compatibility.
    """
    return None

@staticmethod
def synchronize(device=None):
    """
    CPU operations are effectively synchronous; nothing to do here.
    """
    return None

@staticmethod
def device_count():
    """
    Conceptually, PyTorch exposes a single CPU device.
    """
    return 1

@staticmethod
def current_device():
    """
    For CPU, always return device index 0 for compatibility with CUDA/NPU APIs.
    """
    return 0

@staticmethod
def set_device(device):
    """
    Selecting a CPU 'device' has no effect; this is a no-op.
    """
    return None

@staticmethod
def reset_peak_memory_stats(device=None):
    """
    Peak memory statistics are not tracked via a dedicated CPU API.
    This is a no-op provided for API compatibility.
    """
    return None

@staticmethod
def max_memory_allocated(device=None):
    """
    CPU backend does not report max allocated memory via a unified API.
    Return 0 to indicate unsupported metrics while preserving numeric type.
    """
    return 0

@staticmethod
def get_device_properties(device=None):
    """
    Detailed device properties are not exposed for CPU in the same form
    as CUDA/NPU. Return None to indicate that such information is
    unavailable for this backend.
    """
    return None
```
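With these no-op fallbacks, device-agnostic utilities can call the same surface on every backend. For illustration only, a minimal memory-tracking context manager in the spirit of the MemoryTracker mentioned above (this is a sketch, not the actual class in examples/utils.py):

```python
# Illustrative context manager built on the platform API (assumed interface).
class PeakMemoryTracker:
    def __init__(self, platform):
        self.platform = platform
        self.peak_bytes = 0

    def __enter__(self):
        self.platform.reset_peak_memory_stats()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.platform.synchronize()
        # 0 on CpuPlatform; peak allocated bytes on CUDA/NPU platforms.
        self.peak_bytes = self.platform.max_memory_allocated()
        return False
```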
@gameofdimension Can we use a common abstract class to hold these undefined functions? For example:

```python
from abc import ABC, abstractmethod


class BasePlatform(ABC):
    device_type: str = None
    device_control_env_var: str = None
    dispatch_key: str = None
    dist_backend: str = None
    full_dist_backend: str = None

    @staticmethod
    @abstractmethod
    def empty_cache():
        raise NotImplementedError("....")
```

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
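For illustration, a concrete platform would then override only the backend-specific pieces. A minimal sketch assuming the BasePlatform above (the torch.cuda calls are standard PyTorch; the attribute values shown are assumptions):

```python
import torch


# Sketch of a concrete subclass of the proposed BasePlatform (illustrative only).
class CudaPlatform(BasePlatform):
    device_type = "cuda"
    dist_backend = "nccl"  # standard distributed backend for CUDA devices

    @staticmethod
    def empty_cache():
        torch.cuda.empty_cache()
```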
DefTruth left a comment
cool~ lgtm~
test env:
- cache
- cache+tp
- cache+ulysses