diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index c8ced1b747..a52ee783ea 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -30,6 +30,8 @@ th { |`ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | | `WanPipeline` | Wan2.2-T2V, Wan2.2-TI2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | | `WanImageToVideoPipeline` | Wan2.2-I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | +| `LTX2Pipeline` | LTX-2-T2V | `Lightricks/LTX-2` | +| `LTX2ImageToVideoPipeline` | LTX-2-I2V | `Lightricks/LTX-2` | | `OvisImagePipeline` | Ovis-Image | `OvisAI/Ovis-Image` | |`LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` | |`LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` | diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index b0225e2ffb..5a74714c7e 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -54,6 +54,7 @@ The following table shows which models are currently supported by parallelism me | Model | Model Identifier | Ulysses-SP | Ring-Attention | Tensor-Parallel | HSDP | |-------|------------------|:----------:|:--------------:|:---------------:|:----:| | **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | ✅ | ✅ | ✅ | +| **LTX-2** | `Lightricks/LTX-2` | ✅ | ✅ | ✅ | ❌ | ### Tensor Parallelism diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 9fd14a2b07..beb8dba1cd 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -74,6 +74,7 @@ The following table shows which models are currently supported by each accelerat | Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | HSDP | 
|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:------------:|:----:| | **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | +| **LTX-2** | `Lightricks/LTX-2` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ### Quantization diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md index 2643084f1a..2692c76df2 100644 --- a/examples/offline_inference/image_to_video/README.md +++ b/examples/offline_inference/image_to_video/README.md @@ -65,6 +65,7 @@ Key arguments: - `--vae-use-slicing`: Enable VAE slicing for memory optimization. - `--vae-use-tiling`: Enable VAE tiling for memory optimization. - `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. - `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs. - `--hsdp-shard-size`: Number of GPUs to shard model weights across within each replica group. -1 (default) auto-calculates as world_size / replicate_size. diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index a491ddfc33..abbe7c717b 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Image-to-Video generation example using Wan2.2 I2V/TI2V models. +Image-to-Video generation example using Wan2.2 I2V/TI2V models or LTX2. 
Supports: - Wan2.2-I2V-A14B-Diffusers: MoE model with CLIP image encoder - Wan2.2-TI2V-5B-Diffusers: Unified T2V+I2V model (dense 5B) +- LTX2 image-to-video pipeline Usage: # I2V-A14B (MoE) @@ -16,6 +17,13 @@ # TI2V-5B (unified) python image_to_video.py --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \ --image input.jpg --prompt "A cat playing with yarn" + + # LTX2 image-to-video + python image_to_video.py --model /path/to/LTX-2 \ + --model_class_name LTX2ImageToVideoPipeline \ + --image input.jpg --prompt "A cinematic dolly shot of a boat" \ + --num_frames 121 --num_inference_steps 40 --guidance_scale 4.0 \ + --frame_rate 24 --fps 24 --output ltx2_i2v.mp4 """ import argparse @@ -34,12 +42,17 @@ def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Generate a video from an image with Wan2.2 I2V/TI2V.") + parser = argparse.ArgumentParser(description="Generate a video from an image with Wan2.2 or LTX2.") parser.add_argument( "--model", default="Wan-AI/Wan2.2-I2V-A14B-Diffusers", help="Diffusers Wan2.2 I2V model ID or local path.", ) + parser.add_argument( + "--model_class_name", + default=None, + help="Override model class name (e.g., LTX2ImageToVideoPipeline).", + ) parser.add_argument("--image", required=True, help="Path to input image.") parser.add_argument("--prompt", default="", help="Text prompt describing the desired motion.") parser.add_argument("--negative-prompt", default="", help="Negative prompt.") @@ -55,11 +68,17 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--num-frames", type=int, default=81, help="Number of frames.") parser.add_argument("--num-inference-steps", type=int, default=50, help="Sampling steps.") parser.add_argument("--boundary-ratio", type=float, default=0.875, help="Boundary split ratio for MoE models.") + parser.add_argument( + "--frame-rate", + type=float, + default=None, + help="Optional generation frame rate (used by models like LTX2). 
Defaults to --fps.", + ) parser.add_argument( "--flow-shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." ) parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).") - parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.") + parser.add_argument("--fps", type=int, default=None, help="Frames per second for the output video.") parser.add_argument( "--vae-use-slicing", action="store_true", @@ -110,6 +129,23 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Disable torch.compile and force eager execution.", ) + parser.add_argument( + "--audio-sample-rate", + type=int, + default=24000, + help="Sample rate for audio output when saved (default: 24000 for LTX2).", + ) + parser.add_argument( + "--cache-backend", + type=str, + default=None, + choices=["cache_dit", "tea_cache"], + help=( + "Cache backend to use for acceleration. " + "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). " + "Default: None (no cache acceleration)." 
+ ), + ) parser.add_argument( "--use-hsdp", action="store_true", @@ -136,10 +172,13 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def calculate_dimensions(image: PIL.Image.Image, max_area: int = 480 * 832) -> tuple[int, int]: +def calculate_dimensions( + image: PIL.Image.Image, + max_area: int = 480 * 832, + mod_value: int = 16, +) -> tuple[int, int]: """Calculate output dimensions maintaining aspect ratio.""" aspect_ratio = image.height / image.width - mod_value = 16 # Must be divisible by 16 height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value @@ -150,22 +189,66 @@ def calculate_dimensions(image: PIL.Image.Image, max_area: int = 480 * 832) -> t def main(): args = parse_args() generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + model_name = str(args.model).lower() if args.model is not None else "" + model_class_name = args.model_class_name + is_ltx2 = "ltx2" in model_name or (model_class_name and "ltx2" in model_class_name.lower()) + if model_class_name is None and is_ltx2: + model_class_name = "LTX2ImageToVideoPipeline" # Load input image image = PIL.Image.open(args.image).convert("RGB") + fps = args.fps if args.fps is not None else (24 if is_ltx2 else 16) + frame_rate = args.frame_rate if args.frame_rate is not None else float(fps) + guidance_scale = args.guidance_scale if args.guidance_scale is not None else (4.0 if is_ltx2 else 5.0) + num_frames = args.num_frames if args.num_frames is not None else (121 if is_ltx2 else 81) + num_inference_steps = args.num_inference_steps if args.num_inference_steps is not None else (40 if is_ltx2 else 50) + # Calculate dimensions if not provided height = args.height width = args.width if height is None or width is None: - # Default to 480P area for I2V - calc_height, calc_width = calculate_dimensions(image, max_area=480 * 832) + # Default to 480P area for Wan2.2 I2V, 
512x768 area for LTX2 + max_area = 512 * 768 if is_ltx2 else 480 * 832 + mod_value = 32 if is_ltx2 else 16 + calc_height, calc_width = calculate_dimensions(image, max_area=max_area, mod_value=mod_value) height = height or calc_height width = width or calc_width # Resize image to target dimensions image = image.resize((width, height), PIL.Image.Resampling.LANCZOS) + # Configure cache based on backend type + cache_config = None + if args.cache_backend == "cache_dit": + if is_ltx2: + cache_config = { + "Fn_compute_blocks": 2, + "Bn_compute_blocks": 0, + "max_warmup_steps": 8, + "residual_diff_threshold": 0.12, + "max_continuous_cached_steps": 1, + "max_cached_steps": 20, + "enable_taylorseer": False, + "scm_steps_mask_policy": None, + } + else: + cache_config = { + "Fn_compute_blocks": 1, + "Bn_compute_blocks": 0, + "max_warmup_steps": 4, + "residual_diff_threshold": 0.24, + "max_continuous_cached_steps": 3, + "enable_taylorseer": False, + "taylorseer_order": 1, + "scm_steps_mask_policy": None, + "scm_steps_policy": "dynamic", + } + elif args.cache_backend == "tea_cache": + cache_config = { + "rel_l1_thresh": 0.2, + } + # Check if profiling is requested via environment variable profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) parallel_config = DiffusionParallelConfig( @@ -187,6 +270,9 @@ def main(): enable_cpu_offload=args.enable_cpu_offload, parallel_config=parallel_config, enforce_eager=args.enforce_eager, + model_class_name=model_class_name, + cache_backend=args.cache_backend, + cache_config=cache_config, ) if profiler_enabled: @@ -199,7 +285,13 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Frames: {args.num_frames}") - print(f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}") + print( + " Parallel configuration: " + f"cfg_parallel_size={args.cfg_parallel_size}, " + f"ulysses_degree={args.ulysses_degree}, " + f"ring_degree={args.ring_degree}, " + 
f"tensor_parallel_size={args.tensor_parallel_size}" + ) print(f" Video size: {args.width}x{args.height}") print(f"{'=' * 60}\n") @@ -214,38 +306,64 @@ def main(): height=height, width=width, generator=generator, - guidance_scale=args.guidance_scale, + guidance_scale=guidance_scale, guidance_scale_2=args.guidance_scale_high, - num_inference_steps=args.num_inference_steps, - num_frames=args.num_frames, + num_inference_steps=num_inference_steps, + num_frames=num_frames, + frame_rate=frame_rate, ), ) - # Extract video frames from OmniRequestOutput - if isinstance(frames, list) and len(frames) > 0: - first_item = frames[0] - - # Check if it's an OmniRequestOutput - if hasattr(first_item, "final_output_type"): - if first_item.final_output_type != "image": - raise ValueError( - f"Unexpected output type '{first_item.final_output_type}', expected 'image' for video generation." - ) - - # Pipeline mode: extract from nested request_output - if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output: - if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0: - inner_output = first_item.request_output[0] - if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"): - frames = inner_output.images[0] if inner_output.images else None - if frames is None: - raise ValueError("No video frames found in output.") - # Diffusion mode: use direct images field - elif hasattr(first_item, "images") and first_item.images: - frames = first_item.images + audio = None + if isinstance(frames, list): + frames = frames[0] if frames else None + + if isinstance(frames, OmniRequestOutput): + if frames.final_output_type != "image": + raise ValueError( + f"Unexpected output type '{frames.final_output_type}', expected 'image' for video generation." 
+ ) + if frames.multimodal_output and "audio" in frames.multimodal_output: + audio = frames.multimodal_output["audio"] + if frames.is_pipeline_output and frames.request_output is not None: + inner_output = frames.request_output + if isinstance(inner_output, list): + inner_output = inner_output[0] if inner_output else None + if isinstance(inner_output, OmniRequestOutput): + if inner_output.multimodal_output and "audio" in inner_output.multimodal_output: + audio = inner_output.multimodal_output["audio"] + frames = inner_output + if isinstance(frames, OmniRequestOutput): + if frames.images: + if len(frames.images) == 1 and isinstance(frames.images[0], tuple) and len(frames.images[0]) == 2: + frames, audio = frames.images[0] + elif len(frames.images) == 1 and isinstance(frames.images[0], dict): + audio = frames.images[0].get("audio") + frames = frames.images[0].get("frames") or frames.images[0].get("video") + else: + frames = frames.images else: raise ValueError("No video frames found in OmniRequestOutput.") + if isinstance(frames, list) and frames: + first_item = frames[0] + if isinstance(first_item, tuple) and len(first_item) == 2: + frames, audio = first_item + elif isinstance(first_item, dict): + audio = first_item.get("audio") + frames = first_item.get("frames") or first_item.get("video") + elif isinstance(first_item, list): + frames = first_item + + if isinstance(frames, tuple) and len(frames) == 2: + frames, audio = frames + elif isinstance(frames, dict): + audio = frames.get("audio") + frames = frames.get("frames") or frames.get("video") + + if frames is None: + raise ValueError("No video frames found in output.") + output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) @@ -254,32 +372,128 @@ def main(): except ImportError: raise ImportError("diffusers is required for export_to_video.") - # frames may be np.ndarray (preferred) or torch.Tensor + def _normalize_frame(frame): + if isinstance(frame, torch.Tensor): + frame_tensor = 
frame.detach().cpu() + if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor[0] + if frame_tensor.dim() == 3 and frame_tensor.shape[0] in (3, 4): + frame_tensor = frame_tensor.permute(1, 2, 0) + if frame_tensor.is_floating_point(): + frame_tensor = frame_tensor.clamp(-1, 1) * 0.5 + 0.5 + return frame_tensor.float().numpy() + if isinstance(frame, np.ndarray): + frame_array = frame + if frame_array.ndim == 4 and frame_array.shape[0] == 1: + frame_array = frame_array[0] + if np.issubdtype(frame_array.dtype, np.integer): + frame_array = frame_array.astype(np.float32) / 255.0 + return frame_array + try: + from PIL import Image + except ImportError: + Image = None + if Image is not None and isinstance(frame, Image.Image): + return np.asarray(frame).astype(np.float32) / 255.0 + return frame + + def _ensure_frame_list(video_array): + if isinstance(video_array, list): + if len(video_array) == 0: + return video_array + first_item = video_array[0] + if isinstance(first_item, np.ndarray): + if first_item.ndim == 5: + return list(first_item[0]) + if first_item.ndim == 4: + if len(video_array) == 1: + return list(first_item) + return list(first_item) + if first_item.ndim == 3: + return video_array + return video_array + if isinstance(video_array, np.ndarray): + if video_array.ndim == 5: + return list(video_array[0]) + if video_array.ndim == 4: + return list(video_array) + if video_array.ndim == 3: + return [video_array] + return video_array + + # frames may be np.ndarray, torch.Tensor, or list of tensors/arrays/images # export_to_video expects a list of frames with values in [0, 1] if isinstance(frames, torch.Tensor): video_tensor = frames.detach().cpu() if video_tensor.dim() == 5: - # [B, C, F, H, W] or [B, F, H, W, C] if video_tensor.shape[1] in (3, 4): video_tensor = video_tensor[0].permute(1, 2, 3, 0) else: video_tensor = video_tensor[0] elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): video_tensor = video_tensor.permute(1, 
2, 3, 0) - # If float, assume [-1,1] and normalize to [0,1] if video_tensor.is_floating_point(): video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 video_array = video_tensor.float().numpy() - else: + elif isinstance(frames, np.ndarray): video_array = frames - if hasattr(video_array, "shape") and video_array.ndim == 5: + if video_array.ndim == 5: video_array = video_array[0] + if np.issubdtype(video_array.dtype, np.integer): + video_array = video_array.astype(np.float32) / 255.0 + elif isinstance(frames, list): + if len(frames) == 0: + raise ValueError("No video frames found in output.") + video_array = [_normalize_frame(frame) for frame in frames] + else: + video_array = frames - # Convert 4D array (frames, H, W, C) to list of frames for export_to_video - if isinstance(video_array, np.ndarray) and video_array.ndim == 4: - video_array = list(video_array) - - export_to_video(video_array, str(output_path), fps=args.fps) + video_array = _ensure_frame_list(video_array) + + use_ltx2_export = is_ltx2 + encode_video = None + if use_ltx2_export: + try: + from diffusers.pipelines.ltx2.export_utils import encode_video + except ImportError: + encode_video = None + + if use_ltx2_export and encode_video is not None: + if isinstance(video_array, list): + frames_np = np.stack(video_array, axis=0) + elif isinstance(video_array, np.ndarray): + frames_np = video_array + else: + frames_np = np.asarray(video_array) + + if frames_np.ndim == 4 and frames_np.shape[-1] == 4: + frames_np = frames_np[..., :3] + + frames_np = np.clip(frames_np, 0.0, 1.0) + frames_u8 = (frames_np * 255).round().clip(0, 255).astype("uint8") + video_tensor = torch.from_numpy(frames_u8) + + audio_out = None + if audio is not None: + if isinstance(audio, list): + audio = audio[0] if audio else None + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + if isinstance(audio, torch.Tensor): + audio_out = audio + if audio_out.dim() > 1: + audio_out = audio_out[0] + audio_out = 
audio_out.float().cpu() + + encode_video( + video_tensor, + fps=fps, + audio=audio_out, + audio_sample_rate=args.audio_sample_rate if audio_out is not None else None, + output_path=str(output_path), + ) + else: + export_to_video(video_array, str(output_path), fps=fps) print(f"Saved generated video to {output_path}") if profiler_enabled: diff --git a/examples/offline_inference/text_to_video/text_to_video.md b/examples/offline_inference/text_to_video/text_to_video.md index 4e34ff8bab..815622e174 100644 --- a/examples/offline_inference/text_to_video/text_to_video.md +++ b/examples/offline_inference/text_to_video/text_to_video.md @@ -1,6 +1,7 @@ # Text-To-Video -The `Wan-AI/Wan2.2-T2V-A14B-Diffusers` pipeline generates short videos from text prompts. +The `Wan-AI/Wan2.2-T2V-A14B-Diffusers` pipeline generates short videos from text prompts. This script can also be used +for `Lightricks/LTX-2` to generate video+audio. ## Local CLI Usage @@ -19,6 +20,23 @@ python text_to_video.py \ --output t2v_out.mp4 ``` +LTX2 example: + +```bash +python text_to_video.py \ + --model "Lightricks/LTX-2" \ + --prompt "A cinematic close-up of ocean waves at golden hour." \ + --negative_prompt "worst quality, inconsistent motion, blurry, jittery, distorted" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --num_inference_steps 40 \ + --guidance_scale 4.0 \ + --frame_rate 24 \ + --fps 24 \ + --output ltx2_out.mp4 +``` + Key arguments: - `--prompt`: text description (string). @@ -32,6 +50,9 @@ Key arguments: - `--vae-use-slicing`: enable VAE slicing for memory optimization. - `--vae-use-tiling`: enable VAE tiling for memory optimization. - `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. 
+- `--frame-rate`: generation FPS for pipelines that require it (e.g., LTX2). +- `--audio_sample_rate`: audio sample rate for embedded audio (when the pipeline returns audio). > ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage. diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 40fafa1009..68c1256467 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -31,6 +31,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--height", type=int, default=720, help="Video height.") parser.add_argument("--width", type=int, default=1280, help="Video width.") parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (Wan default is 81).") + parser.add_argument( + "--frame-rate", + type=float, + default=None, + help="Optional generation frame rate (used by models like LTX2). Defaults to --fps.", + ) parser.add_argument("--num-inference-steps", type=int, default=40, help="Sampling steps.") parser.add_argument( "--boundary-ratio", @@ -109,12 +115,31 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", ) + parser.add_argument( + "--audio_sample_rate", + type=int, + default=24000, + help="Sample rate for audio output when saved (default: 24000 for LTX2).", + ) + parser.add_argument( + "--cache_backend", + type=str, + default=None, + choices=["cache_dit", "tea_cache"], + help=( + "Cache backend to use for acceleration. " + "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). " + "Default: None (no cache acceleration)." 
+ ), + ) + return parser.parse_args() def main(): args = parse_args() generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + frame_rate = args.frame_rate if args.frame_rate is not None else float(args.fps) # Wan2.2 cache-dit tuning (from cache-dit examples and cache_alignment). cache_config = None @@ -134,8 +159,7 @@ def main(): "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" } - # Configure parallel settings (only SP is supported for Wan) - # Note: cfg_parallel and tensor_parallel are not implemented for Wan models + # Configure parallel settings parallel_config = DiffusionParallelConfig( ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, @@ -153,12 +177,12 @@ def main(): vae_use_tiling=args.vae_use_tiling, boundary_ratio=args.boundary_ratio, flow_shift=args.flow_shift, - cache_backend=args.cache_backend, - cache_config=cache_config, enable_cache_dit_summary=args.enable_cache_dit_summary, enable_cpu_offload=args.enable_cpu_offload, parallel_config=parallel_config, enforce_eager=args.enforce_eager, + cache_backend=args.cache_backend, + cache_config=cache_config, ) if profiler_enabled: @@ -172,7 +196,11 @@ def main(): print(f" Inference steps: {args.num_inference_steps}") print(f" Frames: {args.num_frames}") print( - f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}" + " Parallel configuration: " + f"ulysses_degree={args.ulysses_degree}, " + f"ring_degree={args.ring_degree}, " + f"cfg_parallel_size={args.cfg_parallel_size}, " + f"tensor_parallel_size={args.tensor_parallel_size}" ) print(f" Video size: {args.width}x{args.height}") print(f"{'=' * 60}\n") @@ -191,6 +219,7 @@ def main(): guidance_scale_2=args.guidance_scale_high, 
num_inference_steps=args.num_inference_steps, num_frames=args.num_frames, + frame_rate=frame_rate, ), ) generation_end = time.perf_counter() @@ -199,31 +228,56 @@ def main(): # Print profiling results print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") - # Extract video frames from OmniRequestOutput - if isinstance(frames, list) and len(frames) > 0: - first_item = frames[0] + audio = None + if isinstance(frames, list): + frames = frames[0] if frames else None - # Check if it's an OmniRequestOutput - if hasattr(first_item, "final_output_type"): - if first_item.final_output_type != "image": - raise ValueError( - f"Unexpected output type '{first_item.final_output_type}', expected 'image' for video generation." - ) - - # Pipeline mode: extract from nested request_output - if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output: - if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0: - inner_output = first_item.request_output[0] - if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"): - frames = inner_output.images[0] if inner_output.images else None - if frames is None: - raise ValueError("No video frames found in output.") - # Diffusion mode: use direct images field - elif hasattr(first_item, "images") and first_item.images: - frames = first_item.images + if isinstance(frames, OmniRequestOutput): + if frames.final_output_type != "image": + raise ValueError( + f"Unexpected output type '{frames.final_output_type}', expected 'image' for video generation." 
+ ) + if frames.multimodal_output and "audio" in frames.multimodal_output: + audio = frames.multimodal_output["audio"] + if frames.is_pipeline_output and frames.request_output is not None: + inner_output = frames.request_output + if isinstance(inner_output, list): + inner_output = inner_output[0] if inner_output else None + if isinstance(inner_output, OmniRequestOutput): + if inner_output.multimodal_output and "audio" in inner_output.multimodal_output: + audio = inner_output.multimodal_output["audio"] + frames = inner_output + if isinstance(frames, OmniRequestOutput): + if frames.images: + if len(frames.images) == 1 and isinstance(frames.images[0], tuple) and len(frames.images[0]) == 2: + frames, audio = frames.images[0] + elif len(frames.images) == 1 and isinstance(frames.images[0], dict): + audio = frames.images[0].get("audio") + frames = frames.images[0].get("frames") or frames.images[0].get("video") + else: + frames = frames.images else: raise ValueError("No video frames found in OmniRequestOutput.") + if isinstance(frames, list) and frames: + first_item = frames[0] + if isinstance(first_item, tuple) and len(first_item) == 2: + frames, audio = first_item + elif isinstance(first_item, dict): + audio = first_item.get("audio") + frames = first_item.get("frames") or first_item.get("video") + elif isinstance(first_item, list): + frames = first_item + + if isinstance(frames, tuple) and len(frames) == 2: + frames, audio = frames + elif isinstance(frames, dict): + audio = frames.get("audio") + frames = frames.get("frames") or frames.get("video") + + if frames is None: + raise ValueError("No video frames found in output.") + output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) try: @@ -231,32 +285,127 @@ def main(): except ImportError: raise ImportError("diffusers is required for export_to_video.") - # frames may be np.ndarray (preferred) or torch.Tensor + def _normalize_frame(frame): + if isinstance(frame, torch.Tensor): + frame_tensor 
= frame.detach().cpu() + if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1: + frame_tensor = frame_tensor[0] + if frame_tensor.dim() == 3 and frame_tensor.shape[0] in (3, 4): + frame_tensor = frame_tensor.permute(1, 2, 0) + if frame_tensor.is_floating_point(): + frame_tensor = frame_tensor.clamp(-1, 1) * 0.5 + 0.5 + return frame_tensor.float().numpy() + if isinstance(frame, np.ndarray): + frame_array = frame + if frame_array.ndim == 4 and frame_array.shape[0] == 1: + frame_array = frame_array[0] + if np.issubdtype(frame_array.dtype, np.integer): + frame_array = frame_array.astype(np.float32) / 255.0 + return frame_array + try: + from PIL import Image + except ImportError: + Image = None + if Image is not None and isinstance(frame, Image.Image): + return np.asarray(frame).astype(np.float32) / 255.0 + return frame + + def _ensure_frame_list(video_array): + if isinstance(video_array, list): + if len(video_array) == 0: + return video_array + first_item = video_array[0] + if isinstance(first_item, np.ndarray): + if first_item.ndim == 5: + return list(first_item[0]) + if first_item.ndim == 4: + if len(video_array) == 1: + return list(first_item) + return list(first_item) + if first_item.ndim == 3: + return video_array + return video_array + if isinstance(video_array, np.ndarray): + if video_array.ndim == 5: + return list(video_array[0]) + if video_array.ndim == 4: + return list(video_array) + if video_array.ndim == 3: + return [video_array] + return video_array + + # frames may be np.ndarray, torch.Tensor, or list of tensors/arrays/images # export_to_video expects a list of frames with values in [0, 1] if isinstance(frames, torch.Tensor): video_tensor = frames.detach().cpu() if video_tensor.dim() == 5: - # [B, C, F, H, W] or [B, F, H, W, C] if video_tensor.shape[1] in (3, 4): video_tensor = video_tensor[0].permute(1, 2, 3, 0) else: video_tensor = video_tensor[0] elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): video_tensor = 
video_tensor.permute(1, 2, 3, 0) - # If float, assume [-1,1] and normalize to [0,1] if video_tensor.is_floating_point(): video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 video_array = video_tensor.float().numpy() - else: + elif isinstance(frames, np.ndarray): video_array = frames - if hasattr(video_array, "shape") and video_array.ndim == 5: + if video_array.ndim == 5: video_array = video_array[0] + if np.issubdtype(video_array.dtype, np.integer): + video_array = video_array.astype(np.float32) / 255.0 + elif isinstance(frames, list): + if len(frames) == 0: + raise ValueError("No video frames found in output.") + video_array = [_normalize_frame(frame) for frame in frames] + else: + video_array = frames + + video_array = _ensure_frame_list(video_array) - # Convert 4D array (frames, H, W, C) to list of frames for export_to_video - if isinstance(video_array, np.ndarray) and video_array.ndim == 4: - video_array = list(video_array) + use_ltx2_export = False + if args.model and "ltx" in str(args.model).lower(): + use_ltx2_export = True + if audio is not None: + use_ltx2_export = True - export_to_video(video_array, str(output_path), fps=args.fps) + if use_ltx2_export: + try: + from diffusers.pipelines.ltx2.export_utils import encode_video + except ImportError: + raise ImportError("diffusers is required for LTX2 encode_video.") + + if isinstance(video_array, list): + frames_np = np.stack(video_array, axis=0) + elif isinstance(video_array, np.ndarray): + frames_np = video_array + else: + frames_np = np.asarray(video_array) + + frames_u8 = (frames_np * 255).round().clip(0, 255).astype("uint8") + video_tensor = torch.from_numpy(frames_u8) + + audio_out = None + if audio is not None: + if isinstance(audio, list): + audio = audio[0] if audio else None + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + if isinstance(audio, torch.Tensor): + audio_out = audio + if audio_out.dim() > 1: + audio_out = audio_out[0] + audio_out = audio_out.float().cpu() + + 
encode_video( + video_tensor, + fps=args.fps, + audio=audio_out, + audio_sample_rate=args.audio_sample_rate if audio_out is not None else None, + output_path=str(output_path), + ) + else: + export_to_video(video_array, str(output_path), fps=args.fps) print(f"Saved generated video to {output_path}") if profiler_enabled: diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index dc9414a67d..2cc12f31cf 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -91,6 +91,7 @@ def _fake_encode(video, fps): assert captured.height == 360 assert captured.num_frames == 24 assert captured.fps == 12 + assert captured.frame_rate == 12.0 assert fps_values == [12, 12] @@ -121,6 +122,32 @@ def test_i2v_video_generation_form(test_client, mocker: MockerFixture): assert input_image.size == (48, 32) +def test_i2v_video_generation_resizes_input_to_requested_dimensions(test_client, mocker: MockerFixture): + image_bytes = _make_test_image_bytes((48, 32)) + + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video.encode_video_base64", + return_value="Zg==", + ) + response = test_client.post( + "/v1/videos", + data={ + "prompt": "A bear playing with yarn.", + "width": "96", + "height": "64", + }, + files={"input_reference": ("input.png", image_bytes, "image/png")}, + ) + + assert response.status_code == 200 + + engine = test_client.app.state.openai_serving_video._engine_client + prompt = engine.captured_prompt + input_image = prompt["multi_modal_data"]["image"] + assert isinstance(input_image, Image.Image) + assert input_image.size == (96, 64) + + def test_seconds_defaults_fps_and_frames(test_client, mocker: MockerFixture): fps_values = [] diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index f014d073e8..5008a8fa33 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ 
b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -414,6 +414,60 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool return refresh_cache_context +def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for LTX2 pipelines (audio-video transformer blocks).""" + transformer = pipeline.transformer + + db_cache_config = _build_db_cache_config(cache_config) + + calibrator_config = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + blocks = transformer.transformer_blocks + + logger.info( + f"Enabling cache-dit on LTX2 transformer: " + f"Fn={db_cache_config.Fn_compute_blocks}, " + f"Bn={db_cache_config.Bn_compute_blocks}, " + f"W={db_cache_config.max_warmup_steps}, " + ) + + cache_dit.enable_cache( + BlockAdapter( + transformer=transformer, + blocks=blocks, + # LTX2 blocks return (hidden_states, audio_hidden_states) + forward_pattern=ForwardPattern.Pattern_0, + # Treat audio_hidden_states as encoder_hidden_states in Pattern_0 + check_forward_pattern=False, + ), + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + def 
enable_cache_for_dit(pipeline: Any, cache_config: Any) -> Callable[[int], None]: """Enable cache-dit for regular single-transformer DiT models. @@ -862,6 +916,8 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, "StableDiffusion3Pipeline": enable_cache_for_sd3, + "LTX2Pipeline": enable_cache_for_ltx2, + "LTX2ImageToVideoPipeline": enable_cache_for_ltx2, "BagelPipeline": enable_cache_for_bagel, } ) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 8e4a9f7a20..11c3dc220f 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -94,6 +94,10 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: postprocess_start_time = time.time() outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output + audio_payload = None + if isinstance(outputs, dict): + audio_payload = outputs.get("audio") + outputs = outputs.get("video", outputs) postprocess_time = time.time() - postprocess_start_time logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds") @@ -117,7 +121,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: request_id = request.request_ids[0] if request.request_ids else "" if supports_audio_output(self.od_config.model_class_name): - audio_payload = outputs[0] if len(outputs) == 1 else outputs + request_audio_payload = outputs[0] if len(outputs) == 1 else outputs return [ OmniRequestOutput.from_diffusion( request_id=request_id, @@ -125,11 +129,14 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: prompt=prompt, metrics=metrics, latents=output.trajectory_latents, - multimodal_output={"audio": audio_payload}, + multimodal_output={"audio": request_audio_payload}, final_output_type="audio", ), ] else: + 
mm_output = {} + if audio_payload is not None: + mm_output["audio"] = audio_payload return [ OmniRequestOutput.from_diffusion( request_id=request_id, @@ -137,6 +144,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: prompt=prompt, metrics=metrics, latents=output.trajectory_latents, + multimodal_output=mm_output, ), ] else: @@ -150,11 +158,13 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: # Get images for this request num_outputs = request.sampling_params.num_outputs_per_prompt - request_outputs = outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else [] - output_idx += num_outputs + start_idx = output_idx + end_idx = start_idx + num_outputs + request_outputs = outputs[start_idx:end_idx] if output_idx < len(outputs) else [] + output_idx = end_idx if supports_audio_output(self.od_config.model_class_name): - audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs + request_audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs results.append( OmniRequestOutput.from_diffusion( request_id=request_id, @@ -162,11 +172,24 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: prompt=prompt, metrics=metrics, latents=output.trajectory_latents, - multimodal_output={"audio": audio_payload}, + multimodal_output={"audio": request_audio_payload}, final_output_type="audio", - ) + ), ) else: + mm_output = {} + if audio_payload is not None: + sliced_audio = audio_payload + if isinstance(audio_payload, (list, tuple)): + sliced_audio = audio_payload[start_idx:end_idx] + if len(sliced_audio) == 1: + sliced_audio = sliced_audio[0] + elif hasattr(audio_payload, "shape") and getattr(audio_payload, "shape", None) is not None: + if len(audio_payload.shape) > 0 and audio_payload.shape[0] >= end_idx: + sliced_audio = audio_payload[start_idx:end_idx] + if num_outputs == 1: + sliced_audio = sliced_audio[0] + mm_output["audio"] = 
sliced_audio results.append( OmniRequestOutput.from_diffusion( request_id=request_id, @@ -174,7 +197,8 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: prompt=prompt, metrics=metrics, latents=output.trajectory_latents, - ) + multimodal_output=mm_output, + ), ) return results diff --git a/vllm_omni/diffusion/models/ltx2/__init__.py b/vllm_omni/diffusion/models/ltx2/__init__.py new file mode 100644 index 0000000000..0a92d4f24f --- /dev/null +++ b/vllm_omni/diffusion/models/ltx2/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import ( + LTX2Pipeline, + create_transformer_from_config, + get_ltx2_post_process_func, + load_transformer_config, +) +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2ImageToVideoPipeline + +__all__ = [ + "LTX2Pipeline", + "LTX2ImageToVideoPipeline", + "get_ltx2_post_process_func", + "load_transformer_config", + "create_transformer_from_config", + "LTX2VideoTransformer3DModel", +] diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py new file mode 100644 index 0000000000..b64448d22e --- /dev/null +++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py @@ -0,0 +1,1835 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from collections.abc import Iterable +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection +from diffusers.utils import ( + BaseOutput, + is_torch_version, +) +from torch.utils.checkpoint import checkpoint +from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput, SequenceParallelOutput +from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available + +logger = init_logger(__name__) + +_RMSNORM_INIT_PARAMS = inspect.signature(RMSNorm.__init__).parameters + + +def _make_rms_norm(hidden_size: int, *, eps: float, elementwise_affine: bool) -> nn.Module: + """Bridge diffusers' RMSNorm API onto vLLM's `has_weight` variant.""" + kwargs: dict[str, Any] = {"eps": eps} + if "elementwise_affine" in _RMSNORM_INIT_PARAMS: + kwargs["elementwise_affine"] = elementwise_affine + elif "has_weight" in _RMSNORM_INIT_PARAMS: + kwargs["has_weight"] = elementwise_affine + elif not elementwise_affine: + raise TypeError("RMSNorm backend does not support disabling affine weights.") + return RMSNorm(hidden_size, **kwargs) + + +def 
apply_interleaved_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + + x_dtype = x.dtype + needs_reshape = False + if x.ndim != 4 and cos.ndim == 4: + # cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head) + # The cos/sin batch dim may only be broadcastable, so take batch size from x + b = x.shape[0] + _, h, t, _ = cos.shape + x = x.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + # Split last dim (2*r) into (d=2, r) + last = x.shape[-1] + if last % 2 != 0: + raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.") + r = last // 2 + + # (..., 2, r) + split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float + first_x = split_x[..., :1, :] # (..., 1, r) + second_x = split_x[..., 1:, :] # (..., 1, r) + + cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r) + sin_u = sin.unsqueeze(-2) + + out = split_x * cos_u + first_out = out[..., :1, :] + second_out = out[..., 1:, :] + + first_out.addcmul_(-sin_u, second_x) + second_out.addcmul_(sin_u, first_x) + + out = out.reshape(*out.shape[:-2], last) + + if needs_reshape: + out = out.swapaxes(1, 2).reshape(b, t, -1) + + out = out.to(dtype=x_dtype) + return out + + +@dataclass +class AudioVisualModelOutput(BaseOutput): + r""" + Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_video_tokens, out_channels)`): + The patchified visual output conditioned on the `encoder_hidden_states` input. 
This is the transformer + output before the pipeline unpacks it back into video latent dimensions. + audio_sample (`torch.Tensor` of shape `(batch_size, num_audio_tokens, audio_out_channels)`): + The patchified audio output of the audiovisual model before the pipeline unpacks it back into audio latent + dimensions. + """ + + sample: "torch.Tensor" # noqa: F821 + audio_sample: "torch.Tensor" # noqa: F821 + + +class LTX2AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0 + model. In particular, the number of modulation parameters to be calculated is now configurable. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_mod_params (`int`, *optional*, defaults to `6`): + The number of modulation parameters which will be calculated in the first return argument. The default of 6 + is standard, but sometimes we may want to have a different (usually smaller) number of modulation + parameters. + use_additional_conditions (`bool`, *optional*, defaults to `False`): + Whether to use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False): + super().__init__() + self.num_mod_params = num_mod_params + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + batch_size: int | None = None, + hidden_dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. 
+ added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class ColumnParallelApproxGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = True): + super().__init__() + self.proj = ColumnParallelLinear( + dim_in, + dim_out, + bias=bias, + gather_output=False, + return_bias=False, + ) + self.approximate = approximate + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return F.gelu(x, approximate=self.approximate) + + +class LTX2FeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + activation_fn: str = "gelu-approximate", + inner_dim: int | None = None, + bias: bool = True, + dropout: float = 0.0, + final_dropout: bool = False, + ) -> None: + super().__init__() + + assert activation_fn == "gelu-approximate", "Only gelu-approximate is supported." + + inner_dim = inner_dim or int(dim * mult) + dim_out = dim_out or dim + + dropout_layer: nn.Module = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + + layers: list[nn.Module] = [ + ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias), + dropout_layer, + RowParallelLinear( + inner_dim, + dim_out, + input_is_parallel=True, + return_bias=False, + ), + ] + if final_dropout: + layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity()) + + self.net = nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class TensorParallelRMSNorm(nn.Module): + """RMSNorm that computes stats across TP shards for q/k norm. + + LTX2 uses qk_norm="rms_norm_across_heads" while Q/K are tensor-parallel + sharded. 
A local RMSNorm would compute statistics on only the local shard, + which changes the normalization when TP > 1. We all-reduce the squared + sum to match the global RMS across all heads. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6, elementwise_affine: bool = True, tp_size: int = 1): + super().__init__() + self.hidden_size = hidden_size + self.global_hidden_size = hidden_size * max(tp_size, 1) + self.eps = eps + self.tp_size = tp_size + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + + def _all_reduce(self, tensor: torch.Tensor) -> None: + if not torch.distributed.is_initialized(): + return + try: + import vllm.distributed.parallel_state as vllm_parallel_state + + tp_group = getattr(vllm_parallel_state, "_TP", None) + except Exception: + tp_group = None + + if tp_group is not None and hasattr(tp_group, "all_reduce"): + tp_group.all_reduce(tensor) + return + + if tp_group is not None: + torch.distributed.all_reduce(tensor, group=tp_group) + else: + torch.distributed.all_reduce(tensor) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x_float = x.float() + local_sum = x_float.pow(2).sum(dim=-1, keepdim=True) + if self.tp_size > 1: + self._all_reduce(local_sum) + inv_rms = torch.rsqrt(local_sum / self.global_hidden_size + self.eps) + out = x_float * inv_rms + if self.weight is not None: + out = out * self.weight.float() + return out.to(dtype=x_dtype) + + +class LTX2AudioVideoAttnProcessor: + r""" + Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. + Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can + support audio-to-video (a2v) and video-to-audio (v2a) cross attention. 
+ """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. " + "Please upgrade your PyTorch installation." + ) + + @staticmethod + def _to_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor: + # Convert additive/expanded masks into a 2D padding mask for flash-attn. + if attention_mask.ndim > 2: + if attention_mask.is_floating_point(): + valid = attention_mask >= 0 + else: + valid = attention_mask.to(torch.bool) + b = valid.shape[0] + key_len = valid.shape[-1] + valid = valid.reshape(b, -1, key_len).all(dim=1) + attention_mask = valid + if attention_mask.is_floating_point(): + attention_mask = attention_mask >= 0 + if attention_mask.dtype != torch.bool: + attention_mask = attention_mask.to(torch.bool) + return attention_mask + + @staticmethod + def _is_sp_enabled() -> bool: + if not is_forward_context_available(): + return False + try: + od_config = get_forward_context().omni_diffusion_config + parallel_config = getattr(od_config, "parallel_config", None) if od_config is not None else None + return getattr(parallel_config, "sequence_parallel_size", 1) > 1 + except Exception: + return False + + def _prepare_attention_mask( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + attention_mask: torch.Tensor | None, + batch_size: int, + sequence_length: int, + ) -> torch.Tensor | None: + if attention_mask is None: + return None + + if self._is_sp_enabled(): + # In SP, Ulysses expects a 2D padding mask that matches query length. + # For cross-attention, encoder sequence length != query length, so drop the mask. 
+ if encoder_hidden_states is not None and encoder_hidden_states.shape[1] != hidden_states.shape[1]: + return None + return self._to_padding_mask(attention_mask) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if attn.attn.attn_backend.get_name().upper() == "FLASH_ATTN": + attention_mask = self._to_padding_mask(attention_mask) + return attention_mask + + @staticmethod + def _project_qkv( + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + is_self_attention: bool, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if is_self_attention and attn.to_qkv is not None: + qkv, _ = attn.to_qkv(hidden_states) + q_heads = getattr(attn, "query_num_heads", attn.heads) + kv_heads = getattr(attn, "kv_num_heads", attn.heads) + q_size = q_heads * attn.head_dim + kv_size = kv_heads * attn.head_dim + query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1) + return query, key, value + + query = attn.to_q(hidden_states) + if isinstance(query, tuple): + query = query[0] + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states is required for cross-attention projection.") + key = attn.to_k(encoder_hidden_states) + if isinstance(key, tuple): + key = key[0] + value = attn.to_v(encoder_hidden_states) + if isinstance(value, tuple): + value = value[0] + return query, key, value + + @staticmethod + def _slice_rope_for_tp( + rope: tuple[torch.Tensor, torch.Tensor] | None, + attn_module: "LTX2Attention", + ) -> tuple[torch.Tensor, torch.Tensor] | None: + if rope is None: + return None + cos, sin = rope + tp_size = get_tensor_model_parallel_world_size() + if tp_size <= 1: + return rope + tp_rank = get_tensor_model_parallel_rank() + + if cos.ndim == 4: + if cos.shape[1] != attn_module.heads: + local_heads = cos.shape[1] // tp_size + if local_heads == attn_module.heads: + 
start = tp_rank * local_heads + end = start + local_heads + cos = cos[:, start:end, :, :] + sin = sin[:, start:end, :, :] + elif cos.ndim == 3: + local_dim = attn_module.heads * attn_module.head_dim + if cos.shape[-1] != local_dim: + if cos.shape[-1] == local_dim * tp_size: + start = tp_rank * local_dim + end = start + local_dim + cos = cos[..., start:end] + sin = sin[..., start:end] + + return cos, sin + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + is_self_attention = encoder_hidden_states is None + batch_size, sequence_length, _ = hidden_states.shape if is_self_attention else encoder_hidden_states.shape + + attention_mask = self._prepare_attention_mask( + attn=attn, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + batch_size=batch_size, + sequence_length=sequence_length, + ) + + if is_self_attention: + encoder_hidden_states = hidden_states + + query, key, value = self._project_qkv( + attn=attn, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + is_self_attention=is_self_attention, + ) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + query_rotary_emb = self._slice_rope_for_tp(query_rotary_emb, attn) + if key_rotary_emb is not None: + key_rotary_emb = self._slice_rope_for_tp(key_rotary_emb, attn) + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb(key, 
key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + attn_metadata = AttentionMetadata(attn_mask=attention_mask) if attention_mask is not None else None + hidden_states = attn.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2Attention(torch.nn.Module): + r""" + Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key + RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. + """ + + _default_processor_cls = LTX2AudioVideoAttnProcessor + _available_processors = [LTX2AudioVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: int | None = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + rope_type: str = "interleaved", + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + kv_heads = heads if kv_heads is None else kv_heads + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.total_num_heads = heads + 
self.total_num_kv_heads = kv_heads + self.rope_type = rope_type + + self.to_qkv = None + self.to_q = None + self.to_k = None + self.to_v = None + if cross_attention_dim is None: + self.to_qkv = QKVParallelLinear( + hidden_size=query_dim, + head_size=self.head_dim, + total_num_heads=heads, + bias=bias, + ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + else: + tp_size = get_tensor_model_parallel_world_size() + self.query_num_heads = heads // tp_size + self.kv_num_heads = kv_heads // tp_size + + self.to_q = ColumnParallelLinear( + query_dim, + self.inner_dim, + bias=bias, + gather_output=False, + return_bias=False, + ) + self.to_k = ColumnParallelLinear( + self.cross_attention_dim, + self.inner_kv_dim, + bias=bias, + gather_output=False, + return_bias=False, + ) + self.to_v = ColumnParallelLinear( + self.cross_attention_dim, + self.inner_kv_dim, + bias=bias, + gather_output=False, + return_bias=False, + ) + + self.heads = self.query_num_heads + tp_size = get_tensor_model_parallel_world_size() + self.norm_q = TensorParallelRMSNorm( + dim_head * self.query_num_heads, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + tp_size=tp_size, + ) + self.norm_k = TensorParallelRMSNorm( + dim_head * self.kv_num_heads, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + tp_size=tp_size, + ) + + self.to_out = torch.nn.ModuleList( + [ + RowParallelLinear( + self.inner_dim, + self.out_dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ), + torch.nn.Dropout(dropout) if dropout > 0 else torch.nn.Identity(), + ] + ) + self.attn = Attention( + num_heads=self.query_num_heads, + head_size=dim_head, + num_kv_heads=self.kv_num_heads, + softmax_scale=1.0 / (dim_head**0.5), + causal=False, + ) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def set_processor(self, processor: Any) -> None: + if processor is None: + raise ValueError("processor 
must not be None.") + self.processor = processor + + def get_processor(self) -> Any: + return self.processor + + def prepare_attention_mask( + self, + attention_mask: torch.Tensor | None, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> torch.Tensor | None: + if attention_mask is None: + return None + + current_length = attention_mask.shape[-1] + if current_length != target_length: + pad_length = target_length - current_length + if pad_length > 0: + attention_mask = F.pad(attention_mask, (0, pad_length), value=0.0) + else: + attention_mask = attention_mask[..., :target_length] + + if out_dim == 3: + expected_batch = batch_size * self.heads + if attention_mask.shape[0] != expected_batch: + repeat_factor = expected_batch // attention_mask.shape[0] + if repeat_factor * attention_mask.shape[0] != expected_batch: + raise ValueError( + "attention_mask batch dimension is incompatible with the requested batch/head expansion: " + f"got {attention_mask.shape[0]}, expected a divisor of {expected_batch}." 
+ ) + attention_mask = attention_mask.repeat_interleave(repeat_factor, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=1) + else: + raise ValueError(f"Unsupported out_dim={out_dim}; expected 3 or 4.") + + return attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + "attention_kwargs %s are not expected by %s and will be ignored.", + unused_kwargs, + self.processor.__class__.__name__, + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs + ) + return hidden_states + + +class LTX2VideoTransformerBlock(nn.Module): + r""" + Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_dim: int, + audio_num_attention_heads: int, + audio_attention_head_dim, + audio_cross_attention_dim: int, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + + # 1. Self-Attention (video and audio) + self.norm1 = _make_rms_norm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + self.audio_norm1 = _make_rms_norm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn1 = LTX2Attention( + query_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 2. 
Prompt Cross-Attention + self.norm2 = _make_rms_norm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn2 = LTX2Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + self.audio_norm2 = _make_rms_norm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn2 = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + # Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio + self.audio_to_video_norm = _make_rms_norm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_to_video_attn = LTX2Attention( + query_dim=dim, + cross_attention_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video + self.video_to_audio_norm = _make_rms_norm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_attn = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 4. 
Feedforward layers + self.norm3 = _make_rms_norm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.ff = LTX2FeedForward(dim, activation_fn=activation_fn) + + self.audio_norm3 = _make_rms_norm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_ff = LTX2FeedForward(audio_dim, activation_fn=activation_fn) + + # 5. Per-Layer Modulation Parameters + # Self-Attention / Feedforward AdaLayerNorm-Zero mod params + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) + + # Per-layer a2v, v2a Cross-Attention mod params + self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) + self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + temb_audio: torch.Tensor, + temb_ca_scale_shift: torch.Tensor, + temb_ca_audio_scale_shift: torch.Tensor, + temb_ca_gate: torch.Tensor, + temb_ca_audio_gate: torch.Tensor, + video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + audio_encoder_attention_mask: torch.Tensor | None = None, + a2v_cross_attention_mask: torch.Tensor | None = None, + v2a_cross_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + + # 1. 
Video and Audio Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. 
Video and Audio Cross-Attention with the text embeddings + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + query_rotary_emb=None, + attention_mask=audio_encoder_attention_mask, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) + + # Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) + + 
temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + # Audio-to-Video Cross Attention: Q: Video; K,V: Audio + mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # Video-to-Audio Cross Attention: Q: Audio; K,V: Video + mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. 
Feedforward + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp + audio_ff_output = self.audio_ff(norm_audio_hidden_states) + audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp + + return hidden_states, audio_hidden_states + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + """ + Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model. + + Args: + causal_offset (`int`, *optional*, defaults to `1`): + Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE + treats the very first frame differently), but could also be 0 (for non-causal modeling). + """ + + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: tuple[int, ...] = (8, 32, 32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.patch_size_t = patch_size_t + + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. 
Choose between 'interleaved' and 'split'.") + self.rope_type = rope_type + + self.base_num_frames = base_num_frames + self.num_attention_heads = num_attention_heads + + # Video-specific + self.base_height = base_height + self.base_width = base_width + + # Audio-specific + self.sampling_rate = sampling_rate + self.hop_length = hop_length + self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0]) + + self.scale_factors = scale_factors + self.theta = theta + self.causal_offset = causal_offset + + self.modality = modality + if self.modality not in ["video", "audio"]: + raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + self.double_precision = double_precision + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + fps: float = 24.0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel + space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2) + where + - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames) + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the video latents. + num_frames (`int`): + Number of latent frames in the video latents. + height (`int`): + Latent height of the video latents. + width (`int`): + Latent width of the video latents. + device (`torch.device`): + Device on which to create the video grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2]. + """ + + # 1. 
Generate grid coordinates for each spatiotemporal dimension (frames, height, width) + # Always compute rope in fp32 + grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device) + grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device) + # indexing='ij' ensures that the dimensions are kept in order as (frames, height, width) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches + + # 2. Get the patch boundaries with respect to the latent video grid + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension + latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] + # Reshape to (batch_size, 3, num_patches, 2) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + # 3. Calculate the pixel space patch boundaries from the latent boundaries. 
+ scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # This is the (frame, height, width) dim + # Apply per-axis scaling to convert latent coordinates to pixel space coordinates + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + # As the VAE temporal stride for the first frame is 1 instead of self.vae_scale_factors[0], we need to shift + # and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0) + + # Scale the temporal coordinates by the video FPS + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + shift: int = 0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame. + This will ultimately have shape (batch_size, 3, num_patches, 2) where + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the audio latents. + num_frames (`int`): + Number of latent frames in the audio latents. + device (`torch.device`): + Device on which to create the audio grid. + shift (`int`, *optional*, defaults to `0`): + Offset on the latent indices. Different shift values correspond to different overlapping windows with + respect to the same underlying latent grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2]. + """ + + # 1. Generate coordinates in the frame (time) dimension. 
+ # Always compute rope in fp32 + grid_f = torch.arange( + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + ) + + # 2. Calculate start timestamps in seconds with respect to the original spectrogram grid + audio_scale_factor = self.scale_factors[0] + # Scale back to mel spectrogram space + grid_start_mel = grid_f * audio_scale_factor + # Handle first frame causal offset, ensuring non-negative timestamps + grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) + # Convert mel bins back into seconds + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + # 3. Calculate end timestamps in seconds with respect to the original spectrogram grid + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor + grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) # [num_patches, 2] + audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) # [batch_size, num_patches, 2] + audio_coords = audio_coords.unsqueeze(1) # [batch_size, 1, num_patches, 2] + return audio_coords + + def prepare_coords(self, *args, **kwargs): + if self.modality == "video": + return self.prepare_video_coords(*args, **kwargs) + elif self.modality == "audio": + return self.prepare_audio_coords(*args, **kwargs) + + def forward( + self, coords: torch.Tensor, device: str | torch.device | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or coords.device + + # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) + num_pos_dims = coords.shape[1] + + # 1. 
If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch + # position index + if coords.ndim == 4: + coords_start, coords_end = coords.chunk(2, dim=-1) + coords = (coords_start + coords_end) / 2.0 + coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches] + + # 2. Get coordinates as a fraction of the base data shape + if self.modality == "video": + max_positions = (self.base_num_frames, self.base_height, self.base_width) + elif self.modality == "audio": + max_positions = (self.base_num_frames,) + # [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims] + grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) + # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin + num_rope_elems = num_pos_dims * 2 + + # 3. Create a 1D grid of frequencies for RoPE + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 4. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape + # (self.dim // num_elems,) + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems] + freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2] + + # 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim + # TODO: consider implementing this as a utility and reuse in `connectors.py`. 
+ # src/diffusers/pipelines/ltx2/connectors.py + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + + return cos_freqs, sin_freqs + + +class LTX2VideoTransformer3DModel(nn.Module): + r""" + A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, defaults to `128`): + The number of channels in the output. + patch_size (`int`, defaults to `1`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the temporal patches to use in the patch embedding layer. 
+ num_attention_heads (`int`, defaults to `32`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + cross_attention_dim (`int`, defaults to `2048 `): + The number of channels for cross attention heads. + num_layers (`int`, defaults to `28`): + The number of layers of Transformer blocks to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + qk_norm (`str`, defaults to `"rms_norm_across_heads"`): + The normalization layer to use. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] + _sp_plan: dict[str, Any] | None = None + + @staticmethod + def _build_sp_plan(rope_type: str) -> dict[str, Any]: + if rope_type == "split": + # split RoPE returns (B, H, T, D/2) -> shard along T dim + rope_expected_dims = 4 + rope_split_dim = 2 + else: + # interleaved RoPE returns (B, T, D) -> shard along T dim + rope_expected_dims = 3 + rope_split_dim = 1 + + return { + "": { + # Shard video/audio latents across sequence + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False), + "audio_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False), + # Shard prompt embeds across sequence + "encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False), + "audio_encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, split_output=False), + # Shard video timestep when provided as (B, seq_len) + "timestep": SequenceParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: SequenceParallelInput(split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True), + 1: SequenceParallelInput(split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True), + }, + "audio_rope": { + 0: 
SequenceParallelInput(split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True), + 1: SequenceParallelInput(split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True), + }, + "cross_attn_rope": { + 0: SequenceParallelInput(split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True), + 1: SequenceParallelInput(split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True), + }, + "cross_attn_audio_rope": { + 0: SequenceParallelInput(split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True), + 1: SequenceParallelInput(split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True), + }, + # Gather outputs before returning + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + "audio_proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + def __init__( + self, + in_channels: int = 128, # Video Arguments + out_channels: int | None = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, # Audio Arguments + audio_out_channels: int | None = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, # Shared arguments + activation_fn: str = "gelu-approximate", + qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 
10000.0, + rope_double_precision: bool = True, + causal_offset: int = 1, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1000, + rope_type: str = "interleaved", + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + audio_out_channels = audio_out_channels or audio_in_channels + inner_dim = num_attention_heads * attention_head_dim + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + self.config = SimpleNamespace( + in_channels=in_channels, + out_channels=out_channels, + patch_size=patch_size, + patch_size_t=patch_size_t, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + vae_scale_factors=vae_scale_factors, + pos_embed_max_pos=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + audio_in_channels=audio_in_channels, + audio_out_channels=audio_out_channels, + audio_patch_size=audio_patch_size, + audio_patch_size_t=audio_patch_size_t, + audio_num_attention_heads=audio_num_attention_heads, + audio_attention_head_dim=audio_attention_head_dim, + audio_cross_attention_dim=audio_cross_attention_dim, + audio_scale_factor=audio_scale_factor, + audio_pos_embed_max_pos=audio_pos_embed_max_pos, + audio_sampling_rate=audio_sampling_rate, + audio_hop_length=audio_hop_length, + num_layers=num_layers, + activation_fn=activation_fn, + qk_norm=qk_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + caption_channels=caption_channels, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_offset=causal_offset, + timestep_scale_multiplier=timestep_scale_multiplier, + cross_attn_timestep_scale_multiplier=cross_attn_timestep_scale_multiplier, + rope_type=rope_type, + ) + + # 1. 
Patchification input projections + self.proj_in = nn.Linear(in_channels, inner_dim) + self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) + + # 2. Prompt embeddings + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) + + # 3. Timestep Modulation Params and Embedding + # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding + # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters + self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + self.audio_time_embed = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=6, use_additional_conditions=False + ) + + # 3.2. Global Cross Attention Modulation Parameters + # Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params, + # which are then further modified by per-block modulation params in each transformer block. 
+ # There are 2 sets of scale/shift parameters for each modality, 1 each for audio-to-video (a2v) and + # video-to-audio (v2a) cross attention + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=4, use_additional_conditions=False + ) + # Gate param for audio-to-video (a2v) cross attn (where the video is the queries (Q) and the audio is the keys + # and values (KV)) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=1, use_additional_conditions=False + ) + # Gate param for video-to-audio (v2a) cross attn (where the audio is the queries (Q) and the video is the keys + # and values (KV)) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=1, use_additional_conditions=False + ) + + # 3.3. Output Layer Scale/Shift Modulation parameters + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + + # 4. 
Rotary Positional Embeddings (RoPE) + # Self-Attention + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=inner_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=vae_scale_factors, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_inner_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + scale_factors=[audio_scale_factor], + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # Audio-to-Video, Video-to-Audio Cross-Attention + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # 5. 
Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + LTX2VideoTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + audio_dim=audio_inner_dim, + audio_num_attention_heads=audio_num_attention_heads, + audio_attention_head_dim=audio_attention_head_dim, + audio_cross_attention_dim=audio_cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + # 6. Output layers + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels) + + self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False) + self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels) + + self.gradient_checkpointing = False + self._sp_plan = self._build_sp_plan(rope_type) + + def _gradient_checkpointing_func(self, module: nn.Module, *inputs: Any): + def custom_forward(*checkpoint_inputs: Any): + return module(*checkpoint_inputs) + + return checkpoint(custom_forward, *inputs, use_reentrant=False) + + def enable_gradient_checkpointing(self) -> None: + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self) -> None: + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + audio_timestep: torch.LongTensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + audio_encoder_attention_mask: torch.Tensor | None = None, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, + fps: float = 24.0, + audio_num_frames: 
int | None = None, + video_coords: torch.Tensor | None = None, + audio_coords: torch.Tensor | None = None, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> torch.Tensor: + """ + Forward pass for LTX-2.0 audiovisual video transformer. + + Args: + hidden_states (`torch.Tensor`): + Input patchified video latents of shape `(batch_size, num_video_tokens, in_channels)`. + audio_hidden_states (`torch.Tensor`): + Input patchified audio latents of shape `(batch_size, num_audio_tokens, audio_in_channels)`. + encoder_hidden_states (`torch.Tensor`): + Input video text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`. + audio_encoder_hidden_states (`torch.Tensor`): + Input audio text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`. + timestep (`torch.Tensor`): + Input timestep of shape `(batch_size, num_video_tokens)`. These should already be scaled by + `self.config.timestep_scale_multiplier`. + audio_timestep (`torch.Tensor`, *optional*): + Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation + params. This is only used by certain pipelines such as the I2V pipeline. + encoder_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. + audio_encoder_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling. + num_frames (`int`, *optional*): + The number of latent video frames. Used if calculating the video coordinates for RoPE. + height (`int`, *optional*): + The latent video height. Used if calculating the video coordinates for RoPE. + width (`int`, *optional*): + The latent video width. Used if calculating the video coordinates for RoPE. + fps: (`float`, *optional*, defaults to `24.0`): + The desired frames per second of the generated video. 
Used if calculating the video coordinates for + RoPE. + audio_num_frames: (`int`, *optional*): + The number of latent audio frames. Used if calculating the audio coordinates for RoPE. + video_coords (`torch.Tensor`, *optional*): + The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape + `(batch_size, 3, num_video_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + audio_coords (`torch.Tensor`, *optional*): + The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape + `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + attention_kwargs (`Dict[str, Any]`, *optional*): + Optional dict of keyword args to be passed to the attention processor. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple. + + Returns: + `AudioVisualModelOutput` or `tuple`: + If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a + `tuple` is returned where the first element is the denoised video latent patch sequence and the second + element is the denoised audio latent patch sequence. + """ + # Determine timestep for audio. 
+ audio_timestep = audio_timestep if audio_timestep is not None else timestep + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. Prepare RoPE positional embeddings + if video_coords is None: + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, audio_hidden_states.device + ) + + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) + audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + + video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) + + # 2. Patchify input projections + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # 3. Prepare timestep embeddings and modulation parameters + timestep_cross_attn_gate_scale_factor = ( + self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + ) + + # 3.1. 
Prepare global modality (video and audio) timestep embedding and modulation parameters + # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer + # modulation with scale_shift_table (and similarly for audio) + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + temb_audio, audio_embedded_timestep = self.audio_time_embed( + audio_timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + + # 3.2. Prepare global modality cross attention modulation parameters + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( + audio_timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( + 
batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) + + # 4. Prepare prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + + # 5. Run transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + audio_encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + audio_encoder_attention_mask, + ) + + # 6. 
Output layers (including unpatchification) + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] + + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) + + if not return_dict: + return (output, audio_output) + return AudioVisualModelOutput(sample=output, audio_sample=audio_output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load weights from a pretrained model, mapping separate Q/K/V projections + into fused QKV projections for self-attention blocks. + + Returns: + Set of parameter names that were successfully loaded. 
+ """ + stacked_params_mapping = [ + (".attn1.to_qkv", ".attn1.to_q", "q"), + (".attn1.to_qkv", ".attn1.to_k", "k"), + (".attn1.to_qkv", ".attn1.to_v", "v"), + (".audio_attn1.to_qkv", ".audio_attn1.to_q", "q"), + (".audio_attn1.to_qkv", ".audio_attn1.to_k", "k"), + (".audio_attn1.to_qkv", ".audio_attn1.to_v", "v"), + ] + + params_dict = dict(self.named_parameters()) + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() if tp_size > 1 else 0 + loaded_params: set[str] = set() + + def _maybe_shard_weight(weight: torch.Tensor, param: torch.Tensor) -> torch.Tensor: + if tp_size <= 1 or weight.shape == param.shape: + return weight + + if weight.ndim == 1 and weight.numel() == param.numel() * tp_size: + return weight.chunk(tp_size, dim=0)[tp_rank] + + if weight.ndim == 2: + if weight.shape[0] == param.shape[0] * tp_size: + return weight.chunk(tp_size, dim=0)[tp_rank] + if weight.shape[1] == param.shape[1] * tp_size: + return weight.chunk(tp_size, dim=1)[tp_rank] + + return weight + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is not None: + weight_loader(param, loaded_weight) + else: + loaded_weight = _maybe_shard_weight(loaded_weight, param) + default_weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py new file mode 100644 index 0000000000..a08e75bb2e --- /dev/null +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py @@ -0,0 +1,1181 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project + +from __future__ import annotations + +import copy +import inspect +import json +import os +from collections.abc import Iterable +from contextlib import nullcontext +from typing import Any + +import numpy as np +import torch +from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from torch import nn +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from .ltx2_transformer import LTX2VideoTransformer3DModel + + +def load_transformer_config(model_path: str, subfolder: str = "transformer", local_files_only: bool = True) -> dict: + """Load transformer config from model directory or HF Hub.""" + if local_files_only: + config_path = os.path.join(model_path, subfolder, "config.json") + if os.path.exists(config_path): + with open(config_path) as f: + return json.load(f) + else: + try: + from huggingface_hub import hf_hub_download + + config_path = hf_hub_download( + repo_id=model_path, + filename=f"{subfolder}/config.json", + ) + with 
open(config_path) as f: + return json.load(f) + except Exception: + pass + return {} + + +def create_transformer_from_config(config: dict) -> LTX2VideoTransformer3DModel: + """Create LTX2VideoTransformer3DModel from config dict.""" + if not config: + return LTX2VideoTransformer3DModel() + + signature = inspect.signature(LTX2VideoTransformer3DModel.__init__) + allowed_keys = set(signature.parameters.keys()) + kwargs = {k: v for k, v in config.items() if k in allowed_keys} + return LTX2VideoTransformer3DModel(**kwargs) + + +def get_ltx2_post_process_func( + od_config: OmniDiffusionConfig, +): + def post_process_func(output: tuple[torch.Tensor, torch.Tensor] | torch.Tensor): + if isinstance(output, tuple) and len(output) == 2: + video, audio = output + if isinstance(audio, torch.Tensor): + audio = audio.detach().cpu() + return {"video": video, "audio": audio} + return output + + return post_process_func + + +def _unwrap_request_tensor(value: Any) -> Any: + if isinstance(value, list): + return value[0] if value else None + return value + + +def _get_prompt_field(prompt: Any, key: str) -> Any: + if isinstance(prompt, str): + return None + value = prompt.get(key) + if value is None: + additional = prompt.get("additional_information") + if isinstance(additional, dict): + value = additional.get(key) + return _unwrap_request_tensor(value) + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class LTX2Pipeline(nn.Module, CFGParallelMixin): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + + self.device = get_local_device() + dtype = getattr(od_config, "dtype", torch.bfloat16) + + model = od_config.model + local_files_only = os.path.exists(model) 
+ + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ), + ] + + self.tokenizer = AutoTokenizer.from_pretrained( + model, + subfolder="tokenizer", + local_files_only=local_files_only, + ) + self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + model, + subfolder="text_encoder", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + self.connectors = LTX2TextConnectors.from_pretrained( + model, + subfolder="connectors", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + + self.vae = AutoencoderKLLTX2Video.from_pretrained( + model, + subfolder="vae", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + self.audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + model, + subfolder="audio_vae", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + self.vocoder = LTX2Vocoder.from_pretrained( + model, + subfolder="vocoder", + torch_dtype=dtype, + local_files_only=local_files_only, + ).to(self.device) + + transformer_config = load_transformer_config(model, "transformer", local_files_only) + self.transformer = create_transformer_from_config(transformer_config) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, 
"audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + tokenizer_max_length = 1024 + if getattr(self, "tokenizer", None) is not None: + tokenizer_max_length = self.tokenizer.model_max_length + if tokenizer_max_length is None or tokenizer_max_length > 100000: + encoder_config = getattr(self.text_encoder, "config", None) + config_max_len = getattr(encoder_config, "max_position_embeddings", None) + if config_max_len is None: + config_max_len = getattr(encoder_config, "max_seq_len", None) + tokenizer_max_length = config_max_len or 1024 + self.tokenizer_max_length = int(tokenizer_max_length) + + self._guidance_scale = None + self._guidance_rescale = None + self._attention_kwargs = None + self._interrupt = False + self._num_timesteps = None + self._current_timestep = None + + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + mask = token_indices < sequence_lengths[:, None] + elif padding_side == "left": + start_indices = seq_len - sequence_lengths[:, None] + 
mask = token_indices >= start_indices + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] + + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self.device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask 
+ text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + 
max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when " + "passed directly, but got: `prompt_attention_mask` " + f"{prompt_attention_mask.shape} != `negative_prompt_attention_mask` " + f"{negative_prompt_attention_mask.shape}." 
+ ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + @staticmethod + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + if patch_size is not None and patch_size_t is not None: + batch_size, num_channels, 
latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + latents = latents.transpose(1, 2).flatten(2, 3) + return latents + + @staticmethod + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size) + + return latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + num_mel_bins: int = 64, + num_frames: int = 121, + frame_rate: float = 25.0, + sampling_rate: int = 16000, + hop_length: int = 160, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, int]: + duration_s = num_frames / frame_rate + latents_per_second = float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio) + latent_length = round(duration_s * latents_per_second) + + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + sp_size = getattr(self.od_config.parallel_config, "sequence_parallel_size", 1) + if sp_size > 1: + pad_len = (sp_size - (latent_length % sp_size)) % sp_size + if pad_len > 0: + if latents is not None: + pad_shape = list(latents.shape) + pad_shape[2] = pad_len + padding = torch.zeros(pad_shape, dtype=latents.dtype, device=latents.device) + latents = torch.cat([latents, padding], dim=2) + latent_length += pad_len + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_length + + shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents, latent_length + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + def _is_cfg_parallel_enabled(self, do_true_cfg: bool) -> bool: + return do_true_cfg and get_classifier_free_guidance_world_size() > 1 + + def _transformer_cache_context(self, context_name: str): + cache_context = getattr(self.transformer, "cache_context", None) + if callable(cache_context): + return cache_context(context_name) + return nullcontext() + + def _predict_noise_av(self, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + with self._transformer_cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer(**kwargs) + return noise_pred_video, noise_pred_audio + + def predict_noise_av_maybe_with_cfg( + self, + do_true_cfg: bool, + true_cfg_scale: float, + positive_kwargs: dict[str, Any], + negative_kwargs: dict[str, Any] | None, + guidance_rescale: float = 0.0, + cfg_normalize: bool = False, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if do_true_cfg: + cfg_parallel_ready = get_classifier_free_guidance_world_size() > 1 + + if cfg_parallel_ready: + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + if cfg_rank == 0: + noise_pred_video, noise_pred_audio = self._predict_noise_av(**positive_kwargs) + else: + noise_pred_video, noise_pred_audio = self._predict_noise_av(**negative_kwargs) + + noise_pred_video = 
noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + gathered_video = cfg_group.all_gather(noise_pred_video, separate_tensors=True) + gathered_audio = cfg_group.all_gather(noise_pred_audio, separate_tensors=True) + + if cfg_rank == 0: + noise_pred_video_text = gathered_video[0] + noise_pred_video_uncond = gathered_video[1] + noise_pred_audio_text = gathered_audio[0] + noise_pred_audio_uncond = gathered_audio[1] + + noise_pred_video = self.combine_cfg_noise( + noise_pred_video_text, + noise_pred_video_uncond, + true_cfg_scale, + cfg_normalize, + ) + noise_pred_audio = self.combine_cfg_noise( + noise_pred_audio_text, + noise_pred_audio_uncond, + true_cfg_scale, + cfg_normalize, + ) + + if guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video, + noise_pred_video_text, + guidance_rescale=guidance_rescale, + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, + noise_pred_audio_text, + guidance_rescale=guidance_rescale, + ) + return noise_pred_video, noise_pred_audio + return None, None + + noise_pred_video_text, noise_pred_audio_text = self._predict_noise_av(**positive_kwargs) + noise_pred_video_uncond, noise_pred_audio_uncond = self._predict_noise_av(**negative_kwargs) + + noise_pred_video_text = noise_pred_video_text.float() + noise_pred_audio_text = noise_pred_audio_text.float() + noise_pred_video_uncond = noise_pred_video_uncond.float() + noise_pred_audio_uncond = noise_pred_audio_uncond.float() + + noise_pred_video = self.combine_cfg_noise( + noise_pred_video_text, + noise_pred_video_uncond, + true_cfg_scale, + cfg_normalize, + ) + noise_pred_audio = self.combine_cfg_noise( + noise_pred_audio_text, + noise_pred_audio_uncond, + true_cfg_scale, + cfg_normalize, + ) + + if guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video, + noise_pred_video_text, + guidance_rescale=guidance_rescale, + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, + noise_pred_audio_text, + 
guidance_rescale=guidance_rescale, + ) + + return noise_pred_video, noise_pred_audio + + noise_pred_video, noise_pred_audio = self._predict_noise_av(**positive_kwargs) + return noise_pred_video.float(), noise_pred_audio.float() + + def _scheduler_step_video_audio_maybe_with_cfg( + self, + noise_pred_video: torch.Tensor | None, + noise_pred_audio: torch.Tensor | None, + t: torch.Tensor, + latents: torch.Tensor, + audio_latents: torch.Tensor, + audio_scheduler: FlowMatchEulerDiscreteScheduler, + do_true_cfg: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + cfg_parallel_ready = self._is_cfg_parallel_enabled(do_true_cfg) + + if cfg_parallel_ready: + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + if cfg_rank == 0: + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + latents = latents.contiguous() + audio_latents = audio_latents.contiguous() + cfg_group.broadcast(latents, src=0) + cfg_group.broadcast(audio_latents, src=0) + return latents, audio_latents + + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + return latents, audio_latents + + @torch.no_grad() + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_frames: int | None = None, + frame_rate: float | None = None, + num_inference_steps: int | None = None, + timesteps: list[int] | None = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor 
| None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + max_sequence_length: int | None = None, + ) -> DiffusionOutput: + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or height or 512 + width = req.sampling_params.width or width or 768 + num_frames = req.sampling_params.num_frames or num_frames or 121 + req_fps = req.sampling_params.fps + if isinstance(req_fps, list): + req_fps = req_fps[0] if req_fps else None + frame_rate = ( + req.sampling_params.frame_rate or (float(req_fps) if req_fps is not None else None) or frame_rate or 24.0 + ) + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps or 40 + if timesteps is None: + num_inference_steps = max(int(num_inference_steps), 2) + elif len(timesteps) < 2: + raise ValueError("`timesteps` must contain at least 2 values for FlowMatchEulerDiscreteScheduler.") + num_videos_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_videos_per_prompt or 1 + ) + max_sequence_length = ( + req.sampling_params.max_sequence_length or max_sequence_length or self.tokenizer_max_length + ) + + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + if req.sampling_params.guidance_rescale is not None: + guidance_rescale = 
req.sampling_params.guidance_rescale + + if generator is None: + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) + + latents = req.sampling_params.latents if req.sampling_params.latents is not None else latents + audio_latents = ( + req.sampling_params.audio_latents + if req.sampling_params.audio_latents is not None + else req.sampling_params.extra_args.get("audio_latents", audio_latents) + ) + + req_prompt_embeds = [_get_prompt_field(p, "prompt_embeds") for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore[arg-type] + + req_negative_prompt_embeds = [_get_prompt_field(p, "negative_prompt_embeds") for p in req.prompts] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore[arg-type] + + req_prompt_attention_masks = [ + _get_prompt_field(p, "prompt_attention_mask") or _get_prompt_field(p, "attention_mask") for p in req.prompts + ] + if any(m is not None for m in req_prompt_attention_masks): + prompt_attention_mask = torch.stack(req_prompt_attention_masks) # type: ignore[arg-type] + + req_negative_attention_masks = [ + _get_prompt_field(p, "negative_prompt_attention_mask") or _get_prompt_field(p, "negative_attention_mask") + for p in req.prompts + ] + if any(m is not None for m in req_negative_attention_masks): + negative_prompt_attention_mask = torch.stack(req_negative_attention_masks) # type: ignore[arg-type] + + if req.sampling_params.decode_timestep is not None: + decode_timestep = req.sampling_params.decode_timestep + if req.sampling_params.decode_noise_scale is not None: + decode_noise_scale = req.sampling_params.decode_noise_scale + if req.sampling_params.output_type is not None: + output_type = req.sampling_params.output_type + + 
self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.device + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + cfg_parallel_ready = self._is_cfg_parallel_enabled(self.do_classifier_free_guidance) + if self.do_classifier_free_guidance and not cfg_parallel_ready: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + negative_connector_prompt_embeds = None + negative_connector_audio_prompt_embeds = None + negative_connector_attention_mask = None + if cfg_parallel_ready: + 
negative_additive_attention_mask = ( + 1 - negative_prompt_attention_mask.to(negative_prompt_embeds.dtype) + ) * -1000000.0 + ( + negative_connector_prompt_embeds, + negative_connector_audio_prompt_embeds, + negative_connector_attention_mask, + ) = self.connectors( + negative_prompt_embeds, + negative_additive_attention_mask, + additive_mask=True, + ) + + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents, audio_num_frames = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + num_mel_bins=num_mel_bins, + num_frames=num_frames, + frame_rate=frame_rate, + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + audio_scheduler = copy.deepcopy(self.scheduler) + _ = 
retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if cfg_parallel_ready: + latent_model_input = latents.to(prompt_embeds.dtype) + audio_latent_model_input = audio_latents.to(prompt_embeds.dtype) + timestep = t.expand(latent_model_input.shape[0]) + + positive_kwargs = { + "hidden_states": latent_model_input, + "audio_hidden_states": audio_latent_model_input, + "encoder_hidden_states": connector_prompt_embeds, + "audio_encoder_hidden_states": connector_audio_prompt_embeds, + "timestep": timestep, + "encoder_attention_mask": connector_attention_mask, + "audio_encoder_attention_mask": connector_attention_mask, + "num_frames": latent_num_frames, + "height": latent_height, + "width": latent_width, + "fps": frame_rate, + "audio_num_frames": audio_num_frames, + "video_coords": video_coords, + "audio_coords": audio_coords, + "attention_kwargs": attention_kwargs, + "return_dict": False, + } + negative_kwargs = { + "hidden_states": latent_model_input, + "audio_hidden_states": audio_latent_model_input, + "encoder_hidden_states": negative_connector_prompt_embeds, + "audio_encoder_hidden_states": negative_connector_audio_prompt_embeds, + "timestep": timestep, + "encoder_attention_mask": negative_connector_attention_mask, + 
"audio_encoder_attention_mask": negative_connector_attention_mask, + "num_frames": latent_num_frames, + "height": latent_height, + "width": latent_width, + "fps": frame_rate, + "audio_num_frames": audio_num_frames, + "video_coords": video_coords, + "audio_coords": audio_coords, + "attention_kwargs": attention_kwargs, + "return_dict": False, + } + + noise_pred_video, noise_pred_audio = self.predict_noise_av_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + guidance_rescale=guidance_rescale, + cfg_normalize=False, + ) + + latents, audio_latents = self._scheduler_step_video_audio_maybe_with_cfg( + noise_pred_video, + noise_pred_audio, + t, + latents, + audio_latents, + audio_scheduler, + do_true_cfg=True, + ) + else: + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + + with self._transformer_cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = 
noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=guidance_rescale + ) + + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + pass + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is 
None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + if not return_dict: + return DiffusionOutput(output=(video, audio)) + + return DiffusionOutput(output=(video, audio)) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py new file mode 100644 index 0000000000..8419a206e8 --- /dev/null +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py @@ -0,0 +1,710 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import copy +from typing import Any + +import numpy as np +import PIL.Image +import torch +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from vllm_omni.diffusion.data import 
DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import get_cfg_group, get_classifier_free_guidance_rank +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from .pipeline_ltx2 import ( + LTX2Pipeline, + _get_prompt_field, + calculate_shift, +) +from .pipeline_ltx2 import ( + get_ltx2_post_process_func as _get_ltx2_post_process_func, +) + + +def get_ltx2_post_process_func(od_config: OmniDiffusionConfig): + return _get_ltx2_post_process_func(od_config) + + +class LTX2ImageToVideoPipeline(LTX2Pipeline): + support_image_input = True + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__(od_config=od_config, prefix=prefix) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + def prepare_latents( + self, + image: torch.Tensor | None = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, 
height, width) + + if latents is not None: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + "Provided `latents` tensor has shape" + f" {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." + ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if image is None: + raise ValueError("`image` must be provided when `latents` is None.") + + image_batch_size = image.shape[0] + if image_batch_size == 0: + raise ValueError("`image` batch is empty.") + if batch_size % image_batch_size != 0: + raise ValueError( + f"`batch_size` ({batch_size}) must be divisible by image batch size ({image_batch_size}) " + "for image-to-video outputs." + ) + num_videos_per_prompt = batch_size // image_batch_size + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective" + f" batch size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + image_generators = [generator[i * num_videos_per_prompt] for i in range(image_batch_size)] + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), image_generators[i], "argmax") + for i in range(image_batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + if num_videos_per_prompt > 1: + init_latents = init_latents.repeat_interleave(num_videos_per_prompt, dim=0) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size) + + return latents, conditioning_mask + + def check_inputs( + self, + image, + height, + width, + prompt, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if image is None and latents is None: + raise ValueError("Provide either `image` or `latents`. Cannot leave both undefined.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to" + " only forward one of the two." + ) + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when" + " passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." 
+ ) + + def _step_video_latents_i2v( + self, + noise_pred_video: torch.Tensor, + latents: torch.Tensor, + t: torch.Tensor, + latent_num_frames: int, + latent_height: int, + latent_width: int, + ) -> torch.Tensor: + noise_pred_video = self._unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents_unpacked = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred_video = noise_pred_video[:, :, 1:] + noise_latents = latents_unpacked[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] + + latents_unpacked = torch.cat([latents_unpacked[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents_unpacked, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + @torch.no_grad() + def forward( + self, + req: OmniDiffusionRequest, + image: PIL.Image.Image | torch.Tensor | None = None, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_frames: int | None = None, + frame_rate: float | None = None, + num_inference_steps: int | None = None, + timesteps: list[int] | None = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | 
list[float] | None = None, + output_type: str = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + max_sequence_length: int | None = None, + ) -> DiffusionOutput: + prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt + if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts): + negative_prompt = None + elif req.prompts: + negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] + + height = req.sampling_params.height or height or 512 + width = req.sampling_params.width or width or 768 + num_frames = req.sampling_params.num_frames or num_frames or 121 + req_fps = req.sampling_params.fps + if isinstance(req_fps, list): + req_fps = req_fps[0] if req_fps else None + frame_rate = ( + req.sampling_params.frame_rate or (float(req_fps) if req_fps is not None else None) or frame_rate or 24.0 + ) + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps or 40 + if timesteps is None: + num_inference_steps = max(int(num_inference_steps), 2) + elif len(timesteps) < 2: + raise ValueError("`timesteps` must contain at least 2 values for FlowMatchEulerDiscreteScheduler.") + num_videos_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_videos_per_prompt or 1 + ) + max_sequence_length = ( + req.sampling_params.max_sequence_length or max_sequence_length or self.tokenizer_max_length + ) + + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + if req.sampling_params.guidance_rescale is not None: + guidance_rescale = req.sampling_params.guidance_rescale + + if generator is None: + generator = req.sampling_params.generator + if generator is None and req.sampling_params.seed is not None: + generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed) + + latents = 
req.sampling_params.latents if req.sampling_params.latents is not None else latents + audio_latents = ( + req.sampling_params.audio_latents + if req.sampling_params.audio_latents is not None + else req.sampling_params.extra_args.get("audio_latents", audio_latents) + ) + + req_prompt_embeds = [_get_prompt_field(p, "prompt_embeds") for p in req.prompts] + if any(p is not None for p in req_prompt_embeds): + prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore[arg-type] + + req_negative_prompt_embeds = [_get_prompt_field(p, "negative_prompt_embeds") for p in req.prompts] + if any(p is not None for p in req_negative_prompt_embeds): + negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore[arg-type] + + req_prompt_attention_masks = [ + _get_prompt_field(p, "prompt_attention_mask") or _get_prompt_field(p, "attention_mask") for p in req.prompts + ] + if any(m is not None for m in req_prompt_attention_masks): + prompt_attention_mask = torch.stack(req_prompt_attention_masks) # type: ignore[arg-type] + + req_negative_attention_masks = [ + _get_prompt_field(p, "negative_prompt_attention_mask") or _get_prompt_field(p, "negative_attention_mask") + for p in req.prompts + ] + if any(m is not None for m in req_negative_attention_masks): + negative_prompt_attention_mask = torch.stack(req_negative_attention_masks) # type: ignore[arg-type] + + if req.sampling_params.decode_timestep is not None: + decode_timestep = req.sampling_params.decode_timestep + if req.sampling_params.decode_noise_scale is not None: + decode_noise_scale = req.sampling_params.decode_noise_scale + if req.sampling_params.output_type is not None: + output_type = req.sampling_params.output_type + + if image is None and req.prompts: + raw_images = [] + for prompt_item in req.prompts: + if isinstance(prompt_item, str): + raw_image = None + else: + multi_modal_data = prompt_item.get("multi_modal_data") or {} + raw_image = multi_modal_data.get("image") + if raw_image is None: + additional 
= prompt_item.get("additional_information") or {} + raw_image = ( + additional.get("preprocessed_image") + or additional.get("pixel_values") + or additional.get("image") + ) + if isinstance(raw_image, list): + raw_image = raw_image[0] if raw_image else None + if isinstance(raw_image, str): + raw_image = PIL.Image.open(raw_image).convert("RGB") + raw_images.append(raw_image) + + if any(img is None for img in raw_images): + if latents is None: + raise ValueError("Image is required for LTX2 I2V generation.") + if len(raw_images) == 1: + image = raw_images[0] + elif raw_images: + image = raw_images + + self.check_inputs( + image=image, + height=height, + width=width, + prompt=prompt, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self.device + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + cfg_parallel_ready = self._is_cfg_parallel_enabled(self.do_classifier_free_guidance) + if self.do_classifier_free_guidance and 
not cfg_parallel_ready: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + negative_connector_prompt_embeds = None + negative_connector_audio_prompt_embeds = None + negative_connector_attention_mask = None + if cfg_parallel_ready: + negative_additive_attention_mask = ( + 1 - negative_prompt_attention_mask.to(negative_prompt_embeds.dtype) + ) * -1000000.0 + ( + negative_connector_prompt_embeds, + negative_connector_audio_prompt_embeds, + negative_connector_attention_mask, + ) = self.connectors( + negative_prompt_embeds, + negative_additive_attention_mask, + additive_mask=True, + ) + + if latents is None: + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.unsqueeze(0) + elif isinstance(image, list) and image and isinstance(image[0], torch.Tensor): + image = torch.stack(image, dim=0) + else: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance and not cfg_parallel_ready: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + num_channels_latents_audio = ( + 
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents, audio_num_frames = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + num_mel_bins=num_mel_bins, + num_frames=num_frames, + frame_rate=frame_rate, + sampling_rate=self.audio_sampling_rate, + hop_length=self.audio_hop_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + audio_scheduler = copy.deepcopy(self.scheduler) + _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if cfg_parallel_ready: 
+ latent_model_input = latents.to(prompt_embeds.dtype) + audio_latent_model_input = audio_latents.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + positive_kwargs = { + "hidden_states": latent_model_input, + "audio_hidden_states": audio_latent_model_input, + "encoder_hidden_states": connector_prompt_embeds, + "audio_encoder_hidden_states": connector_audio_prompt_embeds, + "timestep": video_timestep, + "audio_timestep": timestep, + "encoder_attention_mask": connector_attention_mask, + "audio_encoder_attention_mask": connector_attention_mask, + "num_frames": latent_num_frames, + "height": latent_height, + "width": latent_width, + "fps": frame_rate, + "audio_num_frames": audio_num_frames, + "video_coords": video_coords, + "audio_coords": audio_coords, + "attention_kwargs": attention_kwargs, + "return_dict": False, + } + negative_kwargs = { + "hidden_states": latent_model_input, + "audio_hidden_states": audio_latent_model_input, + "encoder_hidden_states": negative_connector_prompt_embeds, + "audio_encoder_hidden_states": negative_connector_audio_prompt_embeds, + "timestep": video_timestep, + "audio_timestep": timestep, + "encoder_attention_mask": negative_connector_attention_mask, + "audio_encoder_attention_mask": negative_connector_attention_mask, + "num_frames": latent_num_frames, + "height": latent_height, + "width": latent_width, + "fps": frame_rate, + "audio_num_frames": audio_num_frames, + "video_coords": video_coords, + "audio_coords": audio_coords, + "attention_kwargs": attention_kwargs, + "return_dict": False, + } + + noise_pred_video, noise_pred_audio = self.predict_noise_av_maybe_with_cfg( + do_true_cfg=True, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + guidance_rescale=guidance_rescale, + cfg_normalize=False, + ) + + if get_classifier_free_guidance_rank() == 0: + latents = 
self._step_video_latents_i2v( + noise_pred_video, + latents, + t, + latent_num_frames, + latent_height, + latent_width, + ) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + cfg_group = get_cfg_group() + latents = latents.contiguous() + audio_latents = audio_latents.contiguous() + cfg_group.broadcast(latents, src=0) + cfg_group.broadcast(audio_latents, src=0) + else: + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + with self._transformer_cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, 
noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=guidance_rescale + ) + + latents = self._step_video_latents_i2v( + noise_pred_video, + latents, + t, + latent_num_frames, + latent_height, + latent_width, + ) + + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + pass + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + 
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + if not return_dict: + return DiffusionOutput(output=(video, audio)) + + return DiffusionOutput(output=(video, audio)) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index ed99b1473b..7ca1afa167 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -55,6 +55,16 @@ "pipeline_wan2_2", "Wan22Pipeline", ), + "LTX2Pipeline": ( + "ltx2", + "pipeline_ltx2", + "LTX2Pipeline", + ), + "LTX2ImageToVideoPipeline": ( + "ltx2", + "pipeline_ltx2_image2video", + "LTX2ImageToVideoPipeline", + ), "StableAudioPipeline": ( "stable_audio", "pipeline_stable_audio", @@ -286,6 +296,8 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "ZImagePipeline": "get_post_process_func", "OvisImagePipeline": "get_ovis_image_post_process_func", "WanPipeline": "get_wan22_post_process_func", + "LTX2Pipeline": "get_ltx2_post_process_func", + "LTX2ImageToVideoPipeline": "get_ltx2_post_process_func", "StableAudioPipeline": "get_stable_audio_post_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", "LongCatImagePipeline": "get_longcat_image_post_process_func", diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 19c8b07ace..ea5732691a 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -180,6 +180,7 @@ def 
_create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st }, "engine_args": { "parallel_config": parallel_config, + "model_class_name": kwargs.get("model_class_name", None), "vae_use_slicing": kwargs.get("vae_use_slicing", False), "vae_use_tiling": kwargs.get("vae_use_tiling", False), "cache_backend": cache_backend, diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 34187ec141..7ff7ecff52 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -86,7 +86,8 @@ def __init__( try: config_dict = get_hf_file_to_dict("model_index.json", od_config.model) if config_dict is not None: - od_config.model_class_name = config_dict.get("_class_name", None) + if od_config.model_class_name is None: + od_config.model_class_name = config_dict.get("_class_name", None) od_config.update_multimodal_support() tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index c9ed7150f7..5f5e337c43 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -193,6 +193,13 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="Number of GPUs to use for diffusion model inference.", ) + omni_config_group.add_argument( + "--model-class-name", + dest="model_class_name", + type=str, + default=None, + help="Override the diffusion pipeline class name (e.g. 
LTX2ImageToVideoPipeline).", + ) omni_config_group.add_argument( "--usp", "--ulysses-degree", diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index 880847533c..419640d4b5 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -57,7 +57,8 @@ def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs): od_config.model, ) if config_dict is not None: - od_config.model_class_name = config_dict.get("_class_name", None) + if od_config.model_class_name is None: + od_config.model_class_name = config_dict.get("_class_name", None) od_config.update_multimodal_support() tf_config_dict = get_hf_file_to_dict( @@ -93,7 +94,8 @@ def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs): if pipeline_class is None: raise ValueError(f"Unknown model type: {model_type}, architectures: {architectures}") - od_config.model_class_name = pipeline_class + if od_config.model_class_name is None: + od_config.model_class_name = pipeline_class od_config.tf_model_config = TransformerConfig() od_config.update_multimodal_support() diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py index c5140f24ab..0f96865cf2 100644 --- a/vllm_omni/entrypoints/openai/serving_video.py +++ b/vllm_omni/entrypoints/openai/serving_video.py @@ -89,12 +89,17 @@ async def generate_videos( status_code=HTTPStatus.BAD_REQUEST.value, detail=str(exc), ) from exc - if input_image is not None: - prompt["multi_modal_data"] = {"image": input_image} gen_params = OmniDiffusionSamplingParams(num_outputs_per_prompt=request.n) width, height, num_frames, fps = self._resolve_video_params(request) + if input_image is not None and width is not None and height is not None: + target_size = (width, height) + if input_image.size != target_size: + input_image = input_image.resize(target_size, Image.Resampling.LANCZOS) + if input_image is not None: + 
prompt["multi_modal_data"] = {"image": input_image} + if width is not None and height is not None: gen_params.width = width gen_params.height = height @@ -102,6 +107,7 @@ async def generate_videos( gen_params.num_frames = num_frames if fps is not None: gen_params.fps = fps + gen_params.frame_rate = float(fps) if request.num_inference_steps is not None: gen_params.num_inference_steps = request.num_inference_steps @@ -136,9 +142,21 @@ async def generate_videos( result = await self._run_generation(prompt, gen_params, request_id, raw_request) videos = self._extract_video_outputs(result) + audios = self._extract_audio_outputs(result, expected_count=len(videos)) output_fps = fps or 24 - - video_data = [VideoData(b64_json=encode_video_base64(video, fps=output_fps)) for video in videos] + audio_sample_rate = 24000 + + video_data = [ + VideoData( + b64_json=encode_video_base64( + video, + fps=output_fps, + audio=audios[idx], + audio_sample_rate=audio_sample_rate if audios[idx] is not None else None, + ) + ) + for idx, video in enumerate(videos) + ] return VideoGenerationResponse(created=int(time.time()), data=video_data) def _resolve_model_name(self, raw_request: Request | None) -> str | None: @@ -323,3 +341,35 @@ def _extract_video_outputs(self, result: Any) -> list[Any]: detail="No video outputs found in generation result.", ) return normalized + + @staticmethod + def _extract_audio_outputs(result: Any, expected_count: int) -> list[Any | None]: + audio = None + if hasattr(result, "multimodal_output") and result.multimodal_output: + audio = result.multimodal_output.get("audio") + elif hasattr(result, "request_output"): + request_output = result.request_output + if isinstance(request_output, dict) and request_output.get("multimodal_output"): + mm_output = request_output.get("multimodal_output") or {} + audio = mm_output.get("audio") + elif hasattr(request_output, "multimodal_output") and request_output.multimodal_output: + audio = 
request_output.multimodal_output.get("audio") + + if audio is None: + return [None] * expected_count + + if isinstance(audio, (list, tuple)): + if len(audio) == expected_count and any(hasattr(item, "shape") or hasattr(item, "ndim") for item in audio): + return list(audio) + if expected_count == 1: + return [audio] + + if hasattr(audio, "ndim") and getattr(audio, "ndim", None) is not None and audio.ndim > 1: + first_dim = getattr(audio, "shape", [0])[0] + if first_dim == expected_count: + return [audio[i] for i in range(expected_count)] + + if expected_count == 1: + return [audio] + + return [audio] + [None] * max(expected_count - 1, 0) diff --git a/vllm_omni/entrypoints/openai/video_api_utils.py b/vllm_omni/entrypoints/openai/video_api_utils.py index 865b1a5283..871ed6ca8c 100644 --- a/vllm_omni/entrypoints/openai/video_api_utils.py +++ b/vllm_omni/entrypoints/openai/video_api_utils.py @@ -136,8 +136,44 @@ def _coerce_video_to_frames(video: Any) -> list[np.ndarray]: raise ValueError(f"Unsupported video payload type: {type(video)}") -def encode_video_base64(video: Any, fps: int) -> str: - """Encode a video (frames/array/tensor) to base64 MP4.""" +def _coerce_audio_to_waveform(audio: Any) -> torch.Tensor: + """Convert an audio payload into a 2-channel CPU float tensor for LTX2 export.""" + if isinstance(audio, torch.Tensor): + waveform = audio.detach().cpu() + elif isinstance(audio, np.ndarray): + waveform = torch.from_numpy(audio) + elif isinstance(audio, list): + waveform = torch.tensor(audio) + else: + raise ValueError(f"Unsupported audio payload type: {type(audio)}") + + waveform = waveform.squeeze() + + if waveform.ndim == 0: + raise ValueError("Audio payload must contain at least one sample.") + + if waveform.ndim == 1: + waveform = waveform.unsqueeze(0) + elif waveform.ndim == 2: + if waveform.shape[0] in (1, 2): + pass + elif waveform.shape[1] in (1, 2): + waveform = waveform.transpose(0, 1) + else: + raise ValueError(f"Unsupported audio payload shape: 
{tuple(waveform.shape)}") + else: + raise ValueError(f"Unsupported audio payload rank: {waveform.ndim}") + + if waveform.shape[0] == 1: + waveform = waveform.repeat(2, 1) + elif waveform.shape[0] != 2: + raise ValueError(f"Expected mono or stereo audio, got shape {tuple(waveform.shape)}") + + return waveform.float().contiguous() + + +def _encode_video_bytes(video: Any, fps: int, audio: Any | None = None, audio_sample_rate: int | None = None) -> bytes: + """Encode a video payload into MP4 bytes, optionally muxing audio.""" try: from diffusers.utils import export_to_video except ImportError as exc: # pragma: no cover - optional dependency @@ -150,12 +186,37 @@ def encode_video_base64(video: Any, fps: int) -> str: tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) tmp_file.close() try: - export_to_video(frames, tmp_file.name, fps=fps) + if audio is not None: + from diffusers.pipelines.ltx2.export_utils import encode_video as encode_ltx2_video + + if encode_ltx2_video is not None: + frames_np = np.stack(frames, axis=0) + if frames_np.ndim == 4 and frames_np.shape[-1] == 4: + frames_np = frames_np[..., :3] + frames_np = np.clip(frames_np, 0.0, 1.0) + frames_u8 = (frames_np * 255).round().clip(0, 255).astype("uint8") + video_tensor = torch.from_numpy(frames_u8) + encode_ltx2_video( + video_tensor, + fps=fps, + audio=_coerce_audio_to_waveform(audio), + audio_sample_rate=audio_sample_rate, + output_path=tmp_file.name, + ) + else: + export_to_video(frames, tmp_file.name, fps=fps) + else: + export_to_video(frames, tmp_file.name, fps=fps) with open(tmp_file.name, "rb") as f: - video_bytes = f.read() - return base64.b64encode(video_bytes).decode("utf-8") + return f.read() finally: try: os.remove(tmp_file.name) except OSError: pass + + +def encode_video_base64(video: Any, fps: int, audio: Any | None = None, audio_sample_rate: int | None = None) -> str: + """Encode a video (frames/array/tensor) to base64 MP4.""" + video_bytes = _encode_video_bytes(video, 
fps=fps, audio=audio, audio_sample_rate=audio_sample_rate) + return base64.b64encode(video_bytes).decode("utf-8") diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index bf8563bd29..ae3b0552fa 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -163,6 +163,7 @@ class OmniDiffusionSamplingParams: max_sequence_length: int | None = None prompt_template: dict[str, Any] | None = None do_classifier_free_guidance: bool = False + output_type: str | None = None # Batch info num_outputs_per_prompt: int = 1 @@ -187,6 +188,7 @@ class OmniDiffusionSamplingParams: # Latent tensors latents: torch.Tensor | None = None + audio_latents: torch.Tensor | None = None raw_latent_shape: torch.Tensor | None = None noise_pred: torch.Tensor | None = None image_latent: torch.Tensor | None = None @@ -201,6 +203,7 @@ class OmniDiffusionSamplingParams: height: int | None = None width: int | None = None fps: int | None = None + frame_rate: float | None = None height_not_provided: bool = False width_not_provided: bool = False @@ -216,6 +219,8 @@ class OmniDiffusionSamplingParams: guidance_scale_provided: bool = False guidance_scale_2: float | None = None guidance_rescale: float = 0.0 + decode_timestep: float | list[float] | None = None + decode_noise_scale: float | list[float] | None = None eta: float = 0.0 sigmas: list[float] | None = None