Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions QEfficient/diffusers/pipelines/wan/pipeline_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def from_pretrained(

Example:
>>> # Load from HuggingFace Hub
>>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model")
>>> pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
>>>
>>> # Load from local path
>>> pipeline = QEffWanPipeline.from_pretrained("/local/path/to/wan")
Expand Down Expand Up @@ -219,7 +219,7 @@ def export(
ValueError: If module configurations are invalid

Example:
>>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model")
>>> pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
>>> export_path = pipeline.export(
... export_dir="/path/to/export",
... use_onnx_subfunctions=True
Expand Down Expand Up @@ -291,7 +291,7 @@ def compile(
OSError: If there are issues with file I/O during compilation

Example:
>>> pipeline = QEffWanPipeline.from_pretrained("path/to/wan/model")
>>> pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers")
>>> # Sequential compilation with default config
>>> pipeline.compile(height=480, width=832, num_frames=81)
>>>
Expand Down Expand Up @@ -356,7 +356,6 @@ def compile(
}

# Use generic utility functions for compilation
logger.warning('For VAE compilation use QAIC_COMPILER_OPTS_UNSUPPORTED="-aic-hmx-conv3d" ')
if parallel:
compile_modules_parallel(self.modules, self.custom_config, specialization_updates)
else:
Expand Down Expand Up @@ -453,7 +452,7 @@ def __call__(
>>> # Save generated video
>>> result.images[0].save("cat_garden.mp4")
"""
device = "cpu"
device = self.model._execution_device

# Compile models with custom configuration if needed
self.compile(
Expand Down Expand Up @@ -616,11 +615,11 @@ def __call__(
timestep = t.expand(latents.shape[0])

# Extract dimensions for patch processing
batch_size, num_channels, num_frames, height, width = latents.shape
batch_size, num_channels, latent_frames, latent_height, latent_width = latents.shape
p_t, p_h, p_w = current_model.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
post_patch_num_frames = latent_frames // p_t
post_patch_height = latent_height // p_h
post_patch_width = latent_width // p_w

# Generate rotary position embeddings
rotary_emb = current_model.rope(latent_model_input)
Expand Down Expand Up @@ -757,7 +756,7 @@ def __call__(

# Allocate output buffer for VAE decoder
output_buffer = {"sample": np.random.rand(batch_size, 3, num_frames, height, width).astype(np.int32)}

self.vae_decoder.qpc_session.set_buffers(output_buffer)
inputs = {"latent_sample": latents.numpy()}

start_decode_time = time.perf_counter()
Expand All @@ -773,7 +772,7 @@ def __call__(

# Step 10: Collect performance metrics
perf_data = {
"transformer": transformer_perf, # Unified transformer (QAIC)
"transformer": transformer_perf,
"vae_decoder": vae_decoder_perf,
}

Expand Down
6 changes: 3 additions & 3 deletions examples/diffusers/wan/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ WAN 2.2 is a text-to-video diffusion model that uses dual-stage processing for h
## Files

- **`wan_lightning.py`** - Complete example with Lightning LoRA for fast video generation
- **`wan_config.json`** - Configuration file for transformer module compilation
- **`wan_config.json`** - Contains the default compilation config for the transformer and VAE modules.

## Quick Start

Expand Down Expand Up @@ -102,7 +102,7 @@ pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList(
[original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config['num_layers'])]
)
pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList(
[org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config.config['num_layers'])]
[org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config['num_layers'])]
)
```

Expand Down Expand Up @@ -160,7 +160,7 @@ Head blocking is common in all modes

## Configuration File

The `wan_config.json` file controls compilation settings for the transformer module:
The `wan_config.json` file controls compilation settings for the transformer and VAE modules:

### Module Structure

Expand Down
3 changes: 1 addition & 2 deletions examples/diffusers/wan/wan_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@
{
"batch_size": 1,
"num_channels": 16
}
,
},
"compilation":
{
"onnx_path": null,
Expand Down
2 changes: 1 addition & 1 deletion scripts/Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pipeline {
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_cli_qaic_diffusion &&
export HF_HUB_CACHE=/huggingface_hub &&
pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not wan) and (not qnn) and (not finetune)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log_diffusion.xml --durations=10 &&
pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not qnn) and (not finetune)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log_diffusion.xml --durations=20 &&
junitparser merge tests/tests_log_diffusion.xml tests/tests_log.xml &&
deactivate"
'''
Expand Down
12 changes: 6 additions & 6 deletions tests/diffusers/flux_test_config.json
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
{
"model_setup": {
"height": 256,
"width": 256,
"height": 64,
"width": 64,
"num_transformer_layers": 2,
"num_single_layers": 2
},
"mad_validation": {
"tolerances": {
"clip_text_encoder": 0.1,
"t5_text_encoder": 5.5,
"transformer": 2.0,
"vae_decoder": 1.0
"clip_text_encoder": 0.01,
"t5_text_encoder": 5,
"transformer": 0.1,
"vae_decoder": 0.01
}
},
"pipeline_params": {
Expand Down
69 changes: 42 additions & 27 deletions tests/diffusers/test_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import numpy as np
import pytest
import torch
from diffusers import FluxPipeline
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from QEfficient import QEffFluxPipeline
from QEfficient.diffusers.pipelines.pipeline_utils import (
Expand Down Expand Up @@ -311,34 +312,48 @@ def flux_pipeline_call_with_mad_validation(

@pytest.fixture(scope="session")
def flux_pipeline():
"""Setup compiled Flux pipeline for testing"""
"""Setup Flux test pipelines with random-initialized (dummy) weights."""
config = INITIAL_TEST_CONFIG["model_setup"]

pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")

# Reduce to 2 layers for testing
original_blocks = pipeline.transformer.model.transformer_blocks
org_single_blocks = pipeline.transformer.model.single_transformer_blocks

pipeline.transformer.model.config["num_layers"] = config["num_transformer_layers"]
pipeline.transformer.model.config["num_single_layers"] = config["num_single_layers"]
pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList(
[original_blocks[i] for i in range(0, pipeline.transformer.model.config["num_layers"])]
)
pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList(
[org_single_blocks[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])]
model_id = "black-forest-labs/FLUX.1-schnell"

# Build random-init components from model configs (no pretrained weights).
vae_config = AutoencoderKL.load_config(model_id, subfolder="vae")
transformer_config = FluxTransformer2DModel.load_config(model_id, subfolder="transformer")
scheduler_cfg = FlowMatchEulerDiscreteScheduler.load_config(model_id, subfolder="scheduler")

transformer_config["num_layers"] = config["num_transformer_layers"]
transformer_config["num_single_layers"] = config["num_single_layers"]

vae = AutoencoderKL.from_config(vae_config)
transformer = FluxTransformer2DModel.from_config(transformer_config)
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_cfg)

clip_text_encoder_cfg = CLIPTextModel.config_class.from_pretrained(model_id, subfolder="text_encoder")
t5_text_encoder_cfg = T5EncoderModel.config_class.from_pretrained(model_id, subfolder="text_encoder_2")

# Reduce text-encoder depth for faster export/compile in this test.
clip_text_encoder_cfg.num_hidden_layers = 1
t5_text_encoder_cfg.num_layers = 1

text_encoder = CLIPTextModel(clip_text_encoder_cfg)
text_encoder_2 = T5EncoderModel(t5_text_encoder_cfg)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer_2")

pytorch_pipeline = FluxPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
transformer=transformer,
)

### Pytorch pipeline
pytorch_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
original_blocks_pt = pytorch_pipeline.transformer.transformer_blocks
org_single_blocks_pt = pytorch_pipeline.transformer.single_transformer_blocks
pytorch_pipeline.transformer.transformer_blocks = torch.nn.ModuleList(
[original_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_layers"])]
)
pytorch_pipeline.transformer.single_transformer_blocks = torch.nn.ModuleList(
[org_single_blocks_pt[i] for i in range(0, pipeline.transformer.model.config["num_single_layers"])]
)
# Use QEff wrapper on a copy of the random-init reference model.
import copy

pipeline = QEffFluxPipeline(copy.deepcopy(pytorch_pipeline))
return pipeline, pytorch_pipeline


Expand Down Expand Up @@ -411,7 +426,7 @@ def test_flux_pipeline(flux_pipeline):
print(f" - Mode: {image_validation['mode']}")
print(f" - Variance: {image_validation['variance']:.2f}")
print(f" - Mean pixel value: {image_validation['mean_pixel_value']:.2f}")
file_path = "test_flux_256x256_2layers.png"
file_path = "test_flux_64x64_2layers.png"
# Save test image
generated_image.save(file_path)

Expand Down
Loading
Loading