
Conversation

@gameofdimension
Member

test env

device: Ascend910B2C
torch: 2.7.1+cpu
torch_npu: 2.7.1.dev20250724
CANN: 8.2.RC1

cache

python generate.py qwen_image \
  --model-path /apps/dat/file/llm/model/Qwen-Image --cache --cpu-offload
qwen_image 1024x1024 C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG1_T0O0_S30

cache+tp

torchrun --master-port=22222 --nproc_per_node=2 generate.py qwen_image \
  --parallel tp --model-path /apps/dat/file/llm/model/Qwen-Image --cache --cpu-offload
qwen_image 1024x1024 C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG1_T0O0_TP2_S31

cache+ulysses

torchrun --master-port=22222 --nproc_per_node=2 generate.py qwen_image \
  --parallel ulysses --model-path /apps/dat/file/llm/model/Qwen-Image --cache --cpu-offload
qwen_image 1024x1024 C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG1_T0O0_Ulysses2_S30

Contributor

Copilot AI left a comment

Pull request overview

This PR adds support for Ascend NPU hardware by introducing a platform abstraction layer that handles device-specific operations across CPU, CUDA, and NPU platforms. The implementation enables seamless switching between different hardware accelerators without code duplication.

Key Changes:

  • Introduced a platform abstraction layer with CpuPlatform, CudaPlatform, and NPUPlatform classes that encapsulate device-specific operations
  • Replaced hardcoded CUDA API calls with platform-agnostic methods throughout the codebase (a minimal before/after sketch follows this list)
  • Added automatic platform detection based on available hardware (CUDA, NPU via torch_npu, or fallback to CPU)
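
A minimal before/after sketch of the call-site change, assuming the singleton is importable as cache_dit.platforms.current_platform (the import path follows the per-file summary below and is not quoted from the diff):

import torch

# Before: hardcoded CUDA calls
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

# After: platform-agnostic calls through the detected platform singleton
from cache_dit.platforms import current_platform

current_platform.empty_cache()
current_platform.ipc_collect()

# Device placement uses the detected device string instead of "cuda"
device = current_platform.device_type  # "cuda", "npu", or "cpu"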

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.

Summary per file

src/cache_dit/platforms/platform.py: Defines three platform classes (CpuPlatform, CudaPlatform, NPUPlatform) that encapsulate device-specific operations like memory management, device synchronization, and distributed backend configuration
src/cache_dit/platforms/__init__.py: Implements platform detection logic that automatically selects the appropriate platform based on available hardware and provides lazy initialization of the current_platform singleton (a sketch follows this table)
src/cache_dit/utils.py: Updates cache management functions to use platform-agnostic empty_cache() and ipc_collect() methods instead of hardcoded torch.cuda calls
src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux2.py: Replaces hardcoded "cuda" device strings with current_platform.device_type for device placement during tensor parallelism operations
src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_zimage.py: Updates torch.amp.autocast to use current_platform.device_type instead of hardcoded "cuda"
src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_zimage_controlnet.py: Updates torch.amp.autocast to use current_platform.device_type for mixed precision training
src/cache_dit/parallelism/attention/_distributed_primitives.py: Wraps kernel imports in try-except for NPU compatibility and updates device selection for distributed communication to use the platform abstraction
examples/utils.py: Updates MemoryTracker, device initialization, and distributed backend configuration to use platform-agnostic methods
examples/base.py: Updates device name retrieval in summary output to use current_platform.get_device_name()
examples/api/run_transformer_only.py: Replaces hardcoded "cuda" with current_platform.device_type for model placement
examples/api/run_steps_mask.py: Updates pipeline device placement to use the platform-agnostic device_type
examples/api/run_cache_refresh_wan_2.2.py: Updates device count checks and device placement to use platform abstraction methods
examples/api/run_cache_refresh_flux.py: Updates model placement to use current_platform.device_type
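
A hypothetical sketch of the lazy detection in src/cache_dit/platforms/__init__.py, based only on the summary above; the helper name and the module-level __getattr__ mechanism (PEP 562) are assumptions, not code from the diff. It also illustrates the try-except import pattern used for NPU compatibility:

import torch

from .platform import CpuPlatform, CudaPlatform, NPUPlatform

_current_platform = None  # lazily initialized singleton


def _detect_platform():
    if torch.cuda.is_available():
        return CudaPlatform()
    try:
        # torch_npu registers the "npu" device on import; wrapping the
        # import in try/except keeps CUDA- and CPU-only setups working.
        import torch_npu  # noqa: F401

        return NPUPlatform()
    except ImportError:
        return CpuPlatform()


def __getattr__(name):
    global _current_platform
    if name == "current_platform":
        if _current_platform is None:
            _current_platform = _detect_platform()
        return _current_platform
    raise AttributeError(name)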


Comment on lines +20 to +21



Copilot AI Jan 7, 2026


The CpuPlatform class is missing several methods that are used throughout the codebase, including: current_device(), device_count(), set_device(), reset_peak_memory_stats(), max_memory_allocated(), get_device_properties(), synchronize(), empty_cache(), and ipc_collect(). These methods are called in examples/utils.py (e.g., MemoryTracker, GiB, get_rank_device) and other files. Without these methods, the code will raise AttributeError when CpuPlatform is selected and these methods are invoked.

Suggested change
    @staticmethod
    def empty_cache():
        """
        CPU backend does not use a global device cache in the same way as CUDA/NPU.
        This is a no-op provided for API compatibility.
        """
        return None

    @staticmethod
    def ipc_collect():
        """
        CPU backend does not provide IPC memory management analogous to CUDA/NPU.
        This is a no-op provided for API compatibility.
        """
        return None

    @staticmethod
    def synchronize(device=None):
        """
        CPU operations are effectively synchronous; nothing to do here.
        """
        return None

    @staticmethod
    def device_count():
        """
        Conceptually, PyTorch exposes a single CPU device.
        """
        return 1

    @staticmethod
    def current_device():
        """
        For CPU, always return device index 0 for compatibility with CUDA/NPU APIs.
        """
        return 0

    @staticmethod
    def set_device(device):
        """
        Selecting a CPU 'device' has no effect; this is a no-op.
        """
        return None

    @staticmethod
    def reset_peak_memory_stats(device=None):
        """
        Peak memory statistics are not tracked via a dedicated CPU API.
        This is a no-op provided for API compatibility.
        """
        return None

    @staticmethod
    def max_memory_allocated(device=None):
        """
        CPU backend does not report max allocated memory via a unified API.
        Return 0 to indicate unsupported metrics while preserving numeric type.
        """
        return 0

    @staticmethod
    def get_device_properties(device=None):
        """
        Detailed device properties are not exposed for CPU in the same form
        as CUDA/NPU. Return None to indicate that such information is
        unavailable for this backend.
        """
        return None
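
For context, a hedged sketch of the kind of call pattern in examples/utils.py that these no-ops must satisfy; the helper below is hypothetical, not code from the PR:

from cache_dit.platforms import current_platform  # import path assumed


def report_peak_memory(run_pipeline):
    # On CpuPlatform these calls hit the no-ops suggested above, so the
    # same tracking code runs unchanged on CUDA, NPU, and CPU.
    current_platform.reset_peak_memory_stats()
    run_pipeline()
    peak = current_platform.max_memory_allocated()  # 0 on CpuPlatform
    print(f"peak device memory: {peak / (1 << 30):.2f} GiB")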

Member


@gameofdimension Can we use a common abstract class to hold these undefined functions? For example:

from abc import ABC, abstractmethod


class BasePlatform(ABC):
    device_type: str = None
    device_control_env_var: str = None
    dispatch_key: str = None
    dist_backend: str = None
    full_dist_backend: str = None

    # Note: abstractmethod must be the innermost decorator when combined
    # with staticmethod, and the class must inherit ABC to be enforced.
    @staticmethod
    @abstractmethod
    def empty_cache():
        raise NotImplementedError("....")
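
A concrete platform could then override just its backend-specific calls. A minimal sketch for CudaPlatform under that base class (the class body is illustrative, not quoted from the PR; "nccl" as the CUDA dist backend is an assumption):

import torch


class CudaPlatform(BasePlatform):
    device_type = "cuda"
    dist_backend = "nccl"  # assumed distributed backend for CUDA

    @staticmethod
    def empty_cache():
        torch.cuda.empty_cache()

    @staticmethod
    def synchronize(device=None):
        torch.cuda.synchronize(device)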

@DefTruth mentioned this pull request Jan 7, 2026
gameofdimension and others added 2 commits January 7, 2026 17:53
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Member

@DefTruth left a comment


cool~ lgtm~

@DefTruth merged commit 47df7b9 into vipshop:main Jan 7, 2026
3 checks passed