45 changes: 44 additions & 1 deletion cog_inference.py
@@ -5,6 +5,7 @@

from svg.models.cog.utils import seed_everything
from svg.models.cog.inference import replace_cog_attention, sample_image
import os

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="A script that sets a random seed.")
@@ -29,9 +30,51 @@
required=True,
help="Output generated videos"
)

# Parallel inference parameters
parser.add_argument(
"--use_sequence_parallel",
action="store_true",
help="Enable sequence parallelism for parallel inference"
)
parser.add_argument(
"--ulysses_degree",
type=int,
default=2,
help="The number of ulysses parallel"
)

args = parser.parse_args()

if args.use_sequence_parallel:
import torch
import torch.distributed as dist
from xfuser.core.distributed import init_distributed_environment, initialize_model_parallel

# Setup distributed environment
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
local_rank = int(os.getenv("LOCAL_RANK", 0))

assert world_size > 1, f"Sequence parallelism requires world_size > 1, got {world_size}"
assert args.ulysses_degree > 1, "ulysses_degree must be > 1 for sequence parallelism"
assert world_size == args.ulysses_degree, (
f"Currently only pure Ulysses parallelism is supported. "
f"world_size ({world_size}) must equal ulysses_degree ({args.ulysses_degree})"
)

# Initialize PyTorch distributed
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size)

# Initialize xFuser model parallelism
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=args.ulysses_degree,
)

device = local_rank

seed_everything(args.seed)


12 changes: 12 additions & 0 deletions scripts/dist_cog_inference.sh
@@ -0,0 +1,12 @@
# Feel free to change the prompt and the image!
prompt="A bright yellow water taxi glides smoothly across the choppy waters, creating gentle ripples in its wake. The iconic Brooklyn Bridge looms majestically in the background, its intricate web of cables and towering stone arches standing out against the city skyline. The boat, bustling with passengers, offers a lively contrast to the serene, expansive sky dotted with fluffy clouds. As it cruises forward, the vibrant cityscape of New York unfolds, with towering skyscrapers and historic buildings lining the waterfront, capturing the dynamic essence of urban life."
img_path="examples/cog/img/boat.jpg"

CUDA_VISIBLE_DEVICES=2,3 torchrun --nproc_per_node=2 \
cog_inference.py \
--prompt "$prompt" \
--image_path $img_path \
--output_path "output-cog-uly2.mp4" \
--ulysses_degree 2 \
--use_sequence_parallel

27 changes: 27 additions & 0 deletions scripts/dist_hyvideo_inference.sh
@@ -0,0 +1,27 @@
#!/bin/bash
# Description: This script demonstrates multi-GPU video inference using the HunyuanVideo model.

# TFP Values:
# Set the following values to control the percentage of timesteps using dense attention:
# 35% → 0.07, 30% → 0.055, 25% → 0.04, 20% → 0.033, 15% → 0.02, 10% → 0.015
first_times_fp=0.055
first_layers_fp=0.025

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
hyvideo_inference.py \
--video-size 720 1280 \
--video-length 129 \
--infer-steps 10 \
--seed 0 \
--prompt "A cat walks on the grass, realistic style." \
--embedded-cfg-scale 6.0 \
--flow-shift 7.0 \
--flow-reverse \
--output_path ./hunyuan_output_svg_step10_sp4.mp4 \
--pattern "SVG" \
--num_sampled_rows 64 \
--sparsity 0.2 \
--first_times_fp $first_times_fp \
--first_layers_fp $first_layers_fp \
--ulysses-degree 4 \
--record_attention
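
Reviewer note: the TFP comment at the top of this script encodes a lookup from the target percentage of timesteps that use dense attention to the `first_times_fp` threshold. A sketch of that mapping (the dict name is illustrative, not part of the repo):

```python
# Values taken verbatim from the comment in the script above:
# percent of timesteps using dense attention -> first_times_fp threshold.
TFP_BY_DENSE_PERCENT = {35: 0.07, 30: 0.055, 25: 0.04, 20: 0.033, 15: 0.02, 10: 0.015}

first_times_fp = TFP_BY_DENSE_PERCENT[30]  # 0.055, the 30% setting used above
```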
48 changes: 46 additions & 2 deletions svg/models/cog/attention.py
@@ -11,6 +11,7 @@

from .placement import sparse_head_placement, hidden_states_placement, ref_sparse_head_placement, ref_hidden_states_placement
from .utils import generate_temporal_head_mask_mod, create_block_mask_cached
import torch.distributed as dist

try:
sys.path.append('svg/kernels/build/')
@@ -44,7 +45,11 @@ def rotary_emb(image_rotary_emb, query, key, text_seq_length):
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
return query, key


try:
from xfuser.core.distributed import get_ulysses_parallel_world_size
from xfuser.model_executor.layers.usp import _ft_c_input_all_to_all, _ft_c_output_all_to_all
except ImportError:  # xfuser is optional; only needed for sequence-parallel runs
pass

flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
torch._dynamo.config.cache_size_limit = 192 * 3
@@ -71,6 +76,11 @@ class CogVideoX_SparseAttn_Processor2_0:

def __init__(self, layer_idx):
self.layer_idx = layer_idx

self.use_sp = False
if dist.is_initialized() and get_ulysses_parallel_world_size() > 1:
self.use_sp = True

if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

@@ -232,7 +242,41 @@ def __call__(
query, key = rotary_emb(image_rotary_emb, query, key, text_seq_length)

# ========================================================================
hidden_states = self.attention_core_logic(query, key, value, timestep)
if self.use_sp:
# input qkv ulysses all_to_all comm
query_text = query[:,:,:text_seq_length,:]
# Ugly but necessary for MMDiT.
# TODO: handle the layout inside all_to_all for cleaner code.
# For sparse attention, the sequence layout must be [text_1, text_2, ..., video_1, video_2, ...];
# an interleaved [text_1, video_1, text_2, video_2, ...] layout would produce a different attention map.
query_latents = query[:,:,text_seq_length:,:]
query_text = _ft_c_input_all_to_all(query_text)
query_latents = _ft_c_input_all_to_all(query_latents)
query = torch.cat([query_text, query_latents], dim=-2)

key_text = key[:,:,:text_seq_length,:]
key_latents = key[:,:,text_seq_length:,:]
key_text = _ft_c_input_all_to_all(key_text)
key_latents = _ft_c_input_all_to_all(key_latents)
key = torch.cat([key_text, key_latents], dim=-2)

value_text = value[:,:,:text_seq_length,:]
value_latents = value[:,:,text_seq_length:,:]
value_text = _ft_c_input_all_to_all(value_text)
value_latents = _ft_c_input_all_to_all(value_latents)
value = torch.cat([value_text, value_latents], dim=-2)

out = self.attention_core_logic(query, key, value, timestep)

# output o ulysses all_to_all comm
out_text = out[:,:,:get_ulysses_parallel_world_size()*text_seq_length,:]
out_latents = out[:,:,get_ulysses_parallel_world_size()*text_seq_length:,:]
out_text = _ft_c_output_all_to_all(out_text)
out_latents = _ft_c_output_all_to_all(out_latents)
hidden_states = torch.cat([out_text, out_latents], dim=-2)

else:
hidden_states = self.attention_core_logic(query, key, value, timestep)
# ========================================================================

hidden_states = self.get_o(attn, hidden_states, batch_size, head_dim)
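Reviewer note: for readers new to the Ulysses pattern used in this processor, below is a minimal sketch of what an input all-to-all such as `_ft_c_input_all_to_all` is assumed to do: before the exchange each rank holds every attention head for its local sequence shard; afterwards it holds a shard of the heads for the full sequence, so the attention kernel runs unmodified on complete sequences. This is illustrative only, not the xfuser implementation.

```python
import torch
import torch.distributed as dist

def ulysses_input_all_to_all(x: torch.Tensor) -> torch.Tensor:
    # x: [batch, heads, seq_shard, head_dim]; heads must divide evenly
    # across ranks.
    world = dist.get_world_size()
    b, h, s, d = x.shape
    assert h % world == 0
    # Group heads into world-size chunks: [world, b, h // world, s, d]
    x = x.reshape(b, world, h // world, s, d).permute(1, 0, 2, 3, 4).contiguous()
    out = torch.empty_like(x)
    # Chunk j of x is sent to rank j; chunk i of out arrives from rank i.
    dist.all_to_all_single(out, x)
    # Concatenate the received sequence shards in rank order:
    # [batch, heads // world, world * seq_shard, head_dim]
    return out.permute(1, 2, 0, 3, 4).reshape(b, h // world, world * s, d)
```

The output all-to-all is the inverse exchange. This also motivates exchanging the text and video segments separately above: a single all-to-all over the interleaved per-rank layout would concatenate shards as [text_1, video_1, text_2, video_2, ...], breaking the [text..., video...] order the sparse kernel expects.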
26 changes: 26 additions & 0 deletions svg/models/cog/custom_models.py
@@ -21,6 +21,16 @@
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
import torch.distributed as dist

try:
from xfuser.core.distributed import (
get_ulysses_parallel_world_size,
get_ulysses_parallel_rank,
get_sp_group
)
except ImportError:  # xfuser is optional; only needed for sequence-parallel runs
pass

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -121,6 +131,16 @@ def forward(
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]

if dist.is_initialized() and get_ulysses_parallel_world_size() > 1:
encoder_hidden_states = torch.chunk(encoder_hidden_states, get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()]
# split video latents along the (temporal * spatial) sequence dim
hidden_states = torch.chunk(hidden_states, get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()]
# image_rotary_emb must also be split along the sequence dim
image_rotary_emb = (
torch.chunk(image_rotary_emb[0], get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()],
torch.chunk(image_rotary_emb[1], get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()],
)

# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -156,12 +176,18 @@ def custom_forward(*inputs):
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
if dist.is_initialized() and get_ulysses_parallel_world_size() > 1:
assert text_seq_length % get_ulysses_parallel_world_size() == 0
text_seq_length = text_seq_length // get_ulysses_parallel_world_size()
hidden_states = hidden_states[:, text_seq_length:]

# 4. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)

if dist.is_initialized() and get_ulysses_parallel_world_size() > 1:
hidden_states = get_sp_group().all_gather(hidden_states, dim=-2)

# 5. Unpatchify
p = self.config.patch_size
p_t = self.config.patch_size_t
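Reviewer note: the forward pass above follows a shard/compute/gather pattern: chunk the text tokens, video latents, and rotary embeddings along the sequence dim per rank, run the transformer blocks on the local shard, then all-gather before unpatchify. A minimal sketch of the round trip, with illustrative names (the diff itself uses `torch.chunk` with `get_ulysses_parallel_rank()` and `get_sp_group().all_gather`):

```python
import torch
import torch.distributed as dist

def shard_seq(x: torch.Tensor, rank: int, world: int) -> torch.Tensor:
    # Take this rank's contiguous chunk: [b, s, d] -> [b, s // world, d].
    assert x.shape[-2] % world == 0, "sequence length must divide evenly"
    return torch.chunk(x, world, dim=-2)[rank]

def gather_seq(x_shard: torch.Tensor, world: int) -> torch.Tensor:
    # Inverse of shard_seq: collect every rank's shard, re-concatenate.
    shards = [torch.empty_like(x_shard) for _ in range(world)]
    dist.all_gather(shards, x_shard)
    return torch.cat(shards, dim=-2)
```

Sharding text and video separately keeps each rank's local sequence in [text_shard, video_shard] order, which is what lets the attention processor's segment-wise all-to-all reassemble the global [text..., video...] layout.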
2 changes: 1 addition & 1 deletion svg/models/hyvideo/config.py
@@ -377,7 +377,7 @@ def add_parallel_args(parser: argparse.ArgumentParser):
"--ring-degree",
type=int,
default=1,
help="Ulysses degree.",
help="Ring Attention degree.",
)

return parser
26 changes: 21 additions & 5 deletions svg/models/hyvideo/inference.py
@@ -36,7 +36,19 @@
initialize_model_parallel = None
init_distributed_environment = None


"""
NOTE: This parallelization function is not used in the current implementation.

This function implements sequence parallelism by splitting the input tensor along either
the height or width dimension. While this approach offers the advantage of parallelizing
the patch embedding process, it introduces significant complexity when adapting the data
layout for the SVG flex attention kernel.

Current approach: We implement parallelism directly in HYVideoDiffusionTransformer.forward()
(located in svg/models/hyvideo/modules/models.py) by splitting along the flattened
(temporal * height * width) dimension after patchify, which provides better compatibility
with the attention mechanism and a simpler implementation.
"""
def parallelize_transformer(pipe):
transformer = pipe.transformer
original_forward = transformer.forward
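
Reviewer note: a sketch contrasting the two splitting strategies described in the NOTE above; both function names and layouts are illustrative, not the repo's API.

```python
import torch

def shard_spatial(latents: torch.Tensor, rank: int, world: int) -> torch.Tensor:
    # The unused approach: split the raw latent along height *before*
    # patch embedding. latents: [batch, channels, frames, height, width]
    assert latents.shape[-2] % world == 0
    return torch.chunk(latents, world, dim=-2)[rank]

def shard_tokens(tokens: torch.Tensor, rank: int, world: int) -> torch.Tensor:
    # The adopted approach: split the flattened token sequence *after*
    # patchify. tokens: [batch, frames * height * width, dim]
    assert tokens.shape[1] % world == 0
    return torch.chunk(tokens, world, dim=1)[rank]
```

Splitting after patchify keeps one contiguous sequence dim that lines up with what the SVG flex-attention kernel sees; the trade-off, as the NOTE implies, is that patch embedding itself is not parallelized.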
@@ -173,7 +185,9 @@ def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
ring_degree=args.ring_degree,
ulysses_degree=args.ulysses_degree,
)
device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
local_rank = int(os.getenv("LOCAL_RANK", 0))
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(local_rank)
else:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -405,8 +419,8 @@ def __init__(
)

self.default_negative_prompt = NEGATIVE_PROMPT
if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
parallelize_transformer(self.pipeline)
# if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
# parallelize_transformer(self.pipeline)

def load_diffusion_pipeline(
self,
@@ -722,7 +736,9 @@ def predict(
n_tokens: {n_tokens}
flow_shift: {flow_shift}
embedded_guidance_scale: {embedded_guidance_scale}"""
logger.debug(debug_str)
if (not torch.distributed.is_initialized()) \
or get_sequence_parallel_rank() == 0:
logger.debug(debug_str)

# ========================================================================
# Pipeline inference