From a25e3a13b80a07512dbe205bb6776320c611c813 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Tue, 23 Dec 2025 11:32:49 +0530 Subject: [PATCH 01/27] Add Z Image LoRA fine tuning support --- .../plugins/diffusion_trainer/index.json | 49 +++++-------------- .../plugins/diffusion_trainer/main.py | 46 ++++++++++++++--- 2 files changed, 51 insertions(+), 44 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/index.json b/api/transformerlab/plugins/diffusion_trainer/index.json index c71e3ff5c..35a9fe492 100644 --- a/api/transformerlab/plugins/diffusion_trainer/index.json +++ b/api/transformerlab/plugins/diffusion_trainer/index.json @@ -4,23 +4,18 @@ "description": "A plugin for fine-tuning Stable Diffusion using LoRA adapters.", "plugin-format": "python", "type": "trainer", - "version": "0.1.8", + "version": "0.1.9", "git": "", "url": "", "model_architectures": [ "StableDiffusionPipeline", "StableDiffusionXLPipeline", "StableDiffusion3Pipeline", - "FluxPipeline" - ], - "files": [ - "main.py", - "setup.sh" - ], - "supported_hardware_architectures": [ - "cuda", - "amd" + "FluxPipeline", + "ZImagePipeline" ], + "files": ["main.py", "setup.sh"], + "supported_hardware_architectures": ["cuda", "amd"], "training_template_format": "none", "setup-script": "setup.sh", "parameters": { @@ -88,12 +83,7 @@ "title": "Image Interpolation Mode", "type": "string", "default": "lanczos", - "enum": [ - "lanczos", - "bilinear", - "bicubic", - "nearest" - ] + "enum": ["lanczos", "bilinear", "bicubic", "nearest"] }, "random_flip": { "title": "Random Horizontal Flip", @@ -172,17 +162,12 @@ "title": "Learning Rate", "type": "number", "default": 0.0001, - "minimum": 1e-07 + "minimum": 1e-7 }, "lr_scheduler": { "title": "LR Scheduler", "type": "string", - "enum": [ - "constant", - "linear", - "cosine", - "constant_with_warmup" - ], + "enum": ["constant", "linear", "cosine", "constant_with_warmup"], "default": "constant" }, "lr_warmup_steps": { @@ -208,7 +193,7 @@ 
"adam_epsilon": { "title": "Adam Epsilon", "type": "number", - "default": 1e-08 + "default": 1e-8 }, "max_grad_norm": { "title": "Max Grad Norm", @@ -218,10 +203,7 @@ "loss_type": { "title": "Loss Type", "type": "string", - "enum": [ - "l2", - "huber" - ], + "enum": ["l2", "huber"], "default": "l2" }, "huber_c": { @@ -232,10 +214,7 @@ "prediction_type": { "title": "Prediction Type", "type": "string", - "enum": [ - "epsilon", - "v_prediction" - ], + "enum": ["epsilon", "v_prediction"], "default": "epsilon" }, "snr_gamma": { @@ -256,11 +235,7 @@ "mixed_precision": { "title": "Mixed Precision", "type": "string", - "enum": [ - "no", - "fp16", - "bf16" - ], + "enum": ["no", "fp16", "bf16"], "default": "no" }, "enable_xformers_memory_efficient_attention": { diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index ee87c94d7..82f7badcf 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -8,7 +8,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from peft import LoraConfig -from peft.utils import get_peft_model_state_dict +from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict from torchvision import transforms from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline @@ -29,6 +29,8 @@ from lab.dirs import get_workspace_dir from lab import storage +from safetensors.torch import load_file + workspace_dir = get_workspace_dir() @@ -445,7 +447,9 @@ def train_diffusion_lora(): is_flux = "FluxPipeline" in model_architecture - print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}") + is_zimage = "Z-Image-Turbo" in pretrained_model_name_or_path or "Z-Image-Turbo" in model_architecture + + print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, Z-Image: {is_zimage}") # Define target modules based on detected architecture if is_sdxl: @@ 
-456,10 +460,10 @@ def train_diffusion_lora(): # SD3 uses Multi-Modal DiT architecture target_modules = ["to_q", "to_k", "to_v", "to_out.0"] architecture_name = "SD3" - elif is_flux: + elif is_flux or is_zimage: # Flux uses transformer-based architecture target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - architecture_name = "Flux" + architecture_name = "Flux" if is_flux else "Z-Image-Turbo" else: # Default SD 1.x targets target_modules = ["to_k", "to_q", "to_v", "to_out.0"] @@ -485,6 +489,15 @@ def train_diffusion_lora(): if mixed_precision == "fp16": cast_training_params(unet, dtype=torch.float32) + if is_zimage and args.get("training_adapter"): + adapter_path = args.get("training_adapter") + if adapter_path.endswith(".safetensors"): + state_dict = load_file(adapter_path) + else: + state_dict = torch.load(adapter_path, map_location="cpu") + unet.load_state_dict(state_dict, strict=False) + print(f"Loaded Z-Image Turbo training adapter from {adapter_path}") + lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) # EMA (Exponential Moving Average) for more stable training - Memory optimized for LoRA @@ -541,11 +554,15 @@ def generate_eval_image(epoch): # Replace the model component with our trained version to include LoRA weights if model_component_name == "transformer": - pipeline.transformer = unet + pipeline_component = pipeline.transformer else: - pipeline.unet = unet - pipeline = pipeline.to(device) + pipeline_component = pipeline.unet + pipeline_component.add_adapter(unet_lora_config) + + lora_state_dict = get_peft_model_state_dict(unet) + set_peft_model_state_dict(pipeline_component, lora_state_dict) + pipeline = pipeline.to(device) # Generate image with torch.no_grad(): image = pipeline( @@ -1147,6 +1164,21 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): except Exception as e: print(f"Error with FluxPipeline.save_lora_weights: {e}") + if not saved_successfully and is_zimage: + try: + # Z-Image pipelines may have 
their own save method + from diffusers import ZImagePipeline + + ZImagePipeline.save_lora_weights( + save_directory=save_directory, + unet_lora_layers=model_lora_state_dict, + safe_serialization=True, + ) + print(f"LoRA weights saved to {save_directory} using ZImagePipeline.save_lora_weights (Z-Image)") + saved_successfully = True + except Exception as e: + print(f"Error with ZImagePipeline.save_lora_weights: {e}") + # Method 5: Try the generic StableDiffusionPipeline method as fallback for all architectures if not saved_successfully: try: From 5c150192ae90d84f402974fc320f698d1875958f Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Tue, 23 Dec 2025 12:20:42 +0530 Subject: [PATCH 02/27] Added existing parameter loading --- .../plugins/diffusion_trainer/index.json | 9 +++++ .../plugins/diffusion_trainer/main.py | 38 ++++++++++++++++--- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/index.json b/api/transformerlab/plugins/diffusion_trainer/index.json index 35a9fe492..4b5eebb81 100644 --- a/api/transformerlab/plugins/diffusion_trainer/index.json +++ b/api/transformerlab/plugins/diffusion_trainer/index.json @@ -289,6 +289,12 @@ "title": "Log to Weights and Biases", "type": "boolean", "default": true + }, + "training_adapter": { + "title": "Z-Image-Turbo Training Adapter Path", + "type": "string", + "default": "", + "ui:help": "Optional local path to a custom de-distillation training adapter (.safetensors or .bin). Leave empty to automatically download and use the recommended ostris v2 adapter when training on Z-Image-Turbo." } }, "parameters_ui": { @@ -298,6 +304,9 @@ "trigger_word": { "ui:help": "Optional trigger word to prepend to all captions during training. Example: 'sks person' or 'ohwx style'" }, + "training_adapter": { + "ui:help": "Leave blank for auto-download of the recommended adapter. Provide a local path if you want to use a custom or offline adapter (e.g., v1 or your own)." 
+ }, "num_train_epochs": { "ui:help": "Total number of training epochs to run." }, diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 82f7badcf..f4a34c7de 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -16,6 +16,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import cast_training_params, compute_snr from diffusers.utils import convert_state_dict_to_diffusers +from diffusers import UNet2DConditionModel # Try to import xformers for memory optimization try: @@ -491,12 +492,39 @@ def train_diffusion_lora(): if is_zimage and args.get("training_adapter"): adapter_path = args.get("training_adapter") - if adapter_path.endswith(".safetensors"): - state_dict = load_file(adapter_path) + if adapter_path: + if adapter_path.endswith(".safetensors"): + state_dict = load_file(adapter_path) + else: + state_dict = torch.load(adapter_path, map_location="cpu") + unet.load_state_dict(state_dict, strict=False) + print(f"Loaded Z-Image Turbo training adapter from {adapter_path}") else: - state_dict = torch.load(adapter_path, map_location="cpu") - unet.load_state_dict(state_dict, strict=False) - print(f"Loaded Z-Image Turbo training adapter from {adapter_path}") + adapter_repo = "ostris/zimage_turbo_training_adapter" + adapter_filename = "zimage_turbo_training_adapter_v2.safetensors" + print( + f"No training_adapter provided. 
Auto-downloading recommended adapter: {adapter_filename} from {adapter_repo}" + ) + + try: + adapter_unet = UNet2DConditionModel.from_pretrained( + adapter_repo, + subfolder="", # Root of repo + filename=adapter_filename, + torch_dtype=weight_dtype, + variant=None, + use_safetensors=True, + low_cpu_mem_usage=True, + ) + # Extract only the LoRA state dict + adapter_state_dict = get_peft_model_state_dict(adapter_unet) + # Apply to our training unet (strict=False to ignore base weights) + unet.load_state_dict(adapter_state_dict, strict=False) + del adapter_unet, adapter_state_dict + cleanup_pipeline() + print("Successfully auto-loaded Z-Image-Turbo training adapter v2") + except Exception as e: + print(f"Failed to auto-download Z-Image Turbo training adapter: {e}") lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) From 2b7c2861383528353a10729b9c1b1087f3953476 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 24 Dec 2025 11:27:45 +0530 Subject: [PATCH 03/27] Updates --- .../plugins/diffusion_trainer/main.py | 72 +++++-------------- 1 file changed, 19 insertions(+), 53 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index f4a34c7de..83c85092d 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -8,7 +8,7 @@ import torch.nn.functional as F import torch.utils.checkpoint from peft import LoraConfig -from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict from torchvision import transforms from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline @@ -16,7 +16,6 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import cast_training_params, compute_snr from diffusers.utils import convert_state_dict_to_diffusers -from diffusers import UNet2DConditionModel # Try to import xformers 
for memory optimization try: @@ -30,8 +29,6 @@ from lab.dirs import get_workspace_dir from lab import storage -from safetensors.torch import load_file - workspace_dir = get_workspace_dir() @@ -327,6 +324,11 @@ def train_diffusion_lora(): eval_prompt = None args["eval_prompt"] = None args["eval_steps"] = 0 + if args.get("model_architecture", "").strip() == "ZImagePipeline": + print("Disabling evaluation for ZImagePipeline as it is not supported.") + eval_prompt = None + args["eval_prompt"] = None + args["eval_steps"] = 0 elif eval_prompt and eval_steps <= 0: print("Warning: eval_steps is set to 0, evaluation will not be performed.") eval_prompt = None @@ -448,9 +450,9 @@ def train_diffusion_lora(): is_flux = "FluxPipeline" in model_architecture - is_zimage = "Z-Image-Turbo" in pretrained_model_name_or_path or "Z-Image-Turbo" in model_architecture + is_zimage = "ZImagePipeline" in model_architecture - print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, Z-Image: {is_zimage}") + print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, ZImage: {is_zimage}") # Define target modules based on detected architecture if is_sdxl: @@ -461,10 +463,14 @@ def train_diffusion_lora(): # SD3 uses Multi-Modal DiT architecture target_modules = ["to_q", "to_k", "to_v", "to_out.0"] architecture_name = "SD3" - elif is_flux or is_zimage: + elif is_flux: # Flux uses transformer-based architecture target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - architecture_name = "Flux" if is_flux else "Z-Image-Turbo" + architecture_name = "Flux" + elif is_zimage: + # ZImage uses a modified UNet architecture + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + architecture_name = "ZImage" else: # Default SD 1.x targets target_modules = ["to_k", "to_q", "to_v", "to_out.0"] @@ -490,42 +496,6 @@ def train_diffusion_lora(): if mixed_precision == "fp16": cast_training_params(unet, dtype=torch.float32) - if is_zimage and 
args.get("training_adapter"): - adapter_path = args.get("training_adapter") - if adapter_path: - if adapter_path.endswith(".safetensors"): - state_dict = load_file(adapter_path) - else: - state_dict = torch.load(adapter_path, map_location="cpu") - unet.load_state_dict(state_dict, strict=False) - print(f"Loaded Z-Image Turbo training adapter from {adapter_path}") - else: - adapter_repo = "ostris/zimage_turbo_training_adapter" - adapter_filename = "zimage_turbo_training_adapter_v2.safetensors" - print( - f"No training_adapter provided. Auto-downloading recommended adapter: {adapter_filename} from {adapter_repo}" - ) - - try: - adapter_unet = UNet2DConditionModel.from_pretrained( - adapter_repo, - subfolder="", # Root of repo - filename=adapter_filename, - torch_dtype=weight_dtype, - variant=None, - use_safetensors=True, - low_cpu_mem_usage=True, - ) - # Extract only the LoRA state dict - adapter_state_dict = get_peft_model_state_dict(adapter_unet) - # Apply to our training unet (strict=False to ignore base weights) - unet.load_state_dict(adapter_state_dict, strict=False) - del adapter_unet, adapter_state_dict - cleanup_pipeline() - print("Successfully auto-loaded Z-Image-Turbo training adapter v2") - except Exception as e: - print(f"Failed to auto-download Z-Image Turbo training adapter: {e}") - lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) # EMA (Exponential Moving Average) for more stable training - Memory optimized for LoRA @@ -582,15 +552,11 @@ def generate_eval_image(epoch): # Replace the model component with our trained version to include LoRA weights if model_component_name == "transformer": - pipeline_component = pipeline.transformer + pipeline.transformer = unet else: - pipeline_component = pipeline.unet - pipeline_component.add_adapter(unet_lora_config) - - lora_state_dict = get_peft_model_state_dict(unet) - set_peft_model_state_dict(pipeline_component, lora_state_dict) - + pipeline.unet = unet pipeline = pipeline.to(device) + # 
Generate image with torch.no_grad(): image = pipeline( @@ -1194,7 +1160,7 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): if not saved_successfully and is_zimage: try: - # Z-Image pipelines may have their own save method + # ZImage pipelines may have their own save method from diffusers import ZImagePipeline ZImagePipeline.save_lora_weights( @@ -1202,7 +1168,7 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): unet_lora_layers=model_lora_state_dict, safe_serialization=True, ) - print(f"LoRA weights saved to {save_directory} using ZImagePipeline.save_lora_weights (Z-Image)") + print(f"LoRA weights saved to {save_directory} using ZImagePipeline.save_lora_weights (ZImage)") saved_successfully = True except Exception as e: print(f"Error with ZImagePipeline.save_lora_weights: {e}") From efada1fe07a355a2f04578c6f6b87b1344db6745 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 5 Feb 2026 01:24:50 +0530 Subject: [PATCH 04/27] Updated ZImage fine tuning code --- .../plugins/diffusion_trainer/main.py | 614 +++++++++++------- .../plugins/diffusion_trainer/setup.sh | 2 +- 2 files changed, 380 insertions(+), 236 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 83c85092d..3405d99a2 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -7,9 +7,15 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint -from peft import LoraConfig +from peft import LoraConfig, get_peft_model from peft.utils import get_peft_model_state_dict from torchvision import transforms +from diffsynth import ModelConfig +from diffsynth.pipelines.z_image import ZImagePipeline +from diffsynth.diffusion.loss import FlowMatchSFTLoss, TrajectoryImitationLoss +import os +import glob +import accelerate from diffusers import AutoPipelineForText2Image, 
StableDiffusionPipeline @@ -52,6 +58,36 @@ def cleanup_pipeline(): cleanup_pipeline() +def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[ModelConfig], ModelConfig]: + """Build ModelConfig list + tokenizer config for Z-Image Turbo.""" + transformer_pattern = os.path.join("transformer", "*.safetensors") + text_encoder_pattern = os.path.join("text_encoder", "*.safetensors") + vae_pattern = os.path.join("vae", "vae", "diffusion_pytorch_model.safetensors") + + tokenizer_pattern = "tokenizer/" + + if model_id_or_path and os.path.isdir(model_id_or_path): + transformer_paths = glob.glob(os.path.join(model_id_or_path, transformer_pattern)) + text_encoder_paths = glob.glob(os.path.join(model_id_or_path, text_encoder_pattern)) + vae_paths = glob.glob(os.path.join(model_id_or_path, vae_pattern)) + + model_configs = [ + ModelConfig(path=transformer_paths), + ModelConfig(path=text_encoder_paths), + ModelConfig(path=vae_paths), + ] + + tokenizer_config = ModelConfig(path=os.path.join(model_id_or_path, tokenizer_pattern)) + else: + model_configs = [ + ModelConfig(model_id=model_id_or_path, subfolder="transformer"), + ModelConfig(model_id=model_id_or_path, subfolder="text_encoder"), + ModelConfig(model_id=model_id_or_path, subfolder="vae/vae"), + ] + + tokenizer_config = ModelConfig(model_id=model_id_or_path, subfolder="tokenizer") + + return model_configs, tokenizer_config def compute_loss_weighting(args, timesteps, noise_scheduler): """ @@ -302,6 +338,49 @@ def encode_prompt_sdxl( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds +def encode_prompt_zimage(pipe, prompts, device, max_sequence_length: int = 512): + """Encode prompts using Z-Image tokenizer/chat template.""" + if isinstance(prompts, str): + prompts = [prompts] + + chat_prompts = [] + for prompt_item in prompts: + messages = [ + {"role": "user", "content": prompt_item}, + ] + + chat_prompt = pipe.tokenizer.build_chat_prompt( + messages, + 
tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + + chat_prompts.append(chat_prompt) + + text_inputs = pipe.tokenizer( + chat_prompts, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = pipe.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embedding_list = [] + + for i in range(len(prompt_embeds)): + embedding_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embedding_list @tlab_trainer.job_wrapper(wandb_project_name="TLab_Training", manual_logging=True) def train_diffusion_lora(): @@ -355,44 +434,89 @@ def train_diffusion_lora(): variant = args.get("variant", None) model_architecture = args.get("model_architecture") - # Load pipeline to auto-detect architecture and get correct components print(f"Loading pipeline to detect model architecture: {pretrained_model_name_or_path}") + + # Detect architecture based on multiple indicators + is_sdxl = "StableDiffusionXLPipeline" in model_architecture + + is_sd3 = "StableDiffusion3Pipeline" in model_architecture + + is_flux = "FluxPipeline" in model_architecture + + is_zimage = "ZImagePipeline" in model_architecture + + print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, ZImage: {is_zimage}") + + # Mixed Precision + weight_dtype = torch.float32 + mixed_precision = args.get("mixed_precision", None) + if is_zimage and (mixed_precision is None or mixed_precision == "" or mixed_precision == "no"): + weight_dtype = torch.bfloat16 + elif mixed_precision == "fp16": + weight_dtype = torch.float16 + elif mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + pipeline_kwargs = { "torch_dtype": torch.float16, "safety_checker": 
None, "requires_safety_checker": False, } - temp_pipeline = AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path, **pipeline_kwargs) - - # Extract components from the loaded pipeline - noise_scheduler = temp_pipeline.scheduler - tokenizer = temp_pipeline.tokenizer - text_encoder = temp_pipeline.text_encoder - vae = temp_pipeline.vae + pipe = None + if is_zimage: + model_configs, tokenizer_config = build_zimage_model_configs(pretrained_model_name_or_path) + pipe = ZImagePipeline.from_pretrained( + torch_dtype=weight_dtype, + device=device, + model_configs=model_configs, + tokenizer_config=tokenizer_config, + ) - # Handle different architectures: FluxPipeline uses 'transformer', others use 'unet' - # We use 'unet' as a unified variable name for the main model component regardless of architecture - if hasattr(temp_pipeline, "transformer"): - # FluxPipeline and other transformer-based models - unet = temp_pipeline.transformer - model_component_name = "transformer" + pipe.scheduler.set_timesteps(int(args.get("num_train_timesteps", 1000)), device=device) + noise_scheduler = pipe.scheduler + tokenizer = pipe.tokenizer + text_encoder = pipe.text_encoder + vae_encoder = pipe.vae_encoder + vae_decoder = pipe.vae_decoder + unet = pipe.dit + model_component_name = "dit" + text_encoder_2 = None + tokenizer_2 = None + vae = None else: - # SD 1.x, SDXL, SD3 and other UNet-based models - unet = temp_pipeline.unet - model_component_name = "unet" + temp_pipeline = AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path, **pipeline_kwargs) + + # Extract components from the loaded pipeline + noise_scheduler = temp_pipeline.scheduler + tokenizer = temp_pipeline.tokenizer + text_encoder = temp_pipeline.text_encoder + vae = temp_pipeline.vae + + # Handle different architectures: FluxPipeline uses 'transformer', others use 'unet' + # We use 'unet' as a unified variable name for the main model component regardless of architecture + if 
hasattr(temp_pipeline, "transformer"): + # FluxPipeline and other transformer-based models + unet = temp_pipeline.transformer + model_component_name = "transformer" + else: + # SD 1.x, SDXL, SD3 and other UNet-based models + unet = temp_pipeline.unet + model_component_name = "unet" - # Handle SDXL case with dual text encoders - text_encoder_2 = getattr(temp_pipeline, "text_encoder_2", None) - tokenizer_2 = getattr(temp_pipeline, "tokenizer_2", None) + # Handle SDXL case with dual text encoders + text_encoder_2 = getattr(temp_pipeline, "text_encoder_2", None) + tokenizer_2 = getattr(temp_pipeline, "tokenizer_2", None) - # Clean up temporary pipeline - del temp_pipeline - torch.cuda.empty_cache() if torch.cuda.is_available() else None + # Clean up temporary pipeline + del temp_pipeline + torch.cuda.empty_cache() if torch.cuda.is_available() else None - print(f"Model components loaded successfully: {pretrained_model_name_or_path}") - print(f"Architecture detected - Model component ({model_component_name}): {type(unet).__name__}") + print(f"Model components loaded successfully: {pretrained_model_name_or_path}") + print(f"Architecture detected - Model component ({model_component_name}): {type(unet).__name__}") if text_encoder_2 is not None: print("Dual text encoder setup detected (likely SDXL)") print(f"Text encoder type: {type(text_encoder).__name__}") @@ -400,7 +524,12 @@ def train_diffusion_lora(): # Freeze parameters unet.requires_grad_(False) - vae.requires_grad_(False) + + if is_zimage: + vae_encoder.requires_grad_(False) + vae_decoder.requires_grad_(False) + else: + vae.requires_grad_(False) text_encoder.requires_grad_(False) if text_encoder_2 is not None: text_encoder_2.requires_grad_(False) @@ -411,6 +540,8 @@ def train_diffusion_lora(): unet.enable_xformers_memory_efficient_attention() if hasattr(vae, "enable_xformers_memory_efficient_attention"): vae.enable_xformers_memory_efficient_attention() + if not is_zimage and hasattr(vae, 
"enable_xformers_memory_efficient_attention"): + vae.enable_xformers_memory_efficient_attention() print("xFormers memory efficient attention enabled") except Exception as e: print(f"Failed to enable xFormers: {e}") @@ -424,14 +555,6 @@ def train_diffusion_lora(): text_encoder_2.gradient_checkpointing_enable() print("Gradient checkpointing enabled") - # Mixed precision - weight_dtype = torch.float32 - mixed_precision = args.get("mixed_precision", None) - if mixed_precision == "fp16": - weight_dtype = torch.float16 - elif mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - # LoRA config - adaptive target modules for different architectures model_type = type(unet).__name__ @@ -443,15 +566,6 @@ def train_diffusion_lora(): f"Has addition_embed_type: {hasattr(unet.config, 'addition_embed_type') if hasattr(unet, 'config') else 'No config'}" ) - # Detect architecture based on multiple indicators - is_sdxl = "StableDiffusionXLPipeline" in model_architecture - - is_sd3 = "StableDiffusion3Pipeline" in model_architecture - - is_flux = "FluxPipeline" in model_architecture - - is_zimage = "ZImagePipeline" in model_architecture - print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, ZImage: {is_zimage}") # Define target modules based on detected architecture @@ -468,9 +582,9 @@ def train_diffusion_lora(): target_modules = ["to_q", "to_k", "to_v", "to_out.0"] architecture_name = "Flux" elif is_zimage: - # ZImage uses a modified UNet architecture + # Z-Image DiT uses standard attention projections target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - architecture_name = "ZImage" + architecture_name = "Z-Image" else: # Default SD 1.x targets target_modules = ["to_k", "to_q", "to_v", "to_out.0"] @@ -485,14 +599,21 @@ def train_diffusion_lora(): target_modules=target_modules, ) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") unet.to(device, dtype=weight_dtype) - vae.to(device, dtype=weight_dtype) + if is_zimage: + 
vae_encoder.to(device, dtype=weight_dtype) + vae_decoder.to(device, dtype=weight_dtype) + else: + vae.to(device, dtype=weight_dtype) text_encoder.to(device, dtype=weight_dtype) if text_encoder_2 is not None: text_encoder_2.to(device, dtype=weight_dtype) - unet.add_adapter(unet_lora_config) + if is_zimage: + unet = get_peft_model(unet, unet_lora_config) + pipe.dit = unet + else: + unet.add_adapter(unet_lora_config) if mixed_precision == "fp16": cast_training_params(unet, dtype=torch.float32) @@ -621,7 +742,7 @@ def generate_eval_image(epoch): train_transforms = transforms.Compose(transform_list) - def tokenize_captions(examples, is_train=True): + def build_captions(examples, is_train=True): captions = [] caption_column = args.get("caption_column", "text") trigger_word = args.get("trigger_word", "").strip() @@ -658,6 +779,8 @@ def tokenize_captions(examples, is_train=True): captions.append(processed_caption) + def tokenize_captions(examples, is_train=True): + captions = build_captions(examples, is_train=is_train) # Primary tokenizer (always present) inputs = tokenizer( captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" @@ -717,13 +840,16 @@ def preprocess_train(examples): examples["crop_coords_top_left"] = crop_coords_top_left examples["target_sizes"] = target_sizes - # Get tokenization results - tokenization_results = tokenize_captions(examples) - examples["input_ids"] = tokenization_results["input_ids"] + if is_zimage: + examples["prompt"] = build_captions(examples, is_train=True) + else: + # Get tokenization results + tokenization_results = tokenize_captions(examples) + examples["input_ids"] = tokenization_results["input_ids"] - # Add second input_ids for SDXL if present - if "input_ids_2" in tokenization_results: - examples["input_ids_2"] = tokenization_results["input_ids_2"] + # Add second input_ids for SDXL if present + if "input_ids_2" in tokenization_results: + examples["input_ids_2"] = 
tokenization_results["input_ids_2"] return examples @@ -799,181 +925,195 @@ def collate_fn(examples): for epoch in range(num_train_epochs): unet.train() for step, batch in enumerate(train_dataloader): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"].to(device, dtype=weight_dtype)).latent_dist.sample() - latents = latents * vae.config.scaling_factor - - # Sample noise - noise = torch.randn_like(latents) - if args.get("noise_offset", 0): - noise += args["noise_offset"] * torch.randn( - (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + if is_zimage: + pixel_values = batch["pixel_values"].to(device, dtype=weight_dtype) + input_latents = vae_encoder(pixel_values) + prompt_embeds = encode_prompt_zimage(pipe, batch["prompt"], device) + + loss = FlowMatchSFTLoss( + pipe, + input_latents=input_latents, + prompt_embeds=prompt_embeds, + image_embeds=None, + image_latents=None, + use_gradient_checkpointing=args.get("gradient_checkpointing", False), + use_gradient_checkpointing_offload=False, ) - bsz = latents.shape[0] - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device - ).long() - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Enhanced text encoding - always use encode_prompt for SDXL - if is_sdxl: - # Always use encode_prompt for SDXL, regardless of text_encoder_2 - prompts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True) - # if tokenizer_2 is not None and "input_ids_2" in batch: - # prompts_2 = tokenizer_2.batch_decode(batch["input_ids_2"], skip_special_tokens=True) - # else: - # prompts_2 = None - - text_encoders = [text_encoder, text_encoder_2] if text_encoder_2 is not None else [text_encoder] - tokenizers = [tokenizer, tokenizer_2] if tokenizer_2 is not None else [tokenizer] - - # Create a temporary pipeline-like object for encode_prompt compatibility - class TempPipeline: - def __init__(self, text_encoder, 
text_encoder_2, tokenizer, tokenizer_2): - self.text_encoder = text_encoder - self.text_encoder_2 = text_encoder_2 - self.tokenizer = tokenizer - self.tokenizer_2 = tokenizer_2 - - temp_pipe = TempPipeline(text_encoder, text_encoder_2, tokenizer, tokenizer_2) - - encoder_hidden_states, _, pooled_prompt_embeds, _ = encode_prompt( - temp_pipe, - text_encoders, - tokenizers, - prompts, - device, - num_images_per_prompt=1, - do_classifier_free_guidance=False, - ) + print(f"Step {step + 1}/{len(train_dataloader)} - FlowMatchSFT Loss: {loss.item()}") else: - # Standard single text encoder approach - encoder_hidden_states = text_encoder(batch["input_ids"].to(device), return_dict=False)[0] - pooled_prompt_embeds = None - - # For SDXL with dual text encoders, handle dimension compatibility and concatenate - if text_encoder_2 is not None and "input_ids_2" in batch: - encoder_hidden_states_2 = text_encoder_2(batch["input_ids_2"].to(device), return_dict=False)[0] - - # Handle dimension mismatch - ensure both tensors have the same number of dimensions - if encoder_hidden_states.dim() != encoder_hidden_states_2.dim(): - # If one is 2D and the other is 3D, add a dimension to the 2D tensor - if encoder_hidden_states.dim() == 2 and encoder_hidden_states_2.dim() == 3: - encoder_hidden_states = encoder_hidden_states.unsqueeze(1) - elif encoder_hidden_states.dim() == 3 and encoder_hidden_states_2.dim() == 2: - encoder_hidden_states_2 = encoder_hidden_states_2.unsqueeze(1) - - # Ensure sequence lengths match for concatenation - seq_len_1 = ( - encoder_hidden_states.shape[1] - if encoder_hidden_states.dim() == 3 - else encoder_hidden_states.shape[0] - ) - seq_len_2 = ( - encoder_hidden_states_2.shape[1] - if encoder_hidden_states_2.dim() == 3 - else encoder_hidden_states_2.shape[0] + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(device, dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise + noise 
= torch.randn_like(latents) + if args.get("noise_offset", 0): + noise += args["noise_offset"] * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device ) - if seq_len_1 != seq_len_2: - # Pad the shorter sequence to match the longer one - max_seq_len = max(seq_len_1, seq_len_2) - - if encoder_hidden_states.dim() == 3: - if encoder_hidden_states.shape[1] < max_seq_len: - pad_size = max_seq_len - encoder_hidden_states.shape[1] - padding = torch.zeros( - encoder_hidden_states.shape[0], - pad_size, - encoder_hidden_states.shape[2], - device=encoder_hidden_states.device, - dtype=encoder_hidden_states.dtype, - ) - encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) - - if encoder_hidden_states_2.shape[1] < max_seq_len: - pad_size = max_seq_len - encoder_hidden_states_2.shape[1] - padding = torch.zeros( - encoder_hidden_states_2.shape[0], - pad_size, - encoder_hidden_states_2.shape[2], - device=encoder_hidden_states_2.device, - dtype=encoder_hidden_states_2.dtype, - ) - encoder_hidden_states_2 = torch.cat([encoder_hidden_states_2, padding], dim=1) - - # Concatenate along the feature dimension (last dimension) - encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_2], dim=-1) - - # Loss target - prediction_type = args.get("prediction_type", None) - if prediction_type is not None: - noise_scheduler.register_to_config(prediction_type=prediction_type) - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError( - f"Unknown prediction type {noise_scheduler.config.prediction_type}" - ) # Handle SDXL-specific conditioning parameters with proper metadata - unet_kwargs = {"timestep": timesteps, "encoder_hidden_states": encoder_hidden_states, "return_dict": False} - - # SDXL requires additional conditioning kwargs with proper pooled 
embeddings and time_ids - if is_sdxl: - batch_size = noisy_latents.shape[0] - - # Use proper pooled embeddings if available, otherwise create dummy ones - if pooled_prompt_embeds is not None: - text_embeds = ( - pooled_prompt_embeds.repeat(batch_size, 1) - if pooled_prompt_embeds.shape[0] == 1 - else pooled_prompt_embeds + bsz = latents.shape[0] + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Enhanced text encoding - always use encode_prompt for SDXL + if is_sdxl: + # Always use encode_prompt for SDXL, regardless of text_encoder_2 + prompts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True) + # if tokenizer_2 is not None and "input_ids_2" in batch: + # prompts_2 = tokenizer_2.batch_decode(batch["input_ids_2"], skip_special_tokens=True) + # else: + # prompts_2 = None + text_encoders = [text_encoder, text_encoder_2] if text_encoder_2 is not None else [text_encoder] + tokenizers = [tokenizer, tokenizer_2] if tokenizer_2 is not None else [tokenizer] + + # Create a temporary pipeline-like object for encode_prompt compatibility + class TempPipeline: + def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): + self.text_encoder = text_encoder + self.text_encoder_2 = text_encoder_2 + self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 + + temp_pipe = TempPipeline(text_encoder, text_encoder_2, tokenizer, tokenizer_2) + + encoder_hidden_states, _, pooled_prompt_embeds, _ = encode_prompt( + temp_pipe, + text_encoders, + tokenizers, + prompts, + device, + num_images_per_prompt=1, + do_classifier_free_guidance=False, ) else: - # Fallback to dummy embeddings for compatibility - text_embeds = torch.zeros(batch_size, 1280, device=device, dtype=weight_dtype) - - # Compute proper time_ids from actual image metadata if available - if "original_sizes" in batch and "crop_coords_top_left" in batch 
and "target_sizes" in batch: - time_ids_list = [] - for i in range(batch_size): - original_size = batch["original_sizes"][i] - crop_coords = batch["crop_coords_top_left"][i] - target_size = batch["target_sizes"][i] - - # Compute time_ids for this sample - time_ids = compute_time_ids( - original_size, - crop_coords, - target_size, - dtype=weight_dtype, - device=device, - weight_dtype=weight_dtype, + # Standard single text encoder approach + encoder_hidden_states = text_encoder(batch["input_ids"].to(device), return_dict=False)[0] + pooled_prompt_embeds = None + + # For SDXL with dual text encoders, handle dimension compatibility and concatenate + if text_encoder_2 is not None and "input_ids_2" in batch: + encoder_hidden_states_2 = text_encoder_2(batch["input_ids_2"].to(device), return_dict=False)[0] + + # Handle dimension mismatch - ensure both tensors have the same number of dimensions + if encoder_hidden_states.dim() != encoder_hidden_states_2.dim(): + # If one is 2D and the other is 3D, add a dimension to the 2D tensor + if encoder_hidden_states.dim() == 2 and encoder_hidden_states_2.dim() == 3: + encoder_hidden_states = encoder_hidden_states.unsqueeze(1) + elif encoder_hidden_states.dim() == 3 and encoder_hidden_states_2.dim() == 2: + encoder_hidden_states_2 = encoder_hidden_states_2.unsqueeze(1) + + # Ensure sequence lengths match for concatenation + seq_len_1 = ( + encoder_hidden_states.shape[1] + if encoder_hidden_states.dim() == 3 + else encoder_hidden_states.shape[0] + ) + seq_len_2 = ( + encoder_hidden_states_2.shape[1] + if encoder_hidden_states_2.dim() == 3 + else encoder_hidden_states_2.shape[0] ) - time_ids_list.append(time_ids) - time_ids = torch.cat(time_ids_list, dim=0) + if seq_len_1 != seq_len_2: + # Pad the shorter sequence to match the longer one + max_seq_len = max(seq_len_1, seq_len_2) + + if encoder_hidden_states.dim() == 3: + if encoder_hidden_states.shape[1] < max_seq_len: + pad_size = max_seq_len - encoder_hidden_states.shape[1] + padding = 
torch.zeros( + encoder_hidden_states.shape[0], + pad_size, + encoder_hidden_states.shape[2], + device=encoder_hidden_states.device, + dtype=encoder_hidden_states.dtype, + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) + + if encoder_hidden_states_2.shape[1] < max_seq_len: + pad_size = max_seq_len - encoder_hidden_states_2.shape[1] + padding = torch.zeros( + encoder_hidden_states_2.shape[0], + pad_size, + encoder_hidden_states_2.shape[2], + device=encoder_hidden_states_2.device, + dtype=encoder_hidden_states_2.dtype, + ) + encoder_hidden_states_2 = torch.cat([encoder_hidden_states_2, padding], dim=1) + + # Concatenate along the feature dimension (last dimension) + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_2], dim=-1) + + # Loss target + prediction_type = args.get("prediction_type", None) + if prediction_type is not None: + noise_scheduler.register_to_config(prediction_type=prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) else: - # Fallback to dummy time_ids for compatibility - resolution = int(args.get("resolution", 512)) - time_ids = torch.tensor( - [[resolution, resolution, 0, 0, resolution, resolution]] * batch_size, - device=device, - dtype=weight_dtype, - ) - - added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} - unet_kwargs["added_cond_kwargs"] = added_cond_kwargs - - model_pred = unet(noisy_latents, **unet_kwargs)[0] + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) # Handle SDXL-specific conditioning parameters with proper metadata + unet_kwargs = {"timestep": timesteps, "encoder_hidden_states": encoder_hidden_states, "return_dict": False} + + # SDXL requires additional conditioning kwargs with proper pooled embeddings and time_ids + if is_sdxl: + 
batch_size = noisy_latents.shape[0] + + # Use proper pooled embeddings if available, otherwise create dummy ones + if pooled_prompt_embeds is not None: + text_embeds = ( + pooled_prompt_embeds.repeat(batch_size, 1) + if pooled_prompt_embeds.shape[0] == 1 + else pooled_prompt_embeds + ) + else: + # Fallback to dummy embeddings for compatibility + text_embeds = torch.zeros(batch_size, 1280, device=device, dtype=weight_dtype) + + # Compute proper time_ids from actual image metadata if available + if "original_sizes" in batch and "crop_coords_top_left" in batch and "target_sizes" in batch: + time_ids_list = [] + for i in range(batch_size): + original_size = batch["original_sizes"][i] + crop_coords = batch["crop_coords_top_left"][i] + target_size = batch["target_sizes"][i] + + # Compute time_ids for this sample + time_ids = compute_time_ids( + original_size, + crop_coords, + target_size, + dtype=weight_dtype, + device=device, + weight_dtype=weight_dtype, + ) + time_ids_list.append(time_ids) + + time_ids = torch.cat(time_ids_list, dim=0) + else: + # Fallback to dummy time_ids for compatibility + resolution = int(args.get("resolution", 512)) + time_ids = torch.tensor( + [[resolution, resolution, 0, 0, resolution, resolution]] * batch_size, + device=device, + dtype=weight_dtype, + ) - # Use improved loss computation with support for different loss types and weighting - loss = compute_loss(model_pred, target, timesteps, noise_scheduler, args) - print(f"Step {step + 1}/{len(train_dataloader)} - Loss: {loss.item()}") + added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} + unet_kwargs["added_cond_kwargs"] = added_cond_kwargs + model_pred = unet(noisy_latents, **unet_kwargs)[0] + loss = compute_loss(model_pred, target, timesteps, noise_scheduler, args) + print(f"Step {step + 1}/{len(train_dataloader)} - Loss: {loss.item()}") + loss.backward() if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader): @@ -1078,8 +1218,27 @@ def 
__init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): with storage.open(storage.join(save_directory, "tlab_adaptor_info.json"), "w", encoding="utf-8") as f: json.dump(save_info, f, indent=4) + # Method 0: Z-Image PEFT safetensors save + if is_zimage: + try: + from safetensors.torch import save_file + + zimage_lora_state_dict = get_peft_model_state_dict(unet) + save_file(zimage_lora_state_dict, storage.join(save_directory, "pytorch_lora_weights.safetensors")) + print(f"LoRA weights saved to {save_directory}/pytorch_lora_weights.safetensors using safetensors (Z-Image)") + saved_successfully = True + except ImportError: + zimage_lora_state_dict = get_peft_model_state_dict(unet) + torch.save(zimage_lora_state_dict, storage.join(save_directory, "pytorch_lora_weights.bin")) + print( + f"LoRA weights saved to {save_directory}/pytorch_lora_weights.bin using PyTorch format (Z-Image)" + ) + saved_successfully = True + except Exception as e: + print(f"Error saving Z-Image LoRA weights: {e}") + # Method 1: Try the original SD 1.x approach that worked perfectly - if not is_sdxl and not is_sd3 and not is_flux: + if not saved_successfully and not is_sdxl and not is_sd3 and not is_flux: try: StableDiffusionPipeline.save_lora_weights( save_directory=save_directory, @@ -1158,21 +1317,6 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): except Exception as e: print(f"Error with FluxPipeline.save_lora_weights: {e}") - if not saved_successfully and is_zimage: - try: - # ZImage pipelines may have their own save method - from diffusers import ZImagePipeline - - ZImagePipeline.save_lora_weights( - save_directory=save_directory, - unet_lora_layers=model_lora_state_dict, - safe_serialization=True, - ) - print(f"LoRA weights saved to {save_directory} using ZImagePipeline.save_lora_weights (ZImage)") - saved_successfully = True - except Exception as e: - print(f"Error with ZImagePipeline.save_lora_weights: {e}") - # Method 5: Try the generic 
StableDiffusionPipeline method as fallback for all architectures if not saved_successfully: try: diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index ab3ecf689..d277e5508 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -1,4 +1,4 @@ -uv pip install "peft>=0.15.0" +uv pip install "peft>=0.15.0" diffsynth # Only install xformers for non-rocm instances if ! command -v rocminfo &> /dev/null; then From 5100330eabc53a63bd545b1557d249ce047253e7 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 5 Feb 2026 01:26:28 +0530 Subject: [PATCH 05/27] Updated ZImage fine tuning code --- api/transformerlab/plugins/diffusion_trainer/index.json | 9 --------- 1 file changed, 9 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/index.json b/api/transformerlab/plugins/diffusion_trainer/index.json index 903e0fe78..20728c454 100644 --- a/api/transformerlab/plugins/diffusion_trainer/index.json +++ b/api/transformerlab/plugins/diffusion_trainer/index.json @@ -289,12 +289,6 @@ "title": "Log to Weights and Biases", "type": "boolean", "default": true - }, - "training_adapter": { - "title": "Z-Image-Turbo Training Adapter Path", - "type": "string", - "default": "", - "ui:help": "Optional local path to a custom de-distillation training adapter (.safetensors or .bin). Leave empty to automatically download and use the recommended ostris v2 adapter when training on Z-Image-Turbo." } }, "parameters_ui": { @@ -304,9 +298,6 @@ "trigger_word": { "ui:help": "Optional trigger word to prepend to all captions during training. Example: 'sks person' or 'ohwx style'" }, - "training_adapter": { - "ui:help": "Leave blank for auto-download of the recommended adapter. Provide a local path if you want to use a custom or offline adapter (e.g., v1 or your own)." 
- }, "num_train_epochs": { "ui:help": "Total number of training epochs to run." }, From ef967b06c9d185be24209e894dcd88795b168822 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 5 Feb 2026 01:27:53 +0530 Subject: [PATCH 06/27] Updated ZImage fine tuning code --- api/transformerlab/plugins/diffusion_trainer/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 6e0b0d809..7a21a365a 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -435,6 +435,7 @@ def train_diffusion_lora(): variant = args.get("variant", None) model_architecture = args.get("model_architecture") + # Load pipeline to auto-detect architecture and get correct components print(f"Loading pipeline to detect model architecture: {pretrained_model_name_or_path}") From 19f23a6e6711749a4f49611905b1d09e23f22d53 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 5 Feb 2026 01:29:10 +0530 Subject: [PATCH 07/27] Reformat and rebase --- .../plugins/diffusion_trainer/main.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 7a21a365a..9346f3c9c 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -59,6 +59,7 @@ def cleanup_pipeline(): cleanup_pipeline() + def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[ModelConfig], ModelConfig]: """Build ModelConfig list + tokenizer config for Z-Image Turbo.""" transformer_pattern = os.path.join("transformer", "*.safetensors") @@ -90,6 +91,7 @@ def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[ModelConfig] return model_configs, tokenizer_config + def compute_loss_weighting(args, timesteps, noise_scheduler): """ 
Compute loss weighting for improved training stability. @@ -339,6 +341,7 @@ def encode_prompt_sdxl( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + def encode_prompt_zimage(pipe, prompts, device, max_sequence_length: int = 512): """Encode prompts using Z-Image tokenizer/chat template.""" if isinstance(prompts, str): @@ -383,6 +386,7 @@ def encode_prompt_zimage(pipe, prompts, device, max_sequence_length: int = 512): return embedding_list + @tlab_trainer.job_wrapper(wandb_project_name="TLab_Training", manual_logging=True) def train_diffusion_lora(): # Extract parameters from tlab_trainer @@ -435,7 +439,7 @@ def train_diffusion_lora(): variant = args.get("variant", None) model_architecture = args.get("model_architecture") - + # Load pipeline to auto-detect architecture and get correct components print(f"Loading pipeline to detect model architecture: {pretrained_model_name_or_path}") @@ -488,7 +492,7 @@ def train_diffusion_lora(): model_component_name = "dit" text_encoder_2 = None tokenizer_2 = None - vae = None + vae = None else: temp_pipeline = AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path, **pipeline_kwargs) @@ -1063,7 +1067,11 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): raise ValueError( f"Unknown prediction type {noise_scheduler.config.prediction_type}" ) # Handle SDXL-specific conditioning parameters with proper metadata - unet_kwargs = {"timestep": timesteps, "encoder_hidden_states": encoder_hidden_states, "return_dict": False} + unet_kwargs = { + "timestep": timesteps, + "encoder_hidden_states": encoder_hidden_states, + "return_dict": False, + } # SDXL requires additional conditioning kwargs with proper pooled embeddings and time_ids if is_sdxl: @@ -1115,7 +1123,7 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): model_pred = unet(noisy_latents, **unet_kwargs)[0] loss = compute_loss(model_pred, target, timesteps, 
noise_scheduler, args) print(f"Step {step + 1}/{len(train_dataloader)} - Loss: {loss.item()}") - + loss.backward() if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader): @@ -1227,14 +1235,14 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): zimage_lora_state_dict = get_peft_model_state_dict(unet) save_file(zimage_lora_state_dict, storage.join(save_directory, "pytorch_lora_weights.safetensors")) - print(f"LoRA weights saved to {save_directory}/pytorch_lora_weights.safetensors using safetensors (Z-Image)") + print( + f"LoRA weights saved to {save_directory}/pytorch_lora_weights.safetensors using safetensors (Z-Image)" + ) saved_successfully = True except ImportError: zimage_lora_state_dict = get_peft_model_state_dict(unet) torch.save(zimage_lora_state_dict, storage.join(save_directory, "pytorch_lora_weights.bin")) - print( - f"LoRA weights saved to {save_directory}/pytorch_lora_weights.bin using PyTorch format (Z-Image)" - ) + print(f"LoRA weights saved to {save_directory}/pytorch_lora_weights.bin using PyTorch format (Z-Image)") saved_successfully = True except Exception as e: print(f"Error saving Z-Image LoRA weights: {e}") From cc2a21db2715e934f2522ee7369268463c364cfd Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 5 Feb 2026 21:28:48 +0530 Subject: [PATCH 08/27] Updates --- api/transformerlab/plugins/diffusion_trainer/main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 9346f3c9c..0528ae81a 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -82,16 +82,17 @@ def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[ModelConfig] tokenizer_config = ModelConfig(path=os.path.join(model_id_or_path, tokenizer_pattern)) else: model_configs = [ - 
ModelConfig(model_id=model_id_or_path, subfolder="transformer"), - ModelConfig(model_id=model_id_or_path, subfolder="text_encoder"), - ModelConfig(model_id=model_id_or_path, subfolder="vae/vae"), + ModelConfig(model_id=model_id_or_path, origin_file_pattern=transformer_pattern), + ModelConfig(model_id=model_id_or_path, origin_file_pattern=text_encoder_pattern), + ModelConfig(model_id=model_id_or_path, origin_file_pattern=vae_pattern), ] - tokenizer_config = ModelConfig(model_id=model_id_or_path, subfolder="tokenizer") + tokenizer_config = ModelConfig( + model_id=model_id_or_path, origin_file_pattern=os.path.join(tokenizer_pattern, "*") + ) return model_configs, tokenizer_config - def compute_loss_weighting(args, timesteps, noise_scheduler): """ Compute loss weighting for improved training stability. From 5c815064128afd61697b6d01d865aeae3dce414f Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 5 Feb 2026 23:40:52 +0530 Subject: [PATCH 09/27] Updates --- .../plugins/diffusion_trainer/main.py | 90 ++++++++++++++++--- 1 file changed, 76 insertions(+), 14 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 0528ae81a..191973790 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -64,7 +64,7 @@ def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[ModelConfig] """Build ModelConfig list + tokenizer config for Z-Image Turbo.""" transformer_pattern = os.path.join("transformer", "*.safetensors") text_encoder_pattern = os.path.join("text_encoder", "*.safetensors") - vae_pattern = os.path.join("vae", "vae", "diffusion_pytorch_model.safetensors") + vae_pattern = os.path.join("vae", "diffusion_pytorch_model.safetensors") tokenizer_pattern = "tokenizer/" @@ -84,6 +84,7 @@ def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[ModelConfig] model_configs = [ 
ModelConfig(model_id=model_id_or_path, origin_file_pattern=transformer_pattern), ModelConfig(model_id=model_id_or_path, origin_file_pattern=text_encoder_pattern), + # Fix: Use the corrected pattern for remote loading ModelConfig(model_id=model_id_or_path, origin_file_pattern=vae_pattern), ] @@ -93,6 +94,7 @@ def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[ModelConfig] return model_configs, tokenizer_config + def compute_loss_weighting(args, timesteps, noise_scheduler): """ Compute loss weighting for improved training stability. @@ -354,12 +356,20 @@ def encode_prompt_zimage(pipe, prompts, device, max_sequence_length: int = 512): {"role": "user", "content": prompt_item}, ] - chat_prompt = pipe.tokenizer.build_chat_prompt( - messages, - tokenize=False, - add_generation_prompt=True, - enable_thinking=True, - ) + if hasattr(pipe.tokenizer, "build_chat_prompt"): + chat_prompt = pipe.tokenizer.build_chat_prompt( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + else: + # Fallback for tokenizers like Qwen2TokenizerFast + chat_prompt = pipe.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) chat_prompts.append(chat_prompt) @@ -441,9 +451,6 @@ def train_diffusion_lora(): model_architecture = args.get("model_architecture") - # Load pipeline to auto-detect architecture and get correct components - print(f"Loading pipeline to detect model architecture: {pretrained_model_name_or_path}") - # Detect architecture based on multiple indicators is_sdxl = "StableDiffusionXLPipeline" in model_architecture @@ -475,6 +482,17 @@ def train_diffusion_lora(): pipe = None if is_zimage: + # Ensure the model is downloaded locally if it's not already a directory + if not os.path.isdir(pretrained_model_name_or_path): + from huggingface_hub import snapshot_download + + print(f"Downloading Z-Image model {pretrained_model_name_or_path} from Hugging Face...") + pretrained_model_name_or_path = 
snapshot_download( + repo_id=pretrained_model_name_or_path, + allow_patterns=["*.safetensors", "*.json", "tokenizer/*"], + ) + print(f"Model downloaded to: {pretrained_model_name_or_path}") + model_configs, tokenizer_config = build_zimage_model_configs(pretrained_model_name_or_path) pipe = ZImagePipeline.from_pretrained( torch_dtype=weight_dtype, @@ -483,7 +501,7 @@ def train_diffusion_lora(): tokenizer_config=tokenizer_config, ) - pipe.scheduler.set_timesteps(int(args.get("num_train_timesteps", 1000)), device=device) + pipe.scheduler.set_timesteps(int(args.get("num_train_timesteps", 1000)), training=True) noise_scheduler = pipe.scheduler tokenizer = pipe.tokenizer text_encoder = pipe.text_encoder @@ -555,7 +573,8 @@ def train_diffusion_lora(): # Enable gradient checkpointing for memory savings if args.get("gradient_checkpointing", False): - unet.enable_gradient_checkpointing() + if hasattr(unet, "enable_gradient_checkpointing"): + unet.enable_gradient_checkpointing() if hasattr(text_encoder, "gradient_checkpointing_enable"): text_encoder.gradient_checkpointing_enable() if text_encoder_2 is not None and hasattr(text_encoder_2, "gradient_checkpointing_enable"): @@ -785,6 +804,7 @@ def build_captions(examples, is_train=True): processed_caption = f"{trigger_word}, {processed_caption}" captions.append(processed_caption) + return captions def tokenize_captions(examples, is_train=True): captions = build_captions(examples, is_train=is_train) @@ -809,8 +829,41 @@ def tokenize_captions(examples, is_train=True): return result image_column = args.get("image_column", "image") + caption_column = args.get("caption_column", "text") + + if image_column not in dataset.column_names: + raise ValueError(f"Image column '{image_column}' not found in dataset.") + + keep_columns = [image_column] + if caption_column in dataset.column_names: + keep_columns.append(caption_column) + + drop_columns = [col for col in dataset.column_names if col not in keep_columns] + if drop_columns: + dataset 
= dataset.remove_columns(drop_columns) + + if hasattr(dataset, "reset_format"): + dataset.reset_format() + + dataset = dataset.filter(lambda x: x.get(image_column) is not None) def preprocess_train(examples): + # Filter out examples with None images + images_value = examples.get(image_column) + if images_value is None: + return {} + + valid_indices = [i for i, img in enumerate(images_value) if img is not None] + if not valid_indices: + return {} + + filtered_examples = {} + for key, value in examples.items(): + if isinstance(value, (list, tuple, np.ndarray)): + filtered_examples[key] = [value[i] for i in valid_indices] + + examples = filtered_examples + images = [image.convert("RGB") for image in examples[image_column]] # Enhanced preprocessing for SDXL with proper image metadata tracking @@ -865,9 +918,18 @@ def preprocess_train(examples): def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = torch.stack([example["input_ids"] for example in examples]) + batch = {"pixel_values": pixel_values} + + if is_zimage: + batch["prompt"] = [example["prompt"] for example in examples] + if "original_sizes" in examples[0]: + batch["original_sizes"] = [example["original_sizes"] for example in examples] + batch["crop_coords_top_left"] = [example["crop_coords_top_left"] for example in examples] + batch["target_sizes"] = [example["target_sizes"] for example in examples] + return batch - batch = {"pixel_values": pixel_values, "input_ids": input_ids} + input_ids = torch.stack([example["input_ids"] for example in examples]) + batch["input_ids"] = input_ids # Add second input_ids for SDXL if present if "input_ids_2" in examples[0]: From 3c6c3743226642e1471cd318461890bc7c4d4afa Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 5 Feb 2026 23:41:50 +0530 Subject: [PATCH 10/27] Updates --- api/transformerlab/plugins/diffusion_trainer/main.py 
| 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 191973790..b7834002c 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -833,7 +833,7 @@ def tokenize_captions(examples, is_train=True): if image_column not in dataset.column_names: raise ValueError(f"Image column '{image_column}' not found in dataset.") - + keep_columns = [image_column] if caption_column in dataset.column_names: keep_columns.append(caption_column) @@ -852,11 +852,11 @@ def preprocess_train(examples): images_value = examples.get(image_column) if images_value is None: return {} - + valid_indices = [i for i, img in enumerate(images_value) if img is not None] if not valid_indices: return {} - + filtered_examples = {} for key, value in examples.items(): if isinstance(value, (list, tuple, np.ndarray)): From 99e483e8ff2b76bf145612ad912a248925aaae2b Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 5 Feb 2026 23:45:15 +0530 Subject: [PATCH 11/27] Updates --- api/transformerlab/plugins/diffusion_trainer/setup.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index d277e5508..4cd1602fd 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -1,3 +1,9 @@ +#!/bin/bash + +# Install compatible torch and torchvision first to avoid version conflicts +uv pip install torch torchvision + +# Install PEFT and diffsynth uv pip install "peft>=0.15.0" diffsynth # Only install xformers for non-rocm instances From a2f85e9e9f7d7692fab3909c4b5f288677b12d37 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Fri, 6 Feb 2026 20:44:50 +0530 Subject: [PATCH 12/27] Fixed saving lora weights --- 
.../plugins/diffusion_trainer/main.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index b7834002c..db6dcf4d4 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -432,7 +432,7 @@ def train_diffusion_lora(): if eval_prompt: eval_images_dir = storage.join(workspace_dir, "temp", f"eval_images_{job_id}") - storage.makedirs(eval_images_dir, exist_ok=True) + asyncio.run(storage.makedirs(eval_images_dir, exist_ok=True)) print(f"Evaluation images will be saved to: {eval_images_dir}") # Add eval images directory to job data @@ -1270,7 +1270,7 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): print(f"Saving LoRA weights to {save_directory}") - storage.makedirs(save_directory, exist_ok=True) + asyncio.run(storage.makedirs(save_directory, exist_ok=True)) # Primary method: Use the original working approach that was perfect for SD 1.5 # Try architecture-specific save methods first, then fall back to universal methods @@ -1288,8 +1288,15 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): }, "tlab_trainer_used": True, } - with storage.open(storage.join(save_directory, "tlab_adaptor_info.json"), "w", encoding="utf-8") as f: - json.dump(save_info, f, indent=4) + async def _write_adaptor_info() -> None: + async with await storage.open( + storage.join(save_directory, "tlab_adaptor_info.json"), + "w", + encoding="utf-8", + ) as f: + await f.write(json.dumps(save_info, indent=4)) + + asyncio.run(_write_adaptor_info()) # Method 0: Z-Image PEFT safetensors save if is_zimage: From d73e250dcec4fa40b958938e636ddbc19b879047 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Fri, 6 Feb 2026 20:45:08 +0530 Subject: [PATCH 13/27] Formatting --- api/transformerlab/plugins/diffusion_trainer/main.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index db6dcf4d4..8582a3e30 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -1288,6 +1288,7 @@ def __init__(self, text_encoder, text_encoder_2, tokenizer, tokenizer_2): }, "tlab_trainer_used": True, } + async def _write_adaptor_info() -> None: async with await storage.open( storage.join(save_directory, "tlab_adaptor_info.json"), From 63160e35116e4f64b34e24811b51a1c52d3a3321 Mon Sep 17 00:00:00 2001 From: Tony Salomone Date: Sun, 8 Feb 2026 13:29:40 -0500 Subject: [PATCH 14/27] ruff --- api/transformerlab/plugins/diffusion_trainer/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 8582a3e30..c41451d02 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -13,10 +13,9 @@ from torchvision import transforms from diffsynth import ModelConfig from diffsynth.pipelines.z_image import ZImagePipeline -from diffsynth.diffusion.loss import FlowMatchSFTLoss, TrajectoryImitationLoss +from diffsynth.diffusion.loss import FlowMatchSFTLoss import os import glob -import accelerate from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline From 79286da2033d6fbdb7267ce243f972325aba37a9 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 12 Feb 2026 20:09:41 +0530 Subject: [PATCH 15/27] updates --- .../plugins/diffusion_trainer/main.py | 22 ++++++++++++++++++- .../plugins/diffusion_trainer/setup.sh | 2 +- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index c41451d02..5226c9422 100644 --- 
a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -459,6 +459,17 @@ def train_diffusion_lora(): is_zimage = "ZImagePipeline" in model_architecture + # Some model cards expose generic architecture strings. Fall back to model id/path hints for Z-Image. + if not is_zimage: + model_name_hint = str(args.get("model_name") or "") + model_path_hint = str(args.get("model_path") or "") + model_hint = f"{model_name_hint} {model_path_hint}".lower() + if "z-image" in model_hint or "zimage" in model_hint: + is_zimage = True + print( + "Detected Z-Image model from model name/path hint; forcing ZImagePipeline training path." + ) + print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, ZImage: {is_zimage}") # Mixed Precision @@ -1025,7 +1036,16 @@ def collate_fn(examples): timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device ).long() - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + if hasattr(noise_scheduler, "add_noise"): + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + elif hasattr(noise_scheduler, "scale_noise"): + # Flow-matching schedulers (for example FlowMatchEulerDiscreteScheduler) expose scale_noise. + noisy_latents = noise_scheduler.scale_noise(latents, timesteps, noise) + else: + raise AttributeError( + f"Unsupported scheduler {type(noise_scheduler).__name__}: " + "expected add_noise(...) or scale_noise(...)."
+ ) # Enhanced text encoding - always use encode_prompt for SDXL if is_sdxl: diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index 4cd1602fd..dc9566913 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -1,7 +1,7 @@ #!/bin/bash # Install compatible torch and torchvision first to avoid version conflicts -uv pip install torch torchvision +uv pip install torch torchvision diffusers transformers --extra-index-url https://download.pytorch.org/whl/cu118 # Install PEFT and diffsynth uv pip install "peft>=0.15.0" diffsynth From 281f25d50ab4f998bd481c9a8a2a06a4d9b925bc Mon Sep 17 00:00:00 2001 From: Tony Salomone Date: Tue, 17 Feb 2026 16:44:32 -0500 Subject: [PATCH 16/27] ruff --- api/transformerlab/plugins/diffusion_trainer/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index 5226c9422..f77a6268e 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -466,9 +466,7 @@ def train_diffusion_lora(): model_hint = f"{model_name_hint} {model_path_hint}".lower() if "z-image" in model_hint or "zimage" in model_hint: is_zimage = True - print( - "Detected Z-Image model from model name/path hint; forcing ZImagePipeline training path." 
- ) + print("Detected Z-Image model from model name/path hint; forcing ZImagePipeline training path.") print(f"Architecture detection - SDXL: {is_sdxl}, SD3: {is_sd3}, Flux: {is_flux}, ZImage: {is_zimage}") From 1026537bcdd088f380039a1c63f231502e08eda7 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 13:57:49 +0530 Subject: [PATCH 17/27] Fixes --- .../plugin_sdk/plugin_harness.py | 107 +++++++++++++++++- .../plugins/diffusion_trainer/setup.sh | 6 +- 2 files changed, 104 insertions(+), 9 deletions(-) diff --git a/api/transformerlab/plugin_sdk/plugin_harness.py b/api/transformerlab/plugin_sdk/plugin_harness.py index 1697eb66b..afdc8b9bc 100644 --- a/api/transformerlab/plugin_sdk/plugin_harness.py +++ b/api/transformerlab/plugin_sdk/plugin_harness.py @@ -12,7 +12,48 @@ import sys import argparse import traceback -import asyncio +import sqlite3 +from typing import Optional + + +def get_db_config_value(key: str, team_id: Optional[str] = None, user_id: Optional[str] = None) -> Optional[str]: + """ + Read config values directly from sqlite without importing transformerlab.plugin. + This keeps harness startup independent from heavy ML dependencies. + """ + from lab import HOME_DIR + + db_path = f"{HOME_DIR}/llmlab.sqlite3" + db = sqlite3.connect(db_path, isolation_level=None) + db.execute("PRAGMA busy_timeout=30000") + try: + # Priority 1: user-specific config (requires both user_id and team_id) + if user_id and team_id: + cursor = db.execute( + "SELECT value FROM config WHERE key = ? AND user_id = ? AND team_id = ?", (key, user_id, team_id) + ) + row = cursor.fetchone() + cursor.close() + if row is not None: + return row[0] + + # Priority 2: team-wide config + if team_id: + cursor = db.execute( + "SELECT value FROM config WHERE key = ? 
AND user_id IS NULL AND team_id = ?", (key, team_id) + ) + row = cursor.fetchone() + cursor.close() + if row is not None: + return row[0] + + # Priority 3: global config + cursor = db.execute("SELECT value FROM config WHERE key = ? AND user_id IS NULL AND team_id IS NULL", (key,)) + row = cursor.fetchone() + cursor.close() + return row[0] if row is not None else None + finally: + db.close() parser = argparse.ArgumentParser() @@ -20,14 +61,63 @@ args, unknown = parser.parse_known_args() -def set_config_env_vars(env_var: str, target_env_var: str = None, user_id: str = None, team_id: str = None): - try: - from transformerlab.plugin import get_db_config_value +def configure_plugin_runtime_library_paths(plugin_dir: str) -> None: + """ + Prefer CUDA/NCCL libraries from the plugin venv over system-wide libraries. + This reduces CUDA symbol mismatches caused by stale host NCCL installs. + """ + if os.name == "nt": + return + + venv_path = os.path.join(plugin_dir, "venv") + if not os.path.isdir(venv_path): + return + + pyver = f"python{sys.version_info.major}.{sys.version_info.minor}" + site_packages = os.path.join(venv_path, "lib", pyver, "site-packages") + + candidate_paths: list[str] = [] + torch_lib = os.path.join(site_packages, "torch", "lib") + if os.path.isdir(torch_lib): + candidate_paths.append(torch_lib) + + nvidia_root = os.path.join(site_packages, "nvidia") + if os.path.isdir(nvidia_root): + for pkg_name in os.listdir(nvidia_root): + lib_dir = os.path.join(nvidia_root, pkg_name, "lib") + if os.path.isdir(lib_dir): + candidate_paths.append(lib_dir) - value = asyncio.run(get_db_config_value(env_var, user_id=user_id, team_id=team_id)) + if not candidate_paths: + return + + existing_paths = [p for p in os.environ.get("LD_LIBRARY_PATH", "").split(os.pathsep) if p] + candidate_norm = {os.path.normpath(c) for c in candidate_paths} + + merged = list(candidate_paths) + for path in existing_paths: + if os.path.normpath(path) not in candidate_norm: + merged.append(path) + 
+ if merged != existing_paths: + os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(merged) + print("Configured LD_LIBRARY_PATH for plugin runtime libraries") + + +configure_plugin_runtime_library_paths(args.plugin_dir) + + +def set_config_env_vars( + env_var: str, + target_env_var: Optional[str] = None, + user_id: Optional[str] = None, + team_id: Optional[str] = None, +) -> None: + try: + value = get_db_config_value(env_var, user_id=user_id, team_id=team_id) if value: os.environ[target_env_var] = value - print(f"Set {target_env_var} from {'user' if user_id else 'team'} config: {value}") + print(f"Set {target_env_var} from {'user' if user_id else 'team'} config") except Exception as e: print(f"Warning: Could not set {target_env_var} from {'user' if user_id else 'team'} config: {e}") @@ -69,6 +159,11 @@ def set_config_env_vars(env_var: str, target_env_var: str = None, user_id: str = except ImportError as e: print(f"Error executing plugin: {e}") traceback.print_exc() + if "ncclCommShrink" in str(e): + print( + "Detected CUDA/NCCL mismatch while importing torch. " + "Reinstall the plugin venv with a torch build matching this machine's CUDA runtime." + ) # if e is a ModuleNotFoundError, the plugin is missing a required package if isinstance(e, ModuleNotFoundError): diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index dc9566913..535be64d9 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -1,7 +1,7 @@ #!/bin/bash -# Install compatible torch and torchvision first to avoid version conflicts -uv pip install torch torchvision diffusers transformers --extra-index-url https://download.pytorch.org/whl/cu118 +# Keep torch stack from base plugin venv setup to avoid CUDA/NCCL mismatches. 
+uv pip install diffusers transformers # Install PEFT and diffsynth uv pip install "peft>=0.15.0" diffsynth @@ -9,4 +9,4 @@ uv pip install "peft>=0.15.0" diffsynth # Only install xformers for non-rocm instances if ! command -v rocminfo &> /dev/null; then uv pip install xformers -fi \ No newline at end of file +fi From 0eacdb5a8ad2ec2fa909af62f8dc3c24a5486f65 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 14:05:28 +0530 Subject: [PATCH 18/27] Fixes --- .../plugins/diffusion_trainer/setup.sh | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index 535be64d9..e949ff7e2 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -1,12 +1,17 @@ #!/bin/bash -# Keep torch stack from base plugin venv setup to avoid CUDA/NCCL mismatches. -uv pip install diffusers transformers +# Keep plugin deps aligned with the base plugin venv created from api/pyproject.toml. +# This avoids resolver-driven torch/torchvision drift (e.g. missing torchvision::nms). +uv pip install \ + "diffusers==0.36.0" \ + "transformers==4.57.1" \ + "peft==0.15.2" \ + diffsynth -# Install PEFT and diffsynth -uv pip install "peft>=0.15.0" diffsynth - -# Only install xformers for non-rocm instances -if ! command -v rocminfo &> /dev/null; then - uv pip install xformers +# Only install xformers for non-ROCm instances. +# Use --no-deps so xformers cannot modify the preinstalled torch stack. +if ! command -v rocminfo >/dev/null 2>&1; then + if ! uv pip install --no-deps xformers; then + echo "xformers wheel unavailable for this environment; continuing without it." 
+ fi fi From 78f7d24ad6c9dc1449d31c4a1e96c35d86aa7ded Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 14:09:28 +0530 Subject: [PATCH 19/27] Updated Peft version --- api/transformerlab/plugins/diffusion_trainer/setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index e949ff7e2..098a3943c 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -5,7 +5,7 @@ uv pip install \ "diffusers==0.36.0" \ "transformers==4.57.1" \ - "peft==0.15.2" \ + "peft>=0.17" \ diffsynth # Only install xformers for non-ROCm instances. From 69687cf840320b7dd6ecd5678d513fedf6e2f2dc Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 16:33:39 +0530 Subject: [PATCH 20/27] Fixes --- api/test/api/test_diffusion.py | 44 ++++++ .../image_diffusion/diffusion_worker.py | 120 +++++++++++++++- .../plugins/image_diffusion/main.py | 133 ++++++++++++++++-- 3 files changed, 288 insertions(+), 9 deletions(-) diff --git a/api/test/api/test_diffusion.py b/api/test/api/test_diffusion.py index 81ac282fb..a624ae0b8 100644 --- a/api/test/api/test_diffusion.py +++ b/api/test/api/test_diffusion.py @@ -1045,3 +1045,47 @@ def test_get_pipeline_key_whitespace_adaptor(): # Should treat whitespace-only adaptor as no adaptor assert key == "test-model:: ::txt2img" + + +def test_resolve_diffusion_model_reference_non_directory(): + """Non-directory model refs should pass through unchanged.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + + with patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=False): + resolved = main.resolve_diffusion_model_reference("Tongyi-MAI/Z-Image-Turbo") + + assert resolved == "Tongyi-MAI/Z-Image-Turbo" + + +def test_resolve_diffusion_model_reference_prefers_local_complete_dir(): + """Local 
directory with model_index.json should stay as local path.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + local_dir = "/tmp/models/Tongyi-MAI_Z-Image-Turbo" + + with ( + patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=True), + patch("transformerlab.plugins.image_diffusion.main.os.path.isfile", return_value=True), + patch("transformerlab.plugins.image_diffusion.main._extract_hf_repo_from_model_metadata") as mock_extract, + ): + resolved = main.resolve_diffusion_model_reference(local_dir) + + mock_extract.assert_not_called() + assert resolved == local_dir + + +def test_resolve_diffusion_model_reference_falls_back_to_hf_repo(): + """Incomplete local directory should fall back to Hugging Face repo id from metadata.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + local_dir = "/tmp/models/Tongyi-MAI_Z-Image-Turbo" + + with ( + patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=True), + patch("transformerlab.plugins.image_diffusion.main.os.path.isfile", return_value=False), + patch( + "transformerlab.plugins.image_diffusion.main._extract_hf_repo_from_model_metadata", + return_value="Tongyi-MAI/Z-Image-Turbo", + ), + ): + resolved = main.resolve_diffusion_model_reference(local_dir) + + assert resolved == "Tongyi-MAI/Z-Image-Turbo" diff --git a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py index bfa7701b1..830ce0835 100644 --- a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py +++ b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py @@ -9,6 +9,7 @@ import json import os import sys +from pathlib import Path import time import gc from PIL import Image @@ -71,6 +72,118 @@ } +def _is_probable_hf_repo_id(value: str) -> bool: + """Heuristic for Hugging Face repo IDs like `org/name`.""" + if not isinstance(value, str): + return False + candidate = 
value.strip() + if not candidate: + return False + if os.path.isabs(candidate): + return False + if candidate.startswith("."): + return False + return "/" in candidate and "\\" not in candidate + + +def _extract_hf_repo_from_model_metadata(model_dir: str) -> str | None: + """ + Extract a Hugging Face repo id from local model metadata. + + This helps recover when `model_dir` exists but is missing `model_index.json`. + """ + metadata_path = os.path.join(model_dir, "index.json") + candidates: list[str] = [] + + if os.path.isfile(metadata_path): + try: + with open(metadata_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + if isinstance(metadata, dict): + json_data = metadata.get("json_data", {}) if isinstance(metadata.get("json_data"), dict) else {} + candidates.extend( + [ + json_data.get("huggingface_repo"), + json_data.get("source_id_or_path"), + metadata.get("model_id"), + ] + ) + except Exception as e: + print(f"Warning: Failed to read model metadata at {metadata_path}: {e}") + + model_key = Path(model_dir).name + if model_key: + try: + from lab.model import Model as ModelService + + import asyncio + + model_service = asyncio.run(ModelService.get(model_key)) + model_metadata = asyncio.run(model_service.get_metadata()) + if isinstance(model_metadata, dict): + json_data = ( + model_metadata.get("json_data", {}) if isinstance(model_metadata.get("json_data"), dict) else {} + ) + candidates.extend( + [ + json_data.get("huggingface_repo"), + json_data.get("source_id_or_path"), + model_metadata.get("model_id"), + ] + ) + except Exception: + pass + + seen = set() + for candidate in candidates: + if not isinstance(candidate, str): + continue + candidate = candidate.strip() + if not candidate or candidate in seen: + continue + seen.add(candidate) + if _is_probable_hf_repo_id(candidate): + return candidate + + return None + + +def resolve_diffusion_model_reference(model: str) -> str: + """ + Resolve model reference for diffusers loading. 
+ + If `model` is a local directory but missing `model_index.json`, try falling back + to the original Hugging Face repo id from local metadata. + """ + if not isinstance(model, str): + return model + + model_ref = model.strip() + if not model_ref: + return model + + if not os.path.isdir(model_ref): + return model_ref + + model_index_path = os.path.join(model_ref, "model_index.json") + if os.path.isfile(model_index_path): + return model_ref + + hf_repo = _extract_hf_repo_from_model_metadata(model_ref) + if hf_repo: + print( + f"Local model directory is missing model_index.json at {model_index_path}. " + f"Falling back to Hugging Face repo: {hf_repo}" + ) + return hf_repo + + print( + f"Warning: Local model directory is missing model_index.json at {model_index_path} " + "and no Hugging Face repo metadata could be resolved." + ) + return model_ref + + def load_controlnet_model(controlnet_id: str, device: str = "cuda") -> ControlNetModel: controlnet_model = ControlNetModel.from_pretrained( controlnet_id, torch_dtype=torch.float16 if device != "cpu" else torch.float32 @@ -194,7 +307,8 @@ def is_flux_model(model_path): # Check if model has FLUX components by looking for config from huggingface_hub import model_info - info = model_info(model_path) + resolved_model = resolve_diffusion_model_reference(model_path) + info = model_info(resolved_model) config = getattr(info, "config", {}) diffusers_config = config.get("diffusers", {}) architectures = diffusers_config.get("_class_name", "") @@ -227,6 +341,8 @@ def load_pipeline_with_sharding( print("Loading pipeline with model sharding...") import torch + model_path = resolve_diffusion_model_reference(model_path) + # Flush memory before starting flush_memory() @@ -527,6 +643,8 @@ def load_pipeline_with_device_map( ): """Load pipeline with proper device mapping for multi-GPU""" + model_path = resolve_diffusion_model_reference(model_path) + # Clean up any existing CUDA cache before loading if torch.cuda.is_available(): 
torch.cuda.empty_cache() diff --git a/api/transformerlab/plugins/image_diffusion/main.py b/api/transformerlab/plugins/image_diffusion/main.py index 3a5d3d77a..8b393702d 100644 --- a/api/transformerlab/plugins/image_diffusion/main.py +++ b/api/transformerlab/plugins/image_diffusion/main.py @@ -1,6 +1,7 @@ from fastapi import HTTPException from pydantic import BaseModel, ValidationError from huggingface_hub import model_info +from pathlib import Path import base64 from io import BytesIO import torch @@ -322,6 +323,119 @@ def is_zimage_model(model: str) -> bool: return "z-image" in name or "zimage" in name +def _is_probable_hf_repo_id(value: str) -> bool: + """Heuristic for Hugging Face repo IDs like `org/name`.""" + if not isinstance(value, str): + return False + candidate = value.strip() + if not candidate: + return False + if os.path.isabs(candidate): + return False + if candidate.startswith("."): + return False + return "/" in candidate and "\\" not in candidate + + +def _extract_hf_repo_from_model_metadata(model_dir: str) -> str | None: + """ + Extract a Hugging Face repo id from local model metadata. + + This helps recover when `model_dir` exists but is missing `model_index.json`. + """ + # First try direct metadata file lookup in the model directory. + metadata_path = os.path.join(model_dir, "index.json") + candidates: list[str] = [] + + if os.path.isfile(metadata_path): + try: + with open(metadata_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + if isinstance(metadata, dict): + json_data = metadata.get("json_data", {}) if isinstance(metadata.get("json_data"), dict) else {} + candidates.extend( + [ + json_data.get("huggingface_repo"), + json_data.get("source_id_or_path"), + metadata.get("model_id"), + ] + ) + except Exception as e: + print(f"Warning: Failed to read model metadata at {metadata_path}: {e}") + + # Fallback to SDK metadata lookup by directory name. 
+ model_key = Path(model_dir).name + if model_key: + try: + from lab.model import Model as ModelService + + model_service = run_async_from_sync(ModelService.get(model_key)) + model_metadata = run_async_from_sync(model_service.get_metadata()) + if isinstance(model_metadata, dict): + json_data = ( + model_metadata.get("json_data", {}) if isinstance(model_metadata.get("json_data"), dict) else {} + ) + candidates.extend( + [ + json_data.get("huggingface_repo"), + json_data.get("source_id_or_path"), + model_metadata.get("model_id"), + ] + ) + except Exception: + # Best-effort lookup only. + pass + + seen = set() + for candidate in candidates: + if not isinstance(candidate, str): + continue + candidate = candidate.strip() + if not candidate or candidate in seen: + continue + seen.add(candidate) + if _is_probable_hf_repo_id(candidate): + return candidate + + return None + + +def resolve_diffusion_model_reference(model: str) -> str: + """ + Resolve model reference for diffusers loading. + + If `model` is a local directory but missing `model_index.json`, try falling back + to the original Hugging Face repo id from local metadata. + """ + if not isinstance(model, str): + return model + + model_ref = model.strip() + if not model_ref: + return model + + if not os.path.isdir(model_ref): + return model_ref + + model_index_path = os.path.join(model_ref, "model_index.json") + if os.path.isfile(model_index_path): + return model_ref + + hf_repo = _extract_hf_repo_from_model_metadata(model_ref) + if hf_repo: + print( + f"Local model directory is missing model_index.json at {model_index_path}. " + f"Falling back to Hugging Face repo: {hf_repo}" + ) + return hf_repo + + print( + f"Warning: Local model directory is missing model_index.json at {model_index_path} " + "and no Hugging Face repo metadata could be resolved." 
+ ) + return model_ref + + def get_pipeline( model: str, adaptor: str = "", @@ -335,8 +449,10 @@ def get_pipeline( # cache_key = get_pipeline_key(model, adaptor, is_img2img, is_inpainting) with _PIPELINES_LOCK: + resolved_model = resolve_diffusion_model_reference(model) + # Detect Z-Image architecture (non-controlnet path) - is_zimage = is_zimage_model(model) + is_zimage = is_zimage_model(resolved_model) # Load appropriate pipeline based on type if is_controlnet: @@ -360,7 +476,7 @@ def get_pipeline( print(f"Loading ControlNet pipeline ({controlnet_id}) for model: {model}") try: - info = model_info(model) + info = model_info(resolved_model) config = getattr(info, "config", {}) diffusers_config = config.get("diffusers", {}) architecture = diffusers_config.get("_class_name", "") @@ -377,7 +493,7 @@ def get_pipeline( print(f"Loaded ControlNet pipeline {controlnet_pipeline} for model {model}") pipe = controlnet_pipeline.from_pretrained( - model, + resolved_model, controlnet=controlnet_model, torch_dtype=torch.float16 if device != "cpu" else torch.float32, safety_checker=None, @@ -386,7 +502,7 @@ def get_pipeline( ) elif is_inpainting: pipe = AutoPipelineForInpainting.from_pretrained( - model, + resolved_model, torch_dtype=torch.float16 if device != "cpu" else torch.float32, safety_checker=None, requires_safety_checker=False, @@ -394,7 +510,7 @@ def get_pipeline( print(f"Loaded inpainting pipeline for model: {model}") elif is_img2img: pipe = AutoPipelineForImage2Image.from_pretrained( - model, + resolved_model, torch_dtype=torch.float16 if device != "cpu" else torch.float32, safety_checker=None, requires_safety_checker=False, @@ -402,14 +518,14 @@ def get_pipeline( print(f"Loaded image-to-image pipeline for model: {model}") elif is_zimage: pipe = DiffusionPipeline.from_pretrained( - model, + resolved_model, torch_dtype=torch.bfloat16 if device != "cpu" else torch.float32, low_cpu_mem_usage=False, ) print(f"Loaded Z-Image pipeline for model: {model} with dtype 
{pipe.dtype}") else: pipe = AutoPipelineForText2Image.from_pretrained( - model, + resolved_model, torch_dtype=torch.float16 if device != "cpu" else torch.float32, safety_checker=None, requires_safety_checker=False, @@ -627,7 +743,8 @@ def should_use_diffusion_worker(model) -> bool: # Check if model has FLUX components by looking for config from huggingface_hub import model_info - info = model_info(model) + resolved_model = resolve_diffusion_model_reference(model) + info = model_info(resolved_model) config = getattr(info, "config", {}) diffusers_config = config.get("diffusers", {}) architectures = diffusers_config.get("_class_name", "") From 4425840939f1bba8ad1fc17937dd551b31492d14 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 16:40:01 +0530 Subject: [PATCH 21/27] Fixes --- api/test/api/test_diffusion.py | 21 ++++++++++ .../image_diffusion/diffusion_worker.py | 37 ++++++++++++++++++ .../plugins/image_diffusion/main.py | 38 +++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/api/test/api/test_diffusion.py b/api/test/api/test_diffusion.py index a624ae0b8..d5f8680b6 100644 --- a/api/test/api/test_diffusion.py +++ b/api/test/api/test_diffusion.py @@ -1089,3 +1089,24 @@ def test_resolve_diffusion_model_reference_falls_back_to_hf_repo(): resolved = main.resolve_diffusion_model_reference(local_dir) assert resolved == "Tongyi-MAI/Z-Image-Turbo" + + +def test_filter_generation_kwargs_for_pipeline_drops_unsupported(): + """Unsupported kwargs should be removed when pipeline call signature is strict.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + + class StrictPipeline: + def __call__(self, prompt, guidance_scale): + return {"prompt": prompt, "guidance_scale": guidance_scale} + + pipe = StrictPipeline() + kwargs = { + "prompt": "test", + "guidance_scale": 7.5, + "cross_attention_kwargs": {"scale": 1.0}, + "callback_on_step_end": lambda *_: None, + } + + filtered = 
main.filter_generation_kwargs_for_pipeline(pipe, kwargs) + + assert filtered == {"prompt": "test", "guidance_scale": 7.5} diff --git a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py index 830ce0835..ef7611970 100644 --- a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py +++ b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py @@ -10,6 +10,7 @@ import os import sys from pathlib import Path +import inspect import time import gc from PIL import Image @@ -184,6 +185,41 @@ def resolve_diffusion_model_reference(model: str) -> str: return model_ref +def filter_generation_kwargs_for_pipeline(pipe, generation_kwargs: dict) -> dict: + """ + Drop kwargs that are not accepted by `pipe.__call__`. + + Some backends (for example ZImagePipeline) reject common diffusers kwargs like + `cross_attention_kwargs` and callback arguments. + """ + try: + signature = inspect.signature(pipe.__call__) + except (TypeError, ValueError): + return generation_kwargs + + params = signature.parameters + supports_var_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()) + if supports_var_kwargs: + return generation_kwargs + + allowed_keys = {name for name in params if name != "self"} + filtered_kwargs = {} + skipped_keys = [] + + for key, value in generation_kwargs.items(): + if key in allowed_keys: + filtered_kwargs[key] = value + else: + skipped_keys.append(key) + + if skipped_keys: + print( + f"Skipping unsupported generation kwargs for {pipe.__class__.__name__}: {', '.join(sorted(skipped_keys))}" + ) + + return filtered_kwargs + + def load_controlnet_model(controlnet_id: str, device: str = "cuda") -> ControlNetModel: controlnet_model = ControlNetModel.from_pretrained( controlnet_id, torch_dtype=torch.float16 if device != "cpu" else torch.float32 @@ -1015,6 +1051,7 @@ def main(): generation_kwargs["callback_on_step_end"] = decode_callback 
generation_kwargs["callback_on_step_end_tensor_inputs"] = ["latents"] print("Enabled intermediate image saving") + generation_kwargs = filter_generation_kwargs_for_pipeline(pipe, generation_kwargs) # Generate images print("Starting image generation...") diff --git a/api/transformerlab/plugins/image_diffusion/main.py b/api/transformerlab/plugins/image_diffusion/main.py index 8b393702d..729ad9a9d 100644 --- a/api/transformerlab/plugins/image_diffusion/main.py +++ b/api/transformerlab/plugins/image_diffusion/main.py @@ -2,6 +2,7 @@ from pydantic import BaseModel, ValidationError from huggingface_hub import model_info from pathlib import Path +import inspect import base64 from io import BytesIO import torch @@ -436,6 +437,42 @@ def resolve_diffusion_model_reference(model: str) -> str: return model_ref +def filter_generation_kwargs_for_pipeline(pipe, generation_kwargs: dict) -> dict: + """ + Drop kwargs that are not accepted by `pipe.__call__`. + + Some backends (for example ZImagePipeline) reject common diffusers kwargs like + `cross_attention_kwargs` and callback arguments. + """ + try: + signature = inspect.signature(pipe.__call__) + except (TypeError, ValueError): + # If signature introspection fails, keep original kwargs. 
+ return generation_kwargs + + params = signature.parameters + supports_var_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()) + if supports_var_kwargs: + return generation_kwargs + + allowed_keys = {name for name in params if name != "self"} + filtered_kwargs = {} + skipped_keys = [] + + for key, value in generation_kwargs.items(): + if key in allowed_keys: + filtered_kwargs[key] = value + else: + skipped_keys.append(key) + + if skipped_keys: + print( + f"Skipping unsupported generation kwargs for {pipe.__class__.__name__}: {', '.join(sorted(skipped_keys))}" + ) + + return filtered_kwargs + + def get_pipeline( model: str, adaptor: str = "", @@ -1240,6 +1277,7 @@ def unified_callback(pipe, step: int, timestep: int, callback_kwargs: dict): generation_kwargs["callback_on_step_end"] = unified_callback generation_kwargs["callback_on_step_end_tensor_inputs"] = ["latents"] + generation_kwargs = filter_generation_kwargs_for_pipeline(pipe, generation_kwargs) result = pipe(**generation_kwargs) images = result.images # Get all images From ddbc9e0f8fd87b64fae2c3fa201de52a73bd4506 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 19:31:54 +0530 Subject: [PATCH 22/27] Fixes --- api/transformerlab/plugin_sdk/plugin_harness.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api/transformerlab/plugin_sdk/plugin_harness.py b/api/transformerlab/plugin_sdk/plugin_harness.py index afdc8b9bc..456a65b63 100644 --- a/api/transformerlab/plugin_sdk/plugin_harness.py +++ b/api/transformerlab/plugin_sdk/plugin_harness.py @@ -113,13 +113,14 @@ def set_config_env_vars( user_id: Optional[str] = None, team_id: Optional[str] = None, ) -> None: + target_key = target_env_var or env_var try: value = get_db_config_value(env_var, user_id=user_id, team_id=team_id) if value: - os.environ[target_env_var] = value - print(f"Set {target_env_var} from {'user' if user_id else 'team'} config") + os.environ[target_key] = value + 
print(f"Set {target_key} from {'user' if user_id else 'team'} config") except Exception as e: - print(f"Warning: Could not set {target_env_var} from {'user' if user_id else 'team'} config: {e}") + print(f"Warning: Could not set {target_key} from {'user' if user_id else 'team'} config: {e}") # Set organization context from environment variable if provided From 2914c12e7ed15c8eeb5a624c5f9fb1d4a6bc73cd Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 19:37:20 +0530 Subject: [PATCH 23/27] Unpin versions --- api/transformerlab/plugins/diffusion_trainer/setup.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index 098a3943c..4e3752c70 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -3,8 +3,8 @@ # Keep plugin deps aligned with the base plugin venv created from api/pyproject.toml. # This avoids resolver-driven torch/torchvision drift (e.g. missing torchvision::nms). 
uv pip install \ - "diffusers==0.36.0" \ - "transformers==4.57.1" \ + "diffusers" \ + "transformers" \ "peft>=0.17" \ diffsynth From f1061d57c77c902281da194a1aa5da3ab4839ef5 Mon Sep 17 00:00:00 2001 From: Tony Salomone Date: Wed, 18 Feb 2026 09:32:14 -0500 Subject: [PATCH 24/27] Bump diffusion plugin version --- api/transformerlab/plugins/image_diffusion/index.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/transformerlab/plugins/image_diffusion/index.json b/api/transformerlab/plugins/image_diffusion/index.json index 0cbf8d933..1c064c24c 100644 --- a/api/transformerlab/plugins/image_diffusion/index.json +++ b/api/transformerlab/plugins/image_diffusion/index.json @@ -4,7 +4,7 @@ "description": "Generate images in the Diffusion tab using this plugin", "plugin-format": "python", "type": "diffusion", - "version": "0.0.9", + "version": "0.1.1", "git": "", "url": "", "files": ["main.py", "diffusion_worker.py", "setup.sh"], From 973795c47acf5252c91d7c75177a3b2438b22159 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 20:12:52 +0530 Subject: [PATCH 25/27] Updates --- api/test/api/test_diffusion.py | 26 +++++++++++++++ .../image_diffusion/diffusion_worker.py | 32 ++++++++++++++++-- .../plugins/image_diffusion/main.py | 33 +++++++++++++++++-- 3 files changed, 86 insertions(+), 5 deletions(-) diff --git a/api/test/api/test_diffusion.py b/api/test/api/test_diffusion.py index d5f8680b6..e6ce07bfd 100644 --- a/api/test/api/test_diffusion.py +++ b/api/test/api/test_diffusion.py @@ -1110,3 +1110,29 @@ def __call__(self, prompt, guidance_scale): filtered = main.filter_generation_kwargs_for_pipeline(pipe, kwargs) assert filtered == {"prompt": "test", "guidance_scale": 7.5} + + +def test_invoke_pipeline_with_safe_kwargs_retries_on_unexpected_keyword(): + """Retry logic should remove unsupported kwargs when strict call signatures are wrapped.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + + class 
WrappedStrictPipeline: + def __call__(self, *args, **kwargs): + if "cross_attention_kwargs" in kwargs: + raise TypeError( + "ZImagePipeline.__call__() got an unexpected keyword argument 'cross_attention_kwargs'" + ) + return kwargs + + pipe = WrappedStrictPipeline() + kwargs = { + "prompt": "test", + "guidance_scale": 7.5, + "cross_attention_kwargs": {"scale": 1.0}, + } + + result = main.invoke_pipeline_with_safe_kwargs(pipe, kwargs) + + assert result["prompt"] == "test" + assert result["guidance_scale"] == 7.5 + assert "cross_attention_kwargs" not in result diff --git a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py index ef7611970..20540a857 100644 --- a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py +++ b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py @@ -9,6 +9,7 @@ import json import os import sys +import re from pathlib import Path import inspect import time @@ -220,6 +221,34 @@ def filter_generation_kwargs_for_pipeline(pipe, generation_kwargs: dict) -> dict return filtered_kwargs +def invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs: dict): + """ + Call a pipeline and recover from wrapper signature mismatches. + + Some pipelines expose a permissive `__call__` signature through decorators while the wrapped + implementation is strict. In that case, retry without the unexpected kwarg. 
+ """ + filtered_kwargs = filter_generation_kwargs_for_pipeline(pipe, generation_kwargs) + + while True: + try: + return pipe(**filtered_kwargs) + except TypeError as exc: + message = str(exc) + match = re.search(r"unexpected keyword argument '([^']+)'", message) + if not match: + raise + + unexpected_key = match.group(1) + if unexpected_key not in filtered_kwargs: + raise + + print( + f"Retrying generation without unsupported kwarg '{unexpected_key}' for {pipe.__class__.__name__}" + ) + filtered_kwargs = {key: value for key, value in filtered_kwargs.items() if key != unexpected_key} + + def load_controlnet_model(controlnet_id: str, device: str = "cuda") -> ControlNetModel: controlnet_model = ControlNetModel.from_pretrained( controlnet_id, torch_dtype=torch.float16 if device != "cpu" else torch.float32 @@ -1051,7 +1080,6 @@ def main(): generation_kwargs["callback_on_step_end"] = decode_callback generation_kwargs["callback_on_step_end_tensor_inputs"] = ["latents"] print("Enabled intermediate image saving") - generation_kwargs = filter_generation_kwargs_for_pipeline(pipe, generation_kwargs) # Generate images print("Starting image generation...") @@ -1108,7 +1136,7 @@ def main(): print(f"Using pre-generated images from sharding: {len(images)} images") else: # This is a normal pipeline, call it to generate images - result = pipe(**generation_kwargs) + result = invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs) images = result.images except RuntimeError as e: if "illegal memory access" in str(e) or "CUDA error" in str(e): diff --git a/api/transformerlab/plugins/image_diffusion/main.py b/api/transformerlab/plugins/image_diffusion/main.py index 729ad9a9d..106f27a04 100644 --- a/api/transformerlab/plugins/image_diffusion/main.py +++ b/api/transformerlab/plugins/image_diffusion/main.py @@ -41,6 +41,7 @@ import os import sys import random +import re from werkzeug.utils import secure_filename import json from datetime import datetime @@ -473,6 +474,34 @@ def 
filter_generation_kwargs_for_pipeline(pipe, generation_kwargs: dict) -> dict return filtered_kwargs +def invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs: dict): + """ + Call a pipeline and recover from wrapper signature mismatches. + + Some pipelines expose a permissive `__call__` signature through decorators while the wrapped + implementation is strict. In that case, retry without the unexpected kwarg. + """ + filtered_kwargs = filter_generation_kwargs_for_pipeline(pipe, generation_kwargs) + + while True: + try: + return pipe(**filtered_kwargs) + except TypeError as exc: + message = str(exc) + match = re.search(r"unexpected keyword argument '([^']+)'", message) + if not match: + raise + + unexpected_key = match.group(1) + if unexpected_key not in filtered_kwargs: + raise + + print( + f"Retrying generation without unsupported kwarg '{unexpected_key}' for {pipe.__class__.__name__}" + ) + filtered_kwargs = {key: value for key, value in filtered_kwargs.items() if key != unexpected_key} + + def get_pipeline( model: str, adaptor: str = "", @@ -1277,9 +1306,7 @@ def unified_callback(pipe, step: int, timestep: int, callback_kwargs: dict): generation_kwargs["callback_on_step_end"] = unified_callback generation_kwargs["callback_on_step_end_tensor_inputs"] = ["latents"] - generation_kwargs = filter_generation_kwargs_for_pipeline(pipe, generation_kwargs) - - result = pipe(**generation_kwargs) + result = invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs) images = result.images # Get all images # Clean up result object to free references From 53a28953a2bdcf6384ebee77357f70aa2e7d999d Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 18 Feb 2026 21:06:14 +0530 Subject: [PATCH 26/27] Updates --- api/test/api/test_diffusion.py | 12 +++++ .../plugins/diffusion_trainer/main.py | 44 +++++++++++---- .../plugins/diffusion_trainer/setup.sh | 16 ++++-- .../image_diffusion/diffusion_worker.py | 53 ++++++++++++++----- .../plugins/image_diffusion/main.py | 53 
++++++++++++++----- 5 files changed, 141 insertions(+), 37 deletions(-) diff --git a/api/test/api/test_diffusion.py b/api/test/api/test_diffusion.py index e6ce07bfd..326eec04d 100644 --- a/api/test/api/test_diffusion.py +++ b/api/test/api/test_diffusion.py @@ -1136,3 +1136,15 @@ def __call__(self, *args, **kwargs): assert result["prompt"] == "test" assert result["guidance_scale"] == 7.5 assert "cross_attention_kwargs" not in result + + +def test_latents_to_rgb_supports_non_sdxl_channel_counts(): + """Intermediate preview conversion should not fail for non-4-channel latents.""" + main = pytest.importorskip("transformerlab.plugins.image_diffusion.main") + torch = pytest.importorskip("torch") + + latents = torch.randn(16, 8, 8) + preview = main.latents_to_rgb(latents) + + assert preview.mode == "RGB" + assert preview.size == (8, 8) diff --git a/api/transformerlab/plugins/diffusion_trainer/main.py b/api/transformerlab/plugins/diffusion_trainer/main.py index f77a6268e..e81f2f517 100644 --- a/api/transformerlab/plugins/diffusion_trainer/main.py +++ b/api/transformerlab/plugins/diffusion_trainer/main.py @@ -3,6 +3,7 @@ import json import gc import asyncio +from typing import Any import numpy as np import torch @@ -11,9 +12,6 @@ from peft import LoraConfig, get_peft_model from peft.utils import get_peft_model_state_dict from torchvision import transforms -from diffsynth import ModelConfig -from diffsynth.pipelines.z_image import ZImagePipeline -from diffsynth.diffusion.loss import FlowMatchSFTLoss import os import glob @@ -22,18 +20,35 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import cast_training_params, compute_snr from diffusers.utils import convert_state_dict_to_diffusers +from transformerlab.sdk.v1.train import tlab_trainer +from lab.dirs import get_workspace_dir +from lab import storage + +# diffsynth is only required for Z-Image training. 
Keep it optional so non-ZImage +# workflows still run when optional low-level deps (for example xformers) are broken. +ModelConfig = None +ZImagePipeline = None +FlowMatchSFTLoss = None +diffsynth_available = False +diffsynth_import_error = None +try: + from diffsynth import ModelConfig + from diffsynth.pipelines.z_image import ZImagePipeline + from diffsynth.diffusion.loss import FlowMatchSFTLoss + + diffsynth_available = True +except Exception as e: + diffsynth_import_error = e + print(f"Warning: diffsynth is unavailable; Z-Image training will be disabled. Error: {e}") # Try to import xformers for memory optimization try: import xformers # noqa: F401 xformers_available = True -except ImportError: +except Exception as e: xformers_available = False - -from transformerlab.sdk.v1.train import tlab_trainer -from lab.dirs import get_workspace_dir -from lab import storage + print(f"Warning: xFormers is unavailable; continuing without it. Error: {e}") workspace_dir = asyncio.run(get_workspace_dir()) @@ -59,8 +74,11 @@ def cleanup_pipeline(): cleanup_pipeline() -def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[ModelConfig], ModelConfig]: +def build_zimage_model_configs(model_id_or_path: str) -> tuple[list[Any], Any]: """Build ModelConfig list + tokenizer config for Z-Image Turbo.""" + if ModelConfig is None: + raise RuntimeError("diffsynth ModelConfig is unavailable.") + transformer_pattern = os.path.join("transformer", "*.safetensors") text_encoder_pattern = os.path.join("text_encoder", "*.safetensors") vae_pattern = os.path.join("vae", "diffusion_pytorch_model.safetensors") @@ -490,6 +508,14 @@ def train_diffusion_lora(): pipe = None if is_zimage: + if not diffsynth_available or ZImagePipeline is None or FlowMatchSFTLoss is None: + raise RuntimeError( + "Z-Image training dependencies are unavailable. " + "This is commonly caused by an incompatible xformers build in the plugin venv. 
" + "Try uninstalling xformers from the diffusion_trainer venv and rerun setup. " + f"Original import error: {diffsynth_import_error}" + ) + # Ensure the model is downloaded locally if it's not already a directory if not os.path.isdir(pretrained_model_name_or_path): from huggingface_hub import snapshot_download diff --git a/api/transformerlab/plugins/diffusion_trainer/setup.sh b/api/transformerlab/plugins/diffusion_trainer/setup.sh index 4e3752c70..db0899d62 100644 --- a/api/transformerlab/plugins/diffusion_trainer/setup.sh +++ b/api/transformerlab/plugins/diffusion_trainer/setup.sh @@ -8,10 +8,18 @@ uv pip install \ "peft>=0.17" \ diffsynth -# Only install xformers for non-ROCm instances. -# Use --no-deps so xformers cannot modify the preinstalled torch stack. -if ! command -v rocminfo >/dev/null 2>&1; then - if ! uv pip install --no-deps xformers; then +# xformers is ABI-sensitive and frequently mismatches the preinstalled torch build. +# Keep it opt-in to avoid import-time crashes (for example undefined C++ symbols). +# To opt in, set TLAB_ENABLE_XFORMERS=1 before running setup. +if [ "${TLAB_ENABLE_XFORMERS:-0}" = "1" ] && ! command -v rocminfo >/dev/null 2>&1; then + if uv pip install --no-deps xformers; then + if ! python -c "import xformers" >/dev/null 2>&1; then + echo "xformers import test failed; uninstalling incompatible wheel and continuing without xformers." + uv pip uninstall -y xformers || true + fi + else echo "xformers wheel unavailable for this environment; continuing without it." fi +else + echo "Skipping xformers install (set TLAB_ENABLE_XFORMERS=1 to opt in)." 
fi diff --git a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py index 20540a857..259cc1de5 100644 --- a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py +++ b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py @@ -257,20 +257,49 @@ def load_controlnet_model(controlnet_id: str, device: str = "cuda") -> ControlNe def latents_to_rgb(latents): - """Convert SDXL latents (4 channels) to RGB tensors (3 channels)""" - weights = ( - (60, -60, 25, -70), - (60, -5, 15, -50), - (60, 10, -5, -35), - ) + """ + Convert latent tensors to an RGB preview image. - weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) - biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) - rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze( - -1 - ) - image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) + Uses SDXL-style weights for 4-channel latents and a robust normalization fallback for + models with other latent channel counts (for example 16-channel Z-Image latents). 
+ """ + if latents.ndim == 4: + latents = latents[0] + + if latents.ndim != 3: + raise ValueError(f"Expected latents shape [C,H,W] or [B,C,H,W], got {tuple(latents.shape)}") + + channels = latents.shape[0] + + if channels == 4: + weights = ( + (60, -60, 25, -70), + (60, -5, 15, -50), + (60, 10, -5, -35), + ) + weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) + biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) + rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze( + -1 + ).unsqueeze(-1) + image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) + return Image.fromarray(image_array) + + # Fallback for non-SDXL latent layouts: map first channels to RGB and normalize per channel. + rgb_tensor = latents.float() + if channels == 1: + rgb_tensor = rgb_tensor.repeat(3, 1, 1) + elif channels == 2: + rgb_tensor = torch.cat([rgb_tensor, rgb_tensor[:1]], dim=0) + else: + rgb_tensor = rgb_tensor[:3] + + channel_min = rgb_tensor.amin(dim=(1, 2), keepdim=True) + channel_max = rgb_tensor.amax(dim=(1, 2), keepdim=True) + denom = (channel_max - channel_min).clamp(min=1e-6) + rgb_tensor = (rgb_tensor - channel_min) / denom + image_array = (rgb_tensor * 255.0).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) return Image.fromarray(image_array) diff --git a/api/transformerlab/plugins/image_diffusion/main.py b/api/transformerlab/plugins/image_diffusion/main.py index 106f27a04..b27658c65 100644 --- a/api/transformerlab/plugins/image_diffusion/main.py +++ b/api/transformerlab/plugins/image_diffusion/main.py @@ -236,20 +236,49 @@ def load_controlnet_model(controlnet_id: str, device: str = "cuda") -> ControlNe def latents_to_rgb(latents): - """Convert SDXL latents (4 channels) to RGB tensors (3 channels)""" - weights = ( - (60, -60, 25, -70), - (60, -5, 15, -50), - (60, 10, -5, -35), - ) + """ + Convert latent tensors to an RGB preview image. 
- weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) - biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) - rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze( - -1 - ) - image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) + Uses SDXL-style weights for 4-channel latents and a robust normalization fallback for + models with other latent channel counts (for example 16-channel Z-Image latents). + """ + if latents.ndim == 4: + latents = latents[0] + + if latents.ndim != 3: + raise ValueError(f"Expected latents shape [C,H,W] or [B,C,H,W], got {tuple(latents.shape)}") + + channels = latents.shape[0] + + if channels == 4: + weights = ( + (60, -60, 25, -70), + (60, -5, 15, -50), + (60, 10, -5, -35), + ) + weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device)) + biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device) + rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze( + -1 + ).unsqueeze(-1) + image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) + return Image.fromarray(image_array) + + # Fallback for non-SDXL latent layouts: map first channels to RGB and normalize per channel. 
+ rgb_tensor = latents.float() + if channels == 1: + rgb_tensor = rgb_tensor.repeat(3, 1, 1) + elif channels == 2: + rgb_tensor = torch.cat([rgb_tensor, rgb_tensor[:1]], dim=0) + else: + rgb_tensor = rgb_tensor[:3] + + channel_min = rgb_tensor.amin(dim=(1, 2), keepdim=True) + channel_max = rgb_tensor.amax(dim=(1, 2), keepdim=True) + denom = (channel_max - channel_min).clamp(min=1e-6) + rgb_tensor = (rgb_tensor - channel_min) / denom + image_array = (rgb_tensor * 255.0).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0) return Image.fromarray(image_array) From edcb473c3cbb6d9114d4566a5fe57f3a031e76f0 Mon Sep 17 00:00:00 2001 From: Tony Salomone Date: Wed, 18 Feb 2026 11:52:19 -0500 Subject: [PATCH 27/27] ruff --- api/test/api/test_diffusion.py | 4 +--- .../plugins/image_diffusion/diffusion_worker.py | 4 +--- api/transformerlab/plugins/image_diffusion/main.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/api/test/api/test_diffusion.py b/api/test/api/test_diffusion.py index 326eec04d..aa702d3bc 100644 --- a/api/test/api/test_diffusion.py +++ b/api/test/api/test_diffusion.py @@ -1119,9 +1119,7 @@ def test_invoke_pipeline_with_safe_kwargs_retries_on_unexpected_keyword(): class WrappedStrictPipeline: def __call__(self, *args, **kwargs): if "cross_attention_kwargs" in kwargs: - raise TypeError( - "ZImagePipeline.__call__() got an unexpected keyword argument 'cross_attention_kwargs'" - ) + raise TypeError("ZImagePipeline.__call__() got an unexpected keyword argument 'cross_attention_kwargs'") return kwargs pipe = WrappedStrictPipeline() diff --git a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py index 259cc1de5..72e3603e0 100644 --- a/api/transformerlab/plugins/image_diffusion/diffusion_worker.py +++ b/api/transformerlab/plugins/image_diffusion/diffusion_worker.py @@ -243,9 +243,7 @@ def invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs: dict): if 
unexpected_key not in filtered_kwargs: raise - print( - f"Retrying generation without unsupported kwarg '{unexpected_key}' for {pipe.__class__.__name__}" - ) + print(f"Retrying generation without unsupported kwarg '{unexpected_key}' for {pipe.__class__.__name__}") filtered_kwargs = {key: value for key, value in filtered_kwargs.items() if key != unexpected_key} diff --git a/api/transformerlab/plugins/image_diffusion/main.py b/api/transformerlab/plugins/image_diffusion/main.py index b27658c65..bd69c24f1 100644 --- a/api/transformerlab/plugins/image_diffusion/main.py +++ b/api/transformerlab/plugins/image_diffusion/main.py @@ -525,9 +525,7 @@ def invoke_pipeline_with_safe_kwargs(pipe, generation_kwargs: dict): if unexpected_key not in filtered_kwargs: raise - print( - f"Retrying generation without unsupported kwarg '{unexpected_key}' for {pipe.__class__.__name__}" - ) + print(f"Retrying generation without unsupported kwarg '{unexpected_key}' for {pipe.__class__.__name__}") filtered_kwargs = {key: value for key, value in filtered_kwargs.items() if key != unexpected_key}