4 changes: 2 additions & 2 deletions examples/api/run_cache_refresh_flux.py
@@ -15,7 +15,7 @@
     TaylorSeerCalibratorConfig,
 )
 import cache_dit
-
+from cache_dit.platforms import current_platform
 
 parser = get_args(parse=False)
 parser.add_argument(
@@ -50,7 +50,7 @@
         )
     ),
     torch_dtype=torch.bfloat16,
-).to("cuda")
+).to(current_platform.device_type)
 
 if args.cache:
 
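The change above is the pattern the whole PR follows: hard-coded "cuda" strings and torch.cuda.* calls are replaced by the current_platform abstraction, so the same example runs on CUDA, Ascend NPU, or CPU without edits. A minimal sketch of the idea (not part of the PR), assuming device_type resolves to a plain device string such as "cuda", "npu", or "cpu":

# Minimal sketch, not from this PR: device-agnostic placement via the platform layer.
import torch
from cache_dit.platforms import current_platform

device = current_platform.device_type  # "cuda", "npu", or "cpu", resolved on first access
x = torch.randn(4, 4).to(device)
print(x.device)  # e.g. cuda:0 on an NVIDIA machine, cpu when no accelerator is found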
12 changes: 6 additions & 6 deletions examples/api/run_cache_refresh_wan_2.2.py
@@ -13,7 +13,7 @@
 )
 from utils import get_args, GiB, strify, MemoryTracker
 import cache_dit
-
+from cache_dit.platforms import current_platform
 
 parser = get_args(parse=False)
 parser.add_argument(
@@ -38,7 +38,7 @@
     ),
     torch_dtype=torch.bfloat16,
     # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
-    device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
+    device_map=("balanced" if (current_platform.device_count() > 1 and GiB() <= 48) else None),
 )
 
 # flow shift should be 3.0 for 480p images, 5.0 for 720p images
@@ -109,12 +109,12 @@
 
 # When device_map is None, we need to explicitly move the model to GPU
 # or enable CPU offload to avoid running on CPU
-if torch.cuda.device_count() <= 1:
+if current_platform.device_count() <= 1:
     # Single GPU: use CPU offload for memory efficiency
     pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
+elif current_platform.device_count() > 1 and pipe.device.type == "cpu":
     # Multi-GPU but model is on CPU (device_map was None): move to default GPU
-    pipe.to("cuda")
+    pipe.to(current_platform.device_type)
 
 # Wan currently requires installing diffusers from source
 assert isinstance(pipe.vae, AutoencoderKLWan)  # enable type check for IDE
@@ -164,7 +164,7 @@ def split_inference_steps(num_inference_steps: int = 30) -> tuple[int, int]:
         boundary_timestep = pipe.config.boundary_ratio * pipe.scheduler.config.num_train_timesteps
     else:
         boundary_timestep = None
-    pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")
+    pipe.scheduler.set_timesteps(num_inference_steps, device=current_platform.device_type)
    timesteps = pipe.scheduler.timesteps
     num_high_noise_steps = 0  # high-noise steps for transformer
     for t in timesteps:
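The Wan 2.2 example now makes its three-way placement decision through the platform layer: shard with device_map="balanced" when there is more than one device and the current device has at most 48 GiB, fall back to CPU offload on a single device, and otherwise move the pipeline to the accelerator explicitly. A small sketch of that rule as a standalone helper (illustrative only; choose_placement is not a function in the repository):

# Sketch: the placement rule used by the Wan 2.2 example, factored into a helper.
# device_count() follows the platform API used above; 48 GiB mirrors the
# example's device_map condition.
from cache_dit.platforms import current_platform


def choose_placement(total_mem_gib: int) -> str:
    n = current_platform.device_count()
    if n > 1 and total_mem_gib <= 48:
        return "balanced"        # let diffusers shard the model (device_map="balanced")
    if n <= 1:
        return "cpu_offload"     # pipe.enable_model_cpu_offload()
    return "to_accelerator"      # pipe.to(current_platform.device_type)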
3 changes: 2 additions & 1 deletion examples/api/run_steps_mask.py
@@ -8,6 +8,7 @@
 from diffusers import FluxPipeline, FluxTransformer2DModel
 from utils import get_args, strify, MemoryTracker
 import cache_dit
+from cache_dit.platforms import current_platform
 
 
 parser = get_args(parse=False)
@@ -110,7 +111,7 @@
     )
     print(f"Applied quantization: {args.quantize_type} to Transformer and Text Encoder 2.")
 
-pipe.to("cuda")
+pipe.to(current_platform.device_type)
 
 if args.attn is not None:
     if hasattr(pipe.transformer, "set_attention_backend"):
3 changes: 2 additions & 1 deletion examples/api/run_transformer_only.py
@@ -8,6 +8,7 @@
 from diffusers import FluxPipeline, FluxTransformer2DModel
 from utils import get_args, strify, MemoryTracker
 import cache_dit
+from cache_dit.platforms import current_platform
 
 
 parser = get_args(parse=False)
@@ -31,7 +32,7 @@
         )
     ),
     torch_dtype=torch.bfloat16,
-).to("cuda")
+).to(current_platform.device_type)
 
 if args.cache:
     from cache_dit import (
4 changes: 3 additions & 1 deletion examples/base.py
@@ -277,9 +277,11 @@ def _default_save_path(self) -> Optional[str]:
         return None
 
     def summary(self, args: argparse.Namespace) -> str:
+        from cache_dit.platforms import current_platform
+
         logger.info("🤖 Example Output Summary:")
         summary_str = f"- Model: {args.example}\n- Optimization: {self.strify_tag}\n"
-        device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "CPU"
+        device_name = current_platform.get_device_name()
         world_size = (
             1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
         )
43 changes: 28 additions & 15 deletions examples/utils.py
@@ -17,27 +17,29 @@
     TaylorSeerCalibratorConfig,
 )
 
+from cache_dit.platforms import current_platform
+
 logger = init_logger(__name__)
 
 
 class MemoryTracker:
     """Track peak GPU memory usage during execution."""
 
     def __init__(self, device=None):
-        self.device = device if device is not None else torch.cuda.current_device()
-        self.enabled = torch.cuda.is_available()
+        self.device = device if device is not None else current_platform.current_device()
+        self.enabled = current_platform.is_accelerator_available()
         self.peak_memory = 0
 
     def __enter__(self):
         if self.enabled:
-            torch.cuda.reset_peak_memory_stats(self.device)
-            torch.cuda.synchronize(self.device)
+            current_platform.reset_peak_memory_stats(self.device)
+            current_platform.synchronize(self.device)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if self.enabled:
-            torch.cuda.synchronize(self.device)
-            self.peak_memory = torch.cuda.max_memory_allocated(self.device)
+            current_platform.synchronize(self.device)
+            self.peak_memory = current_platform.max_memory_allocated(self.device)
 
     def get_peak_memory_gb(self):
         """Get peak memory in GB."""
@@ -54,10 +56,10 @@ def report(self):
 
 def GiB():
     try:
-        if not torch.cuda.is_available():
+        if not current_platform.is_accelerator_available():
             return 0
-        total_memory_bytes = torch.cuda.get_device_properties(
-            torch.cuda.current_device(),
+        total_memory_bytes = current_platform.get_device_properties(
+            current_platform.current_device(),
         ).total_memory
         total_memory_gib = total_memory_bytes / (1024**3)
         return int(total_memory_gib)
@@ -1346,21 +1348,32 @@ def strify(args, pipe_or_stats):
 
 
 def get_rank_device():
+    available = current_platform.is_accelerator_available()
+    device_type = current_platform.device_type
     if dist.is_initialized():
         rank = dist.get_rank()
-        device = torch.device("cuda", rank % torch.cuda.device_count())
+        device = torch.device(device_type, rank % current_platform.device_count())
         return rank, device
-    return 0, torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    return 0, torch.device(device_type if available else "cpu")
 
 
 def maybe_init_distributed(args=None):
+    from cache_dit.platforms.platform import CpuPlatform
+
+    platform_full_backend = current_platform.full_dist_backend
+    cpu_full_backend = CpuPlatform.full_dist_backend
+    backend = (
+        f"{cpu_full_backend},{platform_full_backend}"
+        if args.ulysses_anything
+        else current_platform.dist_backend
+    )
     if args is not None:
        if args.parallel_type is not None:
             dist.init_process_group(
-                backend="cpu:gloo,cuda:nccl" if args.ulysses_anything else "nccl",
+                backend=backend,
             )
             rank, device = get_rank_device()
-            torch.cuda.set_device(device)
+            current_platform.set_device(device)
             return rank, device
     else:
         # no distributed needed
@@ -1370,10 +1383,10 @@
     # always init distributed for other examples
     if not dist.is_initialized():
         dist.init_process_group(
-            backend="nccl",
+            backend=platform_full_backend,
         )
     rank, device = get_rank_device()
-    torch.cuda.set_device(device)
+    current_platform.set_device(device)
    return rank, device
 
 
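maybe_init_distributed() now builds the process-group backend string from the platform instead of hard-coding "nccl" or "cpu:gloo,cuda:nccl". PyTorch accepts a comma-separated list of device:backend pairs, so combining the CPU (Gloo) backend with the accelerator backend lets small metadata collectives run on host tensors. The platform classes are not part of this diff, so the concrete values of dist_backend and full_dist_backend below are assumptions (HCCL for Ascend NPU is a guess based on torch_npu's usual backend), and the helper name composite_backend is illustrative:

# Assumed per-platform backend strings, inferred from the hard-coded values the
# hunk above replaces; composite_backend mirrors maybe_init_distributed().
PRESUMED_BACKENDS = {
    # platform: (dist_backend, full_dist_backend)
    "cuda": ("nccl", "cuda:nccl"),
    "npu": ("hccl", "npu:hccl"),  # assumption: Ascend NPU uses HCCL
    "cpu": ("gloo", "cpu:gloo"),
}


def composite_backend(platform: str, ulysses_anything: bool) -> str:
    dist_backend, full_backend = PRESUMED_BACKENDS[platform]
    cpu_full_backend = PRESUMED_BACKENDS["cpu"][1]
    # With ulysses_anything, Gloo is kept alongside the accelerator backend so
    # small size/metadata tensors can be gathered on the CPU.
    return f"{cpu_full_backend},{full_backend}" if ulysses_anything else dist_backend

For example, composite_backend("cuda", True) returns "cpu:gloo,cuda:nccl", which is exactly the string the old code hard-coded.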
30 changes: 23 additions & 7 deletions src/cache_dit/parallelism/attention/_distributed_primitives.py
@@ -6,12 +6,28 @@
 import torch.distributed._functional_collectives as fc
 import torch.nn.functional as F
 
-from cache_dit.kernels import (
-    per_token_quant_fp8,
-    per_token_dequant_fp8,
-    qkv_permute_quant_fp8,
-    qkv_dequant_permute_fp8,
-)
+from cache_dit.platforms import current_platform
+
+try:
+    from cache_dit.kernels import (
+        per_token_quant_fp8,
+        per_token_dequant_fp8,
+        qkv_permute_quant_fp8,
+        qkv_dequant_permute_fp8,
+    )
+except ImportError:
+
+    def _fp8_kernel_unavailable(*args, **kwargs):
+        raise RuntimeError(
+            "FP8 kernels could not be imported (e.g., Triton may not be available on this "
+            "platform). FP8 async operations are not supported. Please install the required "
+            "dependencies or disable FP8 mode."
+        )
+
+    per_token_quant_fp8 = _fp8_kernel_unavailable
+    per_token_dequant_fp8 = _fp8_kernel_unavailable
+    qkv_permute_quant_fp8 = _fp8_kernel_unavailable
+    qkv_dequant_permute_fp8 = _fp8_kernel_unavailable
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -72,7 +88,7 @@ def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
     # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
     comm_backends = str(dist.get_backend(group=group))
     # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
-    gather_device = "cpu" if "cpu" in comm_backends else torch.device("cuda")
+    gather_device = "cpu" if "cpu" in comm_backends else current_platform.default_device()
     gathered_sizes = [
         torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)
     ]
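The gather-device change keeps the existing trick intact: when the process group was created with a composite backend that includes a CPU (Gloo) component, the one-element size tensors are gathered on the CPU, avoiding host-device round trips; otherwise they live on the platform's default device. A compact sketch of that pattern (gather_sizes is illustrative, not the library's _gather_size_by_comm):

# Illustrative sketch of the backend-aware gather pattern used above.
from typing import List

import torch
import torch.distributed as dist

from cache_dit.platforms import current_platform


def gather_sizes(size: int, group: dist.ProcessGroup) -> List[int]:
    world_size = dist.get_world_size(group=group)
    backends = str(dist.get_backend(group=group))  # e.g. "cpu:gloo,cuda:nccl"
    # A Gloo component means the gather can stay on CPU and skip H2D/D2H copies.
    device = "cpu" if "cpu" in backends else current_platform.default_device()
    local = torch.tensor([size], device=device, dtype=torch.int64)
    gathered = [torch.empty((1,), device=device, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(gathered, local, group=group)
    return [int(t.item()) for t in gathered]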
(file path not shown)
@@ -29,6 +29,7 @@
 from cache_dit.parallelism.attention import _unified_all_to_all_o_async_fn
 from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
 from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
+from cache_dit.platforms import current_platform
 
 from cache_dit.logger import init_logger
 
@@ -112,7 +113,7 @@ def _ulysses_attn_with_async_qkv_proj_zimage_controlnet(
 
     # Apply RoPE
     def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(current_platform.device_type, enabled=False):
             x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
             freqs_cis = freqs_cis.unsqueeze(2)
             x_out = torch.view_as_real(x * freqs_cis).flatten(3)
(file path not shown)
@@ -30,6 +30,7 @@
 from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
 from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
 from cache_dit.parallelism.attention import _maybe_patch_find_submodule
+from cache_dit.platforms import current_platform
 
 from cache_dit.logger import init_logger
 
@@ -193,7 +194,7 @@ def _ulysses_attn_with_async_qkv_proj_zimage(
 
     # Apply RoPE
     def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-        with torch.amp.autocast("cuda", enabled=False):
+        with torch.amp.autocast(current_platform.device_type, enabled=False):
             x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
             freqs_cis = freqs_cis.unsqueeze(2)
             x_out = torch.view_as_real(x * freqs_cis).flatten(3)
(file path not shown)
@@ -12,6 +12,7 @@
 from cache_dit.logger import init_logger
 from cache_dit.parallelism.config import ParallelismConfig
 from cache_dit.utils import maybe_empty_cache
+from cache_dit.platforms import current_platform
 
 from .tp_plan_registers import TensorParallelismPlanner, TensorParallelismPlannerRegister
 from .tp_utils import shard_divisible_attr
@@ -104,7 +105,7 @@ def parallelize_transformer(
         for _, block in transformer.transformer_blocks.named_children():
             # moving to cuda speed up the rearrangement process significantly
             old_device = next(block.parameters()).device
-            block.to("cuda")
+            block.to(current_platform.device_type)
             self.rearrange_feedforward_weight(block, tp_size)
             block.to(old_device)
             shard_divisible_attr(
@@ -139,7 +140,7 @@ def parallelize_transformer(
         for _, block in transformer.single_transformer_blocks.named_children():
             # moving to cuda speed up the rearrangement process significantly
             old_device = next(block.parameters()).device
-            block.to("cuda")
+            block.to(current_platform.device_type)
             self.rearrange_singleblock_weight(block, tp_size)
             block.to(old_device)
             shard_divisible_attr(
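Both loops use the same move-transform-restore pattern: remember the block's current device, move it to the accelerator because the weight rearrangement is much faster there, rearrange, then move it back. A small illustrative helper (on_accelerator is not part of the repository) that captures the pattern as a context manager:

# Illustrative only: temporarily place a module on the current platform's
# accelerator, then restore its original device, as parallelize_transformer()
# does inline above.
from contextlib import contextmanager

import torch.nn as nn

from cache_dit.platforms import current_platform


@contextmanager
def on_accelerator(module: nn.Module):
    old_device = next(module.parameters()).device
    module.to(current_platform.device_type)
    try:
        yield module
    finally:
        module.to(old_device)  # restore the original placement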
39 changes: 39 additions & 0 deletions src/cache_dit/platforms/__init__.py
@@ -0,0 +1,39 @@
+import torch
+import importlib
+from typing import Any
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully-qualified class name.
+    """
+    module_name, obj_name = qualname.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
+
+
+def resolve_current_platform_cls_qualname() -> str:
+    if torch.cuda.is_available():
+        return "cache_dit.platforms.platform.CudaPlatform"
+    try:
+        import torch_npu  # type: ignore # noqa
+
+        return "cache_dit.platforms.platform.NPUPlatform"
+    except ImportError:
+        return "cache_dit.platforms.platform.CpuPlatform"
+
+
+_current_platform = None
+
+
+def __getattr__(name: str):
+    if name == "current_platform":
+        global _current_platform
+        if _current_platform is None:
+            platform_cls_qualname = resolve_current_platform_cls_qualname()
+            _current_platform = resolve_obj_by_qualname(platform_cls_qualname)()
+        return _current_platform
+    elif name in globals():
+        return globals()[name]
+    else:
+        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
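The new package resolves the active platform lazily: importing cache_dit.platforms is cheap, and the CUDA/torch_npu probe only runs the first time current_platform is accessed, via the module-level __getattr__ hook (PEP 562). The platform classes themselves live in cache_dit/platforms/platform.py, which is not included in this diff; the sketch below is a hypothetical reconstruction of their surface, with method names taken from how the rest of the PR calls them and bodies that are assumptions for illustration (an NPUPlatform would wrap torch_npu analogously):

# Hypothetical sketch of cache_dit/platforms/platform.py (not in this diff).
import torch


class CudaPlatform:
    device_type = "cuda"
    dist_backend = "nccl"            # assumed values; see maybe_init_distributed()
    full_dist_backend = "cuda:nccl"

    def is_accelerator_available(self) -> bool:
        return torch.cuda.is_available()

    def device_count(self) -> int:
        return torch.cuda.device_count()

    def current_device(self) -> int:
        return torch.cuda.current_device()

    def default_device(self) -> torch.device:
        return torch.device("cuda", torch.cuda.current_device())

    def set_device(self, device) -> None:
        torch.cuda.set_device(device)

    def synchronize(self, device=None) -> None:
        torch.cuda.synchronize(device)

    def reset_peak_memory_stats(self, device=None) -> None:
        torch.cuda.reset_peak_memory_stats(device)

    def max_memory_allocated(self, device=None) -> int:
        return torch.cuda.max_memory_allocated(device)

    def get_device_name(self) -> str:
        return torch.cuda.get_device_name()

    def get_device_properties(self, device):
        return torch.cuda.get_device_properties(device)


class CpuPlatform:
    device_type = "cpu"
    dist_backend = "gloo"
    full_dist_backend = "cpu:gloo"

    def is_accelerator_available(self) -> bool:
        return False

    def device_count(self) -> int:
        return 0

    def get_device_name(self) -> str:
        return "CPU"

Typical call sites then read as in the examples above: from cache_dit.platforms import current_platform; pipe.to(current_platform.device_type).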