
Commit 47df7b9

gameofdimension, felix01.yu, and Copilot authored
feat: support ascend npu (#651)
* support npu platform
* support npu platform
* support npu platform
* support npu platform
* fix cp import
* revert
* fix destroy group
* better error handle
  Co-authored-by: Copilot <[email protected]>
* remove redundant import
  Co-authored-by: Copilot <[email protected]>

---------

Co-authored-by: felix01.yu <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent cb14d5d commit 47df7b9
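
This commit routes hard-coded torch.cuda / "cuda" usage through the new cache_dit.platforms.current_platform abstraction so the examples and parallelism code run on Ascend NPU as well as CUDA. Below is a minimal usage sketch of the pattern, assuming only the current_platform attributes that appear in the diffs (device_type, is_accelerator_available(), current_device(), synchronize(), max_memory_allocated()); the pipeline and model id are illustrative and not part of this change:

    import torch
    from diffusers import FluxPipeline  # illustrative pipeline
    from cache_dit.platforms import current_platform

    # Resolve the accelerator type at runtime: "cuda" on NVIDIA GPUs,
    # presumably "npu" on Ascend, falling back to CPU when none is present.
    device_type = (
        current_platform.device_type if current_platform.is_accelerator_available() else "cpu"
    )

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative model id
        torch_dtype=torch.bfloat16,
    ).to(device_type)  # replaces the previous hard-coded .to("cuda")

    # Peak accelerator memory so far, mirroring the updated MemoryTracker in examples/utils.py.
    if current_platform.is_accelerator_available():
        dev = current_platform.current_device()
        current_platform.synchronize(dev)
        print(f"peak: {current_platform.max_memory_allocated(dev) / 1024**3:.2f} GiB")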


13 files changed (+259 -41 lines)


examples/api/run_cache_refresh_flux.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
     TaylorSeerCalibratorConfig,
 )
 import cache_dit
-
+from cache_dit.platforms import current_platform

 parser = get_args(parse=False)
 parser.add_argument(
@@ -50,7 +50,7 @@
         )
     ),
     torch_dtype=torch.bfloat16,
-).to("cuda")
+).to(current_platform.device_type)

 if args.cache:

examples/api/run_cache_refresh_wan_2.2.py

Lines changed: 6 additions & 6 deletions
@@ -13,7 +13,7 @@
 )
 from utils import get_args, GiB, strify, MemoryTracker
 import cache_dit
-
+from cache_dit.platforms import current_platform

 parser = get_args(parse=False)
 parser.add_argument(
@@ -38,7 +38,7 @@
     ),
     torch_dtype=torch.bfloat16,
     # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
-    device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
+    device_map=("balanced" if (current_platform.device_count() > 1 and GiB() <= 48) else None),
 )

 # flow shift should be 3.0 for 480p images, 5.0 for 720p images
@@ -109,12 +109,12 @@

 # When device_map is None, we need to explicitly move the model to GPU
 # or enable CPU offload to avoid running on CPU
-if torch.cuda.device_count() <= 1:
+if current_platform.device_count() <= 1:
     # Single GPU: use CPU offload for memory efficiency
     pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
+elif current_platform.device_count() > 1 and pipe.device.type == "cpu":
     # Multi-GPU but model is on CPU (device_map was None): move to default GPU
-    pipe.to("cuda")
+    pipe.to(current_platform.device_type)

 # Wan currently requires installing diffusers from source
 assert isinstance(pipe.vae, AutoencoderKLWan)  # enable type check for IDE
@@ -164,7 +164,7 @@ def split_inference_steps(num_inference_steps: int = 30) -> tuple[int, int]:
     boundary_timestep = pipe.config.boundary_ratio * pipe.scheduler.config.num_train_timesteps
 else:
     boundary_timestep = None
-pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")
+pipe.scheduler.set_timesteps(num_inference_steps, device=current_platform.device_type)
 timesteps = pipe.scheduler.timesteps
 num_high_noise_steps = 0  # high-noise steps for transformer
 for t in timesteps:

examples/api/run_steps_mask.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 from diffusers import FluxPipeline, FluxTransformer2DModel
 from utils import get_args, strify, MemoryTracker
 import cache_dit
+from cache_dit.platforms import current_platform


 parser = get_args(parse=False)
@@ -110,7 +111,7 @@
     )
     print(f"Applied quantization: {args.quantize_type} to Transformer and Text Encoder 2.")

-pipe.to("cuda")
+pipe.to(current_platform.device_type)

 if args.attn is not None:
     if hasattr(pipe.transformer, "set_attention_backend"):

examples/api/run_transformer_only.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 from diffusers import FluxPipeline, FluxTransformer2DModel
 from utils import get_args, strify, MemoryTracker
 import cache_dit
+from cache_dit.platforms import current_platform


 parser = get_args(parse=False)
@@ -31,7 +32,7 @@
         )
     ),
     torch_dtype=torch.bfloat16,
-).to("cuda")
+).to(current_platform.device_type)

 if args.cache:
     from cache_dit import (

examples/base.py

Lines changed: 3 additions & 1 deletion
@@ -292,9 +292,11 @@ def _default_save_path(self) -> Optional[str]:
         return None

     def summary(self, args: argparse.Namespace) -> str:
+        from cache_dit.platforms import current_platform
+
         logger.info("🤖 Example Output Summary:")
         summary_str = f"- Model: {args.example}\n- Optimization: {self.strify_tag}\n"
-        device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "CPU"
+        device_name = current_platform.get_device_name()
         world_size = (
             1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
         )

examples/utils.py

Lines changed: 28 additions & 15 deletions
@@ -17,27 +17,29 @@
     TaylorSeerCalibratorConfig,
 )

+from cache_dit.platforms import current_platform
+
 logger = init_logger(__name__)


 class MemoryTracker:
     """Track peak GPU memory usage during execution."""

     def __init__(self, device=None):
-        self.device = device if device is not None else torch.cuda.current_device()
-        self.enabled = torch.cuda.is_available()
+        self.device = device if device is not None else current_platform.current_device()
+        self.enabled = current_platform.is_accelerator_available()
         self.peak_memory = 0

     def __enter__(self):
         if self.enabled:
-            torch.cuda.reset_peak_memory_stats(self.device)
-            torch.cuda.synchronize(self.device)
+            current_platform.reset_peak_memory_stats(self.device)
+            current_platform.synchronize(self.device)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         if self.enabled:
-            torch.cuda.synchronize(self.device)
-            self.peak_memory = torch.cuda.max_memory_allocated(self.device)
+            current_platform.synchronize(self.device)
+            self.peak_memory = current_platform.max_memory_allocated(self.device)

     def get_peak_memory_gb(self):
         """Get peak memory in GB."""
@@ -54,10 +56,10 @@ def report(self):

 def GiB():
     try:
-        if not torch.cuda.is_available():
+        if not current_platform.is_accelerator_available():
             return 0
-        total_memory_bytes = torch.cuda.get_device_properties(
-            torch.cuda.current_device(),
+        total_memory_bytes = current_platform.get_device_properties(
+            current_platform.current_device(),
         ).total_memory
         total_memory_gib = total_memory_bytes / (1024**3)
         return int(total_memory_gib)
@@ -1346,21 +1348,32 @@ def strify(args, pipe_or_stats):


 def get_rank_device():
+    available = current_platform.is_accelerator_available()
+    device_type = current_platform.device_type
     if dist.is_initialized():
         rank = dist.get_rank()
-        device = torch.device("cuda", rank % torch.cuda.device_count())
+        device = torch.device(device_type, rank % current_platform.device_count())
         return rank, device
-    return 0, torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    return 0, torch.device(device_type if available else "cpu")


 def maybe_init_distributed(args=None):
+    from cache_dit.platforms.platform import CpuPlatform
+
+    platform_full_backend = current_platform.full_dist_backend
+    cpu_full_backend = CpuPlatform.full_dist_backend
+    backend = (
+        f"{cpu_full_backend},{platform_full_backend}"
+        if args.ulysses_anything
+        else current_platform.dist_backend
+    )
     if args is not None:
         if args.parallel_type is not None:
             dist.init_process_group(
-                backend="cpu:gloo,cuda:nccl" if args.ulysses_anything else "nccl",
+                backend=backend,
             )
             rank, device = get_rank_device()
-            torch.cuda.set_device(device)
+            current_platform.set_device(device)
             return rank, device
         else:
             # no distributed needed
@@ -1370,10 +1383,10 @@ def maybe_init_distributed(args=None):
     # always init distributed for other examples
     if not dist.is_initialized():
         dist.init_process_group(
-            backend="nccl",
+            backend=platform_full_backend,
         )
         rank, device = get_rank_device()
-        torch.cuda.set_device(device)
+        current_platform.set_device(device)
     return rank, device

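The maybe_init_distributed change above generalizes the old hard-coded backend choice ("cpu:gloo,cuda:nccl" when ulysses_anything is set, otherwise "nccl") by composing the string from the platform classes. A minimal sketch of that selection in isolation; on CUDA it reproduces the old strings, while on Ascend the platform class is assumed to supply its own collective backend (e.g. HCCL):

    import torch.distributed as dist
    from cache_dit.platforms import current_platform
    from cache_dit.platforms.platform import CpuPlatform

    def pick_backend(ulysses_anything: bool) -> str:
        if ulysses_anything:
            # Pair a CPU-capable backend with the accelerator backend so small
            # metadata collectives can run host-side (see _gather_size_by_comm).
            return f"{CpuPlatform.full_dist_backend},{current_platform.full_dist_backend}"
        return current_platform.dist_backend

    # dist.init_process_group(backend=pick_backend(args.ulysses_anything))
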
src/cache_dit/parallelism/attention/_distributed_primitives.py

Lines changed: 23 additions & 7 deletions
@@ -6,12 +6,28 @@
 import torch.distributed._functional_collectives as fc
 import torch.nn.functional as F

-from cache_dit.kernels import (
-    per_token_quant_fp8,
-    per_token_dequant_fp8,
-    qkv_permute_quant_fp8,
-    qkv_dequant_permute_fp8,
-)
+from cache_dit.platforms import current_platform
+
+try:
+    from cache_dit.kernels import (
+        per_token_quant_fp8,
+        per_token_dequant_fp8,
+        qkv_permute_quant_fp8,
+        qkv_dequant_permute_fp8,
+    )
+except ImportError:
+
+    def _fp8_kernel_unavailable(*args, **kwargs):
+        raise RuntimeError(
+            "FP8 kernels could not be imported (e.g., Triton may not be available on this "
+            "platform). FP8 async operations are not supported. Please install the required "
+            "dependencies or disable FP8 mode."
+        )
+
+    per_token_quant_fp8 = _fp8_kernel_unavailable
+    per_token_dequant_fp8 = _fp8_kernel_unavailable
+    qkv_permute_quant_fp8 = _fp8_kernel_unavailable
+    qkv_dequant_permute_fp8 = _fp8_kernel_unavailable
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)
@@ -72,7 +88,7 @@ def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
     # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
     comm_backends = str(dist.get_backend(group=group))
     # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
-    gather_device = "cpu" if "cpu" in comm_backends else torch.device("cuda")
+    gather_device = "cpu" if "cpu" in comm_backends else current_platform.default_device()
     gathered_sizes = [
         torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)
     ]
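
Two things happen in this file: the FP8 Triton kernels become an optional import that fails only when actually called, and the small size all_gather picks its device from the process-group backend instead of assuming CUDA. A short sketch of the latter decision, assuming default_device() returns the active accelerator device:

    import torch
    import torch.distributed as dist
    from cache_dit.platforms import current_platform

    def size_gather_device(group: dist.ProcessGroup):
        # Composite backends look like "cpu:gloo,cuda:nccl" (or the NPU
        # equivalent), so "cpu" in the string means a host-side gather is
        # possible and the H2D/D2H copies for the tiny size tensors are avoided.
        backends = str(dist.get_backend(group=group))
        return "cpu" if "cpu" in backends else current_platform.default_device()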

src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_zimage_controlnet.py

Lines changed: 2 additions & 1 deletion
@@ -29,6 +29,7 @@
 from cache_dit.parallelism.attention import _unified_all_to_all_o_async_fn
 from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
 from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
+from cache_dit.platforms import current_platform

 from cache_dit.logger import init_logger

@@ -112,7 +113,7 @@ def _ulysses_attn_with_async_qkv_proj_zimage_controlnet(

     # Apply RoPE
     def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(current_platform.device_type, enabled=False):
             x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
             freqs_cis = freqs_cis.unsqueeze(2)
             x_out = torch.view_as_real(x * freqs_cis).flatten(3)

src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_zimage.py

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@
 from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
 from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
 from cache_dit.parallelism.attention import _maybe_patch_find_submodule
+from cache_dit.platforms import current_platform

 from cache_dit.logger import init_logger

@@ -193,7 +194,7 @@ def _ulysses_attn_with_async_qkv_proj_zimage(

     # Apply RoPE
     def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(current_platform.device_type, enabled=False):
             x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
             freqs_cis = freqs_cis.unsqueeze(2)
             x_out = torch.view_as_real(x * freqs_cis).flatten(3)

src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux2.py

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@
 from cache_dit.logger import init_logger
 from cache_dit.parallelism.config import ParallelismConfig
 from cache_dit.utils import maybe_empty_cache
+from cache_dit.platforms import current_platform

 from .tp_plan_registers import TensorParallelismPlanner, TensorParallelismPlannerRegister
 from .tp_utils import shard_divisible_attr
@@ -104,7 +105,7 @@ def parallelize_transformer(
         for _, block in transformer.transformer_blocks.named_children():
             # moving to cuda speed up the rearrangement process significantly
             old_device = next(block.parameters()).device
-            block.to("cuda")
+            block.to(current_platform.device_type)
             self.rearrange_feedforward_weight(block, tp_size)
             block.to(old_device)
             shard_divisible_attr(
@@ -139,7 +140,7 @@ def parallelize_transformer(
         for _, block in transformer.single_transformer_blocks.named_children():
             # moving to cuda speed up the rearrangement process significantly
             old_device = next(block.parameters()).device
-            block.to("cuda")
+            block.to(current_platform.device_type)
             self.rearrange_singleblock_weight(block, tp_size)
             block.to(old_device)
             shard_divisible_attr(
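
The two hunks above share a pattern: temporarily move a transformer block onto the active accelerator to speed up weight rearrangement, then restore its original placement. A hypothetical helper capturing that pattern (run_on_accelerator is not part of this change):

    import torch
    from cache_dit.platforms import current_platform

    def run_on_accelerator(block: torch.nn.Module, fn) -> None:
        # Move to the platform's accelerator ("cuda", or the NPU device type on
        # Ascend), run the work, and always restore the original device.
        old_device = next(block.parameters()).device
        block.to(current_platform.device_type)
        try:
            fn(block)
        finally:
            block.to(old_device)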
