diff --git a/TurboDiffusion_Studio.md b/TurboDiffusion_Studio.md new file mode 100644 index 0000000..d0d7f4f --- /dev/null +++ b/TurboDiffusion_Studio.md @@ -0,0 +1,123 @@ + +## ๐Ÿš€ Scripts & Inference + +This repository contains optimized inference engines for the Wan2.1 and Wan2.2 models, specifically tuned for high-resolution output and robust memory management on consumer hardware. + +### ๐ŸŽฅ Inference Engines + +| Script | Function | Key Features | +| --- | --- | --- | +| **`wan2.2_i2v_infer.py`** | **Image-to-Video** | **Tiered Failover System**: Automatic recovery from OOM errors.
<br>**Intelligent Model Switching**: Switches from the High-Noise to the Low-Noise model at the configured timestep boundary.<br>**Tiled Processing**: Uses 4-chunk tiled encoding/decoding for 720p+ stability. |
+| **`wan2.1_t2v_infer.py`** | **Text-to-Video** | **Hardware Auto-Detection**: Automatically selects TF32, BF16, or FP16 based on GPU capabilities.<br>**Quantization Safety**: Force-disables `torch.compile` for quantized models to prevent graph-break OOMs.<br>
**3-Tier Recovery**: Escalates from GPU โž” Checkpointing โž” Manual CPU Offloading if memory is exceeded. | + +### ๐Ÿ› ๏ธ Utilities + +* **`cache_t5.py`** +* **Purpose**: Pre-computes and saves T5 text embeddings to disk. +* **VRAM Benefit**: Eliminates the need to load the **11GB T5 encoder** during the main inference run, allowing 14B models to fit on GPUs with lower VRAM. +* **Usage**: Run this first to generate a `.pt` file, then pass it to the inference scripts using the `--cached_embedding` flag. + + +--- + +## ๐Ÿš€ Getting Started with TurboDiffusion + +To run the large 14B models on consumer GPUs, it is recommended to use the **T5 Caching** workflow. This offloads the 11GB text encoder from VRAM, leaving more space for the DiT model and high-resolution video decoding. + +### **Step 1: Environment Setup** + +Ensure your project structure is organized as follows: + +* **Root**: `/your/path/to/TurboDiffusion` +* **Checkpoints**: Place your `.pth` models in the `checkpoints/` directory. +* **Output**: Generated videos and metadata will be saved to `output/`. + +### **Step 2: The Two Ways to Cache T5** + +#### **Option A: Manual Pre-Caching (Recommended for Batching)** + +If you have a list of prompts you want to use frequently, use the standalone utility: + +```bash +python turbodiffusion/inference/cache_t5.py --prompt "Your descriptive prompt here" --output cached_t5_embeddings.pt + +``` + +This saves the processed text into a small `.pt` file, allowing the inference scripts to "skip" the heavy T5 model entirely. + +#### **Option B: Automatic Caching via Web UI** + +For a more streamlined experience, use the **TurboDiffusion Studio**: + +1. Launch the UI: `python turbo_diffusion_t5_cache_optimize_v6.py`. +2. Open the **Precision & Advanced Settings** accordion. +3. Check **Use Cached T5 Embeddings (Auto-Run)**. +4. When you click generate, the UI will automatically run the caching script first, clear the T5 model from memory, and then start the video generation. + +### **Step 3: Running Inference** + +Once your UI is launched and caching is configured: + +1. **Select Mode**: Choose between **Text-to-Video** (Wan2.1) or **Image-to-Video** (Wan2.2). +2. **Apply Quantization**: For 24GB VRAM GPUs (like the RTX 3090/4090/5090), ensure **Enable --quant_linear (8-bit)** is checked to avoid OOM errors. +3. **Monitor Hardware**: Watch the **Live GPU Monitor** at the top of the UI to track real-time VRAM usage during the sampling process. +4. **Retrieve Results**: Your video and its reproduction metadata (containing the exact CLI command used) will appear in the `output/` gallery. + + +--- + +## ๐Ÿ–ฅ๏ธ TurboDiffusion Studio (Web UI) + +The `turbo_diffusion_t5_cache_optimize_v6.py` script provides a high-performance, unified **Gradio-based Web interface** for both Text-to-Video and Image-to-Video generation. It serves as a centralized "Studio" dashboard that automates complex environment setups and memory optimizations. + +### **Key Features** + +| Feature | Description | +| --- | --- | +| **Unified Interface** | Toggle between **Text-to-Video (Wan2.1)** and **Image-to-Video (Wan2.2)** workflows within a single dashboard. | +| **Real-time GPU Monitor** | Native PyTorch-based VRAM monitoring that displays current memory usage and hardware status directly in the UI. | +| **Auto-Cache T5 Integration** | Automatically runs the `cache_t5.py` utility before inference to offload the 11GB text encoder, significantly reducing peak VRAM usage. 
| +| **Frame Sanitization** | Automatically enforces the **4n + 1 rule** required by the Wan VAE to prevent kernel crashes during decoding. | +| **Reproduction Metadata** | Every generated video automatically saves a matching `_metadata.txt` file containing the exact CLI command and environment variables needed to reproduce the result. | +| **Live Console Output** | Pipes real-time CLI logs and progress bars directly into a "Live Console" window in the web browser. | + +### **Advanced Controls** + +The UI exposes granular controls for technical users: + +* **Precision & Quantization:** Toggle 8-bit `--quant_linear` mode for low-VRAM operation. +* **Attention Tuning:** Switch between `sagesla`, `sla`, and `original` attention mechanisms. +* **Adaptive I2V:** Enable adaptive resolution and ODE solvers for Image-to-Video workflows. +* **Integrated Gallery:** Browse and view your output history directly within the `output/` directory. + +--- + +## ๐Ÿ› ๏ธ Usage + +To launch the studio: + +```bash +python turbo_diffusion_t5_cache_optimize_v6.py + +``` + +> **Note:** The script defaults to `/your/path/to/TurboDiffusion`as the project root. Ensure your local paths are configured accordingly in the **System Setup** section of the code. + + +--- + +## ๐Ÿ’ณ Credits & Acknowledgments + +If you utilize, share, or build upon these optimized scripts, please include the following acknowledgments: + +* **Optimization & Development**: Co-developed by **Waverly Edwards** and **Google Gemini**. +* **T5 Caching Logic**: Original concept and utility implementation by **John D. Pope**. +* **Base Framework**: Built upon the NVIDIA Imaginaire and Wan-Video research. diff --git a/turbodiffusion/inference/cache_t5.py b/turbodiffusion/inference/cache_t5.py new file mode 100644 index 0000000..a25e3cf --- /dev/null +++ b/turbodiffusion/inference/cache_t5.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# ----------------------------------------------------------------------------------------- +# T5 EMBEDDING CACHE UTILITY +# +# Acknowledgments: +# - Work and creativity of: John D. Pope +# +# Description: +# Pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# ----------------------------------------------------------------------------------------- +""" +Pre-cache T5 text embeddings to avoid loading the 11GB model during inference. + +Usage: + # Cache a single prompt + python scripts/cache_t5.py --prompt "slow head turn, cinematic" --output cached_embeddings.pt + + # Cache multiple prompts from file + python scripts/cache_t5.py --prompts_file prompts.txt --output cached_embeddings.pt + +Then use with inference: + python turbodiffusion/inference/wan2.2_i2v_infer.py \ + --cached_embedding cached_embeddings.pt \ + --skip_t5 \ + ... 
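+
+Cache file layout (a sketch for downstream consumers of the .pt file; the keys
+follow the cache_data dict written in main() below):
+    data = torch.load("cached_t5_embeddings.pt", map_location="cpu")
+    data["prompts"]                        # list[str] of cached prompts
+    data["embeddings"][0]["embedding"]     # CPU tensor for the first prompt
+    # The inference scripts cast/move it themselves, e.g. .to(device="cuda", dtype=torch.bfloat16)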
+""" +import os +import sys +import argparse +import torch + +# Add repo root to path for imports +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.dirname(SCRIPT_DIR) +sys.path.insert(0, REPO_ROOT) + +def main(): + parser = argparse.ArgumentParser(description="Pre-cache T5 text embeddings") + parser.add_argument("--prompt", type=str, default=None, help="Single prompt to cache") + parser.add_argument("--prompts_file", type=str, default=None, help="File with prompts (one per line)") + parser.add_argument("--text_encoder_path", type=str, + default="/media/2TB/ComfyUI/models/text_encoders/models_t5_umt5-xxl-enc-bf16.pth", + help="Path to the umT5 text encoder") + parser.add_argument("--output", type=str, default="cached_t5_embeddings.pt", + help="Output path for cached embeddings") + parser.add_argument("--device", type=str, default="cuda", + help="Device to use for encoding (cuda is faster, memory freed after)") + args = parser.parse_args() + + # Collect prompts + prompts = [] + if args.prompt: + prompts.append(args.prompt) + if args.prompts_file and os.path.exists(args.prompts_file): + with open(args.prompts_file, 'r') as f: + prompts.extend([line.strip() for line in f if line.strip()]) + + if not prompts: + print("Error: Provide --prompt or --prompts_file") + sys.exit(1) + + print(f"Caching embeddings for {len(prompts)} prompt(s)") + print(f"Text encoder: {args.text_encoder_path}") + print(f"Device: {args.device}") + print() + + # Import after path setup + from rcm.utils.umt5 import get_umt5_embedding, clear_umt5_memory + + cache_data = { + 'prompts': prompts, + 'embeddings': [], + 'text_encoder_path': args.text_encoder_path, + } + + with torch.no_grad(): + for i, prompt in enumerate(prompts): + print(f"[{i+1}/{len(prompts)}] Encoding: '{prompt[:60]}...' " if len(prompt) > 60 else f"[{i+1}/{len(prompts)}] Encoding: '{prompt}'") + + # Get embedding (loads T5 if not already loaded) + embedding = get_umt5_embedding( + checkpoint_path=args.text_encoder_path, + prompts=prompt + ) + + # Move to CPU for storage + cache_data['embeddings'].append({ + 'prompt': prompt, + 'embedding': embedding.cpu(), + 'shape': list(embedding.shape), + }) + + print(f" Shape: {embedding.shape}, dtype: {embedding.dtype}") + + # Clear T5 from memory + print("\nClearing T5 from memory...") + clear_umt5_memory() + torch.cuda.empty_cache() + + # Save cache + print(f"\nSaving to: {args.output}") + torch.save(cache_data, args.output) + + # Summary + file_size = os.path.getsize(args.output) / (1024 * 1024) + print(f"Done! Cache file size: {file_size:.2f} MB") + print() + print("Usage:") + print(f" python turbodiffusion/inference/wan2.2_i2v_infer.py \\") + print(f" --cached_embedding {args.output} \\") + print(f" --skip_t5 \\") + print(f" ... (other args)") + + +if __name__ == "__main__": + main() diff --git a/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py b/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py new file mode 100644 index 0000000..0525f4b --- /dev/null +++ b/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py @@ -0,0 +1,343 @@ +import os +import sys +import subprocess +import gradio as gr +import glob +import random +import time +import select +import torch +from datetime import datetime + +# --- 1. 
System Setup --- +PROJECT_ROOT = "/home/wedwards/Documents/Programs/TurboDiffusion" +os.chdir(PROJECT_ROOT) +os.system('clear' if os.name == 'posix' else 'cls') + +CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints") +OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output") +os.makedirs(OUTPUT_DIR, exist_ok=True) + +T2V_SCRIPT = "turbodiffusion/inference/wan2.1_t2v_infer.py" +I2V_SCRIPT = "turbodiffusion/inference/wan2.2_i2v_infer.py" +CACHE_SCRIPT = "turbodiffusion/inference/cache_t5.py" + +def get_gpu_status_original(): + """System-level GPU check.""" + try: + res = subprocess.check_output( + ["nvidia-smi", "--query-gpu=name,memory.used,memory.total", "--format=csv,nounits,noheader"], + encoding='utf-8' + ).strip().split(',') + return f"๐Ÿ–ฅ๏ธ {res[0]} | โšก VRAM: {res[1]}MB / {res[2]}MB" + except: + return "๐Ÿ–ฅ๏ธ GPU Monitor Active" + + +def get_gpu_status(): + """ + Check GPU status using PyTorch. + Returns system-wide VRAM usage without relying on nvidia-smi CLI. + """ + try: + # 1. Check for CUDA (NVIDIA) or ROCm (AMD) + if torch.cuda.is_available(): + # mem_get_info returns (free_bytes, total_bytes) + free_mem, total_mem = torch.cuda.mem_get_info() + + used_mem = total_mem - free_mem + + # Convert to MB for display + total_mb = int(total_mem / 1024**2) + used_mb = int(used_mem / 1024**2) + name = torch.cuda.get_device_name(0) + + return f"๐Ÿ–ฅ๏ธ {name} | โšก VRAM: {used_mb}MB / {total_mb}MB" + + # 2. Check for Apple Silicon (MPS) + # Note: Apple uses Unified Memory, so 'VRAM' is shared with System RAM. + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + return "๐Ÿ–ฅ๏ธ Apple Silicon (MPS) | โšก Unified Memory Active" + + # 3. Fallback to CPU + else: + return "๐Ÿ–ฅ๏ธ Running on CPU" + + except ImportError: + return "๐Ÿ–ฅ๏ธ GPU Monitor: PyTorch not installed" + except Exception as e: + return f"๐Ÿ–ฅ๏ธ GPU Monitor Error: {str(e)}" + +def save_debug_metadata(video_path, script_rel, cmd_list, cache_cmd_list=None): + """ + Saves a fully executable reproduction script with env vars. + """ + meta_path = video_path.replace(".mp4", "_metadata.txt") + with open(meta_path, "w") as f: + f.write(f"# Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write("# Copy and paste the lines below to reproduce this video exactly:\n\n") + + # Environment Variables + f.write("export PYTHONPATH=turbodiffusion\n") + f.write("export PYTORCH_ALLOC_CONF=expandable_segments:True\n") + f.write("export TOKENIZERS_PARALLELISM=false\n\n") + + # Optional Cache Step + if cache_cmd_list: + f.write("# --- Step 1: Pre-Cache Embeddings ---\n") + f.write(f"python {CACHE_SCRIPT} \\\n") + c_args = cache_cmd_list[2:] + for i, arg in enumerate(c_args): + if arg.startswith("--"): + val = f'"{c_args[i+1]}"' if i+1 < len(c_args) and not c_args[i+1].startswith("--") else "" + f.write(f" {arg} {val} \\\n") + f.write("\n# --- Step 2: Run Inference ---\n") + + # Main Inference Command + f.write(f"python {script_rel} \\\n") + args_only = cmd_list[2:] + for i, arg in enumerate(args_only): + if arg.startswith("--"): + val = f'"{args_only[i+1]}"' if i+1 < len(args_only) and not args_only[i+1].startswith("--") else "" + f.write(f" {arg} {val} \\\n") + +def sync_path(scale): + fname = "TurboWan2.1-T2V-1.3B-480P-quant.pth" if "1.3B" in scale else "TurboWan2.1-T2V-14B-720P-quant.pth" + return os.path.join(CHECKPOINT_DIR, fname) + +# --- 2. 
Unified Generation Logic (With Safety Checks) --- + +def run_gen(mode, prompt, model, dit_path, i2v_high, i2v_low, image, res, ratio, steps, seed, quant, attn, top_k, frames, sigma, norm, adapt, ode, use_cache, cache_path, pr=gr.Progress()): + # --- PRE-FLIGHT SAFETY CHECK --- + error_msg = "" + if mode == "T2V": + if "quant" in dit_path.lower() and not quant: + error_msg = "โŒ CONFIG ERROR: Quantized model selected but '8-bit' disabled." + if attn == "original" and ("turbo" in dit_path.lower() or "quant" in dit_path.lower()): + error_msg = "โŒ COMPATIBILITY ERROR: 'Original' attention with Turbo/Quantized checkpoint." + else: + if ("quant" in i2v_high.lower() or "quant" in i2v_low.lower()) and not quant: + error_msg = "โŒ CONFIG ERROR: Quantized I2V model selected but '8-bit' disabled." + if attn == "original" and (("turbo" in i2v_high.lower() or "quant" in i2v_high.lower()) or ("turbo" in i2v_low.lower() or "quant" in i2v_low.lower())): + error_msg = "โŒ COMPATIBILITY ERROR: 'Original' attention with Turbo/Quantized checkpoints." + + if error_msg: + yield None, None, "โŒ Config Error", "๐Ÿ›‘ Aborted", error_msg + return + # ------------------------------- + + actual_seed = random.randint(1, 1000000) if seed <= 0 else int(seed) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + start_time = time.time() + + full_log = f"๐Ÿš€ Starting Job: {timestamp}\n" + pr(0, desc="๐Ÿš€ Starting...") + + # --- FRAME SANITIZATION (4n+1 RULE) --- + # Wan2.1 VAE requires frames to be (4n + 1). If not, we sanitize. + target_frames = int(frames) + valid_frames = ((target_frames - 1) // 4) * 4 + 1 + + # If the user input (e.g., 32) became smaller (29) or changed, we log it. + # Note: We enforce a minimum of 1 frame just in case. + valid_frames = max(1, valid_frames) + + if valid_frames != target_frames: + warning_msg = f"โš ๏ธ AUTO-ADJUST: Frame count {target_frames} is incompatible with VAE (requires 4n+1).\n" + warning_msg += f" Adjusted {target_frames} -> {valid_frames} frames to prevent kernel crash.\n" + full_log += warning_msg + print(warning_msg) # Print to console as well + + # Use valid_frames for the rest of the logic + frames = valid_frames + # -------------------------------------- + + # --- AUTO-CACHE STEP --- + cache_cmd_list = None + if use_cache: + pr(0, desc="๐Ÿ’พ Auto-Caching T5 Embeddings...") + cache_script_full = os.path.join(PROJECT_ROOT, CACHE_SCRIPT) + encoder_path = os.path.join(CHECKPOINT_DIR, "models_t5_umt5-xxl-enc-bf16.pth") + + cache_cmd = [ + sys.executable, + cache_script_full, + "--prompt", prompt, + "--output", cache_path, + "--text_encoder_path", encoder_path + ] + cache_cmd_list = cache_cmd + + full_log += f"\n[System] Running Cache Script: {' '.join(cache_cmd)}\n" + yield None, None, f"Seed: {actual_seed}", "๐Ÿ’พ Caching...", full_log + + cache_process = subprocess.Popen(cache_cmd, cwd=PROJECT_ROOT, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) + + while True: + if cache_process.poll() is not None: + rest = cache_process.stdout.read() + if rest: full_log += rest + break + line = cache_process.stdout.readline() + if line: + full_log += line + yield None, None, f"Seed: {actual_seed}", "๐Ÿ’พ Caching...", full_log + time.sleep(0.02) + + if cache_process.returncode != 0: + full_log += "\nโŒ CACHE FAILED. Aborting generation." + yield None, None, "โŒ Cache Failed", "๐Ÿ›‘ Aborted", full_log + return + + full_log += "\nโœ… Cache Complete. 
Starting Inference...\n" + # ----------------------------------------- + + # --- SETUP VIDEO GENERATION --- + if mode == "T2V": + save_path = os.path.join(OUTPUT_DIR, f"t2v_{timestamp}.mp4") + script_rel = T2V_SCRIPT + cmd = [sys.executable, os.path.join(PROJECT_ROOT, T2V_SCRIPT), "--model", model, "--dit_path", dit_path, "--prompt", prompt, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_samples", "1", "--num_frames", str(frames), "--sigma_max", str(sigma)] + else: + save_path = os.path.join(OUTPUT_DIR, f"i2v_{timestamp}.mp4") + script_rel = I2V_SCRIPT + # Note: Added frames to I2V command in previous step, maintained here. + cmd = [sys.executable, os.path.join(PROJECT_ROOT, I2V_SCRIPT), "--prompt", prompt, "--image_path", image, "--high_noise_model_path", i2v_high, "--low_noise_model_path", i2v_low, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_frames", str(frames)] + if adapt: cmd.append("--adaptive_resolution") + if ode: cmd.append("--ode") + + if quant: cmd.append("--quant_linear") + if norm: cmd.append("--default_norm") + + if use_cache: + cmd.extend(["--cached_embedding", cache_path, "--skip_t5"]) + + cmd.extend(["--save_path", save_path]) + + # Call the restored metadata saver + save_debug_metadata(save_path, script_rel, cmd, cache_cmd_list) + + env = os.environ.copy() + env["PYTHONPATH"] = os.path.join(PROJECT_ROOT, "turbodiffusion") + env["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" + env["TOKENIZERS_PARALLELISM"] = "false" + env["PYTHONUNBUFFERED"] = "1" + + full_log += f"\n[System] Running Inference: {' '.join(cmd)}\n" + process = subprocess.Popen(cmd, cwd=PROJECT_ROOT, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) + + last_ui_update = 0 + + while True: + if process.poll() is not None: + rest = process.stdout.read() + if rest: full_log += rest + break + + reads = [process.stdout.fileno()] + ret = select.select(reads, [], [], 0.1) + + if ret[0]: + line = process.stdout.readline() + full_log += line + + if "Loading DiT" in line: pr(0.1, desc="โšก Loading weights...") + if "Encoding" in line: pr(0.05, desc="๐Ÿ–ผ๏ธ VAE Encoding...") + if "Switching to CPU" in line: pr(0.1, desc="โš ๏ธ CPU Fallback...") + if "Sampling:" in line: + try: + pct = int(line.split('%')[0].split('|')[-1].strip()) + pr(0.2 + (pct/100 * 0.7), desc=f"๐ŸŽฌ Sampling: {pct}%") + except: pass + if "decoding" in line.lower(): pr(0.95, desc="๐ŸŽฅ Decoding VAE...") + + current_time = time.time() + if current_time - last_ui_update > 0.25: + last_ui_update = current_time + elapsed = f"{int(current_time - start_time)}s" + yield None, None, f"Seed: {actual_seed}", f"โฑ๏ธ Time: {elapsed}", full_log + + history = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), key=os.path.getmtime, reverse=True) + total_time = f"{int(time.time() - start_time)}s" + + yield save_path, history, f"โœ… Done | Seed: {actual_seed}", f"๐Ÿ Finished in {total_time}", full_log + +# --- 3. UI Layout --- +with gr.Blocks(title="TurboDiffusion Studio") as demo: + with gr.Row(): + gr.HTML("

⚡ TurboDiffusion Studio

") + with gr.Column(scale=1): + gpu_display = gr.Markdown(get_gpu_status()) + + gr.Timer(2).tick(get_gpu_status, outputs=gpu_display) + + with gr.Tabs(): + with gr.Tab("Text-to-Video"): + with gr.Row(): + with gr.Column(scale=4): + t2v_p = gr.Textbox(label="Prompt", lines=3, value="A stylish woman walks down a Tokyo street...") + with gr.Row(): + t2v_m = gr.Radio(["Wan2.1-1.3B", "Wan2.1-14B"], label="Model", value="Wan2.1-1.3B") + t2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="480p") + t2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9") + t2v_dit = gr.Textbox(label="DiT Path", value=sync_path("Wan2.1-1.3B"), interactive=False) + t2v_btn = gr.Button("Generate Video", variant="primary") + with gr.Column(scale=3): + t2v_out = gr.Video(label="Result", height=320) + with gr.Row(): + t2v_stat = gr.Textbox(label="Status", interactive=False, scale=2) + t2v_time = gr.Textbox(label="Timer", value="โฑ๏ธ Ready", interactive=False, scale=1) + + with gr.Tab("Image-to-Video"): + with gr.Row(): + with gr.Column(scale=4): + with gr.Row(): + i2v_img = gr.Image(label="Source", type="filepath", height=200) + i2v_p = gr.Textbox(label="Motion Prompt", lines=7) + with gr.Row(): + i2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="720p") + i2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9") + with gr.Row(): + i2v_adapt = gr.Checkbox(label="Adaptive Resolution", value=True) + i2v_ode = gr.Checkbox(label="Use ODE", value=False) + with gr.Accordion("I2V Path Overrides", open=False): + i2v_high = gr.Textbox(label="High-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-high-720P-quant.pth")) + i2v_low = gr.Textbox(label="Low-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-low-720P-quant.pth")) + i2v_btn = gr.Button("Animate Image", variant="primary") + with gr.Column(scale=3): + i2v_out = gr.Video(label="Result", height=320) + with gr.Row(): + i2v_stat_2 = gr.Textbox(label="Status", interactive=False, scale=2) + i2v_time_2 = gr.Textbox(label="Timer", value="โฑ๏ธ Ready", interactive=False, scale=1) + + console_out = gr.Textbox(label="Live CLI Console Output", lines=8, max_lines=8, interactive=False) + + with gr.Accordion("โš™๏ธ Precision & Advanced Settings", open=False): + with gr.Row(): + quant_opt = gr.Checkbox(label="Enable --quant_linear (8-bit)", value=True) + steps_opt = gr.Slider(1, 4, value=4, step=1, label="Steps") + seed_opt = gr.Number(label="Seed (0=Random)", value=0, precision=0) + with gr.Row(): + top_k_opt = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="SLA Top-K") + attn_opt = gr.Radio(["sagesla", "sla", "original"], label="Attention", value="sagesla") + sigma_opt = gr.Number(label="Sigma Max", value=80) + norm_opt = gr.Checkbox(label="Original Norms", value=False) + frames_opt = gr.Slider(1, 120, value=77, step=4, label="Frames (Steps of 4)") + with gr.Row(variant="panel"): + # --- T5 CACHE UI --- + use_cache_opt = gr.Checkbox(label="Use Cached T5 Embeddings (Auto-Run)", value=True) + cache_path_opt = gr.Textbox(label="Cache File Path", value="cached_t5_embeddings.pt", scale=2) + # ------------------- + + history_gal = gr.Gallery(value=sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), reverse=True), columns=6, height="auto") + + # --- 4. 
Logic Bindings --- + t2v_m.change(fn=sync_path, inputs=t2v_m, outputs=t2v_dit) + + t2v_args = [gr.State("T2V"), t2v_p, t2v_m, t2v_dit, gr.State(""), gr.State(""), gr.State(""), t2v_res, t2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, sigma_opt, norm_opt, gr.State(False), gr.State(False), use_cache_opt, cache_path_opt] + t2v_btn.click(run_gen, t2v_args, [t2v_out, history_gal, t2v_stat, t2v_time, console_out], show_progress="hidden") + + i2v_args = [i2v_img, i2v_p, gr.State("Wan2.2-A14B"), gr.State(""), i2v_high, i2v_low, i2v_img, i2v_res, i2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, gr.State(200), norm_opt, i2v_adapt, i2v_ode, use_cache_opt, cache_path_opt] + i2v_btn.click(run_gen, i2v_args, [i2v_out, history_gal, i2v_stat_2, i2v_time_2, console_out], show_progress="hidden") + +if __name__ == "__main__": + demo.launch(theme=gr.themes.Default(), allowed_paths=[OUTPUT_DIR]) \ No newline at end of file diff --git a/turbodiffusion/inference/wan2.1_t2v_infer.py b/turbodiffusion/inference/wan2.1_t2v_infer.py index c5581e4..a45220d 100644 --- a/turbodiffusion/inference/wan2.1_t2v_infer.py +++ b/turbodiffusion/inference/wan2.1_t2v_infer.py @@ -1,24 +1,58 @@ +# Blackwell Bridge # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# ----------------------------------------------------------------------------------------- +# TURBODIFFUSION OPTIMIZED INFERENCE SCRIPT (T2V) # -# http://www.apache.org/licenses/LICENSE-2.0 +# Co-developed by: Waverly Edwards & Google Gemini (2025) # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Modifications: +# - Implemented "Tiered Failover System" for robust OOM protection (GPU -> Checkpoint -> CPU). +# - Added Intelligent Hardware Detection (TF32/BF16/FP16 auto-switching). +# - Integrated Tiled Decoding for high-resolution VAE processing. +# - Added Support for Pre-cached Text Embeddings to skip T5 loading. +# - Optimized compilation logic for Quantized models (preventing graph breaks). +# +# Acknowledgments: +# - Made possible by the work (cache_t5.py) and creativity of: John D. Pope +# +# Description: +# cache_t5.py pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# +# CREDIT REQUEST: +# If you utilize, share, or build upon this specific optimized script, please +# acknowledge Waverly Edwards and Google Gemini in your documentation or credits. +# ----------------------------------------------------------------------------------------- import argparse import math +import os +import gc +import time +import sys + +# --- 1. Memory Tuning --- +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch +import torch.nn.functional as F from einops import rearrange, repeat from tqdm import tqdm +import numpy as np + +# --- 2. 
Hardware Optimization --- +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + # 'high' allows TF32 but maintains reasonable precision + torch.set_float32_matmul_precision('high') + +try: + import psutil +except ImportError: + psutil = None from imaginaire.utils.io import save_image_or_video from imaginaire.utils import log @@ -29,123 +63,303 @@ from modify_model import tensor_kwargs, create_model +# Suppress graph break warnings for cleaner output torch._dynamo.config.suppress_errors = True - +torch._dynamo.config.verbose = False def parse_arguments() -> argparse.Namespace: parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.1 T2V") - parser.add_argument("--dit_path", type=str, required=True, help="Custom path to the DiT model checkpoint for distilled models") + parser.add_argument("--dit_path", type=str, required=True, help="Custom path to the DiT model checkpoint") parser.add_argument("--model", choices=["Wan2.1-1.3B", "Wan2.1-14B"], default="Wan2.1-1.3B", help="Model to use") parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate") parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4, help="1~4 for timestep-distilled inference") parser.add_argument("--sigma_max", type=float, default=80, help="Initial sigma for rCM") parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE") parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder") + parser.add_argument("--cached_embedding", type=str, default=None, help="Path to cached text embeddings (pt file)") + parser.add_argument("--skip_t5", action="store_true", help="Skip T5 loading (implied if cached_embedding is used)") parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate") - parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)") + parser.add_argument("--prompt", type=str, required=True, help="Text prompt for video generation") parser.add_argument("--resolution", default="480p", type=str, help="Resolution of the generated output") parser.add_argument("--aspect_ratio", default="16:9", type=str, help="Aspect ratio of the generated output (width:height)") parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility") - parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video (include file extension)") - parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism to use") + parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video") + parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism") parser.add_argument("--sla_topk", type=float, default=0.1, help="Top-k ratio for SLA/SageSLA attention") parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions") - parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions") - parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server 
mode (keeps model loaded)") + parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm with faster versions") + parser.add_argument("--offload_dit", action="store_true", help="Offload DiT to CPU when not in use to save VRAM") + parser.add_argument("--compile", action="store_true", help="Use torch.compile (Inductor) for faster inference") return parser.parse_args() +def check_hardware_compatibility(): + if not torch.cuda.is_available(): return + gpu_name = torch.cuda.get_device_name(0) + log.info(f"Hardware: {gpu_name}") + + current_dtype = tensor_kwargs.get("dtype") + if current_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + log.warning(f"โš ๏ธ Device does not support BFloat16. Switching to Float16.") + tensor_kwargs["dtype"] = torch.float16 -if __name__ == "__main__": - args = parse_arguments() +def print_memory_status(step_name=""): + if not torch.cuda.is_available(): return + torch.cuda.synchronize() + allocated = torch.cuda.memory_allocated() / (1024**3) + free, total = torch.cuda.mem_get_info() + free_gb = free / (1024**3) + print(f"๐Ÿ“Š [MEM] {step_name}: InUse={allocated:.2f}GB, Free={free_gb:.2f}GB") - # Handle serve mode - if args.serve: - # Set mode to t2v for the TUI server - args.mode = "t2v" - from serve.tui import main as serve_main - serve_main(args) - exit(0) +def cleanup_memory(step_info=""): + gc.collect() + torch.cuda.empty_cache() - # Validate prompt is provided for one-shot mode - if args.prompt is None: - log.error("--prompt is required (unless using --serve mode)") - exit(1) +def load_dit_model(args, force_offload=False): + orig_offload = args.offload_dit + if force_offload: args.offload_dit = True + log.info(f"Loading DiT (Offload={args.offload_dit})...") + model = create_model(dit_path=args.dit_path, args=args).cpu() + args.offload_dit = orig_offload + return model - log.info(f"Computing embedding for prompt: {args.prompt}") - with torch.no_grad(): - text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs) - clear_umt5_memory() +def tiled_decode_gpu(tokenizer, latents, overlap=12): + print(f"\n๐Ÿงฑ Starting Tiled GPU Decode (Overlap={overlap})...") + B, C, T, H, W = latents.shape + scale = tokenizer.spatial_compression_factor + h_mid = H // 2 + w_mid = W // 2 + + def decode_tile(tile_latents): + cleanup_memory() + with torch.no_grad(): return tokenizer.decode(tile_latents).cpu() - log.info(f"Loading DiT model from {args.dit_path}") - net = create_model(dit_path=args.dit_path, args=args).cpu() - torch.cuda.empty_cache() - log.success("Successfully loaded DiT model.") + # 1. 
Top Tiles + l_tl = latents[..., :h_mid+overlap, :w_mid+overlap] + l_tr = latents[..., :h_mid+overlap, w_mid-overlap:] + v_tl = decode_tile(l_tl) + v_tr = decode_tile(l_tr) - tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + B_dec, C_dec, T_dec, H_tile, W_tile = v_tl.shape + mid_pix = w_mid * scale + overlap_pix = overlap * scale + + row_top = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_tl.dtype, device='cpu') + end_left = max(0, mid_pix - overlap_pix) + start_right = mid_pix + overlap_pix + + row_top[..., :end_left] = v_tl[..., :end_left] + row_top[..., start_right:] = v_tr[..., 2*overlap_pix:] + + x_linspace = torch.linspace(-6, 6, 2*overlap_pix, device='cpu') + alpha = torch.sigmoid(x_linspace).view(1, 1, 1, 1, -1) + row_top[..., end_left:start_right] = v_tl[..., mid_pix-overlap_pix:] * (1 - alpha) + v_tr[..., :2*overlap_pix] * alpha + del v_tl, v_tr - w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] + # 2. Bottom Tiles + l_bl = latents[..., h_mid-overlap:, :w_mid+overlap] + l_br = latents[..., h_mid-overlap:, w_mid-overlap:] + v_bl = decode_tile(l_bl) + v_br = decode_tile(l_br) + + row_bot = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_bl.dtype, device='cpu') + row_bot[..., :end_left] = v_bl[..., :end_left] + row_bot[..., start_right:] = v_br[..., 2*overlap_pix:] + row_bot[..., end_left:start_right] = v_bl[..., mid_pix-overlap_pix:] * (1 - alpha) + v_br[..., :2*overlap_pix] * alpha + del v_bl, v_br - log.info(f"Generating with prompt: {args.prompt}") - condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)} + # 3. Blend Vertically + h_mid_pix = h_mid * scale + video = torch.zeros(B_dec, 3, T_dec, H*scale, W*scale, dtype=row_top.dtype, device='cpu') + end_top = max(0, h_mid_pix - overlap_pix) + start_bot = h_mid_pix + overlap_pix + + video[..., :end_top, :] = row_top[..., :end_top, :] + video[..., start_bot:, :] = row_bot[..., 2*overlap_pix:, :] + + alpha_v = torch.sigmoid(x_linspace).view(1, 1, 1, -1, 1) + video[..., end_top:start_bot, :] = row_top[..., h_mid_pix-overlap_pix:, :] * (1 - alpha_v) + row_bot[..., :2*overlap_pix, :] * alpha_v + + return video.to(latents.device) - to_show = [] +def force_cpu_float32(target_obj): + for module in target_obj.modules(): + module.cpu().float() - state_shape = [ - tokenizer.latent_ch, - tokenizer.get_latent_num_frames(args.num_frames), - h // tokenizer.spatial_compression_factor, - w // tokenizer.spatial_compression_factor, - ] +def apply_manual_offload(model, device="cuda"): + log.info("Applying Tier 3 Offload...") + block_list_name = None + max_len = 0 + for name, child in model.named_children(): + if isinstance(child, torch.nn.ModuleList): + if len(child) > max_len: + max_len = len(child) + block_list_name = name + + if not block_list_name: + log.warning("Could not identify Block List! 
Offloading entire model to CPU.") + model.to("cpu") + return - generator = torch.Generator(device=tensor_kwargs["device"]) - generator.manual_seed(args.seed) + print(f" ๐Ÿ‘‰ Identified Transformer Blocks: '{block_list_name}' ({max_len} layers)") + try: model.to(device) + except RuntimeError: model.to("cpu") + + blocks = getattr(model, block_list_name) + blocks.to("cpu") + + def pre_hook(module, args): + module.to(device) + return args + def post_hook(module, args, output): + module.to("cpu") + return output + + for i, block in enumerate(blocks): + block.register_forward_pre_hook(pre_hook) + block.register_forward_hook(post_hook) - init_noise = torch.randn( - args.num_samples, - *state_shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) +if __name__ == "__main__": + print_memory_status("Script Start") + + # --- CREDIT PRINT --- + log.info("----------------------------------------------------------------") + log.info("๐Ÿš€ TurboDiffusion Optimized Inference") + log.info(" Co-developed by Waverly Edwards & Google Gemini") + log.info("----------------------------------------------------------------") - # mid_t = [1.3, 1.0, 0.6][: args.num_steps - 1] - # For better visual quality - mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] + check_hardware_compatibility() + args = parse_arguments() - t_steps = torch.tensor( - [math.atan(args.sigma_max), *mid_t, 0], - dtype=torch.float64, - device=init_noise.device, - ) + if (args.num_frames - 1) % 4 != 0: + new_f = ((args.num_frames - 1) // 4 + 1) * 4 + 1 + print(f"โš ๏ธ Adjusting --num_frames to {new_f}") + args.num_frames = new_f - # Convert TrigFlow timesteps to RectifiedFlow - t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) + if args.num_frames > 90 and not args.offload_dit: + args.offload_dit = True + + # --- CRITICAL FIX: Strictly Disable Compile for Quantized Models --- + if args.compile and args.quant_linear: + log.warning("๐Ÿšซ Quantized Model Detected: FORCE DISABLING `torch.compile` to avoid OOM.") + log.warning(" (Custom quantized kernels are not compatible with CUDA Graphs)") + args.compile = False + + # 1. Text Embeddings + if args.cached_embedding and os.path.exists(args.cached_embedding): + log.info(f"Loading cache: {args.cached_embedding}") + c = torch.load(args.cached_embedding, map_location='cpu') + text_emb = c['embeddings'][0]['embedding'].to(**tensor_kwargs) if isinstance(c, dict) else c.to(**tensor_kwargs) + else: + log.info(f"Computing embedding...") + with torch.no_grad(): + text_emb = get_umt5_embedding(args.text_encoder_path, args.prompt).to(**tensor_kwargs) + clear_umt5_memory() + cleanup_memory() + + # 2. VAE Shape Calc & UNLOAD + log.info("VAE Setup (Temp)...") + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] + state_shape = [tokenizer.latent_ch, tokenizer.get_latent_num_frames(args.num_frames), h // tokenizer.spatial_compression_factor, w // tokenizer.spatial_compression_factor] + del tokenizer + cleanup_memory("VAE Unloaded") - # Sampling steps + # 3. Load DiT + net = load_dit_model(args) + + # 4. 
Noise & Schedule + gen = torch.Generator(device=tensor_kwargs["device"]).manual_seed(args.seed) + cond = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)} + init_noise = torch.randn(args.num_samples, *state_shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=gen) + + mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] + t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device) + t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) + x = init_noise.to(torch.float64) * t_steps[0] ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype) - total_steps = t_steps.shape[0] - 1 + + # 5. Fast Sampling Loop + log.info("๐Ÿ”ฅ STARTING SAMPLING (INFERENCE MODE) ๐Ÿ”ฅ") + torch.cuda.empty_cache() net.cuda() - for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)): - with torch.no_grad(): - v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to( - torch.float64 - ) - x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn( - *x.shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) + print_memory_status("Tier 1: GPU Ready") + + # Compile? (Only if NOT disabled above) + if args.compile: + log.info("๐Ÿš€ Compiling model...") + try: + net = torch.compile(net, mode="reduce-overhead") + except Exception as e: + log.warning(f"Compile failed: {e}. Running eager.") + + failover = 0 + + with torch.inference_mode(): + for i, (t_cur, t_next) in enumerate(tqdm(zip(t_steps[:-1], t_steps[1:]), total=len(t_steps)-1)): + retry = True + while retry: + try: + t_cur_scalar = t_cur.item() + t_next_scalar = t_next.item() + + v_pred = net( + x_B_C_T_H_W=x.to(**tensor_kwargs), + timesteps_B_T=(t_cur * ones * 1000).to(**tensor_kwargs), + **cond + ).to(torch.float64) + + if args.offload_dit and i == len(t_steps)-2 and failover == 0: + net.cpu() + + noise = torch.randn(*x.shape, dtype=torch.float32, device=x.device, generator=gen).to(torch.float64) + term1 = x - (v_pred * t_cur_scalar) + x = term1 * (1.0 - t_next_scalar) + (noise * t_next_scalar) + + retry = False + + except torch.OutOfMemoryError: + log.warning(f"โš ๏ธ OOM at Step {i}. Recovering...") + try: net.cpu() + except: pass + del net + cleanup_memory() + failover += 1 + + if failover == 1: + print("โ™ป๏ธ Tier 2: Checkpointing") + net = load_dit_model(args, force_offload=True) + net.cuda() + # Retry compile with safer mode if first attempt was aggressive + if args.compile: + try: net = torch.compile(net, mode="default") + except: pass + elif failover == 2: + print("โ™ป๏ธ Tier 3: Manual Offload") + net = load_dit_model(args, force_offload=True) + apply_manual_offload(net) + else: + sys.exit("โŒ Critical OOM.") + samples = x.float() - net.cpu() - torch.cuda.empty_cache() + # 6. 
Decode + if 'net' in locals(): + try: net.cpu() + except: pass + del net + cleanup_memory("Pre-VAE") + + log.info("Decoding...") + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) with torch.no_grad(): - video = tokenizer.decode(samples) - - to_show.append(video.float().cpu()) + try: + video = tokenizer.decode(samples) + except torch.OutOfMemoryError: + log.warning("Falling back to Tiled Decode...") + video = tiled_decode_gpu(tokenizer, samples) + to_show = [video.float().cpu()] to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0 - save_image_or_video(rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), args.save_path, fps=16) + log.success(f"Saved: {args.save_path}") diff --git a/turbodiffusion/inference/wan2.2_i2v_infer.py b/turbodiffusion/inference/wan2.2_i2v_infer.py index e57e509..f6fe32d 100644 --- a/turbodiffusion/inference/wan2.2_i2v_infer.py +++ b/turbodiffusion/inference/wan2.2_i2v_infer.py @@ -1,28 +1,55 @@ +# Blackwell Bridge # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# ----------------------------------------------------------------------------------------- +# TURBODIFFUSION OPTIMIZED INFERENCE SCRIPT (I2V) # -# http://www.apache.org/licenses/LICENSE-2.0 +# Co-developed by: Waverly Edwards & Google Gemini (2025) # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Modifications: +# - Implemented "Tiered Failover System" for robust OOM protection. +# - Added Intelligent Model Switching (High/Low Noise) with memory optimization. +# - Integrated Tiled Encoding & Decoding for high-resolution processing. +# - Added Support for Pre-cached Text Embeddings to skip T5 loading. +# - Optimized memory management (VAE unload/reload, aggressive GC). +# +# Acknowledgments: +# - Made possible by the work (cache_t5.py) and creativity of: John D. Pope +# +# Description: +# cache_t5.py pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# +# CREDIT REQUEST: +# If you utilize, share, or build upon this specific optimized script, please +# acknowledge Waverly Edwards and Google Gemini in your documentation or credits. +# ----------------------------------------------------------------------------------------- import argparse import math +import os +import gc +import time +import sys + +# --- 1. 
Memory Tuning (Must be before torch imports) --- +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch +import torch.nn.functional as F from einops import rearrange, repeat from tqdm import tqdm from PIL import Image import torchvision.transforms.v2 as T import numpy as np +# Safe import for system memory checks +try: + import psutil +except ImportError: + psutil = None + from imaginaire.utils.io import save_image_or_video from imaginaire.utils import log @@ -34,189 +61,514 @@ torch._dynamo.config.suppress_errors = True - def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.2 I2V with High/Low Noise models") - parser.add_argument("--image_path", type=str, default=None, help="Path to the input image (required unless --serve)") - parser.add_argument("--high_noise_model_path", type=str, required=True, help="Path to the high-noise model") - parser.add_argument("--low_noise_model_path", type=str, required=True, help="Path to the low-noise model") - parser.add_argument("--boundary", type=float, default=0.9, help="Timestep boundary for switching from high to low noise model") - parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B", help="Model to use") - parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate") - parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4, help="1~4 for timestep-distilled inference") - parser.add_argument("--sigma_max", type=float, default=200, help="Initial sigma for rCM") - parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE") - parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder") - parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate") - parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)") - parser.add_argument("--resolution", default="720p", type=str, help="Resolution of the generated output") - parser.add_argument("--aspect_ratio", default="16:9", type=str, help="Aspect ratio of the generated output (width:height)") - parser.add_argument("--adaptive_resolution", action="store_true", help="If set, adapts the output resolution to the input image's aspect ratio, using the area defined by --resolution and --aspect_ratio as a target.") - parser.add_argument("--ode", action="store_true", help="Use ODE for sampling (sharper but less robust than SDE)") - parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility") - parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video (include file extension)") - parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism to use") - parser.add_argument("--sla_topk", type=float, default=0.1, help="Top-k ratio for SLA/SageSLA attention") - parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions") - parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions") - parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server mode (keeps model loaded)") + parser = 
argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.2 I2V") + parser.add_argument("--image_path", type=str, default=None, help="Path to input image") + parser.add_argument("--high_noise_model_path", type=str, required=True, help="Path to high-noise model") + parser.add_argument("--low_noise_model_path", type=str, required=True, help="Path to low-noise model") + parser.add_argument("--boundary", type=float, default=0.9, help="Switch boundary") + parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B") + parser.add_argument("--num_samples", type=int, default=1) + parser.add_argument("--num_steps", type=int, default=4) + parser.add_argument("--sigma_max", type=float, default=200) + parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth") + parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--cached_embedding", type=str, default=None) + parser.add_argument("--skip_t5", action="store_true") + parser.add_argument("--num_frames", type=int, default=81) + parser.add_argument("--prompt", type=str, default=None) + parser.add_argument("--resolution", default="720p", type=str) + parser.add_argument("--aspect_ratio", default="16:9", type=str) + parser.add_argument("--adaptive_resolution", action="store_true") + parser.add_argument("--ode", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--save_path", type=str, default="output/generated_video.mp4") + parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla") + parser.add_argument("--sla_topk", type=float, default=0.1) + parser.add_argument("--quant_linear", action="store_true") + parser.add_argument("--default_norm", action="store_true") + parser.add_argument("--serve", action="store_true") + parser.add_argument("--offload_dit", action="store_true") return parser.parse_args() +def print_memory_status(step_name=""): + """ + Prints a detailed breakdown of GPU memory usage. + """ + if not torch.cuda.is_available(): + return + + torch.cuda.synchronize() + allocated_gb = torch.cuda.memory_allocated() / (1024**3) + reserved_gb = torch.cuda.memory_reserved() / (1024**3) + free_mem, total_mem = torch.cuda.mem_get_info() + free_gb = free_mem / (1024**3) + + print(f"\n๐Ÿ“Š [MEMORY REPORT] {step_name}") + print(f" โ”œโ”€โ”€ ๐Ÿ’พ VRAM In Use: {allocated_gb:.2f} GB") + print(f" โ”œโ”€โ”€ ๐Ÿ“ฆ VRAM Reserved: {reserved_gb:.2f} GB") + print(f" โ”œโ”€โ”€ ๐Ÿ†“ VRAM Free: {free_gb:.2f} GB") + print("-" * 60) + +def cleanup_memory(step_info=""): + """Aggressively clears VRAM.""" + if step_info: + print(f"๐Ÿงน Cleaning memory: {step_info}...") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if step_info: + print_memory_status(f"After Cleanup ({step_info})") + +def get_tensor_size_mb(tensor): + return tensor.element_size() * tensor.nelement() / (1024 * 1024) + +def force_cpu_float32(target_obj): + """ + Recursively forces a model or wrapper object to CPU Float32. 
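+    Used by safe_cpu_fallback_encode() below as a last-resort path when the GPU
+    VAE encode runs out of memory, so the VAE can still run (slowly) in system RAM.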
+ """ + def recursive_cast_to_cpu(obj): + if isinstance(obj, torch.Tensor): + return obj.cpu().float() + elif isinstance(obj, (list, tuple)): + return type(obj)(recursive_cast_to_cpu(x) for x in obj) + elif isinstance(obj, dict): + return {k: recursive_cast_to_cpu(v) for k, v in obj.items()} + return obj + + targets = [target_obj] + if hasattr(target_obj, "model"): + targets.append(target_obj.model) + + for obj in targets: + if isinstance(obj, torch.nn.Module): + try: obj.cpu().float() + except: pass + + for attr_name in dir(obj): + if attr_name.startswith("__"): continue + try: + val = getattr(obj, attr_name) + if isinstance(val, torch.nn.Module): + val.cpu().float() + elif isinstance(val, (torch.Tensor, list, tuple)): + setattr(obj, attr_name, recursive_cast_to_cpu(val)) + except Exception: pass + + if isinstance(obj, torch.nn.Module): + try: + for module in obj.modules(): + for param in module.parameters(recurse=False): + if param is not None: + param.data = param.data.cpu().float() + if param.grad is not None: param.grad = None + for buf in module.buffers(recurse=False): + if buf is not None: + buf.data = buf.data.cpu().float() + except: pass + else: + for attr_name in dir(obj): + if attr_name.startswith("__"): continue + try: + val = getattr(obj, attr_name) + if isinstance(val, torch.nn.Module): + for module in val.modules(): + for param in module.parameters(recurse=False): + if param is not None: + param.data = param.data.cpu().float() + for buf in module.buffers(recurse=False): + if buf is not None: + buf.data = buf.data.cpu().float() + except: pass + +def tiled_encode_4x(tokenizer, frames, target_dtype): + B, C, T, H, W = frames.shape + h_mid = H // 2 + w_mid = W // 2 + print(f"\n๐Ÿงฉ Starting 4-Chunk Tiled Encoding (Input: {W}x{H})") + latents_list = [[None, None], [None, None]] + + try: + print(" ๐Ÿ‘‰ Encoding Chunk 1/4 (Top-Left)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_tl = tokenizer.encode(frames[:, :, :, :h_mid, :w_mid]) + latents_list[0][0] = l_tl.cpu() + del l_tl; cleanup_memory("After Chunk 1") + except Exception as e: + print(f"โŒ Chunk 1 Failed: {e}") + raise e + + try: + print(" ๐Ÿ‘‰ Encoding Chunk 2/4 (Top-Right)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_tr = tokenizer.encode(frames[:, :, :, :h_mid, w_mid:]) + latents_list[0][1] = l_tr.cpu() + del l_tr; cleanup_memory("After Chunk 2") + except Exception as e: + print(f"โŒ Chunk 2 Failed: {e}") + raise e + + try: + print(" ๐Ÿ‘‰ Encoding Chunk 3/4 (Bottom-Left)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_bl = tokenizer.encode(frames[:, :, :, h_mid:, :w_mid]) + latents_list[1][0] = l_bl.cpu() + del l_bl; cleanup_memory("After Chunk 3") + except Exception as e: + print(f"โŒ Chunk 3 Failed: {e}") + raise e + + try: + print(" ๐Ÿ‘‰ Encoding Chunk 4/4 (Bottom-Right)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_br = tokenizer.encode(frames[:, :, :, h_mid:, w_mid:]) + latents_list[1][1] = l_br.cpu() + del l_br; cleanup_memory("After Chunk 4") + except Exception as e: + print(f"โŒ Chunk 4 Failed: {e}") + raise e + + print(" ๐Ÿงต Stitching Latents...") + row1 = torch.cat([latents_list[0][0], latents_list[0][1]], dim=4) + row2 = torch.cat([latents_list[1][0], latents_list[1][1]], dim=4) + full_latents = torch.cat([row1, row2], dim=3) + return full_latents.to(device=tensor_kwargs["device"], dtype=target_dtype) + +def safe_cpu_fallback_encode(tokenizer, frames, target_dtype): + log.warning("๐Ÿ”„ Switching to CPU for VAE Encode (Slow but 
reliable)...") + cleanup_memory("Pre-CPU Encode") + frames_cpu = frames.cpu().to(dtype=torch.float32) + force_cpu_float32(tokenizer) + t0 = time.time() + with torch.autocast("cpu", enabled=False): + with torch.autocast("cuda", enabled=False): + latents = tokenizer.encode(frames_cpu) + print(f" โฑ๏ธ CPU Encode took: {time.time() - t0:.2f}s") + return latents.to(device=tensor_kwargs["device"], dtype=target_dtype) + +def tiled_decode_gpu(tokenizer, latents, overlap=12): + """ + Decodes latents in 4 spatial quadrants with OVERLAP and SIGMOID BLENDING. + Overlap=12 latents (96 pixels). Safe for 720p. + Removing Global Color Matching to prevent exposure shifts. + """ + print(f"\n๐Ÿงฑ Starting Tiled GPU Decode (4 Quadrants, Overlap={overlap}, Blended)...") + B, C, T, H, W = latents.shape + scale = tokenizer.spatial_compression_factor + h_mid = H // 2 + w_mid = W // 2 + + def decode_tile(tile_latents, name): + cleanup_memory(f"Tile {name}") + with torch.no_grad(): + return tokenizer.decode(tile_latents).cpu() + + try: + # 1. Decode Top-Left and Top-Right + l_tl = latents[..., :h_mid+overlap, :w_mid+overlap] + l_tr = latents[..., :h_mid+overlap, w_mid-overlap:] + v_tl = decode_tile(l_tl, "1/4 (TL)") + v_tr = decode_tile(l_tr, "2/4 (TR)") + B_dec, C_dec, T_dec, H_tile, W_tile = v_tl.shape + + print(f" ๐Ÿงต Blending Top Row (Decoded Frames: {T_dec})...") + mid_pix = w_mid * scale + overlap_pix = overlap * scale + + # Slices for overlap + tl_blend_slice = v_tl[..., mid_pix-overlap_pix:] + tr_blend_slice = v_tr[..., :2*overlap_pix] + + row_top = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_tl.dtype, device='cpu') + + # Place non-overlapping parts (Clamped indices) + end_left = max(0, mid_pix - overlap_pix) + start_right = mid_pix + overlap_pix + + row_top[..., :end_left] = v_tl[..., :end_left] + row_top[..., start_right:] = v_tr[..., 2*overlap_pix:] + + x = torch.linspace(-6, 6, 2*overlap_pix, device='cpu') + alpha = torch.sigmoid(x).view(1, 1, 1, 1, -1) + blended_h = tl_blend_slice * (1 - alpha) + tr_blend_slice * alpha + + row_top[..., end_left:start_right] = blended_h + del v_tl, v_tr, l_tl, l_tr + + # 3. Decode Bottom-Left and Bottom-Right + l_bl = latents[..., h_mid-overlap:, :w_mid+overlap] + l_br = latents[..., h_mid-overlap:, w_mid-overlap:] + v_bl = decode_tile(l_bl, "3/4 (BL)") + v_br = decode_tile(l_br, "4/4 (BR)") + + print(" ๐Ÿงต Blending Bottom Row...") + bl_blend_slice = v_bl[..., mid_pix-overlap_pix:] + br_blend_slice = v_br[..., :2*overlap_pix] + + row_bot = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_bl.dtype, device='cpu') + row_bot[..., :end_left] = v_bl[..., :end_left] + row_bot[..., start_right:] = v_br[..., 2*overlap_pix:] + row_bot[..., end_left:start_right] = bl_blend_slice * (1 - alpha) + br_blend_slice * alpha + del v_bl, v_br, l_bl, l_br + + # 5. 
+        print("   🧵 Blending Rows Vertically...")
+        h_mid_pix = h_mid * scale
+
+        # Slices
+        top_blend_slice = row_top[..., h_mid_pix-overlap_pix:, :]
+        bot_blend_slice = row_bot[..., :2*overlap_pix, :]
+
+        video = torch.zeros(B_dec, 3, T_dec, H*scale, W*scale, dtype=row_top.dtype, device='cpu')
+
+        end_top = max(0, h_mid_pix - overlap_pix)
+        start_bot = h_mid_pix + overlap_pix
+
+        video[..., :end_top, :] = row_top[..., :end_top, :]
+        video[..., start_bot:, :] = row_bot[..., 2*overlap_pix:, :]
+
+        alpha_v = torch.sigmoid(x).view(1, 1, 1, -1, 1)
+        blended_v = top_blend_slice * (1 - alpha_v) + bot_blend_slice * alpha_v
+
+        video[..., end_top:start_bot, :] = blended_v
+
+    except Exception as e:
+        print(f"❌ Tiled GPU Decode Failed: {e}")
+        raise
+    return video.to(latents.device)
+
+def load_dit_model(args, is_high_noise=True, force_offload=False):
+    """Helper to load the model, respecting overrides."""
+    original_offload = args.offload_dit
+    if force_offload:
+        args.offload_dit = True
+
+    path = args.high_noise_model_path if is_high_noise else args.low_noise_model_path
+    log.info(f"Loading {'High' if is_high_noise else 'Low'} Noise DiT (Offload={args.offload_dit})...")
+
+    try:
+        model = create_model(dit_path=path, args=args).cpu()
+    finally:
+        args.offload_dit = original_offload
+
+    return model
 if __name__ == "__main__":
+    print_memory_status("Script Start")
     args = parse_arguments()
-    # Handle serve mode
     if args.serve:
-        # Set mode to i2v for the TUI server
         args.mode = "i2v"
         from serve.tui import main as serve_main
         serve_main(args)
         exit(0)
-
-    # Validate required args for one-shot mode
-    if args.prompt is None:
-        log.error("--prompt is required (unless using --serve mode)")
-        exit(1)
-    if args.image_path is None:
-        log.error("--image_path is required (unless using --serve mode)")
-        exit(1)
-
-    log.info(f"Computing embedding for prompt: {args.prompt}")
-    with torch.no_grad():
+
+    # --- AUTO-ADJUST FRAME COUNT ---
+    if (args.num_frames - 1) % 4 != 0:
+        old_f = args.num_frames
+        new_f = ((old_f - 1) // 4 + 1) * 4 + 1
+        print(f"⚠️ Adjusting --num_frames from {old_f} to {new_f} to satisfy VAE temporal stride (4n+1).")
+        args.num_frames = new_f
+
+    # --- AUTO-ENABLE OFFLOAD FOR HIGH FRAMES ---
+    if args.num_frames > 90 and not args.offload_dit:
+        print(f"⚠️ High frame count ({args.num_frames}) detected. Enabling --offload_dit to prevent OOM.")
+        args.offload_dit = True
+
+    # 1. Text Embeddings
+    if args.cached_embedding and os.path.exists(args.cached_embedding):
+        log.info(f"Loading cached embedding from: {args.cached_embedding}")
+        cache_data = torch.load(args.cached_embedding, map_location='cpu')
+        text_emb = cache_data['embeddings'][0]['embedding'].to(**tensor_kwargs)
+    else:
+        log.info("Computing embedding...")
         text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs)
-    clear_umt5_memory()
-
-    log.info(f"Loading DiT models.")
-    high_noise_model = create_model(dit_path=args.high_noise_model_path, args=args).cpu()
-    torch.cuda.empty_cache()
-    low_noise_model = create_model(dit_path=args.low_noise_model_path, args=args).cpu()
-    torch.cuda.empty_cache()
-    log.success(f"Successfully loaded DiT model.")
+        clear_umt5_memory()
+    # 2. VAE Encoding
+    print("-" * 20 + " VAE SETUP " + "-" * 20)
     tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path)
-
-    log.info(f"Loading and preprocessing image from: {args.image_path}")
+    target_dtype = tensor_kwargs.get("dtype", torch.bfloat16)
     input_image = Image.open(args.image_path).convert("RGB")
+
     if args.adaptive_resolution:
-        log.info("Adaptive resolution mode enabled.")
         base_w, base_h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio]
         max_resolution_area = base_w * base_h
-        log.info(f"Target area is based on {args.resolution} {args.aspect_ratio} (~{max_resolution_area} pixels).")
-
         orig_w, orig_h = input_image.size
-        image_aspect_ratio = orig_h / orig_w
-
-        ideal_w = np.sqrt(max_resolution_area / image_aspect_ratio)
-        ideal_h = np.sqrt(max_resolution_area * image_aspect_ratio)
-
+        aspect = orig_h / orig_w
+        ideal_w = np.sqrt(max_resolution_area / aspect)
+        ideal_h = np.sqrt(max_resolution_area * aspect)
         stride = tokenizer.spatial_compression_factor * 2
-        lat_h = round(ideal_h / stride)
-        lat_w = round(ideal_w / stride)
-        h = lat_h * stride
-        w = lat_w * stride
-
-        log.info(f"Input image aspect ratio: {image_aspect_ratio:.4f}. Adaptive resolution set to: {w}x{h}")
+        h = round(ideal_h / stride) * stride
+        w = round(ideal_w / stride) * stride
+        log.info(f"Adaptive Res: {w}x{h}")
     else:
-        log.info("Fixed resolution mode.")
         w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio]
-        log.info(f"Resolution set to: {w}x{h}")
+
     F = args.num_frames
     lat_h = h // tokenizer.spatial_compression_factor
     lat_w = w // tokenizer.spatial_compression_factor
     lat_t = tokenizer.get_latent_num_frames(F)
-    log.info(f"Preprocessing image to {w}x{h}...")
-    image_transforms = T.Compose(
-        [
-            T.ToImage(),
-            T.Resize(size=(h, w), antialias=True),
-            T.ToDtype(torch.float32, scale=True),
-            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-        ]
-    )
-    image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=torch.float32)
-
-    with torch.no_grad():
-        frames_to_encode = torch.cat(
-            [image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device)], dim=2
-        )  # -> B, C, T, H, W
-        encoded_latents = tokenizer.encode(frames_to_encode)  # -> B, C_lat, T_lat, H_lat, W_lat
-
-        del frames_to_encode
-        torch.cuda.empty_cache()
-
+    image_transforms = T.Compose([
+        T.ToImage(),
+        T.Resize(size=(h, w), antialias=True),
+        T.ToDtype(torch.float32, scale=True),
+        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    ])
+
+    image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=target_dtype)
+    frames_to_encode = torch.cat([image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device, dtype=target_dtype)], dim=2)
+
+    log.info(f"Encoding {F} frames...")
+
+    try:
+        # Pre-emptively route to tiled encoding when free VRAM is below the safety margin
+        free_mem, _ = torch.cuda.mem_get_info()
+        if free_mem < 24 * (1024**3):
+            raise torch.OutOfMemoryError("Pre-emptive tiling")
+        with torch.amp.autocast("cuda", dtype=target_dtype):
+            encoded_latents = tokenizer.encode(frames_to_encode)
+    except torch.OutOfMemoryError:
+        try:
+            cleanup_memory("Switching to Tiled Encode")
+            encoded_latents = tiled_encode_4x(tokenizer, frames_to_encode, target_dtype)
+        except Exception as e:
+            log.warning(f"Tiling failed ({e}). Fallback to CPU.")
+            encoded_latents = safe_cpu_fallback_encode(tokenizer, frames_to_encode, target_dtype)
+
+    print("✅ VAE Encode Complete.")
+    del frames_to_encode
+    cleanup_memory("After VAE Encode")
+
+    # Prepare for Diffusion
     msk = torch.zeros(1, 4, lat_t, lat_h, lat_w, device=tensor_kwargs["device"], dtype=tensor_kwargs["dtype"])
     msk[:, :, 0, :, :] = 1.0
-
     y = torch.cat([msk, encoded_latents.to(**tensor_kwargs)], dim=1)
     y = y.repeat(args.num_samples, 1, 1, 1, 1)
+    saved_latent_ch = tokenizer.latent_ch
+
+    # The VAE is fully unloaded here and reloaded after sampling to free VRAM for the DiT
+    del tokenizer
+    cleanup_memory("Unloaded VAE Model")
+
+    # 3. Diffusion Sampling
+    print("-" * 20 + " DIT LOADING " + "-" * 20)
+
+    current_model = load_dit_model(args, is_high_noise=True)
+    is_high_noise_active = True
+    fallback_triggered = args.offload_dit
-    log.info(f"Generating with prompt: {args.prompt}")
     condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples), "y_B_C_T_H_W": y}
-
-    to_show = []
-
-    state_shape = [tokenizer.latent_ch, lat_t, lat_h, lat_w]
-
-    generator = torch.Generator(device=tensor_kwargs["device"])
-    generator.manual_seed(args.seed)
-
-    init_noise = torch.randn(
-        args.num_samples,
-        *state_shape,
-        dtype=torch.float32,
-        device=tensor_kwargs["device"],
-        generator=generator,
-    )
-
+
+    generator = torch.Generator(device=tensor_kwargs["device"]).manual_seed(args.seed)
+    init_noise = torch.randn(args.num_samples, saved_latent_ch, lat_t, lat_h, lat_w, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator)
+
     mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1]
-
-    t_steps = torch.tensor(
-        [math.atan(args.sigma_max), *mid_t, 0],
-        dtype=torch.float64,
-        device=init_noise.device,
-    )
-
-    # Convert TrigFlow timesteps to RectifiedFlow
+    # Convert TrigFlow timesteps to RectifiedFlow
+    t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device)
     t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps))
-
     x = init_noise.to(torch.float64) * t_steps[0]
     ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype)
-    total_steps = t_steps.shape[0] - 1
-    high_noise_model.cuda()
-    net = high_noise_model
-    switched = False
-    for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)):
-        if t_cur.item() < args.boundary and not switched:
-            high_noise_model.cpu()
-            torch.cuda.empty_cache()
-            low_noise_model.cuda()
-            net = low_noise_model
-            switched = True
-            log.info("Switched to low noise model.")
-        with torch.no_grad():
-            v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to(
-                torch.float64
-            )
-        if args.ode:
-            x = x - (t_cur - t_next) * v_pred
-        else:
-            x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(
-                *x.shape,
-                dtype=torch.float32,
-                device=tensor_kwargs["device"],
-                generator=generator,
-            )
+
+    # Always ensure CUDA initially
+    current_model.cuda()
+
+    print("-" * 20 + " SAMPLING START " + "-" * 20)
+    print_memory_status("High Noise Model to GPU")
+
+    # Sampling Loop
+    for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), total=len(t_steps)-1)):
+        if t_cur.item() < args.boundary and is_high_noise_active:
+            print(f"\n🔄 Switching DiT Models (Step {i})...")
+            current_model.cpu()
+            del current_model
+            cleanup_memory("Unloaded High Noise")
+
+            current_model = load_dit_model(args, is_high_noise=False, force_offload=fallback_triggered)
+            current_model.cuda()  # Force CUDA
+            is_high_noise_active = False
+            print_memory_status("Loaded Low Noise to GPU")
+
+        step_success = False
+        while not step_success:
+            try:
+                gc.collect()
+                torch.cuda.empty_cache()
+                with torch.no_grad():
+                    v_pred = current_model(
+                        x_B_C_T_H_W=x.to(**tensor_kwargs),
+                        timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs),
+                        **condition
+                    ).to(torch.float64)
+                step_success = True
+            except torch.OutOfMemoryError:
+                if fallback_triggered:
+                    log.error("❌ OOM occurred even after reload. Physical Memory Limit Reached.")
+                    sys.exit(1)
+
+                print(f"\n⚠️ OOM in DiT Sampling Step {i}. Reloading model to clear fragmentation...")
+                cleanup_memory("Pre-Reload")
+
+                # Unload and Reload to Defrag
+                was_high = is_high_noise_active
+                current_model.cpu()
+                del current_model
+                cleanup_memory("Unload for Reload")
+
+                fallback_triggered = True
+                current_model = load_dit_model(args, is_high_noise=was_high, force_offload=True)
+                current_model.cuda()  # Move back to GPU
+
+                print("♻️ Model Reloaded. Retrying step...")
+
+        if args.ode:
+            x = x - (t_cur - t_next) * v_pred
+        else:
+            x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(*x.shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator)
+
     samples = x.float()
-    low_noise_model.cpu()
-    torch.cuda.empty_cache()
-
+
+    print("-" * 20 + " DECODE SETUP (DEFRAG) " + "-" * 20)
+    samples_cpu_backup = samples.cpu()
+    del samples
+    del x
+    current_model.cpu()
+    del current_model
+    cleanup_memory("FULL WIPE before VAE Load")
+
+    log.info("Reloading VAE for decoding...")
+    tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path)
+    print_memory_status("Reloaded VAE (Clean Slate)")
+
+    samples = samples_cpu_backup.to(device=tensor_kwargs["device"])
+
     with torch.no_grad():
-        video = tokenizer.decode(samples)
-
-    to_show.append(video.float().cpu())
-
+        success = False
+        video = None
+
+        try:
+            log.info("Attempting Standard GPU Decode...")
+            video = tokenizer.decode(samples)
+            success = True
+        except torch.OutOfMemoryError:
+            log.warning("⚠️ GPU OOM (Standard). Switching to Tiled GPU Decode...")
+            cleanup_memory("Pre-Tile Fallback")
+
+            try:
+                # 12 latents of overlap = 96 image pixels
+                video = tiled_decode_gpu(tokenizer, samples, overlap=12)
+                success = True
+            except (torch.OutOfMemoryError, RuntimeError) as e:
+                log.warning(f"⚠️ GPU Tiled Decode Failed ({e}). Switching to CPU Decode (Slow)...")
+                cleanup_memory("Pre-CPU Fallback")
+
+        if not success:
+            log.info("Performing Hard Cast of VAE to CPU Float32...")
+            samples_cpu = samples.cpu().float()
+            force_cpu_float32(tokenizer)
+            with torch.autocast("cpu", enabled=False):
+                with torch.autocast("cuda", enabled=False):
+                    video = tokenizer.decode(samples_cpu)
+
+    to_show = [video.float().cpu()]
     to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0
-
     save_image_or_video(rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), args.save_path, fps=16)
+    log.success("Done.")