diff --git a/api/test/api/test_diffusion.py b/api/test/api/test_diffusion.py index 81ac282fb..aa702d3bc 100644 --- a/api/test/api/test_diffusion.py +++ b/api/test/api/test_diffusion.py @@ -1045,3 +1045,104 @@ def test_get_pipeline_key_whitespace_adaptor(): # Should treat whitespace-only adaptor as no adaptor assert key == "test-model:: ::txt2img" + + +def test_resolve_diffusion_model_reference_non_directory(): + """Non-directory model refs should pass through unchanged.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + + with patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=False): + resolved = main.resolve_diffusion_model_reference("Tongyi-MAI/Z-Image-Turbo") + + assert resolved == "Tongyi-MAI/Z-Image-Turbo" + + +def test_resolve_diffusion_model_reference_prefers_local_complete_dir(): + """Local directory with model_index.json should stay as local path.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + local_dir = "/tmp/models/Tongyi-MAI_Z-Image-Turbo" + + with ( + patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=True), + patch("transformerlab.plugins.image_diffusion.main.os.path.isfile", return_value=True), + patch("transformerlab.plugins.image_diffusion.main._extract_hf_repo_from_model_metadata") as mock_extract, + ): + resolved = main.resolve_diffusion_model_reference(local_dir) + + mock_extract.assert_not_called() + assert resolved == local_dir + + +def test_resolve_diffusion_model_reference_falls_back_to_hf_repo(): + """Incomplete local directory should fall back to Hugging Face repo id from metadata.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + local_dir = "/tmp/models/Tongyi-MAI_Z-Image-Turbo" + + with ( + patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=True), + patch("transformerlab.plugins.image_diffusion.main.os.path.isfile", return_value=False), + patch( + "transformerlab.plugins.image_diffusion.main._extract_hf_repo_from_model_metadata", + return_value="Tongyi-MAI/Z-Image-Turbo", + ), + ): + resolved = main.resolve_diffusion_model_reference(local_dir) + + assert resolved == "Tongyi-MAI/Z-Image-Turbo" + + +def test_filter_generation_kwargs_for_pipeline_drops_unsupported(): + """Unsupported kwargs should be removed when pipeline call signature is strict.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + + class StrictPipeline: + def __call__(self, prompt, guidance_scale): + return {"prompt": prompt, "guidance_scale": guidance_scale} + + pipe = StrictPipeline() + kwargs = { + "prompt": "test", + "guidance_scale": 7.5, + "cross_attention_kwargs": {"scale": 1.0}, + "callback_on_step_end": lambda *_: None, + } + + filtered = main.filter_generation_kwargs_for_pipeline(pipe, kwargs) + + assert filtered == {"prompt": "test", "guidance_scale": 7.5} + + +def test_invoke_pipeline_with_safe_kwargs_retries_on_unexpected_keyword(): + """Retry logic should remove unsupported kwargs when strict call signatures are wrapped.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + + class WrappedStrictPipeline: + def __call__(self, *args, **kwargs): + if "cross_attention_kwargs" in kwargs: + raise TypeError("ZImagePipeline.__call__() got an unexpected keyword argument 'cross_attention_kwargs'") + return kwargs + + pipe = WrappedStrictPipeline() + kwargs = { + "prompt": "test", + "guidance_scale": 7.5, + "cross_attention_kwargs": {"scale": 1.0}, + } + + result = main.invoke_pipeline_with_safe_kwargs(pipe, kwargs) + + assert result["prompt"] == "test" + assert result["guidance_scale"] == 7.5 + assert "cross_attention_kwargs" not in result + + +def test_latents_to_rgb_supports_non_sdxl_channel_counts(): + """Intermediate preview conversion should not fail for non-4-channel latents.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + torch = pytest.importorskip("torch") + + latents = torch.randn(16, 8, 8) + preview = main.latents_to_rgb(latents) + + assert preview.mode == "RGB" + assert preview.size == (8, 8) diff --git a/api/transformerlab/plugin_sdk/plugin_harness.py b/api/transformerlab/plugin_sdk/plugin_harness.py index 1697eb66b..456a65b63 100644 --- a/api/transformerlab/plugin_sdk/plugin_harness.py +++ b/api/transformerlab/plugin_sdk/plugin_harness.py @@ -12,7 +12,48 @@ import sys import argparse import traceback -import asyncio +import sqlite3 +from typing import Optional + + +def get_db_config_value(key: str, team_id: Optional[str] = None, user_id: Optional[str] = None) -> Optional[str]: + """ + Read config values directly from sqlite without importing transformerlab.plugin. + This keeps harness startup independent from heavy ML dependencies. + """ + from lab import HOME_DIR + + db_path = f"{HOME_DIR}/llmlab.sqlite3" + db = sqlite3.connect(db_path, isolation_level=None) + db.execute("PRAGMA busy_timeout=30000") + try: + # Priority 1: user-specific config (requires both user_id and team_id) + if user_id and team_id: + cursor = db.execute( + "SELECT value FROM config WHERE key = ? AND user_id = ? AND team_id = ?", (key, user_id, team_id) + ) + row = cursor.fetchone() + cursor.close() + if row is not None: + return row[0] + + # Priority 2: team-wide config + if team_id: + cursor = db.execute( + "SELECT value FROM config WHERE key = ? AND user_id IS NULL AND team_id = ?", (key, team_id) + ) + row = cursor.fetchone() + cursor.close() + if row is not None: + return row[0] + + # Priority 3: global config + cursor = db.execute("SELECT value FROM config WHERE key = ? AND user_id IS NULL AND team_id IS NULL", (key,)) + row = cursor.fetchone() + cursor.close() + return row[0] if row is not None else None + finally: + db.close() parser = argparse.ArgumentParser() @@ -20,16 +61,66 @@ args, unknown = parser.parse_known_args() -def set_config_env_vars(env_var: str, target_env_var: str = None, user_id: str = None, team_id: str = None): - try: - from transformerlab.plugin import get_db_config_value +def configure_plugin_runtime_library_paths(plugin_dir: str) -> None: + """ + Prefer CUDA/NCCL libraries from the plugin venv over system-wide libraries. + This reduces CUDA symbol mismatches caused by stale host NCCL installs. + """ + if os.name == "nt": + return + + venv_path = os.path.join(plugin_dir, "venv") + if not os.path.isdir(venv_path): + return + + pyver = f"python{sys.version_info.major}.{sys.version_info.minor}" + site_packages = os.path.join(venv_path, "lib", pyver, "site-packages") + + candidate_paths: list[str] = [] + torch_lib = os.path.join(site_packages, "torch", "lib") + if os.path.isdir(torch_lib): + candidate_paths.append(torch_lib) + + nvidia_root = os.path.join(site_packages, "nvidia") + if os.path.isdir(nvidia_root): + for pkg_name in os.listdir(nvidia_root): + lib_dir = os.path.join(nvidia_root, pkg_name, "lib") + if os.path.isdir(lib_dir): + candidate_paths.append(lib_dir) - value = asyncio.run(get_db_config_value(env_var, user_id=user_id, team_id=team_id)) + if not candidate_paths: + return + + existing_paths = [p for p in os.environ.get("LD_LIBRARY_PATH", "").split(os.pathsep) if p] + candidate_norm = {os.path.normpath(c) for c in candidate_paths} + + merged = list(candidate_paths) + for path in existing_paths: + if os.path.normpath(path) not in candidate_norm: + merged.append(path) + + if merged != existing_paths: + os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(merged) + print("Configured LD_LIBRARY_PATH for plugin runtime libraries") + + +configure_plugin_runtime_library_paths(args.plugin_dir) + + +def set_config_env_vars( + env_var: str, + target_env_var: Optional[str] = None, + user_id: Optional[str] = None, + team_id: Optional[str] = None, +) -> None: + target_key = target_env_var or env_var + try: + value = get_db_config_value(env_var, user_id=user_id, team_id=team_id) if value: - os.environ[target_env_var] = value - print(f"Set {target_env_var} from {'user' if user_id else 'team'} config: {value}") + os.environ[target_key] = value + print(f"Set {target_key} from {'user' if user_id else 'team'} config") except Exception as e: - print(f"Warning: Could not set {target_env_var} from {'user' if user_id else 'team'} config: {e}") + print(f"Warning: Could not set {target_key} from {'user' if user_id else 'team'} config: {e}") # Set organization context from environment variable if provided @@ -69,6 +160,11 @@ def set_config_env_vars(env_var: str, target_env_var: str = None, user_id: str = except ImportError as e: print(f"Error executing plugin: {e}") traceback.print_exc() + if "ncclCommShrink" in str(e): + print( + "Detected CUDA/NCCL mismatch while importing torch. " + "Reinstall the plugin venv with a torch build matching this machine's CUDA runtime." + ) # if e is a ModuleNotFoundError, the plugin is missing a required package if isinstance(e, ModuleNotFoundError): diff --git a/api/transformerlab/plugins/diffusion_trainer/index.json b/api/transformerlab/plugins/diffusion_trainer/index.json index 0d59be36d..20728c454 100644 --- a/api/transformerlab/plugins/diffusion_trainer/index.json +++ b/api/transformerlab/plugins/diffusion_trainer/index.json @@ -4,14 +4,15 @@ "description": "A plugin for fine-tuning Stable Diffusion using LoRA adapters.", "plugin-format": "python", "type": "trainer", - "version": "0.1.10", + "version": "0.1.11", "git": "", "url": "", "model_architectures": [ "StableDiffusionPipeline", "StableDiffusionXLPipeline", "StableDiffusion3Pipeline", - "FluxPipeline" + "FluxPipeline", + "ZImagePipeline" ], "files": ["main.py", "setup.sh"], "supported_hardware_architectures": ["cuda", "amd"], diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 3510e1851..e81f2f517 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -3,32 +3,52 @@ import json import gc import asyncio +from typing import Any import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint -from peft import LoraConfig +from peft import LoraConfig, get_peft_model from peft.utils import get_peft_model_state_dict from torchvision import transforms +import os +import glob from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline from diffusers.optimization import get_scheduler from diffusers.training_utils import cast_training_params, compute_snr from diffusers.utils import convert_state_dict_to_diffusers +from transformerlab.sdk.v1.train import tlab_trainer +from lab.dirs import get_workspace_dir +from lab import storage + +# diffsynth is only required for Z-Image training. Keep it optional so non-ZImage +# workflows still run when optional low-level deps (for example xformers) are broken. +ModelConfig = None +ZImagePipeline = None +FlowMatchSFTLoss = None +diffsynth_available = False +diffsynth_import_error = None +try: + from diffsynth import ModelConfig + from diffsynth.pipelines.z_image import ZImagePipeline + from diffsynth.diffusion.loss import FlowMatchSFTLoss + + diffsynth_available = True +except Exception as e: + diffsynth_import_error = e + print(f"Warning: diffsynth is unavailable; Z-Image training will be disabled. Error: {e}") # Try to import xformers for memory optimization try: import xformers # noqa: F401 xformers_available = True -except ImportError: +except Exception as e: xformers_available = False - -from transformerlab.sdk.v1.train import tlab_trainer -from lab.dirs import get_workspace_dir -from lab import storage + print(f"Warning: xFormers is unavailable; continuing without it. Error: {e}") workspace_dir = asyncio.run(get_workspace_dir()) @@ -54,6 +74,44 @@ def cleanup_pipeline(): cleanup_pipeline() +def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[Any], Any]: + """Build ModelConfig list + tokenizer config for Z-Image Turbo.""" + if ModelConfig is None: + raise RuntimeError("diffsynth ModelConfig is unavailable.") + + transformer_pattern = os.path.join("transformer", "*.safetensors") + text_encoder_pattern = os.path.join("text_encoder", "*.safetensors") + vae_pattern = os.path.join("vae", "diffusion_pytorch_model.safetensors") + + tokenizer_pattern = "tokenizer/" + + if model_id_or_path and os.path.isdir(model_id_or_path): + transformer_paths = glob.glob(os.path.join(model_id_or_path, transformer_pattern)) + text_encoder_paths = glob.glob(os.path.join(model_id_or_path, text_encoder_pattern)) + vae_paths = glob.glob(os.path.join(model_id_or_path, vae_pattern)) + + model_configs = [ + ModelConfig(path=transformer_paths), + ModelConfig(path=text_encoder_paths), + ModelConfig(path=vae_paths), + ] + + tokenizer_config = ModelConfig(path=os.path.join(model_id_or_path, tokenizer_pattern)) + else: + model_configs = [ + ModelConfig(model_id=model_id_or_path, origin_file_pattern=transformer_pattern), + ModelConfig(model_id=model_id_or_path, origin_file_pattern=text_encoder_pattern), + # Fix: Use the corrected pattern for remote loading + ModelConfig(model_id=model_id_or_path, origin_file_pattern=vae_pattern), + ] + + tokenizer_config = ModelConfig( + model_id=model_id_or_path, origin_file_pattern=os.path.join(tokenizer_pattern, "*") + ) + + return model_configs, tokenizer_config + + def compute_loss_weighting(args, timesteps, noise_scheduler): """ Compute loss weighting for improved training stability. @@ -304,6 +362,59 @@ def encode_prompt_sdxl( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds +def encode_prompt_zimage(pipe, prompts, device, max_sequence_length: int = 512): + """Encode prompts using Z-Image tokenizer/chat template.""" + if isinstance(prompts, str): + prompts = [prompts] + + chat_prompts = [] + for prompt_item in prompts: + messages = [ + {"role": "user", "content": prompt_item}, + ] + + if hasattr(pipe.tokenizer, "build_chat_prompt"): + chat_prompt = pipe.tokenizer.build_chat_prompt( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + else: + # Fallback for tokenizers like Qwen2TokenizerFast + chat_prompt = pipe.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + chat_prompts.append(chat_prompt) + + text_inputs = pipe.tokenizer( + chat_prompts, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embedding_list = [] + + for i in range(len(prompt_embeds)): + embedding_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embedding_list + + @tlab_trainer.job_wrapper(wandb_project_name="TLab_Training", manual_logging=True) def train_diffusion_lora(): # Extract parameters from tlab_trainer @@ -325,6 +436,11 @@ def train_diffusion_lora(): eval_prompt = None args["eval_prompt"] = None args["eval_steps"] = 0 + if args.get("model_architecture", "").strip() == "ZImagePipeline": + print("Disabling evaluation for ZImagePipeline as it is not supported.") + eval_prompt = None + args["eval_prompt"] = None + args["eval_steps"] = 0 elif eval_prompt and eval_steps <= 0: print("Warning: eval_steps is set to 0, evaluation will not be performed.") eval_prompt = None @@ -333,7 +449,7 @@ def train_diffusion_lora(): if eval_prompt: eval_images_dir = storage.join(workspace_dir, "temp", f"eval_images_{job_id}") - storage.makedirs(eval_images_dir, exist_ok=True) + asyncio.run(storage.makedirs(eval_images_dir, exist_ok=True)) print(f"Evaluation images will be saved to: {eval_images_dir}") # Add eval images directory to job data @@ -352,43 +468,114 @@ def train_diffusion_lora(): model_architecture = args.get("model_architecture") - # Load pipeline to auto-detect architecture and get correct components - print(f"Loading pipeline to detect model architecture: {pretrained_model_name_or_path}") + # Detect architecture based on multiple indicators + is_sdxl = "StableDiffusionXLPipeline" in model_architecture + + is_sd3 = "StableDiffusion3Pipeline" in model_architecture + + is_flux = "FluxPipeline" in model_architecture + + is_zimage = "ZImagePipeline" in model_architecture + + # Some model cards expose generic architecture strings. Fallback to model id/path hints for Z-Image. + if not is_zimage: + model_name_hint = str(args.get("model_name") or "") + model_path_hint = str(args.get("model_path") or "") + model_hint = f"{model_name_hint} {model_path_hint}".lower() + if "z-image" in model_hint or "zimage" in model_hint: + is_zimage = True + print("Detected Z-Image model from model name/path hint; forcing ZImagePipeline training path.") + + print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, ZImage: {is_zimage}") + + # Mixed Precision + weight_dtype = torch.float32 + mixed_precision = args.get("mixed_precision", None) + if is_zimage and (mixed_precision is None or mixed_precision == "" or mixed_precision == "no"): + weight_dtype = torch.bfloat16 + elif mixed_precision == "fp16": + weight_dtype = torch.float16 + elif mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + pipeline_kwargs = { "torch_dtype": torch.float16, "safety_checker": None, "requires_safety_checker": False, } - temp_pipeline = AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path, **pipeline_kwargs) + pipe = None + if is_zimage: + if not diffsynth_available or ZImagePipeline is None or FlowMatchSFTLoss is None: + raise RuntimeError( + "Z-Image training dependencies are unavailable. " + "This is commonly caused by an incompatible xformers build in the plugin venv. " + "Try uninstalling xformers from the diffusion_trainer venv and rerun setup. " + f"Original import error: {diffsynth_import_error}" + ) + + # Ensure the model is downloaded locally if it's not already a directory + if not os.path.isdir(pretrained_model_name_or_path): + from huggingface_hub import snapshot_download - # Extract components from the loaded pipeline - noise_scheduler = temp_pipeline.scheduler - tokenizer = temp_pipeline.tokenizer - text_encoder = temp_pipeline.text_encoder - vae = temp_pipeline.vae + print(f"Downloading Z-Image model {pretrained_model_name_or_path} from Hugging Face...") + pretrained_model_name_or_path = snapshot_download( + repo_id=pretrained_model_name_or_path, + allow_patterns=["*.safetensors", "*.json", "tokenizer/*"], + ) + print(f"Model downloaded to: {pretrained_model_name_or_path}") - # Handle different architectures: FluxPipeline uses 'transformer', others use 'unet' - # We use 'unet' as a unified variable name for the main model component regardless of architecture - if hasattr(temp_pipeline, "transformer"): - # FluxPipeline and other transformer-based models - unet = temp_pipeline.transformer - model_component_name = "transformer" + model_configs, tokenizer_config = build_zimage_model_configs(pretrained_model_name_or_path) + pipe = ZImagePipeline.from_pretrained( + torch_dtype=weight_dtype, + device=device, + model_configs=model_configs, + tokenizer_config=tokenizer_config, + ) + + pipe.scheduler.set_timesteps(int(args.get("num_train_timesteps", 1000)), training=True) + noise_scheduler = pipe.scheduler + tokenizer = pipe.tokenizer + text_encoder = pipe.text_encoder + vae_encoder = pipe.vae_encoder + vae_decoder = pipe.vae_decoder + unet = pipe.dit + model_component_name = "dit" + text_encoder_2 = None + tokenizer_2 = None + vae = None else: - # SD 1.x, SDXL, SD3 and other UNet-based models - unet = temp_pipeline.unet - model_component_name = "unet" + temp_pipeline = AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path, **pipeline_kwargs) + + # Extract components from the loaded pipeline + noise_scheduler = temp_pipeline.scheduler + tokenizer = temp_pipeline.tokenizer + text_encoder = temp_pipeline.text_encoder + vae = temp_pipeline.vae + + # Handle different architectures: FluxPipeline uses 'transformer', others use 'unet' + # We use 'unet' as a unified variable name for the main model component regardless of architecture + if hasattr(temp_pipeline, "transformer"): + # FluxPipeline and other transformer-based models + unet = temp_pipeline.transformer + model_component_name = "transformer" + else: + # SD 1.x, SDXL, SD3 and other UNet-based models + unet = temp_pipeline.unet + model_component_name = "unet" - # Handle SDXL case with dual text encoders - text_encoder_2 = getattr(temp_pipeline, "text_encoder_2", None) - tokenizer_2 = getattr(temp_pipeline, "tokenizer_2", None) + # Handle SDXL case with dual text encoders + text_encoder_2 = getattr(temp_pipeline, "text_encoder_2", None) + tokenizer_2 = getattr(temp_pipeline, "tokenizer_2", None) - # Clean up temporary pipeline - del temp_pipeline - torch.cuda.empty_cache() if torch.cuda.is_available() else None + # Clean up temporary pipeline + del temp_pipeline + torch.cuda.empty_cache() if torch.cuda.is_available() else None - print(f"Model components loaded successfully: {pretrained_model_name_or_path}") - print(f"Architecture detected - Model component ({model_component_name}): {type(unet).__name__}") + print(f"Model components loaded successfully: {pretrained_model_name_or_path}") + print(f"Architecture detected - Model component ({model_component_name}): {type(unet).__name__}") if text_encoder_2 is not None: print("Dual text encoder setup detected (likely SDXL)") print(f"Text encoder type: {type(text_encoder).__name__}") @@ -396,7 +583,12 @@ def train_diffusion_lora(): # Freeze parameters unet.requires_grad_(False) - vae.requires_grad_(False) + + if is_zimage: + vae_encoder.requires_grad_(False) + vae_decoder.requires_grad_(False) + else: + vae.requires_grad_(False) text_encoder.requires_grad_(False) if text_encoder_2 is not None: text_encoder_2.requires_grad_(False) @@ -407,27 +599,22 @@ def train_diffusion_lora(): unet.enable_xformers_memory_efficient_attention() if hasattr(vae, "enable_xformers_memory_efficient_attention"): vae.enable_xformers_memory_efficient_attention() + if not is_zimage and hasattr(vae, "enable_xformers_memory_efficient_attention"): + vae.enable_xformers_memory_efficient_attention() print("xFormers memory efficient attention enabled") except Exception as e: print(f"Failed to enable xFormers: {e}") # Enable gradient checkpointing for memory savings if args.get("gradient_checkpointing", False): - unet.enable_gradient_checkpointing() + if hasattr(unet, "enable_gradient_checkpointing"): + unet.enable_gradient_checkpointing() if hasattr(text_encoder, "gradient_checkpointing_enable"): text_encoder.gradient_checkpointing_enable() if text_encoder_2 is not None and hasattr(text_encoder_2, "gradient_checkpointing_enable"): text_encoder_2.gradient_checkpointing_enable() print("Gradient checkpointing enabled") - # Mixed precision - weight_dtype = torch.float32 - mixed_precision = args.get("mixed_precision", None) - if mixed_precision == "fp16": - weight_dtype = torch.float16 - elif mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - # LoRA config - adaptive target modules for different architectures model_type = type(unet).__name__ @@ -439,14 +626,7 @@ def train_diffusion_lora(): f"Has addition_embed_type: {hasattr(unet.config, 'addition_embed_type') if hasattr(unet, 'config') else 'No config'}" ) - # Detect architecture based on multiple indicators - is_sdxl = "StableDiffusionXLPipeline" in model_architecture - - is_sd3 = "StableDiffusion3Pipeline" in model_architecture - - is_flux = "FluxPipeline" in model_architecture - - print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}") + print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, ZImage: {is_zimage}") # Define target modules based on detected architecture if is_sdxl: @@ -461,6 +641,10 @@ def train_diffusion_lora(): # Flux uses transformer-based architecture target_modules = ["to_q", "to_k", "to_v", "to_out.0"] architecture_name = "Flux" + elif is_zimage: + # Z-Image DiT uses standard attention projections + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + architecture_name = "Z-Image" else: # Default SD 1.x targets target_modules = ["to_k", "to_q", "to_v", "to_out.0"] @@ -475,14 +659,21 @@ def train_diffusion_lora(): target_modules=target_modules, ) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") unet.to(device, dtype=weight_dtype) - vae.to(device, dtype=weight_dtype) + if is_zimage: + vae_encoder.to(device, dtype=weight_dtype) + vae_decoder.to(device, dtype=weight_dtype) + else: + vae.to(device, dtype=weight_dtype) text_encoder.to(device, dtype=weight_dtype) if text_encoder_2 is not None: text_encoder_2.to(device, dtype=weight_dtype) - unet.add_adapter(unet_lora_config) + if is_zimage: + unet = get_peft_model(unet, unet_lora_config) + pipe.dit = unet + else: + unet.add_adapter(unet_lora_config) if mixed_precision == "fp16": cast_training_params(unet, dtype=torch.float32) @@ -611,7 +802,7 @@ def generate_eval_image(epoch): train_transforms = transforms.Compose(transform_list) - def tokenize_captions(examples, is_train=True): + def build_captions(examples, is_train=True): captions = [] caption_column = args.get("caption_column", "text") trigger_word = args.get("trigger_word", "").strip() @@ -647,7 +838,10 @@ def tokenize_captions(examples, is_train=True): processed_caption = f"{trigger_word}, {processed_caption}" captions.append(processed_caption) + return captions + def tokenize_captions(examples, is_train=True): + captions = build_captions(examples, is_train=is_train) # Primary tokenizer (always present) inputs = tokenizer( captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" @@ -669,8 +863,41 @@ def tokenize_captions(examples, is_train=True): return result image_column = args.get("image_column", "image") + caption_column = args.get("caption_column", "text") + + if image_column not in dataset.column_names: + raise ValueError(f"Image column '{image_column}' not found in dataset.") + + keep_columns = [image_column] + if caption_column in dataset.column_names: + keep_columns.append(caption_column) + + drop_columns = [col for col in dataset.column_names if col not in keep_columns] + if drop_columns: + dataset = dataset.remove_columns(drop_columns) + + if hasattr(dataset, "reset_format"): + dataset.reset_format() + + dataset = dataset.filter(lambda x: x.get(image_column) is not None) def preprocess_train(examples): + # Filter out examples with None images + images_value = examples.get(image_column) + if images_value is None: + return {} + + valid_indices = [i for i, img in enumerate(images_value) if img is not None] + if not valid_indices: + return {} + + filtered_examples = {} + for key, value in examples.items(): + if isinstance(value, (list, tuple, np.ndarray)): + filtered_examples[key] = [value[i] for i in valid_indices] + + examples = filtered_examples + images = [image.convert("RGB") for image in examples[image_column]] # Enhanced preprocessing for SDXL with proper image metadata tracking @@ -707,13 +934,16 @@ def preprocess_train(examples): examples["crop_coords_top_left"] = crop_coords_top_left examples["target_sizes"] = target_sizes - # Get tokenization results - tokenization_results = tokenize_captions(examples) - examples["input_ids"] = tokenization_results["input_ids"] + if is_zimage: + examples["prompt"] = build_captions(examples, is_train=True) + else: + # Get tokenization results + tokenization_results = tokenize_captions(examples) + examples["input_ids"] = tokenization_results["input_ids"] - # Add second input_ids for SDXL if present - if "input_ids_2" in tokenization_results: - examples["input_ids_2"] = tokenization_results["input_ids_2"] + # Add second input_ids for SDXL if present + if "input_ids_2" in tokenization_results: + examples["input_ids_2"] = tokenization_results["input_ids_2"] return examples @@ -722,9 +952,18 @@ def preprocess_train(examples): def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = torch.stack([example["input_ids"] for example in examples]) + batch = {"pixel_values": pixel_values} - batch = {"pixel_values": pixel_values, "input_ids": input_ids} + if is_zimage: + batch["prompt"] = [example["prompt"] for example in examples] + if "original_sizes" in examples[0]: + batch["original_sizes"] = [example["original_sizes"] for example in examples] + batch["crop_coords_top_left"] = [example["crop_coords_top_left"] for example in examples] + batch["target_sizes"] = [example["target_sizes"] for example in examples] + return batch + + input_ids = torch.stack([example["input_ids"] for example in examples]) + batch["input_ids"] = input_ids # Add second input_ids for SDXL if present if "input_ids_2" in examples[0]: @@ -789,180 +1028,207 @@ def collate_fn(examples): for epoch in range(num_train_epochs): unet.train() for step, batch in enumerate(train_dataloader): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"].to(device, dtype=weight_dtype)).latent_dist.sample() - latents = latents * vae.config.scaling_factor - - # Sample noise - noise = torch.randn_like(latents) - if args.get("noise_offset", 0): - noise += args["noise_offset"] * torch.randn( - (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + if is_zimage: + pixel_values = batch["pixel_values"].to(device, dtype=weight_dtype) + input_latents = vae_encoder(pixel_values) + prompt_embeds = encode_prompt_zimage(pipe, batch["prompt"], device) + + loss = FlowMatchSFTLoss( + pipe, + input_latents=input_latents, + prompt_embeds=prompt_embeds, + image_embeds=None, + image_latents=None, + use_gradient_checkpointing=args.get("gradient_checkpointing", False), + use_gradient_checkpointing_offload=False, ) - bsz = latents.shape[0] - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device - ).long() - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Enhanced text encoding - always use encode_prompt for SDXL - if is_sdxl: - # Always use encode_prompt for SDXL, regardless of text_encoder_2 - prompts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True) - # if tokenizer_2 is not None and "input_ids_2" in batch: - # prompts_2 = tokenizer_2.batch_decode(batch["input_ids_2"], skip_special_tokens=True) - # else: - # prompts_2 = None - - text_encoders = [text_encoder, text_encoder_2] if text_encoder_2 is not None else [text_encoder] - tokenizers = [tokenizer, tokenizer_2] if tokenizer_2 is not None else [tokenizer] - - # Create a temporary pipeline-like object for encode_prompt compatibility - class TempPipeline: - def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): - self.text_encoder = text_encoder - self.text_encoder_2 = text_encoder_2 - self.tokenizer = tokenizer - self.tokenizer_2 = tokenizer_2 - - temp_pipe = TempPipeline(text_encoder, text_encoder_2, tokenizer, tokenizer_2) - - encoder_hidden_states, _, pooled_prompt_embeds, _ = encode_prompt( - temp_pipe, - text_encoders, - tokenizers, - prompts, - device, - num_images_per_prompt=1, - do_classifier_free_guidance=False, - ) + print(f"Step {step + 1}/{len(train_dataloader)} - FlowMatchSFT Loss: {loss.item()}") else: - # Standard single text encoder approach - encoder_hidden_states = text_encoder(batch["input_ids"].to(device), return_dict=False)[0] - pooled_prompt_embeds = None - - # For SDXL with dual text encoders, handle dimension compatibility and concatenate - if text_encoder_2 is not None and "input_ids_2" in batch: - encoder_hidden_states_2 = text_encoder_2(batch["input_ids_2"].to(device), return_dict=False)[0] - - # Handle dimension mismatch - ensure both tensors have the same number of dimensions - if encoder_hidden_states.dim() != encoder_hidden_states_2.dim(): - # If one is 2D and the other is 3D, add a dimension to the 2D tensor - if encoder_hidden_states.dim() == 2 and encoder_hidden_states_2.dim() == 3: - encoder_hidden_states = encoder_hidden_states.unsqueeze(1) - elif encoder_hidden_states.dim() == 3 and encoder_hidden_states_2.dim() == 2: - encoder_hidden_states_2 = encoder_hidden_states_2.unsqueeze(1) - - # Ensure sequence lengths match for concatenation - seq_len_1 = ( - encoder_hidden_states.shape[1] - if encoder_hidden_states.dim() == 3 - else encoder_hidden_states.shape[0] + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(device, dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise + noise = torch.randn_like(latents) + if args.get("noise_offset", 0): + noise += args["noise_offset"] * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device ) - seq_len_2 = ( - encoder_hidden_states_2.shape[1] - if encoder_hidden_states_2.dim() == 3 - else encoder_hidden_states_2.shape[0] + + bsz = latents.shape[0] + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + if hasattr(noise_scheduler, "add_noise"): + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + elif hasattr(noise_scheduler, "scale_noise"): + # Flow-matching schedulers (for example FlowMatchEulerDiscreteScheduler) expose scale_noise. + noisy_latents = noise_scheduler.scale_noise(latents, timesteps, noise) + else: + raise AttributeError( + f"Unsupported scheduler {type(noise_scheduler).__name__}: " + "expected add_noise(...) or scale_noise(...)." ) - if seq_len_1 != seq_len_2: - # Pad the shorter sequence to match the longer one - max_seq_len = max(seq_len_1, seq_len_2) - - if encoder_hidden_states.dim() == 3: - if encoder_hidden_states.shape[1] < max_seq_len: - pad_size = max_seq_len - encoder_hidden_states.shape[1] - padding = torch.zeros( - encoder_hidden_states.shape[0], - pad_size, - encoder_hidden_states.shape[2], - device=encoder_hidden_states.device, - dtype=encoder_hidden_states.dtype, - ) - encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) - - if encoder_hidden_states_2.shape[1] < max_seq_len: - pad_size = max_seq_len - encoder_hidden_states_2.shape[1] - padding = torch.zeros( - encoder_hidden_states_2.shape[0], - pad_size, - encoder_hidden_states_2.shape[2], - device=encoder_hidden_states_2.device, - dtype=encoder_hidden_states_2.dtype, - ) - encoder_hidden_states_2 = torch.cat([encoder_hidden_states_2, padding], dim=1) - - # Concatenate along the feature dimension (last dimension) - encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_2], dim=-1) - - # Loss target - prediction_type = args.get("prediction_type", None) - if prediction_type is not None: - noise_scheduler.register_to_config(prediction_type=prediction_type) - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError( - f"Unknown prediction type {noise_scheduler.config.prediction_type}" - ) # Handle SDXL-specific conditioning parameters with proper metadata - unet_kwargs = {"timestep": timesteps, "encoder_hidden_states": encoder_hidden_states, "return_dict": False} - - # SDXL requires additional conditioning kwargs with proper pooled embeddings and time_ids - if is_sdxl: - batch_size = noisy_latents.shape[0] - - # Use proper pooled embeddings if available, otherwise create dummy ones - if pooled_prompt_embeds is not None: - text_embeds = ( - pooled_prompt_embeds.repeat(batch_size, 1) - if pooled_prompt_embeds.shape[0] == 1 - else pooled_prompt_embeds + # Enhanced text encoding - always use encode_prompt for SDXL + if is_sdxl: + # Always use encode_prompt for SDXL, regardless of text_encoder_2 + prompts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True) + # if tokenizer_2 is not None and "input_ids_2" in batch: + # prompts_2 = tokenizer_2.batch_decode(batch["input_ids_2"], skip_special_tokens=True) + # else: + # prompts_2 = None + text_encoders = [text_encoder, text_encoder_2] if text_encoder_2 is not None else [text_encoder] + tokenizers = [tokenizer, tokenizer_2] if tokenizer_2 is not None else [tokenizer] + + # Create a temporary pipeline-like object for encode_prompt compatibility + class TempPipeline: + def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): + self.text_encoder = text_encoder + self.text_encoder_2 = text_encoder_2 + self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 + + temp_pipe = TempPipeline(text_encoder, text_encoder_2, tokenizer, tokenizer_2) + + encoder_hidden_states, _, pooled_prompt_embeds, _ = encode_prompt( + temp_pipe, + text_encoders, + tokenizers, + prompts, + device, + num_images_per_prompt=1, + do_classifier_free_guidance=False, ) else: - # Fallback to dummy embeddings for compatibility - text_embeds = torch.zeros(batch_size, 1280, device=device, dtype=weight_dtype) - - # Compute proper time_ids from actual image metadata if available - if "original_sizes" in batch and "crop_coords_top_left" in batch and "target_sizes" in batch: - time_ids_list = [] - for i in range(batch_size): - original_size = batch["original_sizes"][i] - crop_coords = batch["crop_coords_top_left"][i] - target_size = batch["target_sizes"][i] - - # Compute time_ids for this sample - time_ids = compute_time_ids( - original_size, - crop_coords, - target_size, - dtype=weight_dtype, - device=device, - weight_dtype=weight_dtype, + # Standard single text encoder approach + encoder_hidden_states = text_encoder(batch["input_ids"].to(device), return_dict=False)[0] + pooled_prompt_embeds = None + + # For SDXL with dual text encoders, handle dimension compatibility and concatenate + if text_encoder_2 is not None and "input_ids_2" in batch: + encoder_hidden_states_2 = text_encoder_2(batch["input_ids_2"].to(device), return_dict=False)[0] + + # Handle dimension mismatch - ensure both tensors have the same number of dimensions + if encoder_hidden_states.dim() != encoder_hidden_states_2.dim(): + # If one is 2D and the other is 3D, add a dimension to the 2D tensor + if encoder_hidden_states.dim() == 2 and encoder_hidden_states_2.dim() == 3: + encoder_hidden_states = encoder_hidden_states.unsqueeze(1) + elif encoder_hidden_states.dim() == 3 and encoder_hidden_states_2.dim() == 2: + encoder_hidden_states_2 = encoder_hidden_states_2.unsqueeze(1) + + # Ensure sequence lengths match for concatenation + seq_len_1 = ( + encoder_hidden_states.shape[1] + if encoder_hidden_states.dim() == 3 + else encoder_hidden_states.shape[0] + ) + seq_len_2 = ( + encoder_hidden_states_2.shape[1] + if encoder_hidden_states_2.dim() == 3 + else encoder_hidden_states_2.shape[0] ) - time_ids_list.append(time_ids) - time_ids = torch.cat(time_ids_list, dim=0) + if seq_len_1 != seq_len_2: + # Pad the shorter sequence to match the longer one + max_seq_len = max(seq_len_1, seq_len_2) + + if encoder_hidden_states.dim() == 3: + if encoder_hidden_states.shape[1] < max_seq_len: + pad_size = max_seq_len - encoder_hidden_states.shape[1] + padding = torch.zeros( + encoder_hidden_states.shape[0], + pad_size, + encoder_hidden_states.shape[2], + device=encoder_hidden_states.device, + dtype=encoder_hidden_states.dtype, + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) + + if encoder_hidden_states_2.shape[1] < max_seq_len: + pad_size = max_seq_len - encoder_hidden_states_2.shape[1] + padding = torch.zeros( + encoder_hidden_states_2.shape[0], + pad_size, + encoder_hidden_states_2.shape[2], + device=encoder_hidden_states_2.device, + dtype=encoder_hidden_states_2.dtype, + ) + encoder_hidden_states_2 = torch.cat([encoder_hidden_states_2, padding], dim=1) + + # Concatenate along the feature dimension (last dimension) + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_2], dim=-1) + + # Loss target + prediction_type = args.get("prediction_type", None) + if prediction_type is not None: + noise_scheduler.register_to_config(prediction_type=prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) else: - # Fallback to dummy time_ids for compatibility - resolution = int(args.get("resolution", 512)) - time_ids = torch.tensor( - [[resolution, resolution, 0, 0, resolution, resolution]] * batch_size, - device=device, - dtype=weight_dtype, - ) - - added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} - unet_kwargs["added_cond_kwargs"] = added_cond_kwargs + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) # Handle SDXL-specific conditioning parameters with proper metadata + unet_kwargs = { + "timestep": timesteps, + "encoder_hidden_states": encoder_hidden_states, + "return_dict": False, + } + + # SDXL requires additional conditioning kwargs with proper pooled embeddings and time_ids + if is_sdxl: + batch_size = noisy_latents.shape[0] + + # Use proper pooled embeddings if available, otherwise create dummy ones + if pooled_prompt_embeds is not None: + text_embeds = ( + pooled_prompt_embeds.repeat(batch_size, 1) + if pooled_prompt_embeds.shape[0] == 1 + else pooled_prompt_embeds + ) + else: + # Fallback to dummy embeddings for compatibility + text_embeds = torch.zeros(batch_size, 1280, device=device, dtype=weight_dtype) + + # Compute proper time_ids from actual image metadata if available + if "original_sizes" in batch and "crop_coords_top_left" in batch and "target_sizes" in batch: + time_ids_list = [] + for i in range(batch_size): + original_size = batch["original_sizes"][i] + crop_coords = batch["crop_coords_top_left"][i] + target_size = batch["target_sizes"][i] + + # Compute time_ids for this sample + time_ids = compute_time_ids( + original_size, + crop_coords, + target_size, + dtype=weight_dtype, + device=device, + weight_dtype=weight_dtype, + ) + time_ids_list.append(time_ids) + + time_ids = torch.cat(time_ids_list, dim=0) + else: + # Fallback to dummy time_ids for compatibility + resolution = int(args.get("resolution", 512)) + time_ids = torch.tensor( + [[resolution, resolution, 0, 0, resolution, resolution]] * batch_size, + device=device, + dtype=weight_dtype, + ) - model_pred = unet(noisy_latents, **unet_kwargs)[0] + added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} + unet_kwargs["added_cond_kwargs"] = added_cond_kwargs - # Use improved loss computation with support for different loss types and weighting - loss = compute_loss(model_pred, target, timesteps, noise_scheduler, args) - print(f"Step {step + 1}/{len(train_dataloader)} - Loss: {loss.item()}") + model_pred = unet(noisy_latents, **unet_kwargs)[0] + loss = compute_loss(model_pred, target, timesteps, noise_scheduler, args) + print(f"Step {step + 1}/{len(train_dataloader)} - Loss: {loss.item()}") loss.backward() @@ -1047,7 +1313,7 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): print(f"Saving LoRA weights to {save_directory}") - storage.makedirs(save_directory, exist_ok=True) + asyncio.run(storage.makedirs(save_directory, exist_ok=True)) # Primary method: Use the original working approach that was perfect for SD 1.5 # Try architecture-specific save methods first, then fall back to universal methods @@ -1065,11 +1331,38 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): }, "tlab_trainer_used": True, } - with storage.open(storage.join(save_directory, "tlab_adaptor_info.json"), "w", encoding="utf-8") as f: - json.dump(save_info, f, indent=4) + + async def _write_adaptor_info() -> None: + async with await storage.open( + storage.join(save_directory, "tlab_adaptor_info.json"), + "w", + encoding="utf-8", + ) as f: + await f.write(json.dumps(save_info, indent=4)) + + asyncio.run(_write_adaptor_info()) + + # Method 0: Z-Image PEFT safetensors save + if is_zimage: + try: + from safetensors.torch import save_file + + zimage_lora_state_dict = get_peft_model_state_dict(unet) + save_file(zimage_lora_state_dict, storage.join(save_directory, "pytorch_lora_weights.safetensors")) + print( + f"LoRA weights saved to {save_directory}/pytorch_lora_weights.safetensors using safetensors (Z-Image)" + ) + saved_successfully = True + except ImportError: + zimage_lora_state_dict = get_peft_model_state_dict(unet) + torch.save(zimage_lora_state_dict, storage.join(save_directory, "pytorch_lora_weights.bin")) + print(f"LoRA weights saved to {save_directory}/pytorch_lora_weights.bin using PyTorch format (Z-Image)") + saved_successfully = True + except Exception as e: + print(f"Error saving Z-Image LoRA weights: {e}") # Method 1: Try the original SD 1.x approach that worked perfectly - if not is_sdxl and not is_sd3 and not is_flux: + if not saved_successfully and not is_sdxl and not is_sd3 and not is_flux: try: StableDiffusionPipeline.save_lora_weights( save_directory=save_directory, diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index ab3ecf689..db0899d62 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -1,6 +1,25 @@ -uv pip install "peft>=0.15.0" +#!/bin/bash -# Only install xformers for non-rocm instances -if ! command -v rocminfo &> /dev/null; then - uv pip install xformers -fi \ No newline at end of file +# Keep plugin deps aligned with the base plugin venv created from api/pyproject.toml. +# This avoids resolver-driven torch/torchvision drift (e.g. missing torchvision::nms). +uv pip install \ + "diffusers" \ + "transformers" \ + "peft>=0.17" \ + diffsynth + +# xformers is ABI-sensitive and frequently mismatches the preinstalled torch build. +# Keep it opt-in to avoid import-time crashes (for example undefined C++ symbols). +# To opt in, set TLAB_ENABLE_XFORMERS=1 before running setup. +if [ "${TLAB_ENABLE_XFORMERS:-0}" = "1" ] && ! command -v rocminfo >/dev/null 2>&1; then + if uv pip install --no-deps xformers; then + if ! python -c "import xformers" >/dev/null 2>&1; then + echo "xformers import test failed; uninstalling incompatible wheel and continuing without xformers." + uv pip uninstall -y xformers || true + fi + else + echo "xformers wheel unavailable for this environment; continuing without it." + fi +else + echo "Skipping xformers install (set TLAB_ENABLE_XFORMERS=1 to opt in)." +fi diff --git a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py index bfa7701b1..72e3603e0 100644 --- a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py +++ b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py @@ -9,6 +9,9 @@ import json import os import sys +import re +from pathlib import Path +import inspect import time import gc from PIL import Image @@ -71,6 +74,179 @@ } +def _is_probable_hf_repo_id(value: str) -> bool: + """Heuristic for Hugging Face repo IDs like `org/name`.""" + if not isinstance(value, str): + return False + candidate = value.strip() + if not candidate: + return False + if os.path.isabs(candidate): + return False + if candidate.startswith("."): + return False + return "/" in candidate and "\\" not in candidate + + +def _extract_hf_repo_from_model_metadata(model_dir: str) -> str | None: + """ + Extract a Hugging Face repo id from local model metadata. + + This helps recover when `model_dir` exists but is missing `model_index.json`. + """ + metadata_path = os.path.join(model_dir, "index.json") + candidates: list[str] = [] + + if os.path.isfile(metadata_path): + try: + with open(metadata_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + if isinstance(metadata, dict): + json_data = metadata.get("json_data", {}) if isinstance(metadata.get("json_data"), dict) else {} + candidates.extend( + [ + json_data.get("huggingface_repo"), + json_data.get("source_id_or_path"), + metadata.get("model_id"), + ] + ) + except Exception as e: + print(f"Warning: Failed to read model metadata at {metadata_path}: {e}") + + model_key = Path(model_dir).name + if model_key: + try: + from lab.model import Model as ModelService + + import asyncio + + model_service = asyncio.run(ModelService.get(model_key)) + model_metadata = asyncio.run(model_service.get_metadata()) + if isinstance(model_metadata, dict): + json_data = ( + model_metadata.get("json_data", {}) if isinstance(model_metadata.get("json_data"), dict) else {} + ) + candidates.extend( + [ + json_data.get("huggingface_repo"), + json_data.get("source_id_or_path"), + model_metadata.get("model_id"), + ] + ) + except Exception: + pass + + seen = set() + for candidate in candidates: + if not isinstance(candidate, str): + continue + candidate = candidate.strip() + if not candidate or candidate in seen: + continue + seen.add(candidate) + if _is_probable_hf_repo_id(candidate): + return candidate + + return None + + +def resolve_diffusion_model_reference(model: str) -> str: + """ + Resolve model reference for diffusers loading. + + If `model` is a local directory but missing `model_index.json`, try falling back + to the original Hugging Face repo id from local metadata. + """ + if not isinstance(model, str): + return model + + model_ref = model.strip() + if not model_ref: + return model + + if not os.path.isdir(model_ref): + return model_ref + + model_index_path = os.path.join(model_ref, "model_index.json") + if os.path.isfile(model_index_path): + return model_ref + + hf_repo = _extract_hf_repo_from_model_metadata(model_ref) + if hf_repo: + print( + f"Local model directory is missing model_index.json at {model_index_path}. " + f"Falling back to Hugging Face repo: {hf_repo}" + ) + return hf_repo + + print( + f"Warning: Local model directory is missing model_index.json at {model_index_path} " + "and no Hugging Face repo metadata could be resolved." + ) + return model_ref + + +def filter_generation_kwargs_for_pipeline(pipe, generation_kwargs: dict) -> dict: + """ + Drop kwargs that are not accepted by `pipe.__call__`. + + Some backends (for example ZImagePipeline) reject common diffusers kwargs like + `cross_attention_kwargs` and callback arguments. + """ + try: + signature = inspect.signature(pipe.__call__) + except (TypeError, ValueError): + return generation_kwargs + + params = signature.parameters + supports_var_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()) + if supports_var_kwargs: + return generation_kwargs + + allowed_keys = {name for name in params if name != "self"} + filtered_kwargs = {} + skipped_keys = [] + + for key, value in generation_kwargs.items(): + if key in allowed_keys: + filtered_kwargs[key] = value + else: + skipped_keys.append(key) + + if skipped_keys: + print( + f"Skipping unsupported generation kwargs for {pipe.__class__.__name__}: {', '.join(sorted(skipped_keys))}" + ) + + return filtered_kwargs + + +def invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs: dict): + """ + Call a pipeline and recover from wrapper signature mismatches. + + Some pipelines expose a permissive `__call__` signature through decorators while the wrapped + implementation is strict. In that case, retry without the unexpected kwarg. + """ + filtered_kwargs = filter_generation_kwargs_for_pipeline(pipe, generation_kwargs) + + while True: + try: + return pipe(**filtered_kwargs) + except TypeError as exc: + message = str(exc) + match = re.search(r"unexpected keyword argument '([^']+)'", message) + if not match: + raise + + unexpected_key = match.group(1) + if unexpected_key not in filtered_kwargs: + raise + + print(f"Retrying generation without unsupported kwarg '{unexpected_key}' for {pipe.__class__.__name__}") + filtered_kwargs = {key: value for key, value in filtered_kwargs.items() if key != unexpected_key} + + def load_controlnet_model(controlnet_id: str, device: str = "cuda") -> ControlNetModel: controlnet_model = ControlNetModel.from_pretrained( controlnet_id, torch_dtype=torch.float16 if device != "cpu" else torch.float32 @@ -79,20 +255,49 @@ def load_controlnet_model(controlnet_id: str, device: str = "cuda") -> ControlNe def latents_to_rgb(latents): - """Convert SDXL latents (4 channels) to RGB tensors (3 channels)""" - weights = ( - (60, -60, 25, -70), - (60, -5, 15, -50), - (60, 10, -5, -35), - ) + """ + Convert latent tensors to an RGB preview image. - weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) - biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) - rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze( - -1 - ) - image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) + Uses SDXL-style weights for 4-channel latents and a robust normalization fallback for + models with other latent channel counts (for example 16-channel Z-Image latents). + """ + if latents.ndim == 4: + latents = latents[0] + + if latents.ndim != 3: + raise ValueError(f"Expected latents shape [C,H,W] or [B,C,H,W], got {tuple(latents.shape)}") + + channels = latents.shape[0] + if channels == 4: + weights = ( + (60, -60, 25, -70), + (60, -5, 15, -50), + (60, 10, -5, -35), + ) + weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) + biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) + rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze( + -1 + ).unsqueeze(-1) + image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) + return Image.fromarray(image_array) + + # Fallback for non-SDXL latent layouts: map first channels to RGB and normalize per channel. + rgb_tensor = latents.float() + if channels == 1: + rgb_tensor = rgb_tensor.repeat(3, 1, 1) + elif channels == 2: + rgb_tensor = torch.cat([rgb_tensor, rgb_tensor[:1]], dim=0) + else: + rgb_tensor = rgb_tensor[:3] + + channel_min = rgb_tensor.amin(dim=(1, 2), keepdim=True) + channel_max = rgb_tensor.amax(dim=(1, 2), keepdim=True) + denom = (channel_max - channel_min).clamp(min=1e-6) + rgb_tensor = (rgb_tensor - channel_min) / denom + + image_array = (rgb_tensor * 255.0).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) return Image.fromarray(image_array) @@ -194,7 +399,8 @@ def is_flux_model(model_path): # Check if model has FLUX components by looking for config from huggingface_hub import model_info - info = model_info(model_path) + resolved_model = resolve_diffusion_model_reference(model_path) + info = model_info(resolved_model) config = getattr(info, "config", {}) diffusers_config = config.get("diffusers", {}) architectures = diffusers_config.get("_class_name", "") @@ -227,6 +433,8 @@ def load_pipeline_with_sharding( print("Loading pipeline with model sharding...") import torch + model_path = resolve_diffusion_model_reference(model_path) + # Flush memory before starting flush_memory() @@ -527,6 +735,8 @@ def load_pipeline_with_device_map( ): """Load pipeline with proper device mapping for multi-GPU""" + model_path = resolve_diffusion_model_reference(model_path) + # Clean up any existing CUDA cache before loading if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -953,7 +1163,7 @@ def main(): print(f"Using pre-generated images from sharding: {len(images)} images") else: # This is a normal pipeline, call it to generate images - result = pipe(**generation_kwargs) + result = invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs) images = result.images except RuntimeError as e: if "illegal memory access" in str(e) or "CUDA error" in str(e): diff --git a/api/transformerlab/plugins/image_diffusion/index.json b/api/transformerlab/plugins/image_diffusion/index.json index 0cbf8d933..1c064c24c 100644 --- a/api/transformerlab/plugins/image_diffusion/index.json +++ b/api/transformerlab/plugins/image_diffusion/index.json @@ -4,7 +4,7 @@ "description": "Generate images in the Diffusion tab using this plugin", "plugin-format": "python", "type": "diffusion", - "version": "0.0.9", + "version": "0.1.1", "git": "", "url": "", "files": ["main.py", "diffusion_worker.py", "setup.sh"], diff --git a/api/transformerlab/plugins/image_diffusion/main.py b/api/transformerlab/plugins/image_diffusion/main.py index 3a5d3d77a..bd69c24f1 100644 --- a/api/transformerlab/plugins/image_diffusion/main.py +++ b/api/transformerlab/plugins/image_diffusion/main.py @@ -1,6 +1,8 @@ from fastapi import HTTPException from pydantic import BaseModel, ValidationError from huggingface_hub import model_info +from pathlib import Path +import inspect import base64 from io import BytesIO import torch @@ -39,6 +41,7 @@ import os import sys import random +import re from werkzeug.utils import secure_filename import json from datetime import datetime @@ -233,20 +236,49 @@ def load_controlnet_model(controlnet_id: str, device: str = "cuda") -> ControlNe def latents_to_rgb(latents): - """Convert SDXL latents (4 channels) to RGB tensors (3 channels)""" - weights = ( - (60, -60, 25, -70), - (60, -5, 15, -50), - (60, 10, -5, -35), - ) + """ + Convert latent tensors to an RGB preview image. - weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) - biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) - rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze( - -1 - ) - image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) + Uses SDXL-style weights for 4-channel latents and a robust normalization fallback for + models with other latent channel counts (for example 16-channel Z-Image latents). + """ + if latents.ndim == 4: + latents = latents[0] + + if latents.ndim != 3: + raise ValueError(f"Expected latents shape [C,H,W] or [B,C,H,W], got {tuple(latents.shape)}") + + channels = latents.shape[0] + + if channels == 4: + weights = ( + (60, -60, 25, -70), + (60, -5, 15, -50), + (60, 10, -5, -35), + ) + weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) + biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) + rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze( + -1 + ).unsqueeze(-1) + image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) + return Image.fromarray(image_array) + + # Fallback for non-SDXL latent layouts: map first channels to RGB and normalize per channel. + rgb_tensor = latents.float() + if channels == 1: + rgb_tensor = rgb_tensor.repeat(3, 1, 1) + elif channels == 2: + rgb_tensor = torch.cat([rgb_tensor, rgb_tensor[:1]], dim=0) + else: + rgb_tensor = rgb_tensor[:3] + channel_min = rgb_tensor.amin(dim=(1, 2), keepdim=True) + channel_max = rgb_tensor.amax(dim=(1, 2), keepdim=True) + denom = (channel_max - channel_min).clamp(min=1e-6) + rgb_tensor = (rgb_tensor - channel_min) / denom + + image_array = (rgb_tensor * 255.0).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) return Image.fromarray(image_array) @@ -322,6 +354,181 @@ def is_zimage_model(model: str) -> bool: return "z-image" in name or "zimage" in name +def _is_probable_hf_repo_id(value: str) -> bool: + """Heuristic for Hugging Face repo IDs like `org/name`.""" + if not isinstance(value, str): + return False + candidate = value.strip() + if not candidate: + return False + if os.path.isabs(candidate): + return False + if candidate.startswith("."): + return False + return "/" in candidate and "\\" not in candidate + + +def _extract_hf_repo_from_model_metadata(model_dir: str) -> str | None: + """ + Extract a Hugging Face repo id from local model metadata. + + This helps recover when `model_dir` exists but is missing `model_index.json`. + """ + # First try direct metadata file lookup in the model directory. + metadata_path = os.path.join(model_dir, "index.json") + candidates: list[str] = [] + + if os.path.isfile(metadata_path): + try: + with open(metadata_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + if isinstance(metadata, dict): + json_data = metadata.get("json_data", {}) if isinstance(metadata.get("json_data"), dict) else {} + candidates.extend( + [ + json_data.get("huggingface_repo"), + json_data.get("source_id_or_path"), + metadata.get("model_id"), + ] + ) + except Exception as e: + print(f"Warning: Failed to read model metadata at {metadata_path}: {e}") + + # Fallback to SDK metadata lookup by directory name. + model_key = Path(model_dir).name + if model_key: + try: + from lab.model import Model as ModelService + + model_service = run_async_from_sync(ModelService.get(model_key)) + model_metadata = run_async_from_sync(model_service.get_metadata()) + if isinstance(model_metadata, dict): + json_data = ( + model_metadata.get("json_data", {}) if isinstance(model_metadata.get("json_data"), dict) else {} + ) + candidates.extend( + [ + json_data.get("huggingface_repo"), + json_data.get("source_id_or_path"), + model_metadata.get("model_id"), + ] + ) + except Exception: + # Best-effort lookup only. + pass + + seen = set() + for candidate in candidates: + if not isinstance(candidate, str): + continue + candidate = candidate.strip() + if not candidate or candidate in seen: + continue + seen.add(candidate) + if _is_probable_hf_repo_id(candidate): + return candidate + + return None + + +def resolve_diffusion_model_reference(model: str) -> str: + """ + Resolve model reference for diffusers loading. + + If `model` is a local directory but missing `model_index.json`, try falling back + to the original Hugging Face repo id from local metadata. + """ + if not isinstance(model, str): + return model + + model_ref = model.strip() + if not model_ref: + return model + + if not os.path.isdir(model_ref): + return model_ref + + model_index_path = os.path.join(model_ref, "model_index.json") + if os.path.isfile(model_index_path): + return model_ref + + hf_repo = _extract_hf_repo_from_model_metadata(model_ref) + if hf_repo: + print( + f"Local model directory is missing model_index.json at {model_index_path}. " + f"Falling back to Hugging Face repo: {hf_repo}" + ) + return hf_repo + + print( + f"Warning: Local model directory is missing model_index.json at {model_index_path} " + "and no Hugging Face repo metadata could be resolved." + ) + return model_ref + + +def filter_generation_kwargs_for_pipeline(pipe, generation_kwargs: dict) -> dict: + """ + Drop kwargs that are not accepted by `pipe.__call__`. + + Some backends (for example ZImagePipeline) reject common diffusers kwargs like + `cross_attention_kwargs` and callback arguments. + """ + try: + signature = inspect.signature(pipe.__call__) + except (TypeError, ValueError): + # If signature introspection fails, keep original kwargs. + return generation_kwargs + + params = signature.parameters + supports_var_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()) + if supports_var_kwargs: + return generation_kwargs + + allowed_keys = {name for name in params if name != "self"} + filtered_kwargs = {} + skipped_keys = [] + + for key, value in generation_kwargs.items(): + if key in allowed_keys: + filtered_kwargs[key] = value + else: + skipped_keys.append(key) + + if skipped_keys: + print( + f"Skipping unsupported generation kwargs for {pipe.__class__.__name__}: {', '.join(sorted(skipped_keys))}" + ) + + return filtered_kwargs + + +def invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs: dict): + """ + Call a pipeline and recover from wrapper signature mismatches. + + Some pipelines expose a permissive `__call__` signature through decorators while the wrapped + implementation is strict. In that case, retry without the unexpected kwarg. + """ + filtered_kwargs = filter_generation_kwargs_for_pipeline(pipe, generation_kwargs) + + while True: + try: + return pipe(**filtered_kwargs) + except TypeError as exc: + message = str(exc) + match = re.search(r"unexpected keyword argument '([^']+)'", message) + if not match: + raise + + unexpected_key = match.group(1) + if unexpected_key not in filtered_kwargs: + raise + + print(f"Retrying generation without unsupported kwarg '{unexpected_key}' for {pipe.__class__.__name__}") + filtered_kwargs = {key: value for key, value in filtered_kwargs.items() if key != unexpected_key} + + def get_pipeline( model: str, adaptor: str = "", @@ -335,8 +542,10 @@ def get_pipeline( # cache_key = get_pipeline_key(model, adaptor, is_img2img, is_inpainting) with _PIPELINES_LOCK: + resolved_model = resolve_diffusion_model_reference(model) + # Detect Z-Image architecture (non-controlnet path) - is_zimage = is_zimage_model(model) + is_zimage = is_zimage_model(resolved_model) # Load appropriate pipeline based on type if is_controlnet: @@ -360,7 +569,7 @@ def get_pipeline( print(f"Loading ControlNet pipeline ({controlnet_id}) for model: {model}") try: - info = model_info(model) + info = model_info(resolved_model) config = getattr(info, "config", {}) diffusers_config = config.get("diffusers", {}) architecture = diffusers_config.get("_class_name", "") @@ -377,7 +586,7 @@ def get_pipeline( print(f"Loaded ControlNet pipeline {controlnet_pipeline} for model {model}") pipe = controlnet_pipeline.from_pretrained( - model, + resolved_model, controlnet=controlnet_model, torch_dtype=torch.float16 if device != "cpu" else torch.float32, safety_checker=None, @@ -386,7 +595,7 @@ def get_pipeline( ) elif is_inpainting: pipe = AutoPipelineForInpainting.from_pretrained( - model, + resolved_model, torch_dtype=torch.float16 if device != "cpu" else torch.float32, safety_checker=None, requires_safety_checker=False, @@ -394,7 +603,7 @@ def get_pipeline( print(f"Loaded inpainting pipeline for model: {model}") elif is_img2img: pipe = AutoPipelineForImage2Image.from_pretrained( - model, + resolved_model, torch_dtype=torch.float16 if device != "cpu" else torch.float32, safety_checker=None, requires_safety_checker=False, @@ -402,14 +611,14 @@ def get_pipeline( print(f"Loaded image-to-image pipeline for model: {model}") elif is_zimage: pipe = DiffusionPipeline.from_pretrained( - model, + resolved_model, torch_dtype=torch.bfloat16 if device != "cpu" else torch.float32, low_cpu_mem_usage=False, ) print(f"Loaded Z-Image pipeline for model: {model} with dtype {pipe.dtype}") else: pipe = AutoPipelineForText2Image.from_pretrained( - model, + resolved_model, torch_dtype=torch.float16 if device != "cpu" else torch.float32, safety_checker=None, requires_safety_checker=False, @@ -627,7 +836,8 @@ def should_use_diffusion_worker(model) -> bool: # Check if model has FLUX components by looking for config from huggingface_hub import model_info - info = model_info(model) + resolved_model = resolve_diffusion_model_reference(model) + info = model_info(resolved_model) config = getattr(info, "config", {}) diffusers_config = config.get("diffusers", {}) architectures = diffusers_config.get("_class_name", "") @@ -1123,8 +1333,7 @@ def unified_callback(pipe, step: int, timestep: int, callback_kwargs: dict): generation_kwargs["callback_on_step_end"] = unified_callback generation_kwargs["callback_on_step_end_tensor_inputs"] = ["latents"] - - result = pipe(**generation_kwargs) + result = invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs) images = result.images # Get all images # Clean up result object to free references