From 7accbb78b1941770e2f86b8019ea94a0352ac3e2 Mon Sep 17 00:00:00 2001 From: 1145284121 <1145284121@qq.com> Date: Mon, 8 Sep 2025 20:33:52 +0800 Subject: [PATCH 1/2] [Feature] Implement ulysses parallel for multi-gpu inference in CogVideoX1.5-5B-I2V --- cog_inference.py | 45 ++++++++++++++++++++++++++++++- scripts/dist_cog_inference.sh | 12 +++++++++ svg/models/cog/attention.py | 48 +++++++++++++++++++++++++++++++-- svg/models/cog/custom_models.py | 26 ++++++++++++++++++ 4 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 scripts/dist_cog_inference.sh diff --git a/cog_inference.py b/cog_inference.py index f590bd1..a5fd5aa 100644 --- a/cog_inference.py +++ b/cog_inference.py @@ -5,6 +5,7 @@ from svg.models.cog.utils import seed_everything from svg.models.cog.inference import replace_cog_attention, sample_image +import os if __name__ == "__main__": parser = argparse.ArgumentParser(description="A script that sets a random seed.") @@ -29,9 +30,51 @@ required=True, help="Output generated videos" ) - + # Parallel inference parameters + parser.add_argument( + "--use_sequence_parallel", + action="store_true", + help="Enable sequence parallelism for parallel inference" + ) + parser.add_argument( + "--ulysses_degree", + type=int, + default=2, + help="The number of ulysses parallel" + ) + args = parser.parse_args() + if args.use_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import init_distributed_environment, initialize_model_parallel + + # Setup distributed environment + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + + assert world_size > 1, f"Sequence parallelism requires world_size > 1, got {world_size}" + assert args.ulysses_degree > 1, "ulysses_degree must be > 1 for sequence parallelism" + assert world_size == args.ulysses_degree, ( + f"Currently only pure Ulysses parallelism is supported. " + f"world_size ({world_size}) must equal ulysses_degree ({args.ulysses_degree})" + ) + + # Initialize PyTorch distributed + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) + + # Initialize xFuser model parallelism + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=1, + ulysses_degree=args.ulysses_degree, + ) + + device = local_rank + seed_everything(args.seed) diff --git a/scripts/dist_cog_inference.sh b/scripts/dist_cog_inference.sh new file mode 100644 index 0000000..7a54ff9 --- /dev/null +++ b/scripts/dist_cog_inference.sh @@ -0,0 +1,12 @@ +# Feel free to change the prompt and the image! +prompt="A bright yellow water taxi glides smoothly across the choppy waters, creating gentle ripples in its wake. The iconic Brooklyn Bridge looms majestically in the background, its intricate web of cables and towering stone arches standing out against the city skyline. The boat, bustling with passengers, offers a lively contrast to the serene, expansive sky dotted with fluffy clouds. As it cruises forward, the vibrant cityscape of New York unfolds, with towering skyscrapers and historic buildings lining the waterfront, capturing the dynamic essence of urban life." +img_path="examples/cog/img/boat.jpg" + +CUDA_VISIBLE_DEVICES=2,3 torchrun --nproc_per_node=2 \ + cog_inference.py \ + --prompt "$prompt" \ + --image_path $img_path \ + --output_path "output-cog-uly2.mp4" \ + --ulysses_degree 2 \ + --use_sequence_parallel + diff --git a/svg/models/cog/attention.py b/svg/models/cog/attention.py index 7b924ad..2f9a03a 100644 --- a/svg/models/cog/attention.py +++ b/svg/models/cog/attention.py @@ -11,6 +11,7 @@ from .placement import sparse_head_placement, hidden_states_placement, ref_sparse_head_placement, ref_hidden_states_placement from .utils import generate_temporal_head_mask_mod, create_block_mask_cached +import torch.distributed as dist try: sys.path.append('svg/kernels/build/') @@ -44,7 +45,11 @@ def rotary_emb(image_rotary_emb, query, key, text_seq_length): key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) return query, key - +try: + from xfuser.core.distributed import get_ulysses_parallel_world_size + from xfuser.model_executor.layers.usp import _ft_c_input_all_to_all, _ft_c_output_all_to_all +except: + pass flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs") torch._dynamo.config.cache_size_limit = 192 * 3 @@ -71,6 +76,11 @@ class CogVideoX_SparseAttn_Processor2_0: def __init__(self, layer_idx): self.layer_idx = layer_idx + + self.use_sp = False + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + self.use_sp = True + if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -232,7 +242,41 @@ def __call__( query, key = rotary_emb(image_rotary_emb, query, key, text_seq_length) # ======================================================================== - hidden_states = self.attention_core_logic(query, key, value, timestep) + if self.use_sp: + # input qkv ulysses all_to_all comm + query_text = query[:,:,:text_seq_length,:] + # Ugly but useful for MMDiT. + # TODO: handle layout inside all_to_all for cleaner code + # for sparse attention,the layout of sequence must be [text_1, text_2, ..., video_1, video_2, ...], + # [text_1, video_1, text_2, video_2, ...] will lead to different attention map + query_latents = query[:,:,text_seq_length:,:] + query_text = _ft_c_input_all_to_all(query_text) + query_latents = _ft_c_input_all_to_all(query_latents) + query = torch.cat([query_text, query_latents], dim=-2) + + key_text = key[:,:,:text_seq_length,:] + key_latents = key[:,:,text_seq_length:,:] + key_text = _ft_c_input_all_to_all(key_text) + key_latents = _ft_c_input_all_to_all(key_latents) + key = torch.cat([key_text, key_latents], dim=-2) + + value_text = value[:,:,:text_seq_length,:] + value_latents = value[:,:,text_seq_length:,:] + value_text = _ft_c_input_all_to_all(value_text) + value_latents = _ft_c_input_all_to_all(value_latents) + value = torch.cat([value_text, value_latents], dim=-2) + + out = self.attention_core_logic(query, key, value, timestep) + + # output o ulysses all_to_all comm + out_text = out[:,:,:get_ulysses_parallel_world_size()*text_seq_length,:] + out_latents = out[:,:,get_ulysses_parallel_world_size()*text_seq_length:,:] + out_text = _ft_c_output_all_to_all(out_text) + out_latents = _ft_c_output_all_to_all(out_latents) + hidden_states = torch.cat([out_text, out_latents], dim=-2) + + else: + hidden_states = self.attention_core_logic(query, key, value, timestep) # ======================================================================== hidden_states = self.get_o(attn, hidden_states, batch_size, head_dim) diff --git a/svg/models/cog/custom_models.py b/svg/models/cog/custom_models.py index 38f3045..08ed757 100644 --- a/svg/models/cog/custom_models.py +++ b/svg/models/cog/custom_models.py @@ -21,6 +21,16 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel +import torch.distributed as dist + +try: + from xfuser.core.distributed import ( + get_ulysses_parallel_world_size, + get_ulysses_parallel_rank, + get_sp_group + ) +except: + pass logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -121,6 +131,16 @@ def forward( encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + encoder_hidden_states = torch.chunk(encoder_hidden_states, get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()] + # split video latents on dim TS + hidden_states = torch.chunk(hidden_states, get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()] + # image_rotary_emb should be splited in sequence dim + image_rotary_emb = ( + torch.chunk(image_rotary_emb[0], get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()], + torch.chunk(image_rotary_emb[1], get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()], + ) + # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -156,12 +176,18 @@ def custom_forward(*inputs): # CogVideoX-5B hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) hidden_states = self.norm_final(hidden_states) + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + assert text_seq_length % get_ulysses_parallel_world_size() == 0 + text_seq_length = text_seq_length // get_ulysses_parallel_world_size() hidden_states = hidden_states[:, text_seq_length:] # 4. Final block hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.proj_out(hidden_states) + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) + # 5. Unpatchify p = self.config.patch_size p_t = self.config.patch_size_t From ee0c23638b6700c11f5d6001244d3f7b032effb8 Mon Sep 17 00:00:00 2001 From: 1145284121 <1145284121@qq.com> Date: Mon, 8 Sep 2025 20:58:57 +0800 Subject: [PATCH 2/2] [Feature] Implement ulysses parallel for multi-gpu inference in HunyuanVideo --- scripts/dist_hyvideo_inference.sh | 27 +++++ svg/models/hyvideo/config.py | 2 +- svg/models/hyvideo/inference.py | 26 ++++- svg/models/hyvideo/modules/custom_models.py | 122 ++++++++++++++++---- svg/models/hyvideo/modules/models.py | 26 ++++- 5 files changed, 171 insertions(+), 32 deletions(-) create mode 100644 scripts/dist_hyvideo_inference.sh diff --git a/scripts/dist_hyvideo_inference.sh b/scripts/dist_hyvideo_inference.sh new file mode 100644 index 0000000..66475ce --- /dev/null +++ b/scripts/dist_hyvideo_inference.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Description: This script demonstrates multi-gpu video inference using the HunyuanVideo model + +# TFP Values: +# Set the following values to control the percentage of timesteps using dense attention: +# 35% → 0.07, 30% → 0.055, 25% → 0.04, 20% → 0.033, 15% → 0.02, 10% → 0.015 +first_times_fp=0.055 +first_layers_fp=0.025 + +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \ +hyvideo_inference.py \ + --video-size 720 1280 \ + --video-length 129 \ + --infer-steps 10 \ + --seed 0 \ + --prompt "A cat walks on the grass, realistic style." \ + --embedded-cfg-scale 6.0 \ + --flow-shift 7.0 \ + --flow-reverse \ + --output_path ./hunyuan_output_svg_step10_sp4.mp4 \ + --pattern "SVG" \ + --num_sampled_rows 64 \ + --sparsity 0.2 \ + --first_times_fp $first_times_fp \ + --first_layers_fp $first_layers_fp \ + --ulysses-degree 4 \ + --record_attention diff --git a/svg/models/hyvideo/config.py b/svg/models/hyvideo/config.py index 8d12a4e..b9abf2d 100644 --- a/svg/models/hyvideo/config.py +++ b/svg/models/hyvideo/config.py @@ -377,7 +377,7 @@ def add_parallel_args(parser: argparse.ArgumentParser): "--ring-degree", type=int, default=1, - help="Ulysses degree.", + help="Ring Attention degree.", ) return parser diff --git a/svg/models/hyvideo/inference.py b/svg/models/hyvideo/inference.py index 23d5b72..27821cc 100644 --- a/svg/models/hyvideo/inference.py +++ b/svg/models/hyvideo/inference.py @@ -36,7 +36,19 @@ initialize_model_parallel = None init_distributed_environment = None - +""" + NOTE: This parallelization function is not used in the current implementation. + + This function implements sequence parallelism by splitting the input tensor along either + the height or width dimension. While this approach offers the advantage of parallelizing + the patch embedding process, it introduces significant complexity when adapting the data + layout for the SVG flex attention kernel. + + Current approach: We implement parallelism directly in HYVideoDiffusionTransformer.forward() + (located in svg/models/hyvideo/modules/models.py) by splitting along the flattened + (temporal * height * width) dimension after patchify, which provides better compatibility + with the attention mechanism and simpler implementation. +""" def parallelize_transformer(pipe): transformer = pipe.transformer original_forward = transformer.forward @@ -173,7 +185,9 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs): ring_degree=args.ring_degree, ulysses_degree=args.ulysses_degree, ) - device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}") + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(local_rank) else: if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" @@ -405,8 +419,8 @@ def __init__( ) self.default_negative_prompt = NEGATIVE_PROMPT - if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1: - parallelize_transformer(self.pipeline) + # if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1: + # parallelize_transformer(self.pipeline) def load_diffusion_pipeline( self, @@ -722,7 +736,9 @@ def predict( n_tokens: {n_tokens} flow_shift: {flow_shift} embedded_guidance_scale: {embedded_guidance_scale}""" - logger.debug(debug_str) + if (not torch.distributed.is_initialized()) \ + or get_sequence_parallel_rank() == 0: + logger.debug(debug_str) # ======================================================================== # Pipeline inference diff --git a/svg/models/hyvideo/modules/custom_models.py b/svg/models/hyvideo/modules/custom_models.py index e669479..2bbc28e 100644 --- a/svg/models/hyvideo/modules/custom_models.py +++ b/svg/models/hyvideo/modules/custom_models.py @@ -18,6 +18,13 @@ from .modulate_layers import ModulateDiT, modulate, apply_gate from .token_refiner import SingleTokenRefiner from .models import MMDoubleStreamBlock, MMSingleStreamBlock +import torch.distributed as dist + +try: + from xfuser.core.distributed import get_ulysses_parallel_world_size + from xfuser.model_executor.layers.usp import _ft_c_input_all_to_all, _ft_c_output_all_to_all +except: + pass try: import sys @@ -175,6 +182,10 @@ def forward( txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) # Run actual attention. + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + # all_to_all comm : gather in sequence dim and scatter in num_heads dim + img_q, txt_q, img_k, txt_k, img_v, txt_v = hunyuan_uly_comm_before_attention(img_q, txt_q, img_k, txt_k, img_v, txt_v) + q = torch.cat((img_q, txt_q), dim=1) k = torch.cat((img_k, txt_k), dim=1) v = torch.cat((img_v, txt_v), dim=1) @@ -183,36 +194,29 @@ def forward( ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}" # attention computation start - if not self.hybrid_seq_parallel_attn: - attn = attention( - q, - k, - v, - mode="sparse", - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - batch_size=img_k.shape[0], - timestep=timestep, - layer_idx=self.layer_idx - ) - else: - attn = parallel_attention( - self.hybrid_seq_parallel_attn, - q, - k, - v, - img_q_len=img_q.shape[1], - img_kv_len=img_k.shape[1], - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv - ) + attn = attention( + q, + k, + v, + mode="sparse", + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + batch_size=img_k.shape[0], + timestep=timestep, + layer_idx=self.layer_idx + ) # attention computation end img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + img_attn, txt_attn = attn[:, : img.shape[1] * get_ulysses_parallel_world_size()], attn[:, img.shape[1]* get_ulysses_parallel_world_size() :] + # all_to_all comm : gather in num_heads dim and scatter in sequence dim + img_attn, txt_attn = hunyuan_uly_comm_after_attention(img_attn, txt_attn) + # Calculate the img bloks. img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) img = img + apply_gate( @@ -273,6 +277,19 @@ def forward( cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1 ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}" + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + # Ulysses comm + img_q, txt_q = q[:, :-txt_len ], q[:, -txt_len:] + img_k, txt_k = k[:, :-txt_len ], k[:, -txt_len:] + img_v, txt_v = v[:, :-txt_len ], v[:, -txt_len:] + + # all_to_all comm : gather in sequence dim and scatter in num_heads dim + img_q, txt_q, img_k, txt_k, img_v, txt_v = hunyuan_uly_comm_before_attention(img_q, txt_q, img_k, txt_k, img_v, txt_v) + + q = torch.cat((img_q, txt_q), dim=1) + k = torch.cat((img_k, txt_k), dim=1) + v = torch.cat((img_v, txt_v), dim=1) + # attention computation start attn = attention( q, @@ -288,12 +305,67 @@ def forward( layer_idx=self.layer_idx ) # attention computation end + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + txt_len *= get_ulysses_parallel_world_size() + img_attn, txt_attn = attn[:, :-txt_len ], attn[:, -txt_len:] + # all_to_all comm : gather in num_heads dim and scatter in sequence dim + img_attn, txt_attn = hunyuan_uly_comm_after_attention(img_attn, txt_attn) + attn = torch.cat((img_attn, txt_attn), dim=1) # Compute activation in mlp stream, cat again and run second linear layer. output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + apply_gate(output, gate=mod_gate) +""" + NOTE: The complexity of this parallel implementation stems from the requirement of current + svg kernel that the sequence layout must be [video_1, video_2, ..., text_1, text_2, ...]. + Directly applying all_to_all to the complete MMDiT sequence would result in + [video_1, text_1, video_2, text_2, ...], which would produce different attention maps. + + For efficiency, these rearrange operations should ideally be merged with the + pre_attn_layout and post_attn_layout functions in svg/models/hyvideo/modules/attention.py. + However, explicit rearrange operations are used here to maintain code readability. + + A more elegant implementation could utilize videosys library's all_to_all_with_pad + function, which allows flexible selection of scatter_dim and gather_dim. The current + implementation follows xdit's approach for consistency. +""" +def hunyuan_uly_comm_before_attention(query_video, query_text, key_video, key_text, value_video, value_text): + query_text = rearrange(query_text, "b s h d" " -> b h s d").contiguous() + query_video = rearrange(query_video, "b s h d" " -> b h s d").contiguous() + query_text = _ft_c_input_all_to_all(query_text) + query_video = _ft_c_input_all_to_all(query_video) + query_text = rearrange(query_text, "b h s d -> b s h d") + query_video = rearrange(query_video, "b h s d -> b s h d") + + key_text = rearrange(key_text, "b s h d" " -> b h s d").contiguous() + key_video = rearrange(key_video, "b s h d" " -> b h s d").contiguous() + key_text = _ft_c_input_all_to_all(key_text) + key_video = _ft_c_input_all_to_all(key_video) + key_text = rearrange(key_text, "b h s d -> b s h d") + key_video = rearrange(key_video, "b h s d -> b s h d") + + value_text = rearrange(value_text, "b s h d" " -> b h s d").contiguous() + value_video = rearrange(value_video, "b s h d" " -> b h s d").contiguous() + value_text = _ft_c_input_all_to_all(value_text) + value_video = _ft_c_input_all_to_all(value_video) + value_text = rearrange(value_text, "b h s d -> b s h d") + value_video = rearrange(value_video, "b h s d -> b s h d") + + return query_video, query_text, key_video, key_text, value_video, value_text + + +def hunyuan_uly_comm_after_attention(out_video, out_text): + out_text = rearrange(out_text, "b s (h d)" " -> b h s d",d=128).contiguous() + out_video = rearrange(out_video, "b s (h d)" " -> b h s d",d=128).contiguous() + out_text = _ft_c_output_all_to_all(out_text) + out_video = _ft_c_output_all_to_all(out_video) + out_text = rearrange(out_text, "b h s d" " -> b s (h d)") + out_video = rearrange(out_video, "b h s d" " -> b s (h d)") + + return out_video, out_text + def replace_sparse_forward(): MMDoubleStreamBlock.forward = MMDoubleStreamBlock_Sparse.forward diff --git a/svg/models/hyvideo/modules/models.py b/svg/models/hyvideo/modules/models.py index 88f7c0e..4fcf1df 100644 --- a/svg/models/hyvideo/modules/models.py +++ b/svg/models/hyvideo/modules/models.py @@ -17,7 +17,12 @@ from .mlp_layers import MLP, MLPEmbedder, FinalLayer from .modulate_layers import ModulateDiT, modulate, apply_gate from .token_refiner import SingleTokenRefiner +import torch.distributed as dist +try: + from xfuser.core.distributed import get_ulysses_parallel_world_size, get_ulysses_parallel_rank, get_sp_group +except: + pass class MMDoubleStreamBlock(nn.Module): """ @@ -681,7 +686,21 @@ def forward( cu_seqlens_kv = cu_seqlens_q max_seqlen_q = img_seq_len + txt_seq_len max_seqlen_kv = max_seqlen_q - + # NOTE: In Ulysses sequence parallelism, although sequences are distributed across ranks for processing, + # the attention computation operates on the full sequence dimension through all-to-all communication. + # Therefore, cu_seqlens and max_seqlen for flash attention must be calculated using the full sequence lengths + + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + # split both img and txt in sequence dim (temproal * height * width of the latents) + img = torch.chunk(img, get_ulysses_parallel_world_size(),dim=-2)[get_ulysses_parallel_rank()] + txt = torch.chunk(txt, get_ulysses_parallel_world_size(),dim=-2)[get_ulysses_parallel_rank()] + # freqs should be splited in sequence dim + freqs_cos = torch.chunk(freqs_cos, get_ulysses_parallel_world_size(),dim=-2)[get_ulysses_parallel_rank()] + freqs_sin = torch.chunk(freqs_sin, get_ulysses_parallel_world_size(),dim=-2)[get_ulysses_parallel_rank()] + # txt_seq_len in current rank should be divided + txt_seq_len = txt_seq_len // get_ulysses_parallel_world_size() + img_seq_len = img_seq_len // get_ulysses_parallel_world_size() + if hasattr(self, 'sparse_args'): if getattr(self.sparse_args, 'pattern', None) == "SVG": freqs_cos = freqs_cos.to(x.device).to(torch.float32) @@ -727,6 +746,11 @@ def forward( # ---------------------------- Final layer ------------------------------ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + # gather img in sequence dim (temproal * height * width of the latents) + img = img.contiguous() + img = get_sp_group().all_gather(img, dim=-2) + img = self.unpatchify(img, tt, th, tw) if return_dict: out["x"] = img