Skip to content

Commit ac397ae

Browse files
authored
[Diffusion][Perf] Remove Redundant Communication Cost by Refining SP Hook Design (#1275)
Signed-off-by: mxuax <mxuax@connect.ust.hk>
Signed-off-by: XU Mingshi <91017482+mxuax@users.noreply.github.com>
1 parent d3ea943 commit ac397ae

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

vllm_omni/diffusion/attention/layer.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
1515
from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend
1616
from vllm_omni.diffusion.attention.parallel import build_parallel_attention_strategy
17+
from vllm_omni.diffusion.attention.parallel.base import NoParallelAttention
1718
from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention
1819
from vllm_omni.diffusion.attention.selector import get_attn_backend
1920
from vllm_omni.diffusion.distributed.parallel_state import get_sp_group
20-
from vllm_omni.diffusion.forward_context import get_forward_context
21+
from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available
2122

2223
logger = init_logger(__name__)
2324

@@ -87,6 +88,21 @@ def __init__(
8788
gather_idx=gather_idx,
8889
use_sync=use_sync,
8990
)
91+
# Fallback strategy when SP is not active (outside sharded regions)
92+
self._no_parallel_strategy = NoParallelAttention()
93+
94+
def _get_active_parallel_strategy(self):
95+
"""Get the parallel strategy based on current SP active state.
96+
97+
Returns NoParallelAttention if we're outside an SP sharded region
98+
(e.g., in noise_refiner/context_refiner before unified_prepare in Z-Image).
99+
This avoids unnecessary SP communication for layers not covered by _sp_plan.
100+
"""
101+
if is_forward_context_available():
102+
ctx = get_forward_context()
103+
if not ctx.sp_active:
104+
return self._no_parallel_strategy
105+
return self.parallel_strategy
90106

91107
def forward(
92108
self,
@@ -95,20 +111,23 @@ def forward(
95111
value: torch.Tensor,
96112
attn_metadata: AttentionMetadata = None,
97113
) -> torch.Tensor:
114+
# Get the appropriate parallel strategy based on SP active state
115+
strategy = self._get_active_parallel_strategy()
116+
98117
# 1. Prepare inputs (Communication / Resharding)
99118
# For Ulysses: AllToAll Q/K/V; Slicing joint_q/k/v
100119
# For Ring: Concat joint_q
101-
query, key, value, attn_metadata, ctx = self.parallel_strategy.pre_attention(query, key, value, attn_metadata)
120+
query, key, value, attn_metadata, ctx = strategy.pre_attention(query, key, value, attn_metadata)
102121

103122
# 2. Kernel Execution (Computation)
104-
if self.use_ring:
123+
if self.use_ring and strategy is not self._no_parallel_strategy:
105124
out = self._run_ring_attention(query, key, value, attn_metadata)
106125
else:
107126
out = self._run_local_attention(query, key, value, attn_metadata)
108127

109128
# 3. Post-processing (Reverse Communication)
110129
# For Ulysses: AllToAll Output, and AllGather Joint Output
111-
out = self.parallel_strategy.post_attention(out, ctx)
130+
out = strategy.post_attention(out, ctx)
112131

113132
return out
114133

vllm_omni/diffusion/forward_context.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ class ForwardContext:
3131
# Original sequence length before padding (for removing padding in gather)
3232
sp_original_seq_len: int | None = None
3333

34+
# SP active scope tracking
35+
# Tracks the depth of SP sharding - incremented on shard, decremented on gather
36+
# Used by attention layers to determine if SP communication should be enabled
37+
_sp_shard_depth: int = 0
38+
39+
@property
40+
def sp_active(self) -> bool:
41+
"""Returns True when inside an SP sharded region (between shard and gather)."""
42+
return self._sp_shard_depth > 0
43+
3444
def __post_init__(self):
3545
pass
3646

vllm_omni/diffusion/hooks/sequence_parallel.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ def pre_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> tuple[tup
235235

236236
def post_forward(self, module: nn.Module, output: Any) -> Any:
237237
"""Shard outputs for split_output=True entries."""
238+
from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available
239+
238240
is_tensor = isinstance(output, torch.Tensor)
239241
is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
240242

@@ -243,6 +245,7 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
243245
return output
244246

245247
output_list = [output] if is_tensor else list(output)
248+
actually_sharded = False
246249

247250
for index, spm in self.metadata.items():
248251
if not isinstance(index, int):
@@ -252,7 +255,14 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
252255
if index >= len(output_list):
253256
raise ValueError(f"Index {index} out of bounds for output of length {len(output_list)}.")
254257

255-
output_list[index] = self._prepare_sp_input(output_list[index], spm, self._last_args, self._last_kwargs)
258+
original = output_list[index]
259+
output_list[index] = self._prepare_sp_input(original, spm, self._last_args, self._last_kwargs)
260+
if output_list[index] is not original:
261+
actually_sharded = True
262+
263+
# Mark SP as active only if at least one tensor was actually sharded
264+
if actually_sharded and is_forward_context_available():
265+
get_forward_context()._sp_shard_depth += 1
256266

257267
return output_list[0] if is_tensor else type(output)(output_list)
258268

@@ -445,6 +455,8 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
445455
ctx = get_forward_context()
446456
original_seq_len = ctx.sp_original_seq_len
447457

458+
actually_gathered = False
459+
448460
for i, spm in enumerate(self.metadata):
449461
if spm is None:
450462
continue
@@ -465,6 +477,12 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
465477
logger.debug(f"Removed padding: gathered shape {gathered.shape} (original_seq_len={original_seq_len})")
466478

467479
output[i] = gathered
480+
actually_gathered = True
481+
482+
# Mark SP as inactive only if at least one tensor was actually gathered
483+
if actually_gathered and is_forward_context_available():
484+
ctx = get_forward_context()
485+
ctx._sp_shard_depth = max(0, ctx._sp_shard_depth - 1)
468486

469487
return output[0] if is_tensor else type(output)(output)
470488

0 commit comments

Comments (0)