From 624b4a8d67dbb56945123deffc440f1c4b31678e Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 09:48:19 +0800 Subject: [PATCH 01/11] ud --- .../runtime/compilation/__init__.py | 11 + .../piecewise_cuda_graph_runner.py | 465 ++++++++++++++++++ .../pipelines_core/stages/denoising.py | 74 +++ .../multimodal_gen/runtime/server_args.py | 98 ++++ 4 files changed, 648 insertions(+) create mode 100644 python/sglang/multimodal_gen/runtime/compilation/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py diff --git a/python/sglang/multimodal_gen/runtime/compilation/__init__.py b/python/sglang/multimodal_gen/runtime/compilation/__init__.py new file mode 100644 index 000000000000..01c96735c7fd --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/compilation/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .piecewise_cuda_graph_runner import ( + DiffusionPiecewiseCudaGraphRunner, + resolve_capture_sizes, +) + +__all__ = [ + "DiffusionPiecewiseCudaGraphRunner", + "resolve_capture_sizes", +] diff --git a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py new file mode 100644 index 000000000000..da0ce9474538 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py @@ -0,0 +1,465 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Piecewise CUDA graph runner for diffusion DiT models. + +This runner mirrors LLM piecewise CUDA graph behavior: +1. enable piecewise graph splitting via torch.compile backend hooks +2. bucket requests to capture sizes +3. apply padding to bucket shape +4. capture and replay CUDA graph with stable tensor addresses +""" + +from __future__ import annotations + +import bisect +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Hashable + +import torch + +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.compile import install_torch_compiled +from sglang.srt.compilation.piecewise_context_manager import ( + enable_piecewise_cuda_graph, + enable_piecewise_cuda_graph_compile, + set_pcg_capture_stream, +) + +logger = init_logger(__name__) + +_RUNTIME_VALUE = object() +_GLOBAL_GRAPH_POOL = None + + +def _set_torch_compile_config() -> None: + import torch._dynamo.config + + torch._dynamo.config.accumulated_cache_size_limit = 1024 + if hasattr(torch._dynamo.config, "cache_size_limit"): + torch._dynamo.config.cache_size_limit = 1024 + + +@contextmanager +def _graph_capture_stream(): + if not current_platform.is_cuda_alike(): + yield None + return + + stream = torch.cuda.Stream() + current = torch.cuda.current_stream() + if current != stream: + stream.wait_stream(current) + with torch.cuda.stream(stream): + yield stream + + +def _is_tensor(obj: Any) -> bool: + return isinstance(obj, torch.Tensor) + + +def _join_path(parent: str, key: str) -> str: + if not parent: + return key + return f"{parent}.{key}" + + +def _supports_scalar(obj: Any) -> bool: + return isinstance(obj, (int, float, str, bool)) + + +def _pad_dim_for_path(path: str, tensor: torch.Tensor, raw_seq_len: int) -> int | None: + # Mandatory for all sequence DiTs. + if path.endswith("hidden_states") and tensor.ndim >= 2: + if tensor.shape[1] == raw_seq_len: + return 1 + + # Qwen-Image rotary cache for image tokens. + if path == "freqs_cis.0" and tensor.ndim >= 1: + if tensor.shape[0] == raw_seq_len: + return 0 + + return None + + +def _signature( + obj: Any, + *, + raw_seq_len: int, + static_seq_len: int, + path: str = "", +) -> Hashable: + if _is_tensor(obj): + shape = list(obj.shape) + pad_dim = _pad_dim_for_path(path, obj, raw_seq_len) + if pad_dim is not None: + shape[pad_dim] = static_seq_len + return ( + "tensor", + tuple(shape), + str(obj.dtype), + str(obj.device), + ) + + if isinstance(obj, list): + return ( + "list", + tuple( + _signature( + x, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(i)), + ) + for i, x in enumerate(obj) + ), + ) + if isinstance(obj, tuple): + return ( + "tuple", + tuple( + _signature( + x, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(i)), + ) + for i, x in enumerate(obj) + ), + ) + if isinstance(obj, dict): + items = tuple( + sorted( + ( + k, + _signature( + v, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(k)), + ), + ) + for k, v in obj.items() + ) + ) + return ("dict", items) + if obj is None: + return ("none",) + if _supports_scalar(obj): + return ("scalar", obj) + return ("repr", repr(obj)) + + +@dataclass +class _TensorSlot: + tensor: torch.Tensor + pad_dim: int | None = None + + +def _build_slots( + obj: Any, + *, + raw_seq_len: int, + static_seq_len: int, + path: str = "", +) -> Any: + if _is_tensor(obj): + pad_dim = _pad_dim_for_path(path, obj, raw_seq_len) + if pad_dim is None: + return _TensorSlot(tensor=torch.empty_like(obj), pad_dim=None) + new_shape = list(obj.shape) + new_shape[pad_dim] = static_seq_len + return _TensorSlot( + tensor=torch.empty( + tuple(new_shape), dtype=obj.dtype, device=obj.device + ), + pad_dim=pad_dim, + ) + + if isinstance(obj, list): + return [ + _build_slots( + x, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(i)), + ) + for i, x in enumerate(obj) + ] + if isinstance(obj, tuple): + return tuple( + _build_slots( + x, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(i)), + ) + for i, x in enumerate(obj) + ) + if isinstance(obj, dict): + return { + k: _build_slots( + v, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path=_join_path(path, str(k)), + ) + for k, v in obj.items() + } + return _RUNTIME_VALUE + + +def _materialize_call_kwargs(slots: Any, values: Any) -> Any: + if isinstance(slots, _TensorSlot): + src = values + dst = slots.tensor + + if slots.pad_dim is None: + if dst.shape != src.shape: + raise ValueError( + f"Tensor shape changed for CUDA graph replay: {dst.shape} vs {src.shape}" + ) + dst.copy_(src) + return dst + + if src.shape[slots.pad_dim] > dst.shape[slots.pad_dim]: + raise ValueError( + "Input sequence length exceeds padded slot shape: " + f"{src.shape[slots.pad_dim]} > {dst.shape[slots.pad_dim]}" + ) + + dst.zero_() + indices = [slice(None)] * dst.ndim + indices[slots.pad_dim] = slice(0, src.shape[slots.pad_dim]) + dst[tuple(indices)].copy_(src) + return dst + + if slots is _RUNTIME_VALUE: + return values + + if isinstance(slots, list): + return [ + _materialize_call_kwargs(s, v) for s, v in zip(slots, values, strict=True) + ] + if isinstance(slots, tuple): + return tuple( + _materialize_call_kwargs(s, v) for s, v in zip(slots, values, strict=True) + ) + if isinstance(slots, dict): + return {k: _materialize_call_kwargs(slots[k], values[k]) for k in slots.keys()} + + raise TypeError(f"Unsupported slot type: {type(slots)}") + + +def _slice_output_to_raw_seq(output: Any, raw_seq_len: int, static_seq_len: int) -> Any: + if raw_seq_len == static_seq_len: + return output + + if isinstance(output, torch.Tensor): + if output.ndim >= 2 and output.shape[1] == static_seq_len: + return output[:, :raw_seq_len, ...] + return output + if isinstance(output, list): + return [_slice_output_to_raw_seq(x, raw_seq_len, static_seq_len) for x in output] + if isinstance(output, tuple): + return tuple(_slice_output_to_raw_seq(x, raw_seq_len, static_seq_len) for x in output) + if isinstance(output, dict): + return { + k: _slice_output_to_raw_seq(v, raw_seq_len, static_seq_len) + for k, v in output.items() + } + return output + + +def _get_graph_pool(device: torch.device): + global _GLOBAL_GRAPH_POOL + if _GLOBAL_GRAPH_POOL is None: + _GLOBAL_GRAPH_POOL = torch.get_device_module(device).graph_pool_handle() + return _GLOBAL_GRAPH_POOL + + +@dataclass +class _GraphEntry: + signature: Hashable + static_seq_len: int + slots: dict[str, Any] + captured: bool = False + + +class DiffusionPiecewiseCudaGraphRunner: + def __init__( + self, + model: torch.nn.Module, + capture_sizes: list[int], + *, + compiler: str = "eager", + enable_debug: bool = False, + ) -> None: + self.model = model + self.capture_sizes = sorted(set(int(x) for x in capture_sizes)) + self.capture_sizes_set = set(self.capture_sizes) + self.compiler = compiler + self.enable_debug = enable_debug + + self._entries: dict[Hashable, _GraphEntry] = {} + self._installed = False + self._compiled = False + + device = next(model.parameters()).device + self.device = device + self.graph_pool = _get_graph_pool(device) + + self.compile_config = CompilationConfig( + self.capture_sizes, + compiler=self.compiler, + enable_debug_mode=self.enable_debug, + ) + _set_torch_compile_config() + + def _install_compiled(self) -> None: + if self._installed: + return + install_torch_compiled( + self.model, + dynamic_arg_dims={"hidden_states": [1]}, + compile_config=self.compile_config, + graph_pool=self.graph_pool, + fullgraph=True, + ) + self._installed = True + + def can_run( + self, hidden_states: torch.Tensor, seq_len_override: int | None = None + ) -> bool: + if not current_platform.is_cuda_alike(): + return False + if not torch.cuda.is_available(): + return False + if hidden_states.device.type != "cuda": + return False + if hidden_states.ndim < 2: + return False + if seq_len_override is None: + return False + if hidden_states.shape[1] != int(seq_len_override): + return False + if not self.capture_sizes: + return False + if seq_len_override > self.capture_sizes[-1]: + return False + return True + + def _select_static_seq_len(self, seq_len: int) -> int | None: + idx = bisect.bisect_left(self.capture_sizes, seq_len) + if idx >= len(self.capture_sizes): + return None + return self.capture_sizes[idx] + + def _ensure_compiled(self, call_kwargs: dict[str, Any]) -> None: + if self._compiled: + return + with enable_piecewise_cuda_graph(): + with enable_piecewise_cuda_graph_compile(): + _ = self.model(**call_kwargs) + self._compiled = True + + def _capture(self, call_kwargs: dict[str, Any]) -> None: + with enable_piecewise_cuda_graph(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + with _graph_capture_stream() as stream: + if stream is not None: + with set_pcg_capture_stream(stream): + self.model(**call_kwargs) + self.model(**call_kwargs) + else: + self.model(**call_kwargs) + self.model(**call_kwargs) + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + def run(self, *, seq_len_override: int | None = None, **kwargs) -> Any | None: + hidden_states = kwargs.get("hidden_states") + if hidden_states is None or not self.can_run( + hidden_states, seq_len_override=seq_len_override + ): + return None + + raw_seq_len = int(seq_len_override) + static_seq_len = self._select_static_seq_len(raw_seq_len) + if static_seq_len is None: + return None + + self._install_compiled() + + sig = ( + static_seq_len, + _signature( + kwargs, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path="", + ), + ) + entry = self._entries.get(sig) + if entry is None: + entry = _GraphEntry( + signature=sig, + static_seq_len=static_seq_len, + slots=_build_slots( + kwargs, + raw_seq_len=raw_seq_len, + static_seq_len=static_seq_len, + path="", + ), + ) + self._entries[sig] = entry + logger.info( + "Diffusion PCG init for %s (raw_seq=%d, static_seq=%d)", + self.model.__class__.__name__, + raw_seq_len, + static_seq_len, + ) + + call_kwargs = _materialize_call_kwargs(entry.slots, kwargs) + + if not self._compiled: + self._ensure_compiled(call_kwargs) + + if not entry.captured: + self._capture(call_kwargs) + entry.captured = True + + with enable_piecewise_cuda_graph(): + output = self.model(**call_kwargs) + + return _slice_output_to_raw_seq(output, raw_seq_len, static_seq_len) + + +def resolve_capture_sizes( + *, + seq_len: int, + explicit_sizes: list[int] | None, + max_tokens: int, +) -> list[int]: + if explicit_sizes: + sizes = sorted(set(int(x) for x in explicit_sizes if int(x) > 0)) + if seq_len <= max_tokens and seq_len not in sizes: + sizes.append(int(seq_len)) + return sorted(set(sizes)) + + capture_sizes = ( + list(range(4, 33, 4)) + + list(range(48, 257, 16)) + + list(range(288, 513, 32)) + + list(range(576, 1024 + 1, 64)) + + list(range(1280, 4096 + 1, 256)) + + list(range(4608, int(max_tokens) + 1, 512)) + ) + capture_sizes = [s for s in capture_sizes if s <= int(max_tokens)] + if seq_len <= max_tokens and seq_len not in capture_sizes: + capture_sizes.append(int(seq_len)) + return sorted(set(capture_sizes)) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index 844999554d99..5d2f76d3fe05 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -33,6 +33,10 @@ refresh_context_on_dual_transformer, refresh_context_on_transformer, ) +from sglang.multimodal_gen.runtime.compilation import ( + DiffusionPiecewiseCudaGraphRunner, + resolve_capture_sizes, +) from sglang.multimodal_gen.runtime.distributed import ( cfg_model_parallel_all_reduce, get_local_torch_device, @@ -128,6 +132,9 @@ def __init__( self._cache_dit_enabled = False self._cached_num_steps = None self._is_warmed_up = False + # piecewise cuda graph state + self._pcg_runners: dict[int, DiffusionPiecewiseCudaGraphRunner] = {} + self._pcg_disabled_model_ids: set[int] = set() def _maybe_enable_torch_compile(self, module: object) -> None: """ @@ -1345,6 +1352,28 @@ def _predict_noise( guidance: torch.Tensor, **kwargs, ): + runner = self._get_or_create_pcg_runner(current_model, latent_model_input) + if runner is not None: + try: + output = runner.run( + seq_len_override=int(latent_model_input.shape[1]), + hidden_states=latent_model_input, + timestep=timestep, + guidance=guidance, + **kwargs, + ) + if output is not None: + return output + except Exception as e: + model_id = id(current_model) + if model_id not in self._pcg_disabled_model_ids: + logger.warning( + "Disable diffusion PCG for %s after failure: %s", + current_model.__class__.__name__, + e, + ) + self._pcg_disabled_model_ids.add(model_id) + return current_model( hidden_states=latent_model_input, timestep=timestep, @@ -1352,6 +1381,51 @@ def _predict_noise( **kwargs, ) + def _is_pcg_backend_supported(self) -> bool: + return self.attn_backend.get_enum() in { + AttentionBackendEnum.FA, + AttentionBackendEnum.FA2, + AttentionBackendEnum.TORCH_SDPA, + } + + def _get_or_create_pcg_runner( + self, model: nn.Module, hidden_states: torch.Tensor + ) -> DiffusionPiecewiseCudaGraphRunner | None: + if not self.server_args.enable_piecewise_cuda_graph: + return None + if not self._is_pcg_backend_supported(): + return None + if hidden_states is None or hidden_states.ndim != 3: + return None + + model_id = id(model) + if model_id in self._pcg_disabled_model_ids: + return None + + runner = self._pcg_runners.get(model_id) + if runner is not None: + return runner + + capture_sizes = resolve_capture_sizes( + seq_len=int(hidden_states.shape[1]), + explicit_sizes=self.server_args.piecewise_cuda_graph_tokens, + max_tokens=self.server_args.piecewise_cuda_graph_max_tokens, + ) + runner = DiffusionPiecewiseCudaGraphRunner( + model=model, + capture_sizes=capture_sizes, + compiler=self.server_args.piecewise_cuda_graph_compiler, + enable_debug=self.server_args.enable_piecewise_cuda_graph_debug, + ) + self._pcg_runners[model_id] = runner + logger.info( + "Enable diffusion PCG for %s with %d capture buckets (max=%d)", + model.__class__.__name__, + len(capture_sizes), + max(capture_sizes) if capture_sizes else 0, + ) + return runner + def _predict_noise_with_cfg( self, current_model: nn.Module, diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 480efe584245..db35257d8373 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -270,6 +270,11 @@ class ServerArgs: # Compilation enable_torch_compile: bool = False + enable_piecewise_cuda_graph: bool = False + piecewise_cuda_graph_max_tokens: int = 8192 + piecewise_cuda_graph_tokens: list[int] | None = None + piecewise_cuda_graph_compiler: str = "eager" + enable_piecewise_cuda_graph_debug: bool = False # warmup warmup: bool = False @@ -352,6 +357,7 @@ def _adjust_parameters(self): self._adjust_attention_backend() self._adjust_platform_specific() self._adjust_autocast() + self._adjust_piecewise_cuda_graph() def _validate_parameters(self): """check consistency and raise errors for invalid configs""" @@ -359,6 +365,7 @@ def _validate_parameters(self): self._validate_offload() self._validate_parallelism() self._validate_cfg_parallel() + self._validate_piecewise_cuda_graph() def _adjust_save_paths(self): """Normalize empty-string save paths to None (disabled).""" @@ -463,6 +470,51 @@ def _adjust_warmup(self): "Warmup enabled, the launch time is expected to be longer than usual" ) + def _adjust_piecewise_cuda_graph(self): + if self.piecewise_cuda_graph_tokens is None: + self.piecewise_cuda_graph_tokens = self._generate_piecewise_cuda_graph_tokens() + elif isinstance(self.piecewise_cuda_graph_tokens, str): + self.piecewise_cuda_graph_tokens = [ + int(x.strip()) + for x in self.piecewise_cuda_graph_tokens.split(",") + if x.strip() + ] + + self.piecewise_cuda_graph_tokens = sorted( + set( + int(x) + for x in (self.piecewise_cuda_graph_tokens or []) + if int(x) > 0 and int(x) <= int(self.piecewise_cuda_graph_max_tokens) + ) + ) + + if self.enable_piecewise_cuda_graph and self.enable_torch_compile: + logger.warning( + "Both --enable-piecewise-cuda-graph and --enable-torch-compile are set. " + "Disabling torch.compile and keeping piecewise CUDA graph." + ) + self.enable_torch_compile = False + + if not current_platform.is_cuda_alike() and self.enable_piecewise_cuda_graph: + logger.warning( + "Piecewise CUDA graph is enabled but current platform is not CUDA-like. " + "Feature will be disabled." + ) + self.enable_piecewise_cuda_graph = False + + def _generate_piecewise_cuda_graph_tokens(self) -> list[int]: + capture_sizes = ( + list(range(4, 33, 4)) + + list(range(48, 257, 16)) + + list(range(288, 513, 32)) + + list(range(576, 1024 + 1, 64)) + + list(range(1280, 4096 + 1, 256)) + + list(range(4608, int(self.piecewise_cuda_graph_max_tokens) + 1, 512)) + ) + return [ + s for s in capture_sizes if s <= int(self.piecewise_cuda_graph_max_tokens) + ] + def _adjust_network_ports(self): self.port = self.settle_port(self.port) initial_scheduler_port = self.scheduler_port + ( @@ -738,6 +790,37 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Use torch.compile to speed up DiT inference." + "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)", ) + parser.add_argument( + "--enable-piecewise-cuda-graph", + action=StoreBoolean, + default=ServerArgs.enable_piecewise_cuda_graph, + help="Enable piecewise CUDA graph with automatic graph splitting and padding buckets.", + ) + parser.add_argument( + "--piecewise-cuda-graph-max-tokens", + type=int, + default=ServerArgs.piecewise_cuda_graph_max_tokens, + help="Maximum sequence length bucket for diffusion piecewise CUDA graph.", + ) + parser.add_argument( + "--piecewise-cuda-graph-tokens", + type=int, + nargs="+", + default=ServerArgs.piecewise_cuda_graph_tokens, + help="Optional explicit token buckets for diffusion piecewise CUDA graph.", + ) + parser.add_argument( + "--piecewise-cuda-graph-compiler", + type=str, + default=ServerArgs.piecewise_cuda_graph_compiler, + help="Compiler backend for diffusion piecewise CUDA graph: eager or inductor.", + ) + parser.add_argument( + "--enable-piecewise-cuda-graph-debug", + action=StoreBoolean, + default=ServerArgs.enable_piecewise_cuda_graph_debug, + help="Enable debug checks for diffusion piecewise CUDA graph.", + ) # warmup parser.add_argument( @@ -1197,6 +1280,21 @@ def _validate_cfg_parallel(self): "CFG Parallelism is enabled via `--enable-cfg-parallel`, but num_gpus == 1" ) + def _validate_piecewise_cuda_graph(self): + if self.piecewise_cuda_graph_max_tokens <= 0: + raise ValueError("piecewise_cuda_graph_max_tokens must be positive") + + if self.piecewise_cuda_graph_compiler not in ("eager", "inductor"): + raise ValueError( + "piecewise_cuda_graph_compiler must be one of: eager, inductor" + ) + + if self.piecewise_cuda_graph_tokens is not None: + if len(self.piecewise_cuda_graph_tokens) == 0: + raise ValueError( + "piecewise_cuda_graph_tokens is empty; set a non-empty bucket list or unset it" + ) + def _set_default_attention_backend(self) -> None: """Configure ROCm defaults when users do not specify an attention backend.""" if current_platform.is_rocm(): From e9a46e4268a6ed498ce4a043e9d182b297638d1f Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 10:08:01 +0800 Subject: [PATCH 02/11] ud --- python/sglang/multimodal_gen/runtime/layers/linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/linear.py b/python/sglang/multimodal_gen/runtime/layers/linear.py index f74309a53405..45d7b33913bf 100644 --- a/python/sglang/multimodal_gen/runtime/layers/linear.py +++ b/python/sglang/multimodal_gen/runtime/layers/linear.py @@ -39,6 +39,7 @@ from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) +_AMP_SUPPORTED = current_platform.is_amp_supported() WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", @@ -153,7 +154,7 @@ def apply( ) -> torch.Tensor: output = ( F.linear(x, layer.weight, bias) - if current_platform.is_amp_supported() or bias is None + if _AMP_SUPPORTED or bias is None else F.linear(x, layer.weight, bias.to(x.dtype)) ) # NOTE: this line assumes that we are using amp when using cuda and is needed to account for the fact that amp isn't supported in mps return output From 0ff111e91e9f3918614cc11e4069bb5be7806e4a Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 10:13:15 +0800 Subject: [PATCH 03/11] ud --- .../runtime/compilation/piecewise_cuda_graph_runner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py index da0ce9474538..965c866bf6eb 100644 --- a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py +++ b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py @@ -305,6 +305,7 @@ def __init__( self._entries: dict[Hashable, _GraphEntry] = {} self._installed = False self._compiled = False + self._eager_warmup_done = False device = next(model.parameters()).device self.device = device @@ -359,6 +360,12 @@ def _select_static_seq_len(self, seq_len: int) -> int | None: def _ensure_compiled(self, call_kwargs: dict[str, Any]) -> None: if self._compiled: return + # Warm up lazy custom kernels (e.g. JIT kernels that touch filesystem) + # outside torch.compile to avoid Dynamo tracing unsupported Python/C APIs. + if not self._eager_warmup_done: + with torch.no_grad(): + self.model(**call_kwargs) + self._eager_warmup_done = True with enable_piecewise_cuda_graph(): with enable_piecewise_cuda_graph_compile(): _ = self.model(**call_kwargs) From c332d86ee8d4ba313bf261804aebf64f6e8ff2d9 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 10:44:10 +0800 Subject: [PATCH 04/11] ud --- .../runtime/layers/visual_embedding.py | 70 ++++++++++++++++--- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py b/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py index c8eff7d6f0d8..d51bb2d49cc2 100644 --- a/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py +++ b/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py @@ -27,11 +27,64 @@ from sglang.multimodal_gen.runtime.layers.activation import get_act_fn from sglang.multimodal_gen.runtime.layers.linear import ColumnParallelLinear from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.utils import register_custom_op from sglang.multimodal_gen.runtime.platforms import current_platform _is_cuda = current_platform.is_cuda() +def _timestep_embedding_cuda_fake( + timesteps: torch.Tensor, + num_channels: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 0.0, + scale: float = 1, +) -> torch.Tensor: + return torch.empty( + (timesteps.shape[0], num_channels), + dtype=torch.float32, + device=timesteps.device, + ) + + +if _is_cuda: + + @register_custom_op( + op_name="diffusion_timestep_embedding_cuda", + fake_impl=_timestep_embedding_cuda_fake, + ) + def timestep_embedding_cuda_custom_op( + timesteps: torch.Tensor, + num_channels: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 0.0, + scale: float = 1, + ) -> torch.Tensor: + return timestep_embedding_cuda( + timesteps, + num_channels, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + scale=scale, + ) +else: + + def timestep_embedding_cuda_custom_op( + timesteps: torch.Tensor, + num_channels: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 0.0, + scale: float = 1, + ) -> torch.Tensor: + return timestep_embedding_cuda( + timesteps, + num_channels, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + scale=scale, + ) + + class PatchEmbed(nn.Module): """2D Image to Patch Embedding @@ -89,21 +142,20 @@ def forward(self, x): class Timesteps(_Timesteps): def forward(self, timesteps: torch.Tensor) -> torch.Tensor: if _is_cuda: - return timestep_embedding_cuda( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - scale=self.scale, - ) - else: - return timestep_embedding_diffusers( + return timestep_embedding_cuda_custom_op( timesteps, self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, scale=self.scale, ) + return timestep_embedding_diffusers( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) class CombinedTimestepGuidanceTextProjEmbeddings( From 78f902cc0287af966073194ef1b4f7d83fa43746 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 10:51:06 +0800 Subject: [PATCH 05/11] ud --- python/sglang/multimodal_gen/runtime/layers/layernorm.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py index 888824f60dd9..a728ee9698db 100644 --- a/python/sglang/multimodal_gen/runtime/layers/layernorm.py +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -30,8 +30,15 @@ get_tp_group, ) from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp +from sglang.multimodal_gen.runtime.layers.utils import register_custom_op from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var +if _is_cuda: + fused_inplace_qknorm = register_custom_op( + fused_inplace_qknorm, + mutates_args=["q", "k"], + ) + # Copied and adapted from sglang @CustomOp.register("rms_norm") From 3be0b1bfee2975b9698aadeb5c86f2c52e05317e Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 11:12:50 +0800 Subject: [PATCH 06/11] ud --- .../runtime/layers/rotary_embedding/utils.py | 93 ++++++++++++------- 1 file changed, 58 insertions(+), 35 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py b/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py index 4eaa106e8742..5dbcd0e0ddb7 100644 --- a/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py +++ b/python/sglang/multimodal_gen/runtime/layers/rotary_embedding/utils.py @@ -5,6 +5,44 @@ import torch from sglang.jit_kernel.diffusion.triton.rotary import apply_rotary_embedding +from sglang.multimodal_gen.runtime.layers.utils import register_custom_op +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_cuda = current_platform.is_cuda() +_flashinfer_rope_qk_inplace_custom_op = None + +if _is_cuda: + try: + from flashinfer.rope import ( + apply_rope_with_cos_sin_cache_inplace as _flashinfer_apply_rope_with_cos_sin_cache_inplace, + ) + except ImportError: + _flashinfer_apply_rope_with_cos_sin_cache_inplace = None + + if _flashinfer_apply_rope_with_cos_sin_cache_inplace is not None: + + @register_custom_op( + op_name="diffusion_flashinfer_rope_qk_inplace", + mutates_args=["query", "key"], + ) + def flashinfer_rope_qk_inplace_custom_op( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool = False, + ) -> None: + _flashinfer_apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox, + ) + + _flashinfer_rope_qk_inplace_custom_op = flashinfer_rope_qk_inplace_custom_op def _apply_rotary_emb( @@ -67,32 +105,6 @@ def apply_flashinfer_rope_qk_inplace( if head_size != d: raise ValueError(f"head_size mismatch: inferred {d}, but head_size={head_size}") - try: - from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace - except ImportError: - # Triton fallback for AMD/ROCm where FlashInfer is not available - import warnings - - warnings.warn( - "FlashInfer not available, using Triton fallback for RoPE", - stacklevel=2, - ) - half_size = cos_sin_cache.shape[-1] // 2 - if positions is None: - cos = cos_sin_cache[:seqlen, :half_size].to(q.dtype) - sin = cos_sin_cache[:seqlen, half_size:].to(q.dtype) - cos = cos.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1) - sin = sin.unsqueeze(0).expand(bsz, -1, -1).reshape(bsz * seqlen, -1) - else: - positions = positions.to(cos_sin_cache.device).view(-1) - cos = cos_sin_cache[positions, :half_size].to(q.dtype) - sin = cos_sin_cache[positions, half_size:].to(q.dtype) - q_flat = q.reshape(bsz * seqlen, nheads, d) - k_flat = k.reshape(bsz * seqlen, nheads, d) - q_rot = apply_rotary_embedding(q_flat, cos, sin, interleaved=not is_neox) - k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox) - return q_rot.view(bsz, seqlen, nheads, d), k_rot.view(bsz, seqlen, nheads, d) - if positions is None: pos_1d = torch.arange(seqlen, device=q.device, dtype=torch.long) positions = pos_1d if bsz == 1 else pos_1d.repeat(bsz) @@ -110,12 +122,23 @@ def apply_flashinfer_rope_qk_inplace( q_flat = q.reshape(bsz * seqlen, nheads * d).contiguous() k_flat = k.reshape(bsz * seqlen, nheads * d).contiguous() - apply_rope_with_cos_sin_cache_inplace( - positions=positions, - query=q_flat, - key=k_flat, - head_size=d, - cos_sin_cache=cos_sin_cache, - is_neox=is_neox, - ) - return q_flat.view(bsz, seqlen, nheads, d), k_flat.view(bsz, seqlen, nheads, d) + if _flashinfer_rope_qk_inplace_custom_op is not None: + _flashinfer_rope_qk_inplace_custom_op( + positions=positions, + query=q_flat, + key=k_flat, + head_size=d, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox, + ) + return q_flat.view(bsz, seqlen, nheads, d), k_flat.view(bsz, seqlen, nheads, d) + + half_size = cos_sin_cache.shape[-1] // 2 + positions = positions.to(cos_sin_cache.device).view(-1) + cos = cos_sin_cache[positions, :half_size].to(q.dtype) + sin = cos_sin_cache[positions, half_size:].to(q.dtype) + q_flat = q.reshape(bsz * seqlen, nheads, d) + k_flat = k.reshape(bsz * seqlen, nheads, d) + q_rot = apply_rotary_embedding(q_flat, cos, sin, interleaved=not is_neox) + k_rot = apply_rotary_embedding(k_flat, cos, sin, interleaved=not is_neox) + return q_rot.view(bsz, seqlen, nheads, d), k_rot.view(bsz, seqlen, nheads, d) From abde4344b81e55ecebf2693ffbf7b49f21990b82 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 11:22:33 +0800 Subject: [PATCH 07/11] ud --- .../layers/attention/backends/flash_attn.py | 37 +++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py index 4fbe4f0c78c6..ba01b8318587 100644 --- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py @@ -12,6 +12,7 @@ AttentionBackendEnum, current_platform, ) +from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph try: from sgl_kernel.flash_attn import flash_attn_varlen_func @@ -71,7 +72,6 @@ def flash_attn_varlen_func_fake_out( sinks: Optional[torch.Tensor] = None, ver: int = 4, ) -> torch.Tensor: - assert ver == 4, "only support flash attention v4" q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: @@ -131,7 +131,6 @@ def flash_attn_varlen_func_fake_out_lse( sinks: Optional[torch.Tensor] = None, ver: int = 4, ) -> Tuple[torch.Tensor, torch.Tensor]: - assert ver == 4, "only support flash attention v4" q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: @@ -480,9 +479,10 @@ def forward( q_shape = tuple(query.shape) k_shape = tuple(key.shape) v_shape = tuple(value.shape) + in_piecewise = is_in_piecewise_cuda_graph() use_upstream = _should_use_upstream_flash_attention( - flash_attn_varlen_func_upstream is not None, + (flash_attn_varlen_func_upstream is not None) and (not in_piecewise), self._upstream_heads_ok, q_shape, k_shape, @@ -515,6 +515,37 @@ def forward( # - fa_ver == 3: call python function (can return Tensor or (Tensor, Tensor) depending on flag) # - fa_ver == 4: call custom ops with FIXED return schema if fa_ver == 3: + if in_piecewise: + if return_softmax_lse: + out_tensor, softmax_lse = flash_attn_varlen_func_op_lse( + q=query, + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_softmax_lse=True, + ver=fa_ver, + ) + return out_tensor, softmax_lse + out_tensor = flash_attn_varlen_func_op( + q=query, + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_softmax_lse=False, + ver=fa_ver, + ) + return out_tensor + flash_attn_op = flash_attn_func output = flash_attn_op( q=query, From e1d406774eb98c7d4564b68ff92b75f401364d54 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 11:35:29 +0800 Subject: [PATCH 08/11] ud --- .../piecewise_cuda_graph_runner.py | 8 +- .../layers/attention/backends/flash_attn.py | 35 +++++- .../pipelines_core/stages/denoising.py | 115 ++++++++++++++++++ 3 files changed, 153 insertions(+), 5 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py index 965c866bf6eb..049a8070e122 100644 --- a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py +++ b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py @@ -441,7 +441,13 @@ def run(self, *, seq_len_override: int | None = None, **kwargs) -> Any | None: entry.captured = True with enable_piecewise_cuda_graph(): - output = self.model(**call_kwargs) + if current_platform.is_cuda_alike(): + # Runtime recompilation can trigger late on-demand capture in the + # backend. Ensure a valid capture stream is always available. + with set_pcg_capture_stream(torch.cuda.current_stream()): + output = self.model(**call_kwargs) + else: + output = self.model(**call_kwargs) return _slice_output_to_raw_seq(output, raw_seq_len, static_seq_len) diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py index ba01b8318587..d200ce803940 100644 --- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py @@ -1,7 +1,6 @@ # Copied and adapted from: https://github.com/hao-ai-lab/FastVideo # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from functools import lru_cache from typing import Any, List, Optional, Tuple import torch @@ -329,18 +328,29 @@ def flash_attn_varlen_func_op_lse( fa_ver = 3 -@lru_cache(maxsize=128) def _get_cu_seqlens(device_index: int, bsz: int, seqlen: int) -> torch.Tensor: - return torch.arange( + key = (device_index, bsz, seqlen) + cached = _CU_SEQLENS_CACHE.get(key) + if cached is not None: + return cached + + cu_seqlens = torch.arange( 0, (bsz + 1) * seqlen, step=seqlen, device=torch.device("cuda", device_index), dtype=torch.int32, ) + _CU_SEQLENS_CACHE[key] = cu_seqlens + return cu_seqlens + + +_CU_SEQLENS_CACHE: dict[tuple[int, int, int], torch.Tensor] = {} +_UPSTREAM_FLASH_ATTN_CACHE: dict[ + tuple[bool, bool, tuple[int, ...], tuple[int, ...], tuple[int, ...]], bool +] = {} -@lru_cache(maxsize=256) def _should_use_upstream_flash_attention( upstream_available: bool, upstream_heads_ok: bool, @@ -348,10 +358,23 @@ def _should_use_upstream_flash_attention( k_shape: tuple[int, ...], v_shape: tuple[int, ...], ) -> bool: + cache_key = ( + upstream_available, + upstream_heads_ok, + q_shape, + k_shape, + v_shape, + ) + cached = _UPSTREAM_FLASH_ATTN_CACHE.get(cache_key) + if cached is not None: + return cached + if not upstream_available or not upstream_heads_ok: + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False if len(q_shape) != 4 or len(k_shape) != 4 or len(v_shape) != 4: + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False bsz, seqlen, nheads_q, d = q_shape @@ -366,11 +389,15 @@ def _should_use_upstream_flash_attention( or d != d_k or d != d_v ): + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False if nheads_k != nheads_v: + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False if nheads_k == 0 or (nheads_q % nheads_k) != 0: + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = False return False + _UPSTREAM_FLASH_ATTN_CACHE[cache_key] = True return True diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index 5d2f76d3fe05..23dc38d1f98c 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -1002,6 +1002,21 @@ def forward( seq_len = prepared_vars["seq_len"] guidance = prepared_vars["guidance"] + self._maybe_precapture_piecewise_cuda_graph( + batch=batch, + server_args=server_args, + timesteps=timesteps, + target_dtype=target_dtype, + boundary_timestep=boundary_timestep, + seq_len=seq_len, + reserved_frames_mask=reserved_frames_mask, + guidance=guidance, + latents=latents, + image_kwargs=image_kwargs, + pos_cond_kwargs=pos_cond_kwargs, + neg_cond_kwargs=neg_cond_kwargs, + ) + # Initialize lists for ODE trajectory trajectory_timesteps: list[torch.Tensor] = [] trajectory_latents: list[torch.Tensor] = [] @@ -1585,6 +1600,106 @@ def _predict_noise_with_cfg( ) return noise_pred + def _maybe_precapture_piecewise_cuda_graph( + self, + *, + batch: Req, + server_args: ServerArgs, + timesteps: torch.Tensor, + target_dtype: torch.dtype, + boundary_timestep: float | None, + seq_len: int | None, + reserved_frames_mask: torch.Tensor | None, + guidance: torch.Tensor | None, + latents: torch.Tensor, + image_kwargs: dict[str, Any], + pos_cond_kwargs: dict[str, Any], + neg_cond_kwargs: dict[str, Any], + ) -> None: + if not server_args.enable_piecewise_cuda_graph: + return + if batch.is_warmup: + return + if timesteps is None or timesteps.numel() == 0: + return + + expected_model_count = 2 if self.transformer_2 is not None else 1 + visited_model_ids: set[int] = set() + timesteps_cpu = timesteps.cpu() + original_is_cfg_negative = batch.is_cfg_negative + + logger.info( + "Pre-capturing diffusion PCG before denoising loop (target_models=%d)", + expected_model_count, + ) + + for i, t_host in enumerate(timesteps_cpu): + t_int = int(t_host.item()) + t_device = timesteps[i] + current_model, current_guidance_scale = self._select_and_manage_model( + t_int=t_int, + boundary_timestep=boundary_timestep, + server_args=server_args, + batch=batch, + ) + + model_id = id(current_model) + if model_id in visited_model_ids: + continue + + latent_model_input = latents.to(target_dtype) + if batch.image_latent is not None: + latent_model_input = torch.cat( + [latent_model_input, batch.image_latent], dim=1 + ).to(target_dtype) + + timestep = self.expand_timestep_before_forward( + batch, + server_args, + t_device, + target_dtype, + seq_len, + reserved_frames_mask, + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t_device + ) + + attn_metadata = self._build_attn_metadata( + i, + batch, + server_args, + timestep_value=t_int, + timesteps=timesteps_cpu, + ) + + _ = self._predict_noise_with_cfg( + current_model=current_model, + latent_model_input=latent_model_input, + timestep=timestep, + batch=batch, + timestep_index=i, + attn_metadata=attn_metadata, + target_dtype=target_dtype, + current_guidance_scale=current_guidance_scale, + image_kwargs=image_kwargs, + pos_cond_kwargs=pos_cond_kwargs, + neg_cond_kwargs=neg_cond_kwargs, + server_args=server_args, + guidance=guidance, + latents=latents, + ) + + visited_model_ids.add(model_id) + if len(visited_model_ids) >= expected_model_count: + break + + batch.is_cfg_negative = original_is_cfg_negative + logger.info( + "Pre-capture finished for %d model(s) before formal denoising", + len(visited_model_ids), + ) + def prepare_sta_param(self, batch: Req, server_args: ServerArgs): """ Prepare Sliding Tile Attention (STA) parameters and settings. From d0d2ec952f60746b45a8bfeceea7f79d304c16c8 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 11:44:28 +0800 Subject: [PATCH 09/11] ud --- .../layers/attention/backends/flash_attn.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py index d200ce803940..de8671fa97a0 100644 --- a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py @@ -494,14 +494,12 @@ def forward( return_softmax_lse: bool = False, ): attn_metadata: FlashAttentionMetadata = get_forward_context().attn_metadata - if attn_metadata is not None and attn_metadata.max_seqlen_q is None: - attn_metadata.max_seqlen_q = query.shape[1] - attn_metadata.max_seqlen_k = key.shape[1] - max_seqlen_q = attn_metadata.max_seqlen_q - max_seqlen_k = attn_metadata.max_seqlen_k - else: - max_seqlen_q = query.shape[1] - max_seqlen_k = key.shape[1] + # Keep control flow stable for torch.compile/PCG. + max_seqlen_q = query.shape[1] + max_seqlen_k = key.shape[1] + if attn_metadata is not None: + attn_metadata.max_seqlen_q = max_seqlen_q + attn_metadata.max_seqlen_k = max_seqlen_k q_shape = tuple(query.shape) k_shape = tuple(key.shape) From 3a1e70079720764593ac40d7fa17991d956870ee Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 14:46:22 +0800 Subject: [PATCH 10/11] ud --- .../compilation/piecewise_cuda_graph_runner.py | 13 +++++++++++++ .../srt/compilation/cuda_piecewise_backend.py | 18 +++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py index 049a8070e122..6d3c6f90f6d4 100644 --- a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py +++ b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py @@ -19,6 +19,7 @@ from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.srt.compilation.compilation_counter import compilation_counter from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.compile import install_torch_compiled from sglang.srt.compilation.piecewise_context_manager import ( @@ -372,6 +373,7 @@ def _ensure_compiled(self, call_kwargs: dict[str, Any]) -> None: self._compiled = True def _capture(self, call_kwargs: dict[str, Any]) -> None: + before = compilation_counter.num_cudagraph_captured with enable_piecewise_cuda_graph(): if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.barrier() @@ -387,6 +389,17 @@ def _capture(self, call_kwargs: dict[str, Any]) -> None: if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.barrier() + after = compilation_counter.num_cudagraph_captured + if after > before: + logger.info( + "Diffusion PCG captured %d piecewise CUDA graphs", + after - before, + ) + else: + logger.warning( + "Diffusion PCG capture completed but no piecewise CUDA graph was captured; " + "runtime shape may not match configured capture buckets." + ) def run(self, *, seq_len_override: int | None = None, **kwargs) -> Any | None: hidden_states = kwargs.get("hidden_states") diff --git a/python/sglang/srt/compilation/cuda_piecewise_backend.py b/python/sglang/srt/compilation/cuda_piecewise_backend.py index 8ca1e6a43cbc..d001d267632d 100644 --- a/python/sglang/srt/compilation/cuda_piecewise_backend.py +++ b/python/sglang/srt/compilation/cuda_piecewise_backend.py @@ -113,7 +113,23 @@ def __call__(self, *args) -> Any: if len(self.sym_shape_indices) == 0: return self.compiled_graph_for_general_shape(*args) - runtime_shape = args[self.sym_shape_indices[0]] + runtime_shape = None + for idx in self.sym_shape_indices: + candidate = args[idx] + try: + candidate = int(candidate) + except Exception: + pass + if candidate in self.concrete_size_entries: + runtime_shape = candidate + break + if runtime_shape is None: + runtime_shape = args[self.sym_shape_indices[0]] + try: + runtime_shape = int(runtime_shape) + except Exception: + pass + if runtime_shape not in self.concrete_size_entries: # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) From 62ba40233f7c41139c5bdb998288131cd1f502a8 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Wed, 4 Mar 2026 15:07:38 +0800 Subject: [PATCH 11/11] ud --- .../piecewise_cuda_graph_runner.py | 56 +++++++------- .../srt/compilation/cuda_piecewise_backend.py | 76 +++++++++++++++---- .../compilation/piecewise_context_manager.py | 16 ++++ 3 files changed, 108 insertions(+), 40 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py index 6d3c6f90f6d4..e879c7560872 100644 --- a/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py +++ b/python/sglang/multimodal_gen/runtime/compilation/piecewise_cuda_graph_runner.py @@ -26,6 +26,7 @@ enable_piecewise_cuda_graph, enable_piecewise_cuda_graph_compile, set_pcg_capture_stream, + set_pcg_runtime_shape, ) logger = init_logger(__name__) @@ -358,7 +359,7 @@ def _select_static_seq_len(self, seq_len: int) -> int | None: return None return self.capture_sizes[idx] - def _ensure_compiled(self, call_kwargs: dict[str, Any]) -> None: + def _ensure_compiled(self, call_kwargs: dict[str, Any], runtime_shape: int) -> None: if self._compiled: return # Warm up lazy custom kernels (e.g. JIT kernels that touch filesystem) @@ -368,27 +369,29 @@ def _ensure_compiled(self, call_kwargs: dict[str, Any]) -> None: self.model(**call_kwargs) self._eager_warmup_done = True with enable_piecewise_cuda_graph(): - with enable_piecewise_cuda_graph_compile(): - _ = self.model(**call_kwargs) + with set_pcg_runtime_shape(runtime_shape): + with enable_piecewise_cuda_graph_compile(): + _ = self.model(**call_kwargs) self._compiled = True - def _capture(self, call_kwargs: dict[str, Any]) -> None: + def _capture(self, call_kwargs: dict[str, Any], runtime_shape: int) -> None: before = compilation_counter.num_cudagraph_captured with enable_piecewise_cuda_graph(): - if torch.distributed.is_available() and torch.distributed.is_initialized(): - torch.distributed.barrier() - - with _graph_capture_stream() as stream: - if stream is not None: - with set_pcg_capture_stream(stream): - self.model(**call_kwargs) - self.model(**call_kwargs) - else: - self.model(**call_kwargs) - self.model(**call_kwargs) - - if torch.distributed.is_available() and torch.distributed.is_initialized(): - torch.distributed.barrier() + with set_pcg_runtime_shape(runtime_shape): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + with _graph_capture_stream() as stream: + if stream is not None: + with set_pcg_capture_stream(stream): + for _ in range(4): + self.model(**call_kwargs) + else: + for _ in range(4): + self.model(**call_kwargs) + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() after = compilation_counter.num_cudagraph_captured if after > before: logger.info( @@ -447,20 +450,21 @@ def run(self, *, seq_len_override: int | None = None, **kwargs) -> Any | None: call_kwargs = _materialize_call_kwargs(entry.slots, kwargs) if not self._compiled: - self._ensure_compiled(call_kwargs) + self._ensure_compiled(call_kwargs, static_seq_len) if not entry.captured: - self._capture(call_kwargs) + self._capture(call_kwargs, static_seq_len) entry.captured = True with enable_piecewise_cuda_graph(): - if current_platform.is_cuda_alike(): - # Runtime recompilation can trigger late on-demand capture in the - # backend. Ensure a valid capture stream is always available. - with set_pcg_capture_stream(torch.cuda.current_stream()): + with set_pcg_runtime_shape(static_seq_len): + if current_platform.is_cuda_alike(): + # Runtime recompilation can trigger late on-demand capture in the + # backend. Ensure a valid capture stream is always available. + with set_pcg_capture_stream(torch.cuda.current_stream()): + output = self.model(**call_kwargs) + else: output = self.model(**call_kwargs) - else: - output = self.model(**call_kwargs) return _slice_output_to_raw_seq(output, raw_seq_len, static_seq_len) diff --git a/python/sglang/srt/compilation/cuda_piecewise_backend.py b/python/sglang/srt/compilation/cuda_piecewise_backend.py index d001d267632d..e3f2fcf93369 100644 --- a/python/sglang/srt/compilation/cuda_piecewise_backend.py +++ b/python/sglang/srt/compilation/cuda_piecewise_backend.py @@ -13,6 +13,7 @@ from sglang.srt.compilation.compilation_counter import compilation_counter from sglang.srt.compilation.piecewise_context_manager import ( get_pcg_capture_stream, + get_pcg_runtime_shape, is_in_pcg_torch_compile, ) from sglang.srt.compilation.weak_ref_tensor import weak_ref_tensors @@ -104,31 +105,78 @@ def check_for_ending_compilation(self): # save the hash of the inductor graph for the next run self.sglang_backend.compiler_manager.save_to_file() + def _infer_runtime_shape_from_tensors(self, args) -> Optional[int]: + # Prefer sequence-like dims from activation tensors: + # 1) dim-1 for rank>=2 tensors (e.g. [B, S, C] -> S) + # 2) dim-0 for rank>=1 tensors + # 3) any dim as final fallback + candidates: list[int] = [] + + for x in args: + if not isinstance(x, torch.Tensor): + continue + if x.ndim >= 2: + s = int(x.shape[1]) + if s in self.concrete_size_entries: + candidates.append(s) + if candidates: + return max(candidates) + + for x in args: + if not isinstance(x, torch.Tensor): + continue + if x.ndim >= 1: + s = int(x.shape[0]) + if s in self.concrete_size_entries: + candidates.append(s) + if candidates: + return max(candidates) + + for x in args: + if not isinstance(x, torch.Tensor): + continue + for d in x.shape: + s = int(d) + if s in self.concrete_size_entries: + candidates.append(s) + if candidates: + return max(candidates) + return None + def __call__(self, *args) -> Any: if not self.first_run_finished: self.first_run_finished = True self.check_for_ending_compilation() return self.compiled_graph_for_general_shape(*args) - if len(self.sym_shape_indices) == 0: - return self.compiled_graph_for_general_shape(*args) - runtime_shape = None - for idx in self.sym_shape_indices: - candidate = args[idx] + runtime_shape_override = get_pcg_runtime_shape() + if runtime_shape_override is not None: try: - candidate = int(candidate) + runtime_shape_override = int(runtime_shape_override) except Exception: pass - if candidate in self.concrete_size_entries: - runtime_shape = candidate - break + if runtime_shape_override in self.concrete_size_entries: + runtime_shape = runtime_shape_override + + if runtime_shape is None and len(self.sym_shape_indices) > 0: + for idx in self.sym_shape_indices: + candidate = args[idx] + try: + candidate = int(candidate) + except Exception: + pass + if candidate in self.concrete_size_entries: + runtime_shape = candidate + break + if runtime_shape is None: + runtime_shape = args[self.sym_shape_indices[0]] + try: + runtime_shape = int(runtime_shape) + except Exception: + pass if runtime_shape is None: - runtime_shape = args[self.sym_shape_indices[0]] - try: - runtime_shape = int(runtime_shape) - except Exception: - pass + runtime_shape = self._infer_runtime_shape_from_tensors(args) if runtime_shape not in self.concrete_size_entries: # we don't need to do anything for this shape diff --git a/python/sglang/srt/compilation/piecewise_context_manager.py b/python/sglang/srt/compilation/piecewise_context_manager.py index 20a08a9972b9..852cf72775c8 100644 --- a/python/sglang/srt/compilation/piecewise_context_manager.py +++ b/python/sglang/srt/compilation/piecewise_context_manager.py @@ -16,6 +16,7 @@ _in_piecewise_cuda_graph = False _in_pcg_torch_compile = False _pcg_capture_stream = None +_pcg_runtime_shape = None def is_in_piecewise_cuda_graph(): @@ -30,6 +31,10 @@ def get_pcg_capture_stream(): return _pcg_capture_stream +def get_pcg_runtime_shape(): + return _pcg_runtime_shape + + @contextmanager def enable_piecewise_cuda_graph_compile(): global _in_pcg_torch_compile @@ -63,6 +68,17 @@ def set_pcg_capture_stream(stream: torch.cuda.Stream): _pcg_capture_stream = None +@contextmanager +def set_pcg_runtime_shape(runtime_shape: int | None): + global _pcg_runtime_shape + old = _pcg_runtime_shape + _pcg_runtime_shape = runtime_shape + try: + yield + finally: + _pcg_runtime_shape = old + + @dataclass class ForwardContext: def __init__(self):