@@ -235,6 +235,8 @@ def pre_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> tuple[tup
235235
236236 def post_forward (self , module : nn .Module , output : Any ) -> Any :
237237 """Shard outputs for split_output=True entries."""
238+ from vllm_omni .diffusion .forward_context import get_forward_context , is_forward_context_available
239+
238240 is_tensor = isinstance (output , torch .Tensor )
239241 is_tensor_list = isinstance (output , (list , tuple )) and all (isinstance (x , torch .Tensor ) for x in output )
240242
@@ -243,6 +245,7 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
243245 return output
244246
245247 output_list = [output ] if is_tensor else list (output )
248+ actually_sharded = False
246249
247250 for index , spm in self .metadata .items ():
248251 if not isinstance (index , int ):
@@ -252,7 +255,14 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
252255 if index >= len (output_list ):
253256 raise ValueError (f"Index { index } out of bounds for output of length { len (output_list )} ." )
254257
255- output_list [index ] = self ._prepare_sp_input (output_list [index ], spm , self ._last_args , self ._last_kwargs )
258+ original = output_list [index ]
259+ output_list [index ] = self ._prepare_sp_input (original , spm , self ._last_args , self ._last_kwargs )
260+ if output_list [index ] is not original :
261+ actually_sharded = True
262+
263+ # Mark SP as active only if at least one tensor was actually sharded
264+ if actually_sharded and is_forward_context_available ():
265+ get_forward_context ()._sp_shard_depth += 1
256266
257267 return output_list [0 ] if is_tensor else type (output )(output_list )
258268
@@ -445,6 +455,8 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
445455 ctx = get_forward_context ()
446456 original_seq_len = ctx .sp_original_seq_len
447457
458+ actually_gathered = False
459+
448460 for i , spm in enumerate (self .metadata ):
449461 if spm is None :
450462 continue
@@ -465,6 +477,12 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
465477 logger .debug (f"Removed padding: gathered shape { gathered .shape } (original_seq_len={ original_seq_len } )" )
466478
467479 output [i ] = gathered
480+ actually_gathered = True
481+
482+ # Mark SP as inactive only if at least one tensor was actually gathered
483+ if actually_gathered and is_forward_context_available ():
484+ ctx = get_forward_context ()
485+ ctx ._sp_shard_depth = max (0 , ctx ._sp_shard_depth - 1 )
468486
469487 return output [0 ] if is_tensor else type (output )(output )
470488
0 commit comments