diff --git a/TurboDiffusion_Studio.md b/TurboDiffusion_Studio.md
new file mode 100644
index 0000000..d0d7f4f
--- /dev/null
+++ b/TurboDiffusion_Studio.md
@@ -0,0 +1,123 @@
+
+## 🚀 Scripts & Inference
+
+This repository contains optimized inference engines for the Wan2.1 and Wan2.2 models, specifically tuned for high-resolution output and robust memory management on consumer hardware.
+
+### 🔥 Inference Engines
+
+| Script | Function | Key Features |
+| --- | --- | --- |
+| **`wan2.2_i2v_infer.py`** | **Image-to-Video** | **Tiered Failover System**: Automatic recovery from OOM errors.<br><br>**Intelligent Model Switching**: Transitions between High and Low Noise models based on step boundaries.<br><br>**Tiled Processing**: Uses 4-chunk tiled encoding/decoding for 720p+ stability. |
+| **`wan2.1_t2v_infer.py`** | **Text-to-Video** | **Hardware Auto-Detection**: Automatically selects TF32, BF16, or FP16 based on GPU capabilities.<br><br>**Quantization Safety**: Force-disables `torch.compile` for quantized models to prevent graph-break OOMs.<br><br>**3-Tier Recovery**: Escalates from GPU → Checkpointing → Manual CPU Offloading if memory is exceeded. |
+
+### 🛠️ Utilities
+
+* **`cache_t5.py`**
+  * **Purpose**: Pre-computes and saves T5 text embeddings to disk.
+  * **VRAM Benefit**: Eliminates the need to load the **11GB T5 encoder** during the main inference run, allowing the 14B models to fit on GPUs with less VRAM.
+  * **Usage**: Run this first to generate a `.pt` file, then pass that file to the inference scripts using the `--cached_embedding` flag.
+
+
+---
+
+## 🚀 Getting Started with TurboDiffusion
+
+To run the large 14B models on consumer GPUs, it is recommended to use the **T5 Caching** workflow. This offloads the 11GB text encoder from VRAM, leaving more space for the DiT model and high-resolution video decoding.
+
+### **Step 1: Environment Setup**
+
+Ensure your project structure is organized as follows:
+
+* **Root**: `/your/path/to/TurboDiffusion`
+* **Checkpoints**: Place your `.pth` models in the `checkpoints/` directory.
+* **Output**: Generated videos and metadata will be saved to `output/`.
+
+### **Step 2: The Two Ways to Cache T5**
+
+#### **Option A: Manual Pre-Caching (Recommended for Batching)**
+
+If you have a list of prompts you want to use frequently, use the standalone utility:
+
+```bash
+python turbodiffusion/inference/cache_t5.py --prompt "Your descriptive prompt here" --output cached_t5_embeddings.pt
+
+```
+
+This saves the processed text into a small `.pt` file, allowing the inference scripts to "skip" the heavy T5 model entirely.
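+
+For reference, the cache is an ordinary `torch.save` dictionary (its structure is defined in `cache_t5.py`). A minimal sketch for inspecting one, assuming the default output name used above:
+
+```python
+import torch
+
+# Structure written by cache_t5.py:
+# {'prompts': [...], 'embeddings': [{'prompt', 'embedding', 'shape'}, ...], 'text_encoder_path': ...}
+cache = torch.load("cached_t5_embeddings.pt", map_location="cpu")
+
+for entry in cache["embeddings"]:
+    print(entry["prompt"][:60], entry["shape"], entry["embedding"].dtype)
+```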
+
+#### **Option B: Automatic Caching via Web UI**
+
+For a more streamlined experience, use the **TurboDiffusion Studio**:
+
+1. Launch the UI: `python turbo_diffusion_t5_cache_optimize_v6.py`.
+2. Open the **Precision & Advanced Settings** accordion.
+3. Check **Use Cached T5 Embeddings (Auto-Run)**.
+4. When you click generate, the UI will automatically run the caching script first, clear the T5 model from memory, and then start the video generation.
+
+### **Step 3: Running Inference**
+
+Once your UI is launched and caching is configured:
+
+1. **Select Mode**: Choose between **Text-to-Video** (Wan2.1) or **Image-to-Video** (Wan2.2).
+2. **Apply Quantization**: On consumer GPUs in the 24-32 GB range (RTX 3090/4090/5090), ensure **Enable --quant_linear (8-bit)** is checked to avoid OOM errors.
+3. **Monitor Hardware**: Watch the **Live GPU Monitor** at the top of the UI to track real-time VRAM usage during the sampling process.
+4. **Retrieve Results**: Your video and its reproduction metadata (containing the exact CLI command used) will appear in the `output/` gallery.
+
+
+---
+
+## 🖥️ TurboDiffusion Studio (Web UI)
+
+The `turbo_diffusion_t5_cache_optimize_v6.py` script provides a high-performance, unified **Gradio-based Web interface** for both Text-to-Video and Image-to-Video generation. It serves as a centralized "Studio" dashboard that automates complex environment setups and memory optimizations.
+
+### **Key Features**
+
+| Feature | Description |
+| --- | --- |
+| **Unified Interface** | Toggle between **Text-to-Video (Wan2.1)** and **Image-to-Video (Wan2.2)** workflows within a single dashboard. |
+| **Real-time GPU Monitor** | Native PyTorch-based VRAM monitoring that displays current memory usage and hardware status directly in the UI. |
+| **Auto-Cache T5 Integration** | Automatically runs the `cache_t5.py` utility before inference to offload the 11GB text encoder, significantly reducing peak VRAM usage. |
+| **Frame Sanitization** | Automatically enforces the **4n + 1 rule** required by the Wan VAE to prevent kernel crashes during decoding (see the sketch after this table). |
+| **Reproduction Metadata** | Every generated video automatically saves a matching `_metadata.txt` file containing the exact CLI command and environment variables needed to reproduce the result. |
+| **Live Console Output** | Pipes real-time CLI logs and progress bars directly into a "Live Console" window in the web browser. |
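+
+For reference, the sanitization reduces to a one-line rule. A minimal sketch mirroring the adjustment the Studio applies before launching inference (the function name is illustrative):
+
+```python
+def sanitize_frames(requested: int) -> int:
+    """Snap a frame count down to the 4n + 1 grid required by the Wan VAE."""
+    return max(1, ((requested - 1) // 4) * 4 + 1)
+
+assert sanitize_frames(32) == 29   # incompatible request gets adjusted
+assert sanitize_frames(81) == 81   # already valid, left unchanged
+```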
+
+### **Advanced Controls**
+
+The UI exposes granular controls for technical users:
+
+* **Precision & Quantization:** Toggle 8-bit `--quant_linear` mode for low-VRAM operation.
+* **Attention Tuning:** Switch between `sagesla`, `sla`, and `original` attention mechanisms.
+* **Adaptive I2V:** Enable adaptive resolution and ODE solvers for Image-to-Video workflows.
+* **Integrated Gallery:** Browse and view your output history directly within the `output/` directory.
+
+---
+
+## 🛠️ Usage
+
+To launch the studio:
+
+```bash
+python turbo_diffusion_t5_cache_optimize_v6.py
+
+```
+
+> **Note:** The script defaults to `/your/path/to/TurboDiffusion` as the project root. Ensure your local paths are configured accordingly in the **System Setup** section of the code.
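+
+The relevant constants sit near the top of the script; a trimmed sketch of that section (the root shown is a placeholder, the derived directories match the script's defaults):
+
+```python
+import os
+
+# --- 1. System Setup --- (edit PROJECT_ROOT for your machine)
+PROJECT_ROOT = "/your/path/to/TurboDiffusion"
+CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
+OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output")
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+```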
+
+
+---
+
+## 🌳 Credits & Acknowledgments
+
+If you utilize, share, or build upon these optimized scripts, please include the following acknowledgments:
+
+* **Optimization & Development**: Co-developed by **Waverly Edwards** and **Google Gemini**.
+* **T5 Caching Logic**: Original concept and utility implementation by **John D. Pope**.
+* **Base Framework**: Built upon the NVIDIA Imaginaire and Wan-Video research.
diff --git a/turbodiffusion/inference/cache_t5.py b/turbodiffusion/inference/cache_t5.py
new file mode 100644
index 0000000..a25e3cf
--- /dev/null
+++ b/turbodiffusion/inference/cache_t5.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+# -----------------------------------------------------------------------------------------
+# T5 EMBEDDING CACHE UTILITY
+#
+# Acknowledgments:
+# - Work and creativity of: John D. Pope
+#
+# Description:
+# Pre-computes text embeddings to allow running inference on GPUs with limited VRAM
+# by removing the need to keep the 11GB T5 encoder loaded in memory.
+# -----------------------------------------------------------------------------------------
+"""
+Pre-cache T5 text embeddings to avoid loading the 11GB model during inference.
+
+Usage:
+ # Cache a single prompt
+ python turbodiffusion/inference/cache_t5.py --prompt "slow head turn, cinematic" --output cached_embeddings.pt
+
+ # Cache multiple prompts from file
+ python turbodiffusion/inference/cache_t5.py --prompts_file prompts.txt --output cached_embeddings.pt
+
+Then use with inference:
+ python turbodiffusion/inference/wan2.2_i2v_infer.py \
+ --cached_embedding cached_embeddings.pt \
+ --skip_t5 \
+ ...
+"""
+import os
+import sys
+import argparse
+import torch
+
+# Add repo root to path for imports
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+REPO_ROOT = os.path.dirname(SCRIPT_DIR)
+sys.path.insert(0, REPO_ROOT)
+
+def main():
+ parser = argparse.ArgumentParser(description="Pre-cache T5 text embeddings")
+ parser.add_argument("--prompt", type=str, default=None, help="Single prompt to cache")
+ parser.add_argument("--prompts_file", type=str, default=None, help="File with prompts (one per line)")
+ parser.add_argument("--text_encoder_path", type=str,
+ default="/media/2TB/ComfyUI/models/text_encoders/models_t5_umt5-xxl-enc-bf16.pth",
+ help="Path to the umT5 text encoder")
+ parser.add_argument("--output", type=str, default="cached_t5_embeddings.pt",
+ help="Output path for cached embeddings")
+ parser.add_argument("--device", type=str, default="cuda",
+ help="Device to use for encoding (cuda is faster, memory freed after)")
+ args = parser.parse_args()
+
+ # Collect prompts
+ prompts = []
+ if args.prompt:
+ prompts.append(args.prompt)
+ if args.prompts_file and os.path.exists(args.prompts_file):
+ with open(args.prompts_file, 'r') as f:
+ prompts.extend([line.strip() for line in f if line.strip()])
+
+ if not prompts:
+ print("Error: Provide --prompt or --prompts_file")
+ sys.exit(1)
+
+ print(f"Caching embeddings for {len(prompts)} prompt(s)")
+ print(f"Text encoder: {args.text_encoder_path}")
+ print(f"Device: {args.device}")
+ print()
+
+ # Import after path setup
+ from rcm.utils.umt5 import get_umt5_embedding, clear_umt5_memory
+
+ cache_data = {
+ 'prompts': prompts,
+ 'embeddings': [],
+ 'text_encoder_path': args.text_encoder_path,
+ }
+
+ with torch.no_grad():
+ for i, prompt in enumerate(prompts):
+ print(f"[{i+1}/{len(prompts)}] Encoding: '{prompt[:60]}...' " if len(prompt) > 60 else f"[{i+1}/{len(prompts)}] Encoding: '{prompt}'")
+
+ # Get embedding (loads T5 if not already loaded)
+ embedding = get_umt5_embedding(
+ checkpoint_path=args.text_encoder_path,
+ prompts=prompt
+ )
+
+ # Move to CPU for storage
+ cache_data['embeddings'].append({
+ 'prompt': prompt,
+ 'embedding': embedding.cpu(),
+ 'shape': list(embedding.shape),
+ })
+
+ print(f" Shape: {embedding.shape}, dtype: {embedding.dtype}")
+
+ # Clear T5 from memory
+ print("\nClearing T5 from memory...")
+ clear_umt5_memory()
+ torch.cuda.empty_cache()
+
+ # Save cache
+ print(f"\nSaving to: {args.output}")
+ torch.save(cache_data, args.output)
+
+ # Summary
+ file_size = os.path.getsize(args.output) / (1024 * 1024)
+ print(f"Done! Cache file size: {file_size:.2f} MB")
+ print()
+ print("Usage:")
+ print(f" python turbodiffusion/inference/wan2.2_i2v_infer.py \\")
+ print(f" --cached_embedding {args.output} \\")
+ print(f" --skip_t5 \\")
+ print(f" ... (other args)")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py b/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py
new file mode 100644
index 0000000..0525f4b
--- /dev/null
+++ b/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py
@@ -0,0 +1,343 @@
+import os
+import sys
+import subprocess
+import gradio as gr
+import glob
+import random
+import time
+import select
+import torch
+from datetime import datetime
+
+# --- 1. System Setup ---
+PROJECT_ROOT = "/home/wedwards/Documents/Programs/TurboDiffusion"
+os.chdir(PROJECT_ROOT)
+os.system('clear' if os.name == 'posix' else 'cls')
+
+CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints")
+OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output")
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+T2V_SCRIPT = "turbodiffusion/inference/wan2.1_t2v_infer.py"
+I2V_SCRIPT = "turbodiffusion/inference/wan2.2_i2v_infer.py"
+CACHE_SCRIPT = "turbodiffusion/inference/cache_t5.py"
+
+def get_gpu_status_original():
+ """System-level GPU check."""
+ try:
+ res = subprocess.check_output(
+ ["nvidia-smi", "--query-gpu=name,memory.used,memory.total", "--format=csv,nounits,noheader"],
+ encoding='utf-8'
+ ).strip().split(',')
+ return f"🖥️ {res[0]} | ⚡ VRAM: {res[1]}MB / {res[2]}MB"
+ except:
+ return "🖥️ GPU Monitor Active"
+
+
+def get_gpu_status():
+ """
+ Check GPU status using PyTorch.
+ Returns system-wide VRAM usage without relying on nvidia-smi CLI.
+ """
+ try:
+ # 1. Check for CUDA (NVIDIA) or ROCm (AMD)
+ if torch.cuda.is_available():
+ # mem_get_info returns (free_bytes, total_bytes)
+ free_mem, total_mem = torch.cuda.mem_get_info()
+
+ used_mem = total_mem - free_mem
+
+ # Convert to MB for display
+ total_mb = int(total_mem / 1024**2)
+ used_mb = int(used_mem / 1024**2)
+ name = torch.cuda.get_device_name(0)
+
+ return f"🖥️ {name} | ⚡ VRAM: {used_mb}MB / {total_mb}MB"
+
+ # 2. Check for Apple Silicon (MPS)
+ # Note: Apple uses Unified Memory, so 'VRAM' is shared with System RAM.
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+ return "🖥️ Apple Silicon (MPS) | ⚡ Unified Memory Active"
+
+ # 3. Fallback to CPU
+ else:
+ return "🖥️ Running on CPU"
+
+ except ImportError:
+ return "🖥️ GPU Monitor: PyTorch not installed"
+ except Exception as e:
+ return f"🖥️ GPU Monitor Error: {str(e)}"
+
+def save_debug_metadata(video_path, script_rel, cmd_list, cache_cmd_list=None):
+ """
+ Saves a fully executable reproduction script with env vars.
+ """
+ meta_path = video_path.replace(".mp4", "_metadata.txt")
+ with open(meta_path, "w") as f:
+ f.write(f"# Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
+ f.write("# Copy and paste the lines below to reproduce this video exactly:\n\n")
+
+ # Environment Variables
+ f.write("export PYTHONPATH=turbodiffusion\n")
+ f.write("export PYTORCH_ALLOC_CONF=expandable_segments:True\n")
+ f.write("export TOKENIZERS_PARALLELISM=false\n\n")
+
+ # Optional Cache Step
+ if cache_cmd_list:
+ f.write("# --- Step 1: Pre-Cache Embeddings ---\n")
+ f.write(f"python {CACHE_SCRIPT} \\\n")
+ c_args = cache_cmd_list[2:]
+ for i, arg in enumerate(c_args):
+ if arg.startswith("--"):
+ val = f'"{c_args[i+1]}"' if i+1 < len(c_args) and not c_args[i+1].startswith("--") else ""
+ f.write(f" {arg} {val} \\\n")
+ f.write("\n# --- Step 2: Run Inference ---\n")
+
+ # Main Inference Command
+ f.write(f"python {script_rel} \\\n")
+ args_only = cmd_list[2:]
+ for i, arg in enumerate(args_only):
+ if arg.startswith("--"):
+ val = f'"{args_only[i+1]}"' if i+1 < len(args_only) and not args_only[i+1].startswith("--") else ""
+ f.write(f" {arg} {val} \\\n")
+
+def sync_path(scale):
+ fname = "TurboWan2.1-T2V-1.3B-480P-quant.pth" if "1.3B" in scale else "TurboWan2.1-T2V-14B-720P-quant.pth"
+ return os.path.join(CHECKPOINT_DIR, fname)
+
+# --- 2. Unified Generation Logic (With Safety Checks) ---
+
+def run_gen(mode, prompt, model, dit_path, i2v_high, i2v_low, image, res, ratio, steps, seed, quant, attn, top_k, frames, sigma, norm, adapt, ode, use_cache, cache_path, pr=gr.Progress()):
+ # --- PRE-FLIGHT SAFETY CHECK ---
+ error_msg = ""
+ if mode == "T2V":
+ if "quant" in dit_path.lower() and not quant:
+ error_msg = "❌ CONFIG ERROR: Quantized model selected but '8-bit' disabled."
+ if attn == "original" and ("turbo" in dit_path.lower() or "quant" in dit_path.lower()):
+ error_msg = "❌ COMPATIBILITY ERROR: 'Original' attention with Turbo/Quantized checkpoint."
+ else:
+ if ("quant" in i2v_high.lower() or "quant" in i2v_low.lower()) and not quant:
+ error_msg = "❌ CONFIG ERROR: Quantized I2V model selected but '8-bit' disabled."
+ if attn == "original" and (("turbo" in i2v_high.lower() or "quant" in i2v_high.lower()) or ("turbo" in i2v_low.lower() or "quant" in i2v_low.lower())):
+ error_msg = "❌ COMPATIBILITY ERROR: 'Original' attention with Turbo/Quantized checkpoints."
+
+ if error_msg:
+ yield None, None, "❌ Config Error", "🛑 Aborted", error_msg
+ return
+ # -------------------------------
+
+ actual_seed = random.randint(1, 1000000) if seed <= 0 else int(seed)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ start_time = time.time()
+
+ full_log = f"🚀 Starting Job: {timestamp}\n"
+ pr(0, desc="🚀 Starting...")
+
+ # --- FRAME SANITIZATION (4n+1 RULE) ---
+ # Wan2.1 VAE requires frames to be (4n + 1). If not, we sanitize.
+ target_frames = int(frames)
+ valid_frames = ((target_frames - 1) // 4) * 4 + 1
+
+ # If the user input (e.g., 32) became smaller (29) or changed, we log it.
+ # Note: We enforce a minimum of 1 frame just in case.
+ valid_frames = max(1, valid_frames)
+
+ if valid_frames != target_frames:
+ warning_msg = f"⚠️ AUTO-ADJUST: Frame count {target_frames} is incompatible with VAE (requires 4n+1).\n"
+ warning_msg += f" Adjusted {target_frames} -> {valid_frames} frames to prevent kernel crash.\n"
+ full_log += warning_msg
+ print(warning_msg) # Print to console as well
+
+ # Use valid_frames for the rest of the logic
+ frames = valid_frames
+ # --------------------------------------
+
+ # --- AUTO-CACHE STEP ---
+ cache_cmd_list = None
+ if use_cache:
+ pr(0, desc="💾 Auto-Caching T5 Embeddings...")
+ cache_script_full = os.path.join(PROJECT_ROOT, CACHE_SCRIPT)
+ encoder_path = os.path.join(CHECKPOINT_DIR, "models_t5_umt5-xxl-enc-bf16.pth")
+
+ cache_cmd = [
+ sys.executable,
+ cache_script_full,
+ "--prompt", prompt,
+ "--output", cache_path,
+ "--text_encoder_path", encoder_path
+ ]
+ cache_cmd_list = cache_cmd
+
+ full_log += f"\n[System] Running Cache Script: {' '.join(cache_cmd)}\n"
+ yield None, None, f"Seed: {actual_seed}", "💾 Caching...", full_log
+
+ cache_process = subprocess.Popen(cache_cmd, cwd=PROJECT_ROOT, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
+
+ while True:
+ if cache_process.poll() is not None:
+ rest = cache_process.stdout.read()
+ if rest: full_log += rest
+ break
+ line = cache_process.stdout.readline()
+ if line:
+ full_log += line
+ yield None, None, f"Seed: {actual_seed}", "💾 Caching...", full_log
+ time.sleep(0.02)
+
+ if cache_process.returncode != 0:
+ full_log += "\n❌ CACHE FAILED. Aborting generation."
+ yield None, None, "❌ Cache Failed", "🛑 Aborted", full_log
+ return
+
+ full_log += "\n✅ Cache Complete. Starting Inference...\n"
+ # -----------------------------------------
+
+ # --- SETUP VIDEO GENERATION ---
+ if mode == "T2V":
+ save_path = os.path.join(OUTPUT_DIR, f"t2v_{timestamp}.mp4")
+ script_rel = T2V_SCRIPT
+ cmd = [sys.executable, os.path.join(PROJECT_ROOT, T2V_SCRIPT), "--model", model, "--dit_path", dit_path, "--prompt", prompt, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_samples", "1", "--num_frames", str(frames), "--sigma_max", str(sigma)]
+ else:
+ save_path = os.path.join(OUTPUT_DIR, f"i2v_{timestamp}.mp4")
+ script_rel = I2V_SCRIPT
+ # Note: Added frames to I2V command in previous step, maintained here.
+ cmd = [sys.executable, os.path.join(PROJECT_ROOT, I2V_SCRIPT), "--prompt", prompt, "--image_path", image, "--high_noise_model_path", i2v_high, "--low_noise_model_path", i2v_low, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_frames", str(frames)]
+ if adapt: cmd.append("--adaptive_resolution")
+ if ode: cmd.append("--ode")
+
+ if quant: cmd.append("--quant_linear")
+ if norm: cmd.append("--default_norm")
+
+ if use_cache:
+ cmd.extend(["--cached_embedding", cache_path, "--skip_t5"])
+
+ cmd.extend(["--save_path", save_path])
+
+ # Call the restored metadata saver
+ save_debug_metadata(save_path, script_rel, cmd, cache_cmd_list)
+
+ env = os.environ.copy()
+ env["PYTHONPATH"] = os.path.join(PROJECT_ROOT, "turbodiffusion")
+ env["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
+ env["TOKENIZERS_PARALLELISM"] = "false"
+ env["PYTHONUNBUFFERED"] = "1"
+
+ full_log += f"\n[System] Running Inference: {' '.join(cmd)}\n"
+ process = subprocess.Popen(cmd, cwd=PROJECT_ROOT, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
+
+ last_ui_update = 0
+
+ while True:
+ if process.poll() is not None:
+ rest = process.stdout.read()
+ if rest: full_log += rest
+ break
+
+ reads = [process.stdout.fileno()]
+ ret = select.select(reads, [], [], 0.1)
+
+ if ret[0]:
+ line = process.stdout.readline()
+ full_log += line
+
+ if "Loading DiT" in line: pr(0.1, desc="⚡ Loading weights...")
+ if "Encoding" in line: pr(0.05, desc="🖼️ VAE Encoding...")
+ if "Switching to CPU" in line: pr(0.1, desc="⚠️ CPU Fallback...")
+ if "Sampling:" in line:
+ try:
+ pct = int(line.split('%')[0].split('|')[-1].strip())
+ pr(0.2 + (pct/100 * 0.7), desc=f"🎬 Sampling: {pct}%")
+ except: pass
+ if "decoding" in line.lower(): pr(0.95, desc="🎥 Decoding VAE...")
+
+ current_time = time.time()
+ if current_time - last_ui_update > 0.25:
+ last_ui_update = current_time
+ elapsed = f"{int(current_time - start_time)}s"
+ yield None, None, f"Seed: {actual_seed}", f"⏱️ Time: {elapsed}", full_log
+
+ history = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), key=os.path.getmtime, reverse=True)
+ total_time = f"{int(time.time() - start_time)}s"
+
+ yield save_path, history, f"✅ Done | Seed: {actual_seed}", f"🏁 Finished in {total_time}", full_log
+
+# --- 3. UI Layout ---
+with gr.Blocks(title="TurboDiffusion Studio") as demo:
+ with gr.Row():
+ gr.HTML("<h1>⚡ TurboDiffusion Studio</h1>")
+ with gr.Column(scale=1):
+ gpu_display = gr.Markdown(get_gpu_status())
+
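+ # Poll get_gpu_status() every 2 seconds and push the result into the Markdown display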
+ gr.Timer(2).tick(get_gpu_status, outputs=gpu_display)
+
+ with gr.Tabs():
+ with gr.Tab("Text-to-Video"):
+ with gr.Row():
+ with gr.Column(scale=4):
+ t2v_p = gr.Textbox(label="Prompt", lines=3, value="A stylish woman walks down a Tokyo street...")
+ with gr.Row():
+ t2v_m = gr.Radio(["Wan2.1-1.3B", "Wan2.1-14B"], label="Model", value="Wan2.1-1.3B")
+ t2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="480p")
+ t2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9")
+ t2v_dit = gr.Textbox(label="DiT Path", value=sync_path("Wan2.1-1.3B"), interactive=False)
+ t2v_btn = gr.Button("Generate Video", variant="primary")
+ with gr.Column(scale=3):
+ t2v_out = gr.Video(label="Result", height=320)
+ with gr.Row():
+ t2v_stat = gr.Textbox(label="Status", interactive=False, scale=2)
+ t2v_time = gr.Textbox(label="Timer", value="⏱️ Ready", interactive=False, scale=1)
+
+ with gr.Tab("Image-to-Video"):
+ with gr.Row():
+ with gr.Column(scale=4):
+ with gr.Row():
+ i2v_img = gr.Image(label="Source", type="filepath", height=200)
+ i2v_p = gr.Textbox(label="Motion Prompt", lines=7)
+ with gr.Row():
+ i2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="720p")
+ i2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9")
+ with gr.Row():
+ i2v_adapt = gr.Checkbox(label="Adaptive Resolution", value=True)
+ i2v_ode = gr.Checkbox(label="Use ODE", value=False)
+ with gr.Accordion("I2V Path Overrides", open=False):
+ i2v_high = gr.Textbox(label="High-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-high-720P-quant.pth"))
+ i2v_low = gr.Textbox(label="Low-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-low-720P-quant.pth"))
+ i2v_btn = gr.Button("Animate Image", variant="primary")
+ with gr.Column(scale=3):
+ i2v_out = gr.Video(label="Result", height=320)
+ with gr.Row():
+ i2v_stat_2 = gr.Textbox(label="Status", interactive=False, scale=2)
+ i2v_time_2 = gr.Textbox(label="Timer", value="⏱️ Ready", interactive=False, scale=1)
+
+ console_out = gr.Textbox(label="Live CLI Console Output", lines=8, max_lines=8, interactive=False)
+
+ with gr.Accordion("⚙️ Precision & Advanced Settings", open=False):
+ with gr.Row():
+ quant_opt = gr.Checkbox(label="Enable --quant_linear (8-bit)", value=True)
+ steps_opt = gr.Slider(1, 4, value=4, step=1, label="Steps")
+ seed_opt = gr.Number(label="Seed (0=Random)", value=0, precision=0)
+ with gr.Row():
+ top_k_opt = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="SLA Top-K")
+ attn_opt = gr.Radio(["sagesla", "sla", "original"], label="Attention", value="sagesla")
+ sigma_opt = gr.Number(label="Sigma Max", value=80)
+ norm_opt = gr.Checkbox(label="Original Norms", value=False)
+ frames_opt = gr.Slider(1, 120, value=77, step=4, label="Frames (Steps of 4)")
+ with gr.Row(variant="panel"):
+ # --- T5 CACHE UI ---
+ use_cache_opt = gr.Checkbox(label="Use Cached T5 Embeddings (Auto-Run)", value=True)
+ cache_path_opt = gr.Textbox(label="Cache File Path", value="cached_t5_embeddings.pt", scale=2)
+ # -------------------
+
+ history_gal = gr.Gallery(value=sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), reverse=True), columns=6, height="auto")
+
+ # --- 4. Logic Bindings ---
+ t2v_m.change(fn=sync_path, inputs=t2v_m, outputs=t2v_dit)
+
+ t2v_args = [gr.State("T2V"), t2v_p, t2v_m, t2v_dit, gr.State(""), gr.State(""), gr.State(""), t2v_res, t2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, sigma_opt, norm_opt, gr.State(False), gr.State(False), use_cache_opt, cache_path_opt]
+ t2v_btn.click(run_gen, t2v_args, [t2v_out, history_gal, t2v_stat, t2v_time, console_out], show_progress="hidden")
+
+ i2v_args = [gr.State("I2V"), i2v_p, gr.State("Wan2.2-A14B"), gr.State(""), i2v_high, i2v_low, i2v_img, i2v_res, i2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, gr.State(200), norm_opt, i2v_adapt, i2v_ode, use_cache_opt, cache_path_opt]
+ i2v_btn.click(run_gen, i2v_args, [i2v_out, history_gal, i2v_stat_2, i2v_time_2, console_out], show_progress="hidden")
+
+if __name__ == "__main__":
+ demo.launch(theme=gr.themes.Default(), allowed_paths=[OUTPUT_DIR])
\ No newline at end of file
diff --git a/turbodiffusion/inference/wan2.1_t2v_infer.py b/turbodiffusion/inference/wan2.1_t2v_infer.py
index c5581e4..a45220d 100644
--- a/turbodiffusion/inference/wan2.1_t2v_infer.py
+++ b/turbodiffusion/inference/wan2.1_t2v_infer.py
@@ -1,24 +1,58 @@
+# Blackwell Bridge
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# -----------------------------------------------------------------------------------------
+# TURBODIFFUSION OPTIMIZED INFERENCE SCRIPT (T2V)
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# Co-developed by: Waverly Edwards & Google Gemini (2025)
#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Modifications:
+# - Implemented "Tiered Failover System" for robust OOM protection (GPU -> Checkpoint -> CPU).
+# - Added Intelligent Hardware Detection (TF32/BF16/FP16 auto-switching).
+# - Integrated Tiled Decoding for high-resolution VAE processing.
+# - Added Support for Pre-cached Text Embeddings to skip T5 loading.
+# - Optimized compilation logic for Quantized models (preventing graph breaks).
+#
+# Acknowledgments:
+# - Made possible by the work (cache_t5.py) and creativity of: John D. Pope
+#
+# Description:
+# cache_t5.py pre-computes text embeddings to allow running inference on GPUs with limited VRAM
+# by removing the need to keep the 11GB T5 encoder loaded in memory.
+#
+# CREDIT REQUEST:
+# If you utilize, share, or build upon this specific optimized script, please
+# acknowledge Waverly Edwards and Google Gemini in your documentation or credits.
+# -----------------------------------------------------------------------------------------
import argparse
import math
+import os
+import gc
+import time
+import sys
+
+# --- 1. Memory Tuning ---
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
+import torch.nn.functional as F
from einops import rearrange, repeat
from tqdm import tqdm
+import numpy as np
+
+# --- 2. Hardware Optimization ---
+if torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ # 'high' allows TF32 but maintains reasonable precision
+ torch.set_float32_matmul_precision('high')
+
+try:
+ import psutil
+except ImportError:
+ psutil = None
from imaginaire.utils.io import save_image_or_video
from imaginaire.utils import log
@@ -29,123 +63,303 @@
from modify_model import tensor_kwargs, create_model
+# Suppress graph break warnings for cleaner output
torch._dynamo.config.suppress_errors = True
-
+torch._dynamo.config.verbose = False
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.1 T2V")
- parser.add_argument("--dit_path", type=str, required=True, help="Custom path to the DiT model checkpoint for distilled models")
+ parser.add_argument("--dit_path", type=str, required=True, help="Custom path to the DiT model checkpoint")
parser.add_argument("--model", choices=["Wan2.1-1.3B", "Wan2.1-14B"], default="Wan2.1-1.3B", help="Model to use")
parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate")
parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4, help="1~4 for timestep-distilled inference")
parser.add_argument("--sigma_max", type=float, default=80, help="Initial sigma for rCM")
parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE")
parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder")
+ parser.add_argument("--cached_embedding", type=str, default=None, help="Path to cached text embeddings (pt file)")
+ parser.add_argument("--skip_t5", action="store_true", help="Skip T5 loading (implied if cached_embedding is used)")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
- parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)")
+ parser.add_argument("--prompt", type=str, required=True, help="Text prompt for video generation")
parser.add_argument("--resolution", default="480p", type=str, help="Resolution of the generated output")
parser.add_argument("--aspect_ratio", default="16:9", type=str, help="Aspect ratio of the generated output (width:height)")
parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
- parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video (include file extension)")
- parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism to use")
+ parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video")
+ parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism")
parser.add_argument("--sla_topk", type=float, default=0.1, help="Top-k ratio for SLA/SageSLA attention")
parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions")
- parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions")
- parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server mode (keeps model loaded)")
+ parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm with faster versions")
+ parser.add_argument("--offload_dit", action="store_true", help="Offload DiT to CPU when not in use to save VRAM")
+ parser.add_argument("--compile", action="store_true", help="Use torch.compile (Inductor) for faster inference")
return parser.parse_args()
+def check_hardware_compatibility():
+ if not torch.cuda.is_available(): return
+ gpu_name = torch.cuda.get_device_name(0)
+ log.info(f"Hardware: {gpu_name}")
+
+ current_dtype = tensor_kwargs.get("dtype")
+ if current_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+ log.warning(f"⚠️ Device does not support BFloat16. Switching to Float16.")
+ tensor_kwargs["dtype"] = torch.float16
-if __name__ == "__main__":
- args = parse_arguments()
+def print_memory_status(step_name=""):
+ if not torch.cuda.is_available(): return
+ torch.cuda.synchronize()
+ allocated = torch.cuda.memory_allocated() / (1024**3)
+ free, total = torch.cuda.mem_get_info()
+ free_gb = free / (1024**3)
+ print(f"📊 [MEM] {step_name}: InUse={allocated:.2f}GB, Free={free_gb:.2f}GB")
- # Handle serve mode
- if args.serve:
- # Set mode to t2v for the TUI server
- args.mode = "t2v"
- from serve.tui import main as serve_main
- serve_main(args)
- exit(0)
+def cleanup_memory(step_info=""):
+ gc.collect()
+ torch.cuda.empty_cache()
- # Validate prompt is provided for one-shot mode
- if args.prompt is None:
- log.error("--prompt is required (unless using --serve mode)")
- exit(1)
+def load_dit_model(args, force_offload=False):
+ orig_offload = args.offload_dit
+ if force_offload: args.offload_dit = True
+ log.info(f"Loading DiT (Offload={args.offload_dit})...")
+ model = create_model(dit_path=args.dit_path, args=args).cpu()
+ args.offload_dit = orig_offload
+ return model
- log.info(f"Computing embedding for prompt: {args.prompt}")
- with torch.no_grad():
- text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs)
- clear_umt5_memory()
+def tiled_decode_gpu(tokenizer, latents, overlap=12):
+ print(f"\n🧱 Starting Tiled GPU Decode (Overlap={overlap})...")
+ B, C, T, H, W = latents.shape
+ scale = tokenizer.spatial_compression_factor
+ h_mid = H // 2
+ w_mid = W // 2
+
+ def decode_tile(tile_latents):
+ cleanup_memory()
+ with torch.no_grad(): return tokenizer.decode(tile_latents).cpu()
- log.info(f"Loading DiT model from {args.dit_path}")
- net = create_model(dit_path=args.dit_path, args=args).cpu()
- torch.cuda.empty_cache()
- log.success("Successfully loaded DiT model.")
+ # 1. Top Tiles
+ l_tl = latents[..., :h_mid+overlap, :w_mid+overlap]
+ l_tr = latents[..., :h_mid+overlap, w_mid-overlap:]
+ v_tl = decode_tile(l_tl)
+ v_tr = decode_tile(l_tr)
- tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path)
+ B_dec, C_dec, T_dec, H_tile, W_tile = v_tl.shape
+ mid_pix = w_mid * scale
+ overlap_pix = overlap * scale
+
+ row_top = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_tl.dtype, device='cpu')
+ end_left = max(0, mid_pix - overlap_pix)
+ start_right = mid_pix + overlap_pix
+
+ row_top[..., :end_left] = v_tl[..., :end_left]
+ row_top[..., start_right:] = v_tr[..., 2*overlap_pix:]
+
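+ # Sigmoid ramp (0 -> 1 across the overlap region) crossfades the left and right tiles at the seam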
+ x_linspace = torch.linspace(-6, 6, 2*overlap_pix, device='cpu')
+ alpha = torch.sigmoid(x_linspace).view(1, 1, 1, 1, -1)
+ row_top[..., end_left:start_right] = v_tl[..., mid_pix-overlap_pix:] * (1 - alpha) + v_tr[..., :2*overlap_pix] * alpha
+ del v_tl, v_tr
- w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio]
+ # 2. Bottom Tiles
+ l_bl = latents[..., h_mid-overlap:, :w_mid+overlap]
+ l_br = latents[..., h_mid-overlap:, w_mid-overlap:]
+ v_bl = decode_tile(l_bl)
+ v_br = decode_tile(l_br)
+
+ row_bot = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_bl.dtype, device='cpu')
+ row_bot[..., :end_left] = v_bl[..., :end_left]
+ row_bot[..., start_right:] = v_br[..., 2*overlap_pix:]
+ row_bot[..., end_left:start_right] = v_bl[..., mid_pix-overlap_pix:] * (1 - alpha) + v_br[..., :2*overlap_pix] * alpha
+ del v_bl, v_br
- log.info(f"Generating with prompt: {args.prompt}")
- condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)}
+ # 3. Blend Vertically
+ h_mid_pix = h_mid * scale
+ video = torch.zeros(B_dec, 3, T_dec, H*scale, W*scale, dtype=row_top.dtype, device='cpu')
+ end_top = max(0, h_mid_pix - overlap_pix)
+ start_bot = h_mid_pix + overlap_pix
+
+ video[..., :end_top, :] = row_top[..., :end_top, :]
+ video[..., start_bot:, :] = row_bot[..., 2*overlap_pix:, :]
+
+ alpha_v = torch.sigmoid(x_linspace).view(1, 1, 1, -1, 1)
+ video[..., end_top:start_bot, :] = row_top[..., h_mid_pix-overlap_pix:, :] * (1 - alpha_v) + row_bot[..., :2*overlap_pix, :] * alpha_v
+
+ return video.to(latents.device)
- to_show = []
+def force_cpu_float32(target_obj):
+ for module in target_obj.modules():
+ module.cpu().float()
- state_shape = [
- tokenizer.latent_ch,
- tokenizer.get_latent_num_frames(args.num_frames),
- h // tokenizer.spatial_compression_factor,
- w // tokenizer.spatial_compression_factor,
- ]
+def apply_manual_offload(model, device="cuda"):
+ log.info("Applying Tier 3 Offload...")
+ block_list_name = None
+ max_len = 0
+ for name, child in model.named_children():
+ if isinstance(child, torch.nn.ModuleList):
+ if len(child) > max_len:
+ max_len = len(child)
+ block_list_name = name
+
+ if not block_list_name:
+ log.warning("Could not identify Block List! Offloading entire model to CPU.")
+ model.to("cpu")
+ return
- generator = torch.Generator(device=tensor_kwargs["device"])
- generator.manual_seed(args.seed)
+ print(f" 📍 Identified Transformer Blocks: '{block_list_name}' ({max_len} blocks)")
+ try: model.to(device)
+ except RuntimeError: model.to("cpu")
+
+ blocks = getattr(model, block_list_name)
+ blocks.to("cpu")
+
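+ # The hooks below stream each transformer block onto the GPU only for its own forward
+ # pass, then return it to CPU, so at most one block is resident in VRAM at a time.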
+ def pre_hook(module, args):
+ module.to(device)
+ return args
+ def post_hook(module, args, output):
+ module.to("cpu")
+ return output
+
+ for i, block in enumerate(blocks):
+ block.register_forward_pre_hook(pre_hook)
+ block.register_forward_hook(post_hook)
- init_noise = torch.randn(
- args.num_samples,
- *state_shape,
- dtype=torch.float32,
- device=tensor_kwargs["device"],
- generator=generator,
- )
+if __name__ == "__main__":
+ print_memory_status("Script Start")
+
+ # --- CREDIT PRINT ---
+ log.info("----------------------------------------------------------------")
+ log.info("🚀 TurboDiffusion Optimized Inference")
+ log.info(" Co-developed by Waverly Edwards & Google Gemini")
+ log.info("----------------------------------------------------------------")
- # mid_t = [1.3, 1.0, 0.6][: args.num_steps - 1]
- # For better visual quality
- mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1]
+ check_hardware_compatibility()
+ args = parse_arguments()
- t_steps = torch.tensor(
- [math.atan(args.sigma_max), *mid_t, 0],
- dtype=torch.float64,
- device=init_noise.device,
- )
+ if (args.num_frames - 1) % 4 != 0:
+ new_f = ((args.num_frames - 1) // 4 + 1) * 4 + 1
+ print(f"⚠️ Adjusting --num_frames to {new_f}")
+ args.num_frames = new_f
- # Convert TrigFlow timesteps to RectifiedFlow
- t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps))
+ if args.num_frames > 90 and not args.offload_dit:
+ args.offload_dit = True
+
+ # --- CRITICAL FIX: Strictly Disable Compile for Quantized Models ---
+ if args.compile and args.quant_linear:
+ log.warning("🚫 Quantized Model Detected: FORCE DISABLING `torch.compile` to avoid OOM.")
+ log.warning(" (Custom quantized kernels are not compatible with CUDA Graphs)")
+ args.compile = False
+
+ # 1. Text Embeddings
+ if args.cached_embedding and os.path.exists(args.cached_embedding):
+ log.info(f"Loading cache: {args.cached_embedding}")
+ c = torch.load(args.cached_embedding, map_location='cpu')
+ text_emb = c['embeddings'][0]['embedding'].to(**tensor_kwargs) if isinstance(c, dict) else c.to(**tensor_kwargs)
+ else:
+ log.info(f"Computing embedding...")
+ with torch.no_grad():
+ text_emb = get_umt5_embedding(args.text_encoder_path, args.prompt).to(**tensor_kwargs)
+ clear_umt5_memory()
+ cleanup_memory()
+
+ # 2. VAE Shape Calc & UNLOAD
+ log.info("VAE Setup (Temp)...")
+ tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path)
+ w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio]
+ state_shape = [tokenizer.latent_ch, tokenizer.get_latent_num_frames(args.num_frames), h // tokenizer.spatial_compression_factor, w // tokenizer.spatial_compression_factor]
+ del tokenizer
+ cleanup_memory("VAE Unloaded")
- # Sampling steps
+ # 3. Load DiT
+ net = load_dit_model(args)
+
+ # 4. Noise & Schedule
+ gen = torch.Generator(device=tensor_kwargs["device"]).manual_seed(args.seed)
+ cond = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)}
+ init_noise = torch.randn(args.num_samples, *state_shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=gen)
+
+ mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1]
+ t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device)
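+ # Convert TrigFlow timesteps to RectifiedFlow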
+ t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps))
+
x = init_noise.to(torch.float64) * t_steps[0]
ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype)
- total_steps = t_steps.shape[0] - 1
+
+ # 5. Fast Sampling Loop
+ log.info("🔥 STARTING SAMPLING (INFERENCE MODE) 🔥")
+ torch.cuda.empty_cache()
net.cuda()
- for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)):
- with torch.no_grad():
- v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to(
- torch.float64
- )
- x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(
- *x.shape,
- dtype=torch.float32,
- device=tensor_kwargs["device"],
- generator=generator,
- )
+ print_memory_status("Tier 1: GPU Ready")
+
+ # Compile? (Only if NOT disabled above)
+ if args.compile:
+ log.info("🚀 Compiling model...")
+ try:
+ net = torch.compile(net, mode="reduce-overhead")
+ except Exception as e:
+ log.warning(f"Compile failed: {e}. Running eager.")
+
+ failover = 0
+
+ with torch.inference_mode():
+ for i, (t_cur, t_next) in enumerate(tqdm(zip(t_steps[:-1], t_steps[1:]), total=len(t_steps)-1)):
+ retry = True
+ while retry:
+ try:
+ t_cur_scalar = t_cur.item()
+ t_next_scalar = t_next.item()
+
+ v_pred = net(
+ x_B_C_T_H_W=x.to(**tensor_kwargs),
+ timesteps_B_T=(t_cur * ones * 1000).to(**tensor_kwargs),
+ **cond
+ ).to(torch.float64)
+
+ if args.offload_dit and i == len(t_steps)-2 and failover == 0:
+ net.cpu()
+
+ noise = torch.randn(*x.shape, dtype=torch.float32, device=x.device, generator=gen).to(torch.float64)
+ term1 = x - (v_pred * t_cur_scalar)
+ x = term1 * (1.0 - t_next_scalar) + (noise * t_next_scalar)
+
+ retry = False
+
+ except torch.OutOfMemoryError:
+ log.warning(f"⚠️ OOM at Step {i}. Recovering...")
+ try: net.cpu()
+ except: pass
+ del net
+ cleanup_memory()
+ failover += 1
+
+ if failover == 1:
+ print("♻️ Tier 2: Checkpointing")
+ net = load_dit_model(args, force_offload=True)
+ net.cuda()
+ # Retry compile with safer mode if first attempt was aggressive
+ if args.compile:
+ try: net = torch.compile(net, mode="default")
+ except: pass
+ elif failover == 2:
+ print("♻️ Tier 3: Manual Offload")
+ net = load_dit_model(args, force_offload=True)
+ apply_manual_offload(net)
+ else:
+ sys.exit("❌ Critical OOM.")
+
samples = x.float()
- net.cpu()
- torch.cuda.empty_cache()
+ # 6. Decode
+ if 'net' in locals():
+ try: net.cpu()
+ except: pass
+ del net
+ cleanup_memory("Pre-VAE")
+
+ log.info("Decoding...")
+ tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path)
with torch.no_grad():
- video = tokenizer.decode(samples)
-
- to_show.append(video.float().cpu())
+ try:
+ video = tokenizer.decode(samples)
+ except torch.OutOfMemoryError:
+ log.warning("Falling back to Tiled Decode...")
+ video = tiled_decode_gpu(tokenizer, samples)
+ to_show = [video.float().cpu()]
to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0
-
save_image_or_video(rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), args.save_path, fps=16)
+ log.success(f"Saved: {args.save_path}")
diff --git a/turbodiffusion/inference/wan2.2_i2v_infer.py b/turbodiffusion/inference/wan2.2_i2v_infer.py
index e57e509..f6fe32d 100644
--- a/turbodiffusion/inference/wan2.2_i2v_infer.py
+++ b/turbodiffusion/inference/wan2.2_i2v_infer.py
@@ -1,28 +1,55 @@
+# Blackwell Bridge
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# -----------------------------------------------------------------------------------------
+# TURBODIFFUSION OPTIMIZED INFERENCE SCRIPT (I2V)
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# Co-developed by: Waverly Edwards & Google Gemini (2025)
#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Modifications:
+# - Implemented "Tiered Failover System" for robust OOM protection.
+# - Added Intelligent Model Switching (High/Low Noise) with memory optimization.
+# - Integrated Tiled Encoding & Decoding for high-resolution processing.
+# - Added Support for Pre-cached Text Embeddings to skip T5 loading.
+# - Optimized memory management (VAE unload/reload, aggressive GC).
+#
+# Acknowledgments:
+# - Made possible by the work (cache_t5.py) and creativity of: John D. Pope
+#
+# Description:
+# cache_t5.py pre-computes text embeddings to allow running inference on GPUs with limited VRAM
+# by removing the need to keep the 11GB T5 encoder loaded in memory.
+#
+# CREDIT REQUEST:
+# If you utilize, share, or build upon this specific optimized script, please
+# acknowledge Waverly Edwards and Google Gemini in your documentation or credits.
+# -----------------------------------------------------------------------------------------
import argparse
import math
+import os
+import gc
+import time
+import sys
+
+# --- 1. Memory Tuning (Must be before torch imports) ---
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
+import torch.nn.functional as F
from einops import rearrange, repeat
from tqdm import tqdm
from PIL import Image
import torchvision.transforms.v2 as T
import numpy as np
+# Safe import for system memory checks
+try:
+ import psutil
+except ImportError:
+ psutil = None
+
from imaginaire.utils.io import save_image_or_video
from imaginaire.utils import log
@@ -34,189 +61,514 @@
torch._dynamo.config.suppress_errors = True
-
def parse_arguments() -> argparse.Namespace:
- parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.2 I2V with High/Low Noise models")
- parser.add_argument("--image_path", type=str, default=None, help="Path to the input image (required unless --serve)")
- parser.add_argument("--high_noise_model_path", type=str, required=True, help="Path to the high-noise model")
- parser.add_argument("--low_noise_model_path", type=str, required=True, help="Path to the low-noise model")
- parser.add_argument("--boundary", type=float, default=0.9, help="Timestep boundary for switching from high to low noise model")
- parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B", help="Model to use")
- parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate")
- parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4, help="1~4 for timestep-distilled inference")
- parser.add_argument("--sigma_max", type=float, default=200, help="Initial sigma for rCM")
- parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE")
- parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder")
- parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate")
- parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)")
- parser.add_argument("--resolution", default="720p", type=str, help="Resolution of the generated output")
- parser.add_argument("--aspect_ratio", default="16:9", type=str, help="Aspect ratio of the generated output (width:height)")
- parser.add_argument("--adaptive_resolution", action="store_true", help="If set, adapts the output resolution to the input image's aspect ratio, using the area defined by --resolution and --aspect_ratio as a target.")
- parser.add_argument("--ode", action="store_true", help="Use ODE for sampling (sharper but less robust than SDE)")
- parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
- parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video (include file extension)")
- parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism to use")
- parser.add_argument("--sla_topk", type=float, default=0.1, help="Top-k ratio for SLA/SageSLA attention")
- parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions")
- parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions")
- parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server mode (keeps model loaded)")
+ parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.2 I2V")
+ parser.add_argument("--image_path", type=str, default=None, help="Path to input image")
+ parser.add_argument("--high_noise_model_path", type=str, required=True, help="Path to high-noise model")
+ parser.add_argument("--low_noise_model_path", type=str, required=True, help="Path to low-noise model")
+ parser.add_argument("--boundary", type=float, default=0.9, help="Switch boundary")
+ parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B")
+ parser.add_argument("--num_samples", type=int, default=1)
+ parser.add_argument("--num_steps", type=int, default=4)
+ parser.add_argument("--sigma_max", type=float, default=200)
+ parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth")
+ parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth")
+ parser.add_argument("--cached_embedding", type=str, default=None)
+ parser.add_argument("--skip_t5", action="store_true")
+ parser.add_argument("--num_frames", type=int, default=81)
+ parser.add_argument("--prompt", type=str, default=None)
+ parser.add_argument("--resolution", default="720p", type=str)
+ parser.add_argument("--aspect_ratio", default="16:9", type=str)
+ parser.add_argument("--adaptive_resolution", action="store_true")
+ parser.add_argument("--ode", action="store_true")
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--save_path", type=str, default="output/generated_video.mp4")
+ parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla")
+ parser.add_argument("--sla_topk", type=float, default=0.1)
+ parser.add_argument("--quant_linear", action="store_true")
+ parser.add_argument("--default_norm", action="store_true")
+ parser.add_argument("--serve", action="store_true")
+ parser.add_argument("--offload_dit", action="store_true")
return parser.parse_args()
+def print_memory_status(step_name=""):
+ """
+ Prints a detailed breakdown of GPU memory usage.
+ """
+ if not torch.cuda.is_available():
+ return
+
+ torch.cuda.synchronize()
+ allocated_gb = torch.cuda.memory_allocated() / (1024**3)
+ reserved_gb = torch.cuda.memory_reserved() / (1024**3)
+ free_mem, total_mem = torch.cuda.mem_get_info()
+ free_gb = free_mem / (1024**3)
+
+ print(f"\n📊 [MEMORY REPORT] {step_name}")
+ print(f" ├── 💾 VRAM In Use: {allocated_gb:.2f} GB")
+ print(f" ├── 📦 VRAM Reserved: {reserved_gb:.2f} GB")
+ print(f" └── 🆓 VRAM Free: {free_gb:.2f} GB")
+ print("-" * 60)
+
+def cleanup_memory(step_info=""):
+ """Aggressively clears VRAM."""
+ if step_info:
+ print(f"🧹 Cleaning memory: {step_info}...")
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ if step_info:
+ print_memory_status(f"After Cleanup ({step_info})")
+
+def get_tensor_size_mb(tensor):
+ return tensor.element_size() * tensor.nelement() / (1024 * 1024)
+
+def force_cpu_float32(target_obj):
+ """
+ Recursively forces a model or wrapper object to CPU Float32.
+ """
+ def recursive_cast_to_cpu(obj):
+ if isinstance(obj, torch.Tensor):
+ return obj.cpu().float()
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(recursive_cast_to_cpu(x) for x in obj)
+ elif isinstance(obj, dict):
+ return {k: recursive_cast_to_cpu(v) for k, v in obj.items()}
+ return obj
+
+ targets = [target_obj]
+ if hasattr(target_obj, "model"):
+ targets.append(target_obj.model)
+
+ for obj in targets:
+ if isinstance(obj, torch.nn.Module):
+ try: obj.cpu().float()
+ except: pass
+
+ for attr_name in dir(obj):
+ if attr_name.startswith("__"): continue
+ try:
+ val = getattr(obj, attr_name)
+ if isinstance(val, torch.nn.Module):
+ val.cpu().float()
+ elif isinstance(val, (torch.Tensor, list, tuple)):
+ setattr(obj, attr_name, recursive_cast_to_cpu(val))
+ except Exception: pass
+
+ if isinstance(obj, torch.nn.Module):
+ try:
+ for module in obj.modules():
+ for param in module.parameters(recurse=False):
+ if param is not None:
+ param.data = param.data.cpu().float()
+ if param.grad is not None: param.grad = None
+ for buf in module.buffers(recurse=False):
+ if buf is not None:
+ buf.data = buf.data.cpu().float()
+ except: pass
+ else:
+ for attr_name in dir(obj):
+ if attr_name.startswith("__"): continue
+ try:
+ val = getattr(obj, attr_name)
+ if isinstance(val, torch.nn.Module):
+ for module in val.modules():
+ for param in module.parameters(recurse=False):
+ if param is not None:
+ param.data = param.data.cpu().float()
+ for buf in module.buffers(recurse=False):
+ if buf is not None:
+ buf.data = buf.data.cpu().float()
+ except: pass
+
+def tiled_encode_4x(tokenizer, frames, target_dtype):
+ B, C, T, H, W = frames.shape
+ h_mid = H // 2
+ w_mid = W // 2
+ print(f"\n🧩 Starting 4-Chunk Tiled Encoding (Input: {W}x{H})")
+ latents_list = [[None, None], [None, None]]
+
+ try:
+ print(" 🔄 Encoding Chunk 1/4 (Top-Left)...")
+ with torch.amp.autocast("cuda", dtype=target_dtype):
+ l_tl = tokenizer.encode(frames[:, :, :, :h_mid, :w_mid])
+ latents_list[0][0] = l_tl.cpu()
+ del l_tl; cleanup_memory("After Chunk 1")
+ except Exception as e:
+ print(f"❌ Chunk 1 Failed: {e}")
+ raise e
+
+ try:
+ print(" 🔄 Encoding Chunk 2/4 (Top-Right)...")
+ with torch.amp.autocast("cuda", dtype=target_dtype):
+ l_tr = tokenizer.encode(frames[:, :, :, :h_mid, w_mid:])
+ latents_list[0][1] = l_tr.cpu()
+ del l_tr; cleanup_memory("After Chunk 2")
+ except Exception as e:
+ print(f"❌ Chunk 2 Failed: {e}")
+ raise e
+
+ try:
+ print(" 🔄 Encoding Chunk 3/4 (Bottom-Left)...")
+ with torch.amp.autocast("cuda", dtype=target_dtype):
+ l_bl = tokenizer.encode(frames[:, :, :, h_mid:, :w_mid])
+ latents_list[1][0] = l_bl.cpu()
+ del l_bl; cleanup_memory("After Chunk 3")
+ except Exception as e:
+ print(f"❌ Chunk 3 Failed: {e}")
+ raise e
+
+ try:
+ print(" 🔄 Encoding Chunk 4/4 (Bottom-Right)...")
+ with torch.amp.autocast("cuda", dtype=target_dtype):
+ l_br = tokenizer.encode(frames[:, :, :, h_mid:, w_mid:])
+ latents_list[1][1] = l_br.cpu()
+ del l_br; cleanup_memory("After Chunk 4")
+ except Exception as e:
+ print(f"❌ Chunk 4 Failed: {e}")
+ raise e
+
+ print(" 🧵 Stitching Latents...")
+ row1 = torch.cat([latents_list[0][0], latents_list[0][1]], dim=4)
+ row2 = torch.cat([latents_list[1][0], latents_list[1][1]], dim=4)
+ full_latents = torch.cat([row1, row2], dim=3)
+ return full_latents.to(device=tensor_kwargs["device"], dtype=target_dtype)
+
+def safe_cpu_fallback_encode(tokenizer, frames, target_dtype):
+ log.warning("🔄 Switching to CPU for VAE Encode (Slow but reliable)...")
+ cleanup_memory("Pre-CPU Encode")
+ frames_cpu = frames.cpu().to(dtype=torch.float32)
+ force_cpu_float32(tokenizer)
+ t0 = time.time()
+ with torch.autocast("cpu", enabled=False):
+ with torch.autocast("cuda", enabled=False):
+ latents = tokenizer.encode(frames_cpu)
+ print(f" ⏱️ CPU Encode took: {time.time() - t0:.2f}s")
+ return latents.to(device=tensor_kwargs["device"], dtype=target_dtype)
+
+def tiled_decode_gpu(tokenizer, latents, overlap=12):
+ """
+ Decodes latents in 4 spatial quadrants with OVERLAP and SIGMOID BLENDING.
+ Overlap=12 latents (96 pixels). Safe for 720p.
+ Removing Global Color Matching to prevent exposure shifts.
+ """
+ print(f"\n🧱 Starting Tiled GPU Decode (4 Quadrants, Overlap={overlap}, Blended)...")
+ B, C, T, H, W = latents.shape
+ scale = tokenizer.spatial_compression_factor
+ h_mid = H // 2
+ w_mid = W // 2
+
+ def decode_tile(tile_latents, name):
+ cleanup_memory(f"Tile {name}")
+ with torch.no_grad():
+ return tokenizer.decode(tile_latents).cpu()
+
+ try:
+ # 1. Decode Top-Left and Top-Right
+ l_tl = latents[..., :h_mid+overlap, :w_mid+overlap]
+ l_tr = latents[..., :h_mid+overlap, w_mid-overlap:]
+ v_tl = decode_tile(l_tl, "1/4 (TL)")
+ v_tr = decode_tile(l_tr, "2/4 (TR)")
+ B_dec, C_dec, T_dec, H_tile, W_tile = v_tl.shape
+
+ print(f" ๐งต Blending Top Row (Decoded Frames: {T_dec})...")
+ mid_pix = w_mid * scale
+ overlap_pix = overlap * scale
+
+ # Slices for overlap
+ tl_blend_slice = v_tl[..., mid_pix-overlap_pix:]
+ tr_blend_slice = v_tr[..., :2*overlap_pix]
+
+ row_top = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_tl.dtype, device='cpu')
+
+ # Place non-overlapping parts (Clamped indices)
+ end_left = max(0, mid_pix - overlap_pix)
+ start_right = mid_pix + overlap_pix
+
+ row_top[..., :end_left] = v_tl[..., :end_left]
+ row_top[..., start_right:] = v_tr[..., 2*overlap_pix:]
+
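+ # Sigmoid ramp across the 2*overlap_pix seam: alpha goes from ~0 (left tile dominates)
+ # to ~1 (right tile dominates), so the tile boundary is crossfaded rather than hard-cut.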
+ x = torch.linspace(-6, 6, 2*overlap_pix, device='cpu')
+ alpha = torch.sigmoid(x).view(1, 1, 1, 1, -1)
+ blended_h = tl_blend_slice * (1 - alpha) + tr_blend_slice * alpha
+
+ row_top[..., end_left:start_right] = blended_h
+ del v_tl, v_tr, l_tl, l_tr
+
+ # 2. Decode Bottom-Left and Bottom-Right
+ l_bl = latents[..., h_mid-overlap:, :w_mid+overlap]
+ l_br = latents[..., h_mid-overlap:, w_mid-overlap:]
+ v_bl = decode_tile(l_bl, "3/4 (BL)")
+ v_br = decode_tile(l_br, "4/4 (BR)")
+
+ print(" ๐งต Blending Bottom Row...")
+ bl_blend_slice = v_bl[..., mid_pix-overlap_pix:]
+ br_blend_slice = v_br[..., :2*overlap_pix]
+
+ row_bot = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_bl.dtype, device='cpu')
+ row_bot[..., :end_left] = v_bl[..., :end_left]
+ row_bot[..., start_right:] = v_br[..., 2*overlap_pix:]
+ row_bot[..., end_left:start_right] = bl_blend_slice * (1 - alpha) + br_blend_slice * alpha
+ del v_bl, v_br, l_bl, l_br
+
+ # 3. Blend Top and Bottom Vertically
+ print(" 🧵 Blending Rows Vertically...")
+ h_mid_pix = h_mid * scale
+
+ # Slices
+ top_blend_slice = row_top[..., h_mid_pix-overlap_pix:, :]
+ bot_blend_slice = row_bot[..., :2*overlap_pix, :]
+
+ video = torch.zeros(B_dec, 3, T_dec, H*scale, W*scale, dtype=row_top.dtype, device='cpu')
+
+ end_top = max(0, h_mid_pix - overlap_pix)
+ start_bot = h_mid_pix + overlap_pix
+
+ video[..., :end_top, :] = row_top[..., :end_top, :]
+ video[..., start_bot:, :] = row_bot[..., 2*overlap_pix:, :]
+
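+ # Same sigmoid ramp as the horizontal blend, reshaped to run along the height axis.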
+ alpha_v = torch.sigmoid(x).view(1, 1, 1, -1, 1)
+ blended_v = top_blend_slice * (1 - alpha_v) + bot_blend_slice * alpha_v
+
+ video[..., end_top:start_bot, :] = blended_v
+
+ except Exception as e:
+ print(f"โ Tiled GPU Decode Failed: {e}")
+ raise e
+ return video.to(latents.device)
+
+def load_dit_model(args, is_high_noise=True, force_offload=False):
+ """Helper to load the model, respecting overrides."""
+ original_offload = args.offload_dit
+ if force_offload:
+ args.offload_dit = True
+
+ path = args.high_noise_model_path if is_high_noise else args.low_noise_model_path
+ log.info(f"Loading {'High' if is_high_noise else 'Low'} Noise DiT (Offload={args.offload_dit})...")
+
+ try:
+ model = create_model(dit_path=path, args=args).cpu()
+ finally:
+ args.offload_dit = original_offload
+
+ return model
if __name__ == "__main__":
+ print_memory_status("Script Start")
args = parse_arguments()
- # Handle serve mode
if args.serve:
- # Set mode to i2v for the TUI server
args.mode = "i2v"
from serve.tui import main as serve_main
serve_main(args)
exit(0)
-
- # Validate required args for one-shot mode
- if args.prompt is None:
- log.error("--prompt is required (unless using --serve mode)")
- exit(1)
- if args.image_path is None:
- log.error("--image_path is required (unless using --serve mode)")
- exit(1)
-
- log.info(f"Computing embedding for prompt: {args.prompt}")
- with torch.no_grad():
+
+ # --- AUTO-ADJUST FRAME COUNT ---
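+ # The VAE's temporal stride requires num_frames of the form 4n + 1 (81, 97, ...);
+ # e.g. a request for 96 frames is rounded up to 97 by the formula below.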
+ if (args.num_frames - 1) % 4 != 0:
+ old_f = args.num_frames
+ new_f = ((old_f - 1) // 4 + 1) * 4 + 1
+ print(f"โ ๏ธ Adjusting --num_frames from {old_f} to {new_f} to satisfy VAE temporal stride (4n+1).")
+ args.num_frames = new_f
+
+ # --- AUTO-ENABLE OFFLOAD FOR HIGH FRAMES ---
+ if args.num_frames > 90 and not args.offload_dit:
+ print(f"โ ๏ธ High frame count ({args.num_frames}) detected. Enabling --offload_dit to prevent OOM.")
+ args.offload_dit = True
+
+ # 1. Text Embeddings
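+ # A pre-computed cache is expected to contain {'embeddings': [{'embedding': <tensor>}, ...]};
+ # without one, the UMT5 encoder is loaded, run once on the prompt, and released right after.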
+ if args.cached_embedding and os.path.exists(args.cached_embedding):
+ log.info(f"Loading cached embedding from: {args.cached_embedding}")
+ cache_data = torch.load(args.cached_embedding, map_location='cpu')
+ text_emb = cache_data['embeddings'][0]['embedding'].to(**tensor_kwargs)
+ else:
+ log.info(f"Computing embedding...")
text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs)
- clear_umt5_memory()
-
- log.info(f"Loading DiT models.")
- high_noise_model = create_model(dit_path=args.high_noise_model_path, args=args).cpu()
- torch.cuda.empty_cache()
- low_noise_model = create_model(dit_path=args.low_noise_model_path, args=args).cpu()
- torch.cuda.empty_cache()
- log.success(f"Successfully loaded DiT model.")
+ clear_umt5_memory()
+ # 2. VAE Encoding
+ print("-" * 20 + " VAE SETUP " + "-" * 20)
tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path)
-
- log.info(f"Loading and preprocessing image from: {args.image_path}")
+ target_dtype = tensor_kwargs.get("dtype", torch.bfloat16)
input_image = Image.open(args.image_path).convert("RGB")
+
if args.adaptive_resolution:
- log.info("Adaptive resolution mode enabled.")
base_w, base_h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio]
max_resolution_area = base_w * base_h
- log.info(f"Target area is based on {args.resolution} {args.aspect_ratio} (~{max_resolution_area} pixels).")
-
orig_w, orig_h = input_image.size
- image_aspect_ratio = orig_h / orig_w
-
- ideal_w = np.sqrt(max_resolution_area / image_aspect_ratio)
- ideal_h = np.sqrt(max_resolution_area * image_aspect_ratio)
-
+ aspect = orig_h / orig_w
+ ideal_w = np.sqrt(max_resolution_area / aspect)
+ ideal_h = np.sqrt(max_resolution_area * aspect)
stride = tokenizer.spatial_compression_factor * 2
- lat_h = round(ideal_h / stride)
- lat_w = round(ideal_w / stride)
- h = lat_h * stride
- w = lat_w * stride
-
- log.info(f"Input image aspect ratio: {image_aspect_ratio:.4f}. Adaptive resolution set to: {w}x{h}")
+ h = round(ideal_h / stride) * stride
+ w = round(ideal_w / stride) * stride
+ log.info(f"Adaptive Res: {w}x{h}")
else:
- log.info("Fixed resolution mode.")
w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio]
- log.info(f"Resolution set to: {w}x{h}")
+
F = args.num_frames
lat_h = h // tokenizer.spatial_compression_factor
lat_w = w // tokenizer.spatial_compression_factor
lat_t = tokenizer.get_latent_num_frames(F)
- log.info(f"Preprocessing image to {w}x{h}...")
- image_transforms = T.Compose(
- [
- T.ToImage(),
- T.Resize(size=(h, w), antialias=True),
- T.ToDtype(torch.float32, scale=True),
- T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
- ]
- )
- image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=torch.float32)
-
- with torch.no_grad():
- frames_to_encode = torch.cat(
- [image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device)], dim=2
- ) # -> B, C, T, H, W
- encoded_latents = tokenizer.encode(frames_to_encode) # -> B, C_lat, T_lat, H_lat, W_lat
-
- del frames_to_encode
- torch.cuda.empty_cache()
-
+ image_transforms = T.Compose([
+ T.ToImage(),
+ T.Resize(size=(h, w), antialias=True),
+ T.ToDtype(torch.float32, scale=True),
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+
+ image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=target_dtype)
+ frames_to_encode = torch.cat([image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device, dtype=target_dtype)], dim=2)
+
+ log.info(f"Encoding {F} frames...")
+
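+ # Three-tier encode: try a full-frame GPU encode, but pre-emptively skip it when less
+ # than 24 GiB is free; on OOM fall back to the 4-chunk tiled encode, and if tiling
+ # also fails, drop to the float32 CPU path.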
+ try:
+ free_mem, _ = torch.cuda.mem_get_info()
+ if free_mem < 24 * (1024**3):
+ raise torch.OutOfMemoryError("Pre-emptive tiling")
+ with torch.amp.autocast("cuda", dtype=target_dtype):
+ encoded_latents = tokenizer.encode(frames_to_encode)
+ except torch.OutOfMemoryError:
+ try:
+ cleanup_memory("Switching to Tiled Encode")
+ encoded_latents = tiled_encode_4x(tokenizer, frames_to_encode, target_dtype)
+ except Exception as e:
+ log.warning(f"Tiling failed ({e}). Fallback to CPU.")
+ encoded_latents = safe_cpu_fallback_encode(tokenizer, frames_to_encode, target_dtype)
+
+ print(f"โ
VAE Encode Complete.")
+ del frames_to_encode
+ cleanup_memory("After VAE Encode")
+
+ # Prepare for Diffusion
msk = torch.zeros(1, 4, lat_t, lat_h, lat_w, device=tensor_kwargs["device"], dtype=tensor_kwargs["dtype"])
msk[:, :, 0, :, :] = 1.0
-
y = torch.cat([msk, encoded_latents.to(**tensor_kwargs)], dim=1)
y = y.repeat(args.num_samples, 1, 1, 1, 1)
+ saved_latent_ch = tokenizer.latent_ch
+
+ del tokenizer
+ cleanup_memory("Unloaded VAE Model")
+
+ # 3. Diffusion Sampling
+ print("-" * 20 + " DIT LOADING " + "-" * 20)
+
+ current_model = load_dit_model(args, is_high_noise=True)
+ is_high_noise_active = True
+ fallback_triggered = args.offload_dit
- log.info(f"Generating with prompt: {args.prompt}")
condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples), "y_B_C_T_H_W": y}
-
- to_show = []
-
- state_shape = [tokenizer.latent_ch, lat_t, lat_h, lat_w]
-
- generator = torch.Generator(device=tensor_kwargs["device"])
- generator.manual_seed(args.seed)
-
- init_noise = torch.randn(
- args.num_samples,
- *state_shape,
- dtype=torch.float32,
- device=tensor_kwargs["device"],
- generator=generator,
- )
-
+
+ generator = torch.Generator(device=tensor_kwargs["device"]).manual_seed(args.seed)
+ init_noise = torch.randn(args.num_samples, saved_latent_ch, lat_t, lat_h, lat_w, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator)
+
mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1]
-
- t_steps = torch.tensor(
- [math.atan(args.sigma_max), *mid_t, 0],
- dtype=torch.float64,
- device=init_noise.device,
- )
-
- # Convert TrigFlow timesteps to RectifiedFlow
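+ # Convert TrigFlow timesteps to RectifiedFlow: t_rf = sin(t) / (cos(t) + sin(t)).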
+ t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device)
t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps))
-
x = init_noise.to(torch.float64) * t_steps[0]
ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype)
- total_steps = t_steps.shape[0] - 1
- high_noise_model.cuda()
- net = high_noise_model
- switched = False
- for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)):
- if t_cur.item() < args.boundary and not switched:
- high_noise_model.cpu()
- torch.cuda.empty_cache()
- low_noise_model.cuda()
- net = low_noise_model
- switched = True
- log.info("Switched to low noise model.")
- with torch.no_grad():
- v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to(
- torch.float64
- )
- if args.ode:
- x = x - (t_cur - t_next) * v_pred
- else:
- x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(
- *x.shape,
- dtype=torch.float32,
- device=tensor_kwargs["device"],
- generator=generator,
- )
+
+ # Always ensure CUDA initially
+ current_model.cuda()
+
+ print("-" * 20 + " SAMPLING START " + "-" * 20)
+ print_memory_status("High Noise Model to GPU")
+
+ # Sampling Loop
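+ # Each step runs the currently loaded DiT once; when t_cur drops below args.boundary,
+ # the high-noise model is swapped out for the low-noise model.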
+ for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), total=len(t_steps)-1)):
+ if t_cur.item() < args.boundary and is_high_noise_active:
+ print(f"\n๐ Switching DiT Models (Step {i})...")
+ current_model.cpu()
+ del current_model
+ cleanup_memory("Unloaded High Noise")
+
+ current_model = load_dit_model(args, is_high_noise=False, force_offload=fallback_triggered)
+ current_model.cuda() # Force CUDA
+ is_high_noise_active = False
+ print_memory_status("Loaded Low Noise to GPU")
+
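+ # OOM guard: if the forward pass runs out of memory, the DiT is moved to CPU, deleted,
+ # and reloaded with --offload_dit forced on, then the same step is retried. A second
+ # OOM after the reload aborts the run.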
+ step_success = False
+ while not step_success:
+ try:
+ gc.collect()
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ v_pred = current_model(
+ x_B_C_T_H_W=x.to(**tensor_kwargs),
+ timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs),
+ **condition
+ ).to(torch.float64)
+ step_success = True
+ except torch.OutOfMemoryError:
+ if fallback_triggered:
+ log.error("โ OOM occurred even after reload. Physical Memory Limit Reached.")
+ sys.exit(1)
+
+ print(f"\nโ ๏ธ OOM in DiT Sampling Step {i}. Reloading model to clear fragmentation...")
+ cleanup_memory("Pre-Reload")
+
+ # Unload and Reload to Defrag
+ was_high = is_high_noise_active
+ current_model.cpu()
+ del current_model
+ cleanup_memory("Unload for Reload")
+
+ fallback_triggered = True
+ current_model = load_dit_model(args, is_high_noise=was_high, force_offload=True)
+ current_model.cuda() # Move back to GPU
+
+ print("โป๏ธ Model Reloaded. Retrying step...")
+
+ if args.ode:
+ x = x - (t_cur - t_next) * v_pred
+ else:
+ x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(*x.shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator)
+
samples = x.float()
- low_noise_model.cpu()
- torch.cuda.empty_cache()
-
+
+ print("-" * 20 + " DECODE SETUP (DEFRAG) " + "-" * 20)
+ samples_cpu_backup = samples.cpu()
+ del samples
+ del x
+ current_model.cpu()
+ del current_model
+ cleanup_memory("FULL WIPE before VAE Load")
+
+ log.info("Reloading VAE for decoding...")
+ tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path)
+ print_memory_status("Reloaded VAE (Clean Slate)")
+
+ samples = samples_cpu_backup.to(device=tensor_kwargs["device"])
+
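+ # Three-tier decode, mirroring the encode path: standard GPU decode first, then the
+ # overlap-blended tiled GPU decode, then a float32 CPU decode as the final fallback.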
with torch.no_grad():
- video = tokenizer.decode(samples)
-
- to_show.append(video.float().cpu())
-
+ success = False
+ video = None
+
+ try:
+ log.info("Attempting Standard GPU Decode...")
+ video = tokenizer.decode(samples)
+ success = True
+ except torch.OutOfMemoryError:
+ log.warning("โ ๏ธ GPU OOM (Standard). Switching to Tiled GPU Decode...")
+ cleanup_memory("Pre-Tile Fallback")
+
+ try:
+ # 12-latent overlap = 96 image pixels
+ video = tiled_decode_gpu(tokenizer, samples, overlap=12)
+ success = True
+ except (torch.OutOfMemoryError, RuntimeError) as e:
+ log.warning(f"โ ๏ธ GPU Tiled Decode Failed ({e}). Switching to CPU Decode (Slow)...")
+ cleanup_memory("Pre-CPU Fallback")
+
+ if not success:
+ log.info("Performing Hard Cast of VAE to CPU Float32...")
+ samples_cpu = samples.cpu().float()
+ force_cpu_float32(tokenizer)
+ with torch.autocast("cpu", enabled=False):
+ with torch.autocast("cuda", enabled=False):
+ video = tokenizer.decode(samples_cpu)
+
+ to_show = [video.float().cpu()]
to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0
-
save_image_or_video(rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), args.save_path, fps=16)
+ log.success("Done.")