Changes from all commits (35 commits)
f2ee36b  add ltx2 (david6666666, Feb 2, 2026)
ca6820f  adapt #797 (david6666666, Feb 2, 2026)
e697209  fix example (david6666666, Feb 2, 2026)
eedefb8  support cfg (david6666666, Feb 2, 2026)
9ac7ed9  support tp (david6666666, Feb 2, 2026)
5668d4c  support tp2 (david6666666, Feb 2, 2026)
d571942  support tp3 (david6666666, Feb 2, 2026)
31244db  support tp4 (david6666666, Feb 3, 2026)
3393702  modify cache-dit config (david6666666, Feb 3, 2026)
32c1d52  fix pre-commit and comment (david6666666, Feb 3, 2026)
1bbe93e  fix codex (david6666666, Feb 3, 2026)
30ddb16  fix comment 1 (david6666666, Feb 4, 2026)
4df4ac2  fix comment 2 (david6666666, Feb 4, 2026)
4f233ae  fix pre-commit (david6666666, Feb 4, 2026)
7480bed  fix comment ZJY (david6666666, Feb 6, 2026)
d25b07e  fix bug1 (david6666666, Feb 6, 2026)
2fba252  fix comment (david6666666, Feb 6, 2026)
41e3f88  Merge branch 'main' into ltx2 (david6666666, Feb 27, 2026)
443da6c  fix pre-commit (david6666666, Feb 27, 2026)
6527027  remove redun code (david6666666, Feb 27, 2026)
781b70d  remove redun code 2 (david6666666, Feb 27, 2026)
55091e9  fix 1 (david6666666, Feb 27, 2026)
7f78347  fix comment 1 (david6666666, Feb 27, 2026)
e9cd92d  fix comment 2 (david6666666, Feb 28, 2026)
c7dccdd  fix comment 3 (david6666666, Feb 28, 2026)
3ec6b0f  fix comment 3 (david6666666, Feb 28, 2026)
fcf0cdc  fix comment 3 (david6666666, Feb 28, 2026)
9f6239a  fix comment 3 (david6666666, Feb 28, 2026)
83c546c  fix pre-commit (david6666666, Feb 28, 2026)
c0bfa07  fix comment 4 (david6666666, Feb 28, 2026)
e0ee0d9  fix comment 5 (david6666666, Mar 2, 2026)
4ed6ec7  support online serving (david6666666, Mar 2, 2026)
4518808  support online serving 1 (david6666666, Mar 2, 2026)
7d88144  support online serving 2 (david6666666, Mar 2, 2026)
f69a5c3  support online serving 3 (david6666666, Mar 2, 2026)
2 changes: 2 additions & 0 deletions docs/models/supported_models.md
@@ -30,6 +30,8 @@ th {
|`ZImagePipeline` | Z-Image | `Tongyi-MAI/Z-Image-Turbo` |
| `WanPipeline` | Wan2.2-T2V, Wan2.2-TI2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.2-TI2V-5B-Diffusers` |
| `WanImageToVideoPipeline` | Wan2.2-I2V | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` |
| `LTX2Pipeline` | LTX-2-T2V | `Lightricks/LTX-2` |
| `LTX2ImageToVideoPipeline` | LTX-2-I2V | `Lightricks/LTX-2` |
| `OvisImagePipeline` | Ovis-Image | `OvisAI/Ovis-Image` |
|`LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` |
|`LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` |
1 change: 1 addition & 0 deletions docs/user_guide/diffusion/parallelism_acceleration.md
@@ -54,6 +54,7 @@ The following table shows which models are currently supported by parallelism me
| Model | Model Identifier | Ulysses-SP | Ring-Attention | Tensor-Parallel | HSDP |
|-------|------------------|:----------:|:--------------:|:---------------:|:----:|
| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | ✅ | ✅ | ✅ |
| **LTX-2** | `Lightricks/LTX-2` | ✅ | ✅ | ✅ | ❌ |

### Tensor Parallelism

1 change: 1 addition & 0 deletions docs/user_guide/diffusion_acceleration.md
@@ -74,6 +74,7 @@ The following table shows which models are currently supported by each accelerat
| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | HSDP |
|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:------------:|:----:|
| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ |
| **LTX-2** | `Lightricks/LTX-2` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |

### Quantization

1 change: 1 addition & 0 deletions examples/offline_inference/image_to_video/README.md
@@ -65,6 +65,7 @@ Key arguments:
- `--vae-use-slicing`: Enable VAE slicing for memory optimization.
- `--vae-use-tiling`: Enable VAE tiling for memory optimization.
- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
- `--tensor-parallel-size`: tensor-parallel size (effective only for models that support TP, e.g., LTX-2); see the sketch after this list.
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
- `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs.
- `--hsdp-shard-size`: Number of GPUs to shard model weights across within each replica group. -1 (default) auto-calculates as world_size / replicate_size.
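Since LTX-2 supports tensor parallelism (per the parallelism table in the user guide), a TP run might look like the sketch below. The `--image` and `--prompt` flag names here are assumptions for illustration and may differ from the script's actual interface:

```bash
python image_to_video.py \
    --model "Lightricks/LTX-2" \
    --image input.png \
    --prompt "The waves begin to roll as the camera slowly pulls back." \
    --tensor-parallel-size 2 \
    --output ltx2_i2v_out.mp4
```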
300 changes: 257 additions & 43 deletions examples/offline_inference/image_to_video/image_to_video.py

Large diffs are not rendered by default.

23 changes: 22 additions & 1 deletion examples/offline_inference/text_to_video/text_to_video.md
@@ -1,6 +1,7 @@
# Text-To-Video

The `Wan-AI/Wan2.2-T2V-A14B-Diffusers` pipeline generates short videos from text prompts.
The `Wan-AI/Wan2.2-T2V-A14B-Diffusers` pipeline generates short videos from text prompts. This script can also be used with `Lightricks/LTX-2` to generate video with audio.

## Local CLI Usage

@@ -19,6 +20,23 @@ python text_to_video.py \
--output t2v_out.mp4
```

LTX-2 example:

```bash
python text_to_video.py \
    --model "Lightricks/LTX-2" \
    --prompt "A cinematic close-up of ocean waves at golden hour." \
    --negative-prompt "worst quality, inconsistent motion, blurry, jittery, distorted" \
    --height 512 \
    --width 768 \
    --num-frames 121 \
    --num-inference-steps 40 \
    --guidance-scale 4.0 \
    --frame-rate 24 \
    --fps 24 \
    --output ltx2_out.mp4
```
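The script also accepts a `--cache_backend` flag (choices `cache_dit` or `tea_cache`, added in the parser changes below), and the acceleration support table lists Cache-DiT as supported for LTX-2. A minimal sketch layering it onto the run above:

```bash
python text_to_video.py \
    --model "Lightricks/LTX-2" \
    --prompt "A cinematic close-up of ocean waves at golden hour." \
    --num-frames 121 \
    --cache_backend cache_dit \
    --output ltx2_cached_out.mp4
```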

Key arguments:

- `--prompt`: text description (string).
@@ -32,6 +50,9 @@ Key arguments:
- `--vae-use-slicing`: enable VAE slicing for memory optimization.
- `--vae-use-tiling`: enable VAE tiling for memory optimization.
- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
- `--tensor-parallel-size`: tensor-parallel size (effective only for models that support TP, e.g., LTX-2); see the example at the end of this document.
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
- `--frame-rate`: generation frame rate for pipelines that require it (e.g., LTX-2); defaults to the value of `--fps`.
- `--audio_sample_rate`: audio sample rate for embedded audio when the pipeline returns audio (default: 24000).

> ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage.
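LTX-2 also supports tensor parallelism, as noted in the argument list above. A minimal sketch, assuming the launch mechanics match the repository's other parallel examples:

```bash
python text_to_video.py \
    --model "Lightricks/LTX-2" \
    --prompt "A cinematic close-up of ocean waves at golden hour." \
    --num-frames 121 \
    --tensor-parallel-size 2 \
    --output ltx2_tp_out.mp4
```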
219 changes: 184 additions & 35 deletions examples/offline_inference/text_to_video/text_to_video.py
@@ -31,6 +31,12 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--height", type=int, default=720, help="Video height.")
parser.add_argument("--width", type=int, default=1280, help="Video width.")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (Wan default is 81).")
parser.add_argument(
"--frame-rate",
type=float,
default=None,
help="Optional generation frame rate (used by models like LTX2). Defaults to --fps.",
)
parser.add_argument("--num-inference-steps", type=int, default=40, help="Sampling steps.")
parser.add_argument(
"--boundary-ratio",
@@ -109,12 +115,31 @@ def parse_args() -> argparse.Namespace:
default=1,
help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
)
parser.add_argument(
"--audio_sample_rate",
type=int,
default=24000,
help="Sample rate for audio output when saved (default: 24000 for LTX2).",
)
parser.add_argument(
"--cache_backend",
type=str,
default=None,
choices=["cache_dit", "tea_cache"],
help=(
"Cache backend to use for acceleration. "
"Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). "
"Default: None (no cache acceleration)."
),
)

return parser.parse_args()


def main():
args = parse_args()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed)
frame_rate = args.frame_rate if args.frame_rate is not None else float(args.fps)

# Wan2.2 cache-dit tuning (from cache-dit examples and cache_alignment).
cache_config = None
@@ -134,8 +159,7 @@ def main():
"scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra"
"scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static"
}
# Configure parallel settings (only SP is supported for Wan)
# Note: cfg_parallel and tensor_parallel are not implemented for Wan models
# Configure parallel settings
parallel_config = DiffusionParallelConfig(
ulysses_degree=args.ulysses_degree,
ring_degree=args.ring_degree,
@@ -153,12 +177,12 @@ def main():
vae_use_tiling=args.vae_use_tiling,
boundary_ratio=args.boundary_ratio,
flow_shift=args.flow_shift,
cache_backend=args.cache_backend,
cache_config=cache_config,
enable_cache_dit_summary=args.enable_cache_dit_summary,
enable_cpu_offload=args.enable_cpu_offload,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
cache_backend=args.cache_backend,
cache_config=cache_config,
)

if profiler_enabled:
@@ -172,7 +196,11 @@ def main():
print(f" Inference steps: {args.num_inference_steps}")
print(f" Frames: {args.num_frames}")
print(
f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}"
" Parallel configuration: "
f"ulysses_degree={args.ulysses_degree}, "
f"ring_degree={args.ring_degree}, "
f"cfg_parallel_size={args.cfg_parallel_size}, "
f"tensor_parallel_size={args.tensor_parallel_size}"
)
print(f" Video size: {args.width}x{args.height}")
print(f"{'=' * 60}\n")
@@ -191,6 +219,7 @@ def main():
guidance_scale_2=args.guidance_scale_high,
num_inference_steps=args.num_inference_steps,
num_frames=args.num_frames,
frame_rate=frame_rate,
),
)
generation_end = time.perf_counter()
@@ -199,64 +228,184 @@ def main():
# Print profiling results
print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")

# Extract video frames from OmniRequestOutput
if isinstance(frames, list) and len(frames) > 0:
first_item = frames[0]
audio = None
if isinstance(frames, list):
frames = frames[0] if frames else None

# Check if it's an OmniRequestOutput
if hasattr(first_item, "final_output_type"):
if first_item.final_output_type != "image":
raise ValueError(
f"Unexpected output type '{first_item.final_output_type}', expected 'image' for video generation."
)

# Pipeline mode: extract from nested request_output
if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output:
if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0:
inner_output = first_item.request_output[0]
if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"):
frames = inner_output.images[0] if inner_output.images else None
if frames is None:
raise ValueError("No video frames found in output.")
# Diffusion mode: use direct images field
elif hasattr(first_item, "images") and first_item.images:
frames = first_item.images
if isinstance(frames, OmniRequestOutput):
if frames.final_output_type != "image":
raise ValueError(
f"Unexpected output type '{frames.final_output_type}', expected 'image' for video generation."
)
if frames.multimodal_output and "audio" in frames.multimodal_output:
audio = frames.multimodal_output["audio"]
if frames.is_pipeline_output and frames.request_output is not None:
inner_output = frames.request_output
if isinstance(inner_output, list):
inner_output = inner_output[0] if inner_output else None
if isinstance(inner_output, OmniRequestOutput):
if inner_output.multimodal_output and "audio" in inner_output.multimodal_output:
audio = inner_output.multimodal_output["audio"]
frames = inner_output
if isinstance(frames, OmniRequestOutput):
if frames.images:
if len(frames.images) == 1 and isinstance(frames.images[0], tuple) and len(frames.images[0]) == 2:
frames, audio = frames.images[0]
elif len(frames.images) == 1 and isinstance(frames.images[0], dict):
audio = frames.images[0].get("audio")
frames = frames.images[0].get("frames") or frames.images[0].get("video")
else:
frames = frames.images
else:
raise ValueError("No video frames found in OmniRequestOutput.")

if isinstance(frames, list) and frames:
first_item = frames[0]
if isinstance(first_item, tuple) and len(first_item) == 2:
frames, audio = first_item
elif isinstance(first_item, dict):
audio = first_item.get("audio")
frames = first_item.get("frames") or first_item.get("video")
elif isinstance(first_item, list):
frames = first_item

if isinstance(frames, tuple) and len(frames) == 2:
frames, audio = frames
elif isinstance(frames, dict):
audio = frames.get("audio")
frames = frames.get("frames") or frames.get("video")

if frames is None:
raise ValueError("No video frames found in output.")
Comment on lines +231 to +279

Copilot AI commented on Feb 6, 2026:

Complex and fragile output unpacking logic. Lines 227-275 contain deeply nested conditionals to extract frames and audio from various possible output formats. This is brittle and hard to maintain. Consider creating a dedicated helper function or class to standardize output format handling, possibly in a shared utility module. The same complex logic is also duplicated in image_to_video.py lines 303-351.
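A minimal sketch of such a helper, covering the shapes this diff branches on (a duck-typed `OmniRequestOutput`, `(frames, audio)` tuples, and dicts with `frames`/`video`/`audio` keys); the name `extract_frames_and_audio` is hypothetical, and pipeline-mode unwrapping is omitted for brevity:

```python
def extract_frames_and_audio(output):
    """Return (frames, audio) from a generation result; audio is None if absent."""
    audio = None
    # Batch-of-one results arrive wrapped in a single-element list.
    if isinstance(output, list) and len(output) == 1:
        output = output[0]
    # Duck-typed OmniRequestOutput: audio lives in multimodal_output,
    # frames in images (possibly as a (frames, audio) tuple or a dict).
    if hasattr(output, "multimodal_output"):
        mm = output.multimodal_output
        if mm and "audio" in mm:
            audio = mm["audio"]
        images = getattr(output, "images", None)
        output = images[0] if images else None
    # Paired container shapes.
    if isinstance(output, tuple) and len(output) == 2:
        output, audio = output
    elif isinstance(output, dict):
        audio = output.get("audio")
        output = output.get("frames") or output.get("video")
    if output is None:
        raise ValueError("No video frames found in output.")
    return output, audio
```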

output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
try:
from diffusers.utils import export_to_video
except ImportError:
raise ImportError("diffusers is required for export_to_video.")

# frames may be np.ndarray (preferred) or torch.Tensor
def _normalize_frame(frame):
if isinstance(frame, torch.Tensor):
frame_tensor = frame.detach().cpu()
if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1:
frame_tensor = frame_tensor[0]
if frame_tensor.dim() == 3 and frame_tensor.shape[0] in (3, 4):
frame_tensor = frame_tensor.permute(1, 2, 0)
if frame_tensor.is_floating_point():
frame_tensor = frame_tensor.clamp(-1, 1) * 0.5 + 0.5
return frame_tensor.float().numpy()
if isinstance(frame, np.ndarray):
frame_array = frame
if frame_array.ndim == 4 and frame_array.shape[0] == 1:
frame_array = frame_array[0]
if np.issubdtype(frame_array.dtype, np.integer):
frame_array = frame_array.astype(np.float32) / 255.0
return frame_array
try:
from PIL import Image
except ImportError:
Image = None
if Image is not None and isinstance(frame, Image.Image):
return np.asarray(frame).astype(np.float32) / 255.0
return frame

def _ensure_frame_list(video_array):
if isinstance(video_array, list):
if len(video_array) == 0:
return video_array
first_item = video_array[0]
if isinstance(first_item, np.ndarray):
if first_item.ndim == 5:
return list(first_item[0])
if first_item.ndim == 4:
if len(video_array) == 1:
return list(first_item)
return list(first_item)
if first_item.ndim == 3:
return video_array
return video_array
if isinstance(video_array, np.ndarray):
if video_array.ndim == 5:
return list(video_array[0])
if video_array.ndim == 4:
return list(video_array)
if video_array.ndim == 3:
return [video_array]
return video_array

# frames may be np.ndarray, torch.Tensor, or list of tensors/arrays/images
# export_to_video expects a list of frames with values in [0, 1]
if isinstance(frames, torch.Tensor):
video_tensor = frames.detach().cpu()
if video_tensor.dim() == 5:
# [B, C, F, H, W] or [B, F, H, W, C]
if video_tensor.shape[1] in (3, 4):
video_tensor = video_tensor[0].permute(1, 2, 3, 0)
else:
video_tensor = video_tensor[0]
elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4):
video_tensor = video_tensor.permute(1, 2, 3, 0)
# If float, assume [-1,1] and normalize to [0,1]
if video_tensor.is_floating_point():
video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5
video_array = video_tensor.float().numpy()
else:
elif isinstance(frames, np.ndarray):
video_array = frames
if hasattr(video_array, "shape") and video_array.ndim == 5:
if video_array.ndim == 5:
video_array = video_array[0]
if np.issubdtype(video_array.dtype, np.integer):
video_array = video_array.astype(np.float32) / 255.0
elif isinstance(frames, list):
if len(frames) == 0:
raise ValueError("No video frames found in output.")
video_array = [_normalize_frame(frame) for frame in frames]
else:
video_array = frames

video_array = _ensure_frame_list(video_array)

# Convert 4D array (frames, H, W, C) to list of frames for export_to_video
if isinstance(video_array, np.ndarray) and video_array.ndim == 4:
video_array = list(video_array)
use_ltx2_export = False
if args.model and "ltx" in str(args.model).lower():
use_ltx2_export = True
if audio is not None:
use_ltx2_export = True

export_to_video(video_array, str(output_path), fps=args.fps)
if use_ltx2_export:
try:
from diffusers.pipelines.ltx2.export_utils import encode_video
except ImportError:
raise ImportError("diffusers is required for LTX2 encode_video.")

if isinstance(video_array, list):
frames_np = np.stack(video_array, axis=0)
elif isinstance(video_array, np.ndarray):
frames_np = video_array
else:
frames_np = np.asarray(video_array)

frames_u8 = (frames_np * 255).round().clip(0, 255).astype("uint8")
video_tensor = torch.from_numpy(frames_u8)

audio_out = None
if audio is not None:
if isinstance(audio, list):
audio = audio[0] if audio else None
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
if isinstance(audio, torch.Tensor):
audio_out = audio
if audio_out.dim() > 1:
audio_out = audio_out[0]
audio_out = audio_out.float().cpu()

encode_video(
video_tensor,
fps=args.fps,
audio=audio_out,
audio_sample_rate=args.audio_sample_rate if audio_out is not None else None,
output_path=str(output_path),
)
else:
export_to_video(video_array, str(output_path), fps=args.fps)
print(f"Saved generated video to {output_path}")

if profiler_enabled: