diff --git a/python/sglang/multimodal_gen/runtime/compilation/__init__.py b/python/sglang/multimodal_gen/runtime/compilation/__init__.py new file mode 100644 index 000000000000..01c96735c7fd --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/compilation/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .piecewise_cuda_graph_runner import ( + DiffusionPiecewiseCudaGraphRunner, + resolve_capture_sizes, +) + +__all__ = [ + "DiffusionPiecewiseCudaGraphRunner", + "resolve_capture_sizes", +] diff --git a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py new file mode 100644 index 000000000000..e879c7560872 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py @@ -0,0 +1,495 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Piecewise CUDA graph runner for diffusion DiT models. + +This runner mirrors LLM piecewise CUDA graph behavior: +1. enable piecewise graph splitting via torch.compile backend hooks +2. bucket requests to capture sizes +3. apply padding to bucket shape +4. 
capture and replay CUDA graph with stable tensor addresses +""" + +from __future__ import annotations + +import bisect +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Hashable + +import torch + +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.srt.compilation.compilation_counter import compilation_counter +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.compile import install_torch_compiled +from sglang.srt.compilation.piecewise_context_manager import ( + enable_piecewise_cuda_graph, + enable_piecewise_cuda_graph_compile, + set_pcg_capture_stream, + set_pcg_runtime_shape, +) + +logger = init_logger(__name__) + +_RUNTIME_VALUE = object() +_GLOBAL_GRAPH_POOL = None + + +def _set_torch_compile_config() -> None: + import torch._dynamo.config + + torch._dynamo.config.accumulated_cache_size_limit = 1024 + if hasattr(torch._dynamo.config, "cache_size_limit"): + torch._dynamo.config.cache_size_limit = 1024 + + +@contextmanager +def _graph_capture_stream(): + if not current_platform.is_cuda_alike(): + yield None + return + + stream = torch.cuda.Stream() + current = torch.cuda.current_stream() + if current != stream: + stream.wait_stream(current) + with torch.cuda.stream(stream): + yield stream + + +def _is_tensor(obj: Any) -> bool: + return isinstance(obj, torch.Tensor) + + +def _join_path(parent: str, key: str) -> str: + if not parent: + return key + return f"{parent}.{key}" + + +def _supports_scalar(obj: Any) -> bool: + return isinstance(obj, (int, float, str, bool)) + + +def _pad_dim_for_path(path: str, tensor: torch.Tensor, raw_seq_len: int) -> int | None: + # Mandatory for all sequence DiTs. + if path.endswith("hidden_states") and tensor.ndim >= 2: + if tensor.shape[1] == raw_seq_len: + return 1 + + # Qwen-Image rotary cache for image tokens. 
+ if path == "freqs_cis.0" and tensor.ndim >= 1: + if tensor.shape[0] == raw_seq_len: + return 0 + + return None + + +def _signature( + obj: Any, + *, + raw_seq_len: int, + static_seq_len: int, + path: str = "", +) -> Hashable: + if _is_tensor(obj): + shape = list(obj.shape) + pad_dim = _pad_dim_for_path(path, obj, raw_seq_len) + if pad_dim is not None: + shape[pad_dim] = static_seq_len + return ( + "tensor", + tuple(shape), + str(obj.dtype), + str(obj.device), + ) + + if isinstance(obj, list): + return ( + "list", + tuple( + _signature( + x, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(i)), + ) + for i, x in enumerate(obj) + ), + ) + if isinstance(obj, tuple): + return ( + "tuple", + tuple( + _signature( + x, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(i)), + ) + for i, x in enumerate(obj) + ), + ) + if isinstance(obj, dict): + items = tuple( + sorted( + ( + k, + _signature( + v, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(k)), + ), + ) + for k, v in obj.items() + ) + ) + return ("dict", items) + if obj is None: + return ("none",) + if _supports_scalar(obj): + return ("scalar", obj) + return ("repr", repr(obj)) + + +@dataclass +class _TensorSlot: + tensor: torch.Tensor + pad_dim: int | None = None + + +def _build_slots( + obj: Any, + *, + raw_seq_len: int, + static_seq_len: int, + path: str = "", +) -> Any: + if _is_tensor(obj): + pad_dim = _pad_dim_for_path(path, obj, raw_seq_len) + if pad_dim is None: + return _TensorSlot(tensor=torch.empty_like(obj), pad_dim=None) + new_shape = list(obj.shape) + new_shape[pad_dim] = static_seq_len + return _TensorSlot( + tensor=torch.empty( + tuple(new_shape), dtype=obj.dtype, device=obj.device + ), + pad_dim=pad_dim, + ) + + if isinstance(obj, list): + return [ + _build_slots( + x, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(i)), + ) + for i, x in 
enumerate(obj) + ] + if isinstance(obj, tuple): + return tuple( + _build_slots( + x, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(i)), + ) + for i, x in enumerate(obj) + ) + if isinstance(obj, dict): + return { + k: _build_slots( + v, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(k)), + ) + for k, v in obj.items() + } + return _RUNTIME_VALUE + + +def _materialize_call_kwargs(slots: Any, values: Any) -> Any: + if isinstance(slots, _TensorSlot): + src = values + dst = slots.tensor + + if slots.pad_dim is None: + if dst.shape != src.shape: + raise ValueError( + f"Tensor shape changed for CUDA graph replay: {dst.shape} vs {src.shape}" + ) + dst.copy_(src) + return dst + + if src.shape[slots.pad_dim] > dst.shape[slots.pad_dim]: + raise ValueError( + "Input sequence length exceeds padded slot shape: " + f"{src.shape[slots.pad_dim]} > {dst.shape[slots.pad_dim]}" + ) + + dst.zero_() + indices = [slice(None)] * dst.ndim + indices[slots.pad_dim] = slice(0, src.shape[slots.pad_dim]) + dst[tuple(indices)].copy_(src) + return dst + + if slots is _RUNTIME_VALUE: + return values + + if isinstance(slots, list): + return [ + _materialize_call_kwargs(s, v) for s, v in zip(slots, values, strict=True) + ] + if isinstance(slots, tuple): + return tuple( + _materialize_call_kwargs(s, v) for s, v in zip(slots, values, strict=True) + ) + if isinstance(slots, dict): + return {k: _materialize_call_kwargs(slots[k], values[k]) for k in slots.keys()} + + raise TypeError(f"Unsupported slot type: {type(slots)}") + + +def _slice_output_to_raw_seq(output: Any, raw_seq_len: int, static_seq_len: int) -> Any: + if raw_seq_len == static_seq_len: + return output + + if isinstance(output, torch.Tensor): + if output.ndim >= 2 and output.shape[1] == static_seq_len: + return output[:, :raw_seq_len, ...] 
+ return output + if isinstance(output, list): + return [_slice_output_to_raw_seq(x, raw_seq_len, static_seq_len) for x in output] + if isinstance(output, tuple): + return tuple(_slice_output_to_raw_seq(x, raw_seq_len, static_seq_len) for x in output) + if isinstance(output, dict): + return { + k: _slice_output_to_raw_seq(v, raw_seq_len, static_seq_len) + for k, v in output.items() + } + return output + + +def _get_graph_pool(device: torch.device): + global _GLOBAL_GRAPH_POOL + if _GLOBAL_GRAPH_POOL is None: + _GLOBAL_GRAPH_POOL = torch.get_device_module(device).graph_pool_handle() + return _GLOBAL_GRAPH_POOL + + +@dataclass +class _GraphEntry: + signature: Hashable + static_seq_len: int + slots: dict[str, Any] + captured: bool = False + + +class DiffusionPiecewiseCudaGraphRunner: + def __init__( + self, + model: torch.nn.Module, + capture_sizes: list[int], + *, + compiler: str = "eager", + enable_debug: bool = False, + ) -> None: + self.model = model + self.capture_sizes = sorted(set(int(x) for x in capture_sizes)) + self.capture_sizes_set = set(self.capture_sizes) + self.compiler = compiler + self.enable_debug = enable_debug + + self._entries: dict[Hashable, _GraphEntry] = {} + self._installed = False + self._compiled = False + self._eager_warmup_done = False + + device = next(model.parameters()).device + self.device = device + self.graph_pool = _get_graph_pool(device) + + self.compile_config = CompilationConfig( + self.capture_sizes, + compiler=self.compiler, + enable_debug_mode=self.enable_debug, + ) + _set_torch_compile_config() + + def _install_compiled(self) -> None: + if self._installed: + return + install_torch_compiled( + self.model, + dynamic_arg_dims={"hidden_states": [1]}, + compile_config=self.compile_config, + graph_pool=self.graph_pool, + fullgraph=True, + ) + self._installed = True + + def can_run( + self, hidden_states: torch.Tensor, seq_len_override: int | None = None + ) -> bool: + if not current_platform.is_cuda_alike(): + return False + if not 
torch.cuda.is_available(): + return False + if hidden_states.device.type != "cuda": + return False + if hidden_states.ndim < 2: + return False + if seq_len_override is None: + return False + if hidden_states.shape[1] != int(seq_len_override): + return False + if not self.capture_sizes: + return False + if seq_len_override > self.capture_sizes[-1]: + return False + return True + + def _select_static_seq_len(self, seq_len: int) -> int | None: + idx = bisect.bisect_left(self.capture_sizes, seq_len) + if idx >= len(self.capture_sizes): + return None + return self.capture_sizes[idx] + + def _ensure_compiled(self, call_kwargs: dict[str, Any], runtime_shape: int) -> None: + if self._compiled: + return + # Warm up lazy custom kernels (e.g. JIT kernels that touch filesystem) + # outside torch.compile to avoid Dynamo tracing unsupported Python/C APIs. + if not self._eager_warmup_done: + with torch.no_grad(): + self.model(**call_kwargs) + self._eager_warmup_done = True + with enable_piecewise_cuda_graph(): + with set_pcg_runtime_shape(runtime_shape): + with enable_piecewise_cuda_graph_compile(): + _ = self.model(**call_kwargs) + self._compiled = True + + def _capture(self, call_kwargs: dict[str, Any], runtime_shape: int) -> None: + before = compilation_counter.num_cudagraph_captured + with enable_piecewise_cuda_graph(): + with set_pcg_runtime_shape(runtime_shape): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + with _graph_capture_stream() as stream: + if stream is not None: + with set_pcg_capture_stream(stream): + for _ in range(4): + self.model(**call_kwargs) + else: + for _ in range(4): + self.model(**call_kwargs) + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + after = compilation_counter.num_cudagraph_captured + if after > before: + logger.info( + "Diffusion PCG captured %d piecewise CUDA graphs", + after - before, + ) + else: + logger.warning( + 
"Diffusion PCG capture completed but no piecewise CUDA graph was captured; " + "runtime shape may not match configured capture buckets." + ) + + def run(self, *, seq_len_override: int | None = None, **kwargs) -> Any | None: + hidden_states = kwargs.get("hidden_states") + if hidden_states is None or not self.can_run( + hidden_states, seq_len_override=seq_len_override + ): + return None + + raw_seq_len = int(seq_len_override) + static_seq_len = self._select_static_seq_len(raw_seq_len) + if static_seq_len is None: + return None + + self._install_compiled() + + sig = ( + static_seq_len, + _signature( + kwargs, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path="", + ), + ) + entry = self._entries.get(sig) + if entry is None: + entry = _GraphEntry( + signature=sig, + static_seq_len=static_seq_len, + slots=_build_slots( + kwargs, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path="", + ), + ) + self._entries[sig] = entry + logger.info( + "Diffusion PCG init for %s (raw_seq=%d, static_seq=%d)", + self.model.__class__.__name__, + raw_seq_len, + static_seq_len, + ) + + call_kwargs = _materialize_call_kwargs(entry.slots, kwargs) + + if not self._compiled: + self._ensure_compiled(call_kwargs, static_seq_len) + + if not entry.captured: + self._capture(call_kwargs, static_seq_len) + entry.captured = True + + with enable_piecewise_cuda_graph(): + with set_pcg_runtime_shape(static_seq_len): + if current_platform.is_cuda_alike(): + # Runtime recompilation can trigger late on-demand capture in the + # backend. Ensure a valid capture stream is always available. 
+ with set_pcg_capture_stream(torch.cuda.current_stream()): + output = self.model(**call_kwargs) + else: + output = self.model(**call_kwargs) + + return _slice_output_to_raw_seq(output, raw_seq_len, static_seq_len) + + +def resolve_capture_sizes( + *, + seq_len: int, + explicit_sizes: list[int] | None, + max_tokens: int, +) -> list[int]: + if explicit_sizes: + sizes = sorted(set(int(x) for x in explicit_sizes if int(x) > 0)) + if seq_len <= max_tokens and seq_len not in sizes: + sizes.append(int(seq_len)) + return sorted(set(sizes)) + + capture_sizes = ( + list(range(4, 33, 4)) + + list(range(48, 257, 16)) + + list(range(288, 513, 32)) + + list(range(576, 1024 + 1, 64)) + + list(range(1280, 4096 + 1, 256)) + + list(range(4608, int(max_tokens) + 1, 512)) + ) + capture_sizes = [s for s in capture_sizes if s <= int(max_tokens)] + if seq_len <= max_tokens and seq_len not in capture_sizes: + capture_sizes.append(int(seq_len)) + return sorted(set(capture_sizes)) diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py index 4fbe4f0c78c6..de8671fa97a0 100644 --- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py @@ -1,7 +1,6 @@ # Copied and adapted from: https://github.com/hao-ai-lab/FastVideo # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from functools import lru_cache from typing import Any, List, Optional, Tuple import torch @@ -12,6 +11,7 @@ AttentionBackendEnum, current_platform, ) +from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph try: from sgl_kernel.flash_attn import flash_attn_varlen_func @@ -71,7 +71,6 @@ def flash_attn_varlen_func_fake_out( sinks: Optional[torch.Tensor] = None, ver: int = 4, ) -> torch.Tensor: - assert ver == 4, "only support flash attention v4" q, k, v = 
[maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: @@ -131,7 +130,6 @@ def flash_attn_varlen_func_fake_out_lse( sinks: Optional[torch.Tensor] = None, ver: int = 4, ) -> Tuple[torch.Tensor, torch.Tensor]: - assert ver == 4, "only support flash attention v4" q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: @@ -330,18 +328,29 @@ def flash_attn_varlen_func_op_lse( fa_ver = 3 -@lru_cache(maxsize=128) def _get_cu_seqlens(device_index: int, bsz: int, seqlen: int) -> torch.Tensor: - return torch.arange( + key = (device_index, bsz, seqlen) + cached = _CU_SEQLENS_CACHE.get(key) + if cached is not None: + return cached + + cu_seqlens = torch.arange( 0, (bsz + 1) * seqlen, step=seqlen, device=torch.device("cuda", device_index), dtype=torch.int32, ) + _CU_SEQLENS_CACHE[key] = cu_seqlens + return cu_seqlens + + +_CU_SEQLENS_CACHE: dict[tuple[int, int, int], torch.Tensor] = {} +_UPSTREAM_FLASH_ATTN_CACHE: dict[ + tuple[bool, bool, tuple[int, ...], tuple[int, ...], tuple[int, ...]], bool +] = {} -@lru_cache(maxsize=256) def _should_use_upstream_flash_attention( upstream_available: bool, upstream_heads_ok: bool, @@ -349,10 +358,23 @@ def _should_use_upstream_flash_attention( k_shape: tuple[int, ...], v_shape: tuple[int, ...], ) -> bool: + cache_key = ( + upstream_available, + upstream_heads_ok, + q_shape, + k_shape, + v_shape, + ) + cached = _UPSTREAM_FLASH_ATTN_CACHE.get(cache_key) + if cached is not None: + return cached + if not upstream_available or not upstream_heads_ok: + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False if len(q_shape) != 4 or len(k_shape) != 4 or len(v_shape) != 4: + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False bsz, seqlen, nheads_q, d = q_shape @@ -367,11 +389,15 @@ def _should_use_upstream_flash_attention( or d != d_k or d != d_v ): + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False if nheads_k != nheads_v: + 
_UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False if nheads_k == 0 or (nheads_q % nheads_k) != 0: + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = True return True @@ -468,21 +494,20 @@ def forward( return_softmax_lse: bool = False, ): attn_metadata: FlashAttentionMetadata = get_forward_context().attn_metadata - if attn_metadata is not None and attn_metadata.max_seqlen_q is None: - attn_metadata.max_seqlen_q = query.shape[1] - attn_metadata.max_seqlen_k = key.shape[1] - max_seqlen_q = attn_metadata.max_seqlen_q - max_seqlen_k = attn_metadata.max_seqlen_k - else: - max_seqlen_q = query.shape[1] - max_seqlen_k = key.shape[1] + # Keep control flow stable for torch.compile/PCG. + max_seqlen_q = query.shape[1] + max_seqlen_k = key.shape[1] + if attn_metadata is not None: + attn_metadata.max_seqlen_q = max_seqlen_q + attn_metadata.max_seqlen_k = max_seqlen_k q_shape = tuple(query.shape) k_shape = tuple(key.shape) v_shape = tuple(value.shape) + in_piecewise = is_in_piecewise_cuda_graph() use_upstream = _should_use_upstream_flash_attention( - flash_attn_varlen_func_upstream is not None, + (flash_attn_varlen_func_upstream is not None) and (not in_piecewise), self._upstream_heads_ok, q_shape, k_shape, @@ -515,6 +540,37 @@ def forward( # - fa_ver == 3: call python function (can return Tensor or (Tensor, Tensor) depending on flag) # - fa_ver == 4: call custom ops with FIXED return schema if fa_ver == 3: + if in_piecewise: + if return_softmax_lse: + out_tensor, softmax_lse = flash_attn_varlen_func_op_lse( + q=query, + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_softmax_lse=True, + ver=fa_ver, + ) + return out_tensor, softmax_lse + out_tensor = flash_attn_varlen_func_op( + q=query, + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + 
max_seqlen_k=max_seqlen_k, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_softmax_lse=False, + ver=fa_ver, + ) + return out_tensor + flash_attn_op = flash_attn_func output = flash_attn_op( q=query, diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index 888824f60dd9..a728ee9698db 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -30,8 +30,15 @@ get_tp_group, ) from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp +from sglang.multimodal_gen.runtime.layers.utils import register_custom_op from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var +if _is_cuda: + fused_inplace_qknorm = register_custom_op( + fused_inplace_qknorm, + mutates_args=["q", "k"], + ) + # Copied and adapted from sglang @CustomOp.register("rms_norm") diff --git a/python/sglang/multimodal_gen/runtime/layers/linear.py b/python/sglang/multimodal_gen/runtime/layers/linear.py index f74309a53405..45d7b33913bf 100644 --- a/python/sglang/multimodal_gen/runtime/layers/linear.py +++ b/python/sglang/multimodal_gen/runtime/layers/linear.py @@ -39,6 +39,7 @@ from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) +_AMP_SUPPORTED = current_platform.is_amp_supported() WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", @@ -153,7 +154,7 @@ def apply( ) -> torch.Tensor: output = ( F.linear(x, layer.weight, bias) - if current_platform.is_amp_supported() or bias is None + if _AMP_SUPPORTED or bias is None else F.linear(x, layer.weight, bias.to(x.dtype)) ) # NOTE: this line assumes that we are using amp when using cuda and is needed to account for the fact that amp isn't supported in mps return output diff --git a/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py 
b/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py index 4eaa106e8742..5dbcd0e0ddb7 100644 --- a/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py +++ b/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py @@ -5,6 +5,44 @@ import torch from sglang.jit_kernel.diffusion.triton.rotary import apply_rotary_embedding +from sglang.multimodal_gen.runtime.layers.utils import register_custom_op +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_cuda = current_platform.is_cuda() +_flashinfer_rope_qk_inplace_custom_op = None + +if _is_cuda: + try: + from flashinfer.rope import ( + apply_rope_with_cos_sin_cache_inplace as _flashinfer_apply_rope_with_cos_sin_cache_inplace, + ) + except ImportError: + _flashinfer_apply_rope_with_cos_sin_cache_inplace = None + + if _flashinfer_apply_rope_with_cos_sin_cache_inplace is not None: + + @register_custom_op( + op_name="diffusion_flashinfer_rope_qk_inplace", + mutates_args=["query", "key"], + ) + def flashinfer_rope_qk_inplace_custom_op( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool = False, + ) -> None: + _flashinfer_apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox, + ) + + _flashinfer_rope_qk_inplace_custom_op = flashinfer_rope_qk_inplace_custom_op def _apply_rotary_emb( @@ -67,32 +105,6 @@ def apply_flashinfer_rope_qk_inplace( if head_size != d: raise ValueError(f"head_size mismatch: inferred {d}, but head_size={head_size}") - try: - from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace - except ImportError: - # Triton fallback for AMD/ROCm where FlashInfer is not available - import warnings - - warnings.warn( - "FlashInfer not available, using Triton fallback for RoPE", - stacklevel=2, - ) - half_size = cos_sin_cache.shape[-1] // 2 - if 
positions is None: - cos = cos_sin_cache[:seqlen, :half_size].to(q.dtype) - sin = cos_sin_cache[:seqlen, half_size:].to(q.dtype) - cos = cos.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1) - sin = sin.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1) - else: - positions = positions.to(cos_sin_cache.device).view(-1) - cos = cos_sin_cache[positions, :half_size].to(q.dtype) - sin = cos_sin_cache[positions, half_size:].to(q.dtype) - q_flat = q.reshape(bsz * seqlen, nheads, d) - k_flat = k.reshape(bsz * seqlen, nheads, d) - q_rot = apply_rotary_embedding(q_flat, cos, sin, interleaved=not is_neox) - k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox) - return q_rot.view(bsz, seqlen, nheads, d), k_rot.view(bsz, seqlen, nheads, d) - if positions is None: pos_1d = torch.arange(seqlen, device=q.device, dtype=torch.long) positions = pos_1d if bsz == 1 else pos_1d.repeat(bsz) @@ -110,12 +122,23 @@ def apply_flashinfer_rope_qk_inplace( q_flat = q.reshape(bsz * seqlen, nheads * d).contiguous() k_flat = k.reshape(bsz * seqlen, nheads * d).contiguous() - apply_rope_with_cos_sin_cache_inplace( - positions=positions, - query=q_flat, - key=k_flat, - head_size=d, - cos_sin_cache=cos_sin_cache, - is_neox=is_neox, - ) - return q_flat.view(bsz, seqlen, nheads, d), k_flat.view(bsz, seqlen, nheads, d) + if _flashinfer_rope_qk_inplace_custom_op is not None: + _flashinfer_rope_qk_inplace_custom_op( + positions=positions, + query=q_flat, + key=k_flat, + head_size=d, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox, + ) + return q_flat.view(bsz, seqlen, nheads, d), k_flat.view(bsz, seqlen, nheads, d) + + half_size = cos_sin_cache.shape[-1] // 2 + positions = positions.to(cos_sin_cache.device).view(-1) + cos = cos_sin_cache[positions, :half_size].to(q.dtype) + sin = cos_sin_cache[positions, half_size:].to(q.dtype) + q_flat = q.reshape(bsz * seqlen, nheads, d) + k_flat = k.reshape(bsz * seqlen, nheads, d) + q_rot = apply_rotary_embedding(q_flat, cos, 
sin, interleaved=not is_neox) + k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox) + return q_rot.view(bsz, seqlen, nheads, d), k_rot.view(bsz, seqlen, nheads, d) diff --git a/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py b/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py index c8eff7d6f0d8..d51bb2d49cc2 100644 --- a/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py +++ b/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py @@ -27,11 +27,64 @@ from sglang.multimodal_gen.runtime.layers.activation import get_act_fn from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.utils import register_custom_op from sglang.multimodal_gen.runtime.platforms import current_platform _is_cuda = current_platform.is_cuda() +def _timestep_embedding_cuda_fake( + timesteps: torch.Tensor, + num_channels: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 0.0, + scale: float = 1, +) -> torch.Tensor: + return torch.empty( + (timesteps.shape[0], num_channels), + dtype=torch.float32, + device=timesteps.device, + ) + + +if _is_cuda: + + @register_custom_op( + op_name="diffusion_timestep_embedding_cuda", + fake_impl=_timestep_embedding_cuda_fake, + ) + def timestep_embedding_cuda_custom_op( + timesteps: torch.Tensor, + num_channels: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 0.0, + scale: float = 1, + ) -> torch.Tensor: + return timestep_embedding_cuda( + timesteps, + num_channels, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + scale=scale, + ) +else: + + def timestep_embedding_cuda_custom_op( + timesteps: torch.Tensor, + num_channels: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 0.0, + scale: float = 1, + ) -> torch.Tensor: + return timestep_embedding_cuda( + timesteps, + num_channels, + 
flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + scale=scale, + ) + + class PatchEmbed(nn.Module): """2D Image to Patch Embedding @@ -89,21 +142,20 @@ def forward(self, x): class Timesteps(_Timesteps): def forward(self, timesteps: torch.Tensor) -> torch.Tensor: if _is_cuda: - return timestep_embedding_cuda( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - scale=self.scale, - ) - else: - return timestep_embedding_diffusers( + return timestep_embedding_cuda_custom_op( timesteps, self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, scale=self.scale, ) + return timestep_embedding_diffusers( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) class CombinedTimestepGuidanceTextProjEmbeddings( diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index 844999554d99..23dc38d1f98c 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -33,6 +33,10 @@ refresh_context_on_dual_transformer, refresh_context_on_transformer, ) +from sglang.multimodal_gen.runtime.compilation import ( + DiffusionPiecewiseCudaGraphRunner, + resolve_capture_sizes, +) from sglang.multimodal_gen.runtime.distributed import ( cfg_model_parallel_all_reduce, get_local_torch_device, @@ -128,6 +132,9 @@ def __init__( self._cache_dit_enabled = False self._cached_num_steps = None self._is_warmed_up = False + # piecewise cuda graph state + self._pcg_runners: dict[int, DiffusionPiecewiseCudaGraphRunner] = {} + self._pcg_disabled_model_ids: set[int] = set() def _maybe_enable_torch_compile(self, module: object) -> None: """ @@ -995,6 +1002,21 @@ def 
# NOTE(review): reconstructed from a collapsed diff. These are methods of the
# diffusion pipeline / model-runner class; AttentionBackendEnum, Req,
# ServerArgs, DiffusionPiecewiseCudaGraphRunner, resolve_capture_sizes and
# `logger` are provided by the enclosing module. The `forward()` and
# `_predict_noise` hunks in this span are only partially visible in the diff
# and are intentionally not reproduced here.

def _is_pcg_backend_supported(self) -> bool:
    """Return True when the attention backend is safe for piecewise CUDA graph capture."""
    return self.attn_backend.get_enum() in {
        AttentionBackendEnum.FA,
        AttentionBackendEnum.FA2,
        AttentionBackendEnum.TORCH_SDPA,
    }


def _get_or_create_pcg_runner(
    self, model: nn.Module, hidden_states: torch.Tensor
) -> DiffusionPiecewiseCudaGraphRunner | None:
    """Return the piecewise-CUDA-graph runner for *model*, creating it lazily.

    Returns None when PCG is disabled by config, the attention backend is not
    graph-capturable, the activation is not rank-3, or the model was disabled
    after an earlier capture/replay failure.
    """
    if not self.server_args.enable_piecewise_cuda_graph:
        return None
    if not self._is_pcg_backend_supported():
        return None
    # PCG bucketing keys on dim-1; assumes [batch, seq, hidden] activations.
    if hidden_states is None or hidden_states.ndim != 3:
        return None

    model_id = id(model)
    if model_id in self._pcg_disabled_model_ids:
        return None

    runner = self._pcg_runners.get(model_id)
    if runner is not None:
        return runner

    capture_sizes = resolve_capture_sizes(
        seq_len=int(hidden_states.shape[1]),
        explicit_sizes=self.server_args.piecewise_cuda_graph_tokens,
        max_tokens=self.server_args.piecewise_cuda_graph_max_tokens,
    )
    runner = DiffusionPiecewiseCudaGraphRunner(
        model=model,
        capture_sizes=capture_sizes,
        compiler=self.server_args.piecewise_cuda_graph_compiler,
        enable_debug=self.server_args.enable_piecewise_cuda_graph_debug,
    )
    self._pcg_runners[model_id] = runner
    logger.info(
        "Enable diffusion PCG for %s with %d capture buckets (max=%d)",
        model.__class__.__name__,
        len(capture_sizes),
        max(capture_sizes) if capture_sizes else 0,
    )
    return runner


def _maybe_precapture_piecewise_cuda_graph(
    self,
    *,
    batch: Req,
    server_args: ServerArgs,
    timesteps: torch.Tensor,
    target_dtype: torch.dtype,
    boundary_timestep: float | None,
    seq_len: int | None,
    reserved_frames_mask: torch.Tensor | None,
    guidance: torch.Tensor | None,
    latents: torch.Tensor,
    image_kwargs: dict[str, Any],
    pos_cond_kwargs: dict[str, Any],
    neg_cond_kwargs: dict[str, Any],
) -> None:
    """Run one forward per distinct model before the denoising loop.

    This captures the piecewise CUDA graphs outside the timed sampling phase,
    so the first real denoising step replays instead of capturing. Skipped for
    warmup batches and when PCG is disabled.
    """
    if not server_args.enable_piecewise_cuda_graph:
        return
    if batch.is_warmup:
        return
    if timesteps is None or timesteps.numel() == 0:
        return

    # With a two-stage model (e.g. high/low-noise experts) both transformers
    # must be captured; otherwise a single pass suffices.
    expected_model_count = 2 if self.transformer_2 is not None else 1
    visited_model_ids: set[int] = set()
    timesteps_cpu = timesteps.cpu()
    original_is_cfg_negative = batch.is_cfg_negative

    logger.info(
        "Pre-capturing diffusion PCG before denoising loop (target_models=%d)",
        expected_model_count,
    )

    try:
        for i, t_host in enumerate(timesteps_cpu):
            t_int = int(t_host.item())
            t_device = timesteps[i]
            current_model, current_guidance_scale = self._select_and_manage_model(
                t_int=t_int,
                boundary_timestep=boundary_timestep,
                server_args=server_args,
                batch=batch,
            )

            model_id = id(current_model)
            if model_id in visited_model_ids:
                continue

            latent_model_input = latents.to(target_dtype)
            if batch.image_latent is not None:
                latent_model_input = torch.cat(
                    [latent_model_input, batch.image_latent], dim=1
                ).to(target_dtype)

            timestep = self.expand_timestep_before_forward(
                batch,
                server_args,
                t_device,
                target_dtype,
                seq_len,
                reserved_frames_mask,
            )
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t_device
            )

            attn_metadata = self._build_attn_metadata(
                i,
                batch,
                server_args,
                timestep_value=t_int,
                timesteps=timesteps_cpu,
            )

            # Result intentionally discarded: this forward exists only so the
            # runner captures graphs for this model's shapes.
            _ = self._predict_noise_with_cfg(
                current_model=current_model,
                latent_model_input=latent_model_input,
                timestep=timestep,
                batch=batch,
                timestep_index=i,
                attn_metadata=attn_metadata,
                target_dtype=target_dtype,
                current_guidance_scale=current_guidance_scale,
                image_kwargs=image_kwargs,
                pos_cond_kwargs=pos_cond_kwargs,
                neg_cond_kwargs=neg_cond_kwargs,
                server_args=server_args,
                guidance=guidance,
                latents=latents,
            )

            visited_model_ids.add(model_id)
            if len(visited_model_ids) >= expected_model_count:
                break
    finally:
        # Fix: restore CFG polarity even when a pre-capture forward raises, so
        # the real denoising loop never starts with a stale is_cfg_negative
        # flag (previously only restored on the success path).
        batch.is_cfg_negative = original_is_cfg_negative

    logger.info(
        "Pre-capture finished for %d model(s) before formal denoising",
        len(visited_model_ids),
    )
diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 480efe584245..db35257d8373 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -270,6 +270,11 @@ class ServerArgs: # Compilation enable_torch_compile: bool = False + enable_piecewise_cuda_graph: bool = False + piecewise_cuda_graph_max_tokens: int = 8192 + piecewise_cuda_graph_tokens: list[int] | None = None + piecewise_cuda_graph_compiler: str = "eager" + enable_piecewise_cuda_graph_debug: bool = False # warmup warmup: bool = False @@ -352,6 +357,7 @@ def _adjust_parameters(self): self._adjust_attention_backend() self._adjust_platform_specific() self._adjust_autocast() + self._adjust_piecewise_cuda_graph() def _validate_parameters(self): """check consistency and raise errors for invalid configs""" @@ -359,6 +365,7 @@ def _validate_parameters(self): self._validate_offload() self._validate_parallelism() self._validate_cfg_parallel() + self._validate_piecewise_cuda_graph() def _adjust_save_paths(self): """Normalize empty-string save paths to None (disabled).""" @@ -463,6 +470,51 @@ def _adjust_warmup(self): "Warmup enabled, the launch time is expected to be longer than usual" ) + def _adjust_piecewise_cuda_graph(self): + if self.piecewise_cuda_graph_tokens is None: + self.piecewise_cuda_graph_tokens = self._generate_piecewise_cuda_graph_tokens() + elif isinstance(self.piecewise_cuda_graph_tokens, str): + self.piecewise_cuda_graph_tokens = [ + int(x.strip()) + for x in self.piecewise_cuda_graph_tokens.split(",") + if x.strip() + ] + + self.piecewise_cuda_graph_tokens = sorted( + set( + int(x) + for x in (self.piecewise_cuda_graph_tokens or []) + if int(x) > 0 and int(x) <= int(self.piecewise_cuda_graph_max_tokens) + ) + ) + + if self.enable_piecewise_cuda_graph and self.enable_torch_compile: + logger.warning( + "Both --enable-piecewise-cuda-graph and --enable-torch-compile 
are set. " + "Disabling torch.compile and keeping piecewise CUDA graph." + ) + self.enable_torch_compile = False + + if not current_platform.is_cuda_alike() and self.enable_piecewise_cuda_graph: + logger.warning( + "Piecewise CUDA graph is enabled but current platform is not CUDA-like. " + "Feature will be disabled." + ) + self.enable_piecewise_cuda_graph = False + + def _generate_piecewise_cuda_graph_tokens(self) -> list[int]: + capture_sizes = ( + list(range(4, 33, 4)) + + list(range(48, 257, 16)) + + list(range(288, 513, 32)) + + list(range(576, 1024 + 1, 64)) + + list(range(1280, 4096 + 1, 256)) + + list(range(4608, int(self.piecewise_cuda_graph_max_tokens) + 1, 512)) + ) + return [ + s for s in capture_sizes if s <= int(self.piecewise_cuda_graph_max_tokens) + ] + def _adjust_network_ports(self): self.port = self.settle_port(self.port) initial_scheduler_port = self.scheduler_port + ( @@ -738,6 +790,37 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Use torch.compile to speed up DiT inference." + "However, will likely cause precision drifts. 
See (https://github.com/pytorch/pytorch/issues/145213)", ) + parser.add_argument( + "--enable-piecewise-cuda-graph", + action=StoreBoolean, + default=ServerArgs.enable_piecewise_cuda_graph, + help="Enable piecewise CUDA graph with automatic graph splitting and padding buckets.", + ) + parser.add_argument( + "--piecewise-cuda-graph-max-tokens", + type=int, + default=ServerArgs.piecewise_cuda_graph_max_tokens, + help="Maximum sequence length bucket for diffusion piecewise CUDA graph.", + ) + parser.add_argument( + "--piecewise-cuda-graph-tokens", + type=int, + nargs="+", + default=ServerArgs.piecewise_cuda_graph_tokens, + help="Optional explicit token buckets for diffusion piecewise CUDA graph.", + ) + parser.add_argument( + "--piecewise-cuda-graph-compiler", + type=str, + default=ServerArgs.piecewise_cuda_graph_compiler, + help="Compiler backend for diffusion piecewise CUDA graph: eager or inductor.", + ) + parser.add_argument( + "--enable-piecewise-cuda-graph-debug", + action=StoreBoolean, + default=ServerArgs.enable_piecewise_cuda_graph_debug, + help="Enable debug checks for diffusion piecewise CUDA graph.", + ) # warmup parser.add_argument( @@ -1197,6 +1280,21 @@ def _validate_cfg_parallel(self): "CFG Parallelism is enabled via `--enable-cfg-parallel`, but num_gpus == 1" ) + def _validate_piecewise_cuda_graph(self): + if self.piecewise_cuda_graph_max_tokens <= 0: + raise ValueError("piecewise_cuda_graph_max_tokens must be positive") + + if self.piecewise_cuda_graph_compiler not in ("eager", "inductor"): + raise ValueError( + "piecewise_cuda_graph_compiler must be one of: eager, inductor" + ) + + if self.piecewise_cuda_graph_tokens is not None: + if len(self.piecewise_cuda_graph_tokens) == 0: + raise ValueError( + "piecewise_cuda_graph_tokens is empty; set a non-empty bucket list or unset it" + ) + def _set_default_attention_backend(self) -> None: """Configure ROCm defaults when users do not specify an attention backend.""" if current_platform.is_rocm(): diff 
--git a/python/sglang/srt/compilation/cuda_piecewise_backend.py b/python/sglang/srt/compilation/cuda_piecewise_backend.py index 8ca1e6a43cbc..e3f2fcf93369 100644 --- a/python/sglang/srt/compilation/cuda_piecewise_backend.py +++ b/python/sglang/srt/compilation/cuda_piecewise_backend.py @@ -13,6 +13,7 @@ from sglang.srt.compilation.compilation_counter import compilation_counter from sglang.srt.compilation.piecewise_context_manager import ( get_pcg_capture_stream, + get_pcg_runtime_shape, is_in_pcg_torch_compile, ) from sglang.srt.compilation.weak_ref_tensor import weak_ref_tensors @@ -104,16 +105,79 @@ def check_for_ending_compilation(self): # save the hash of the inductor graph for the next run self.sglang_backend.compiler_manager.save_to_file() + def _infer_runtime_shape_from_tensors(self, args) -> Optional[int]: + # Prefer sequence-like dims from activation tensors: + # 1) dim-1 for rank>=2 tensors (e.g. [B, S, C] -> S) + # 2) dim-0 for rank>=1 tensors + # 3) any dim as final fallback + candidates: list[int] = [] + + for x in args: + if not isinstance(x, torch.Tensor): + continue + if x.ndim >= 2: + s = int(x.shape[1]) + if s in self.concrete_size_entries: + candidates.append(s) + if candidates: + return max(candidates) + + for x in args: + if not isinstance(x, torch.Tensor): + continue + if x.ndim >= 1: + s = int(x.shape[0]) + if s in self.concrete_size_entries: + candidates.append(s) + if candidates: + return max(candidates) + + for x in args: + if not isinstance(x, torch.Tensor): + continue + for d in x.shape: + s = int(d) + if s in self.concrete_size_entries: + candidates.append(s) + if candidates: + return max(candidates) + return None + def __call__(self, *args) -> Any: if not self.first_run_finished: self.first_run_finished = True self.check_for_ending_compilation() return self.compiled_graph_for_general_shape(*args) - if len(self.sym_shape_indices) == 0: - return self.compiled_graph_for_general_shape(*args) + runtime_shape = None + runtime_shape_override = 
get_pcg_runtime_shape() + if runtime_shape_override is not None: + try: + runtime_shape_override = int(runtime_shape_override) + except Exception: + pass + if runtime_shape_override in self.concrete_size_entries: + runtime_shape = runtime_shape_override + + if runtime_shape is None and len(self.sym_shape_indices) > 0: + for idx in self.sym_shape_indices: + candidate = args[idx] + try: + candidate = int(candidate) + except Exception: + pass + if candidate in self.concrete_size_entries: + runtime_shape = candidate + break + if runtime_shape is None: + runtime_shape = args[self.sym_shape_indices[0]] + try: + runtime_shape = int(runtime_shape) + except Exception: + pass + if runtime_shape is None: + runtime_shape = self._infer_runtime_shape_from_tensors(args) - runtime_shape = args[self.sym_shape_indices[0]] if runtime_shape not in self.concrete_size_entries: # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) diff --git a/python/sglang/srt/compilation/piecewise_context_manager.py b/python/sglang/srt/compilation/piecewise_context_manager.py index 20a08a9972b9..852cf72775c8 100644 --- a/python/sglang/srt/compilation/piecewise_context_manager.py +++ b/python/sglang/srt/compilation/piecewise_context_manager.py @@ -16,6 +16,7 @@ _in_piecewise_cuda_graph = False _in_pcg_torch_compile = False _pcg_capture_stream = None +_pcg_runtime_shape = None def is_in_piecewise_cuda_graph(): @@ -30,6 +31,10 @@ def get_pcg_capture_stream(): return _pcg_capture_stream +def get_pcg_runtime_shape(): + return _pcg_runtime_shape + + @contextmanager def enable_piecewise_cuda_graph_compile(): global _in_pcg_torch_compile @@ -63,6 +68,17 @@ def set_pcg_capture_stream(stream: torch.cuda.Stream): _pcg_capture_stream = None +@contextmanager +def set_pcg_runtime_shape(runtime_shape: int | None): + global _pcg_runtime_shape + old = _pcg_runtime_shape + _pcg_runtime_shape = runtime_shape + try: + yield + finally: + _pcg_runtime_shape = old + + @dataclass 
class ForwardContext: def __init__(self):