Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
a25e3a1
Add Z Image LoRA fine tuning support
ParamThakkar123 Dec 23, 2025
5c15019
Added existing parameter loading
ParamThakkar123 Dec 23, 2025
2b7c286
Updates
ParamThakkar123 Dec 24, 2025
038552a
Merge branch 'main' of https://github.com/transformerlab/transformerl…
ParamThakkar123 Dec 24, 2025
d571ab8
Merge branch 'main' into add/z-image-ft
josh-janes Dec 24, 2025
e334552
Merge branch 'main' into add/z-image-ft
dadmobile Jan 9, 2026
222a669
Merge branch 'main' into add/z-image-ft
dadmobile Jan 20, 2026
ddd7319
Merge branch 'main' into add/z-image-ft
deep1401 Feb 2, 2026
efada1f
Updated ZImage fine tuning code
ParamThakkar123 Feb 4, 2026
2f2f9ee
Merge branch 'add/z-image-ft' of https://github.com/transformerlab/tr…
ParamThakkar123 Feb 4, 2026
5100330
Updated ZImage fine tuning code
ParamThakkar123 Feb 4, 2026
ef967b0
Updated ZImage fine tuning code
ParamThakkar123 Feb 4, 2026
e3262d2
Merge branch 'main' of https://github.com/transformerlab/transformerl…
ParamThakkar123 Feb 4, 2026
19f23a6
Reformat and rebase
ParamThakkar123 Feb 4, 2026
cc2a21d
Updates
ParamThakkar123 Feb 5, 2026
5c81506
Updates
ParamThakkar123 Feb 5, 2026
3c6c374
Updates
ParamThakkar123 Feb 5, 2026
d6f2822
Merge branch 'main' of https://github.com/transformerlab/transformerl…
ParamThakkar123 Feb 5, 2026
99e483e
Updates
ParamThakkar123 Feb 5, 2026
a2f85e9
Fixed saving lora weights
ParamThakkar123 Feb 6, 2026
d73e250
Formatting
ParamThakkar123 Feb 6, 2026
c1044c7
Merge branch 'main' of https://github.com/transformerlab/transformerl…
ParamThakkar123 Feb 8, 2026
63160e3
ruff
dadmobile Feb 8, 2026
4c9ec6c
Merge branch 'main' into add/z-image-ft
dadmobile Feb 8, 2026
f798019
Merge branch 'main' into add/z-image-ft
ParamThakkar123 Feb 11, 2026
42418c4
Merge branch 'main' into add/z-image-ft
ParamThakkar123 Feb 11, 2026
2d3a814
Merge branch 'main' into add/z-image-ft
ParamThakkar123 Feb 12, 2026
79286da
updates
ParamThakkar123 Feb 12, 2026
6df347d
Merge branch 'main' into add/z-image-ft
ParamThakkar123 Feb 13, 2026
7072c16
Merge branch 'main' into add/z-image-ft
ParamThakkar123 Feb 17, 2026
281f25d
ruff
dadmobile Feb 17, 2026
06b95c7
Merge remote-tracking branch 'origin/main' into add/z-image-ft
dadmobile Feb 17, 2026
008f8ba
Merge branch 'main' of https://github.com/transformerlab/transformerl…
ParamThakkar123 Feb 18, 2026
fae56f9
Merge branch 'add/z-image-ft' of https://github.com/transformerlab/tr…
ParamThakkar123 Feb 18, 2026
1026537
Fixes
ParamThakkar123 Feb 18, 2026
0eacdb5
Fixes
ParamThakkar123 Feb 18, 2026
78f7d24
Updated Peft version
ParamThakkar123 Feb 18, 2026
69687cf
Fixes
ParamThakkar123 Feb 18, 2026
4425840
Fixes
ParamThakkar123 Feb 18, 2026
ddbc9e0
Fixes
ParamThakkar123 Feb 18, 2026
2914c12
Unpin versions
ParamThakkar123 Feb 18, 2026
f1061d5
Bump diffusion plugin version
dadmobile Feb 18, 2026
9950ffd
Merge branch 'main' into add/z-image-ft
ParamThakkar123 Feb 18, 2026
973795c
Updates
ParamThakkar123 Feb 18, 2026
53a2895
Updates
ParamThakkar123 Feb 18, 2026
5a3f72c
Merge branch 'add/z-image-ft' of https://github.com/transformerlab/tr…
ParamThakkar123 Feb 18, 2026
87f465d
Merge branch 'main' of https://github.com/transformerlab/transformerl…
deep1401 Feb 18, 2026
edcb473
ruff
dadmobile Feb 18, 2026
6cc7388
Merge branch 'main' into add/z-image-ft
ParamThakkar123 Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions api/test/api/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,3 +1045,104 @@ def test_get_pipeline_key_whitespace_adaptor():

# Should treat whitespace-only adaptor as no adaptor
assert key == "test-model:: ::txt2img"


def test_resolve_diffusion_model_reference_non_directory():
    """A model reference that is not a local directory must be returned as-is."""
    main = pytest.importorskip("transformerlab.plugins.image_diffusion.main")
    model_ref = "Tongyi-MAI/Z-Image-Turbo"

    with patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=False):
        result = main.resolve_diffusion_model_reference(model_ref)

    assert result == model_ref


def test_resolve_diffusion_model_reference_prefers_local_complete_dir():
    """A local directory containing model_index.json should resolve to itself."""
    main = pytest.importorskip("transformerlab.plugins.image_diffusion.main")
    model_dir = "/tmp/models/Tongyi-MAI_Z-Image-Turbo"

    with (
        patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=True),
        patch("transformerlab.plugins.image_diffusion.main.os.path.isfile", return_value=True),
        patch("transformerlab.plugins.image_diffusion.main._extract_hf_repo_from_model_metadata") as extract_mock,
    ):
        result = main.resolve_diffusion_model_reference(model_dir)

    # A complete local snapshot must short-circuit the metadata lookup.
    extract_mock.assert_not_called()
    assert result == model_dir


def test_resolve_diffusion_model_reference_falls_back_to_hf_repo():
    """Without model_index.json, resolution should use the HF repo id from metadata."""
    main = pytest.importorskip("transformerlab.plugins.image_diffusion.main")
    model_dir = "/tmp/models/Tongyi-MAI_Z-Image-Turbo"

    with (
        patch("transformerlab.plugins.image_diffusion.main.os.path.isdir", return_value=True),
        patch("transformerlab.plugins.image_diffusion.main.os.path.isfile", return_value=False),
        patch(
            "transformerlab.plugins.image_diffusion.main._extract_hf_repo_from_model_metadata",
            return_value="Tongyi-MAI/Z-Image-Turbo",
        ),
    ):
        result = main.resolve_diffusion_model_reference(model_dir)

    assert result == "Tongyi-MAI/Z-Image-Turbo"


def test_filter_generation_kwargs_for_pipeline_drops_unsupported():
    """Kwargs not accepted by a strict pipeline signature must be filtered out."""
    main = pytest.importorskip("transformerlab.plugins.image_diffusion.main")

    class StrictPipeline:
        def __call__(self, prompt, guidance_scale):
            return {"prompt": prompt, "guidance_scale": guidance_scale}

    requested = {
        "prompt": "test",
        "guidance_scale": 7.5,
        "cross_attention_kwargs": {"scale": 1.0},
        "callback_on_step_end": lambda *_: None,
    }

    kept = main.filter_generation_kwargs_for_pipeline(StrictPipeline(), requested)

    assert kept == {"prompt": "test", "guidance_scale": 7.5}


def test_invoke_pipeline_with_safe_kwargs_retries_on_unexpected_keyword():
    """When a wrapped call rejects a kwarg, the retry path should drop it and succeed."""
    main = pytest.importorskip("transformerlab.plugins.image_diffusion.main")

    class WrappedStrictPipeline:
        def __call__(self, *args, **kwargs):
            if "cross_attention_kwargs" in kwargs:
                raise TypeError("ZImagePipeline.__call__() got an unexpected keyword argument 'cross_attention_kwargs'")
            return kwargs

    accepted = main.invoke_pipeline_with_safe_kwargs(
        WrappedStrictPipeline(),
        {
            "prompt": "test",
            "guidance_scale": 7.5,
            "cross_attention_kwargs": {"scale": 1.0},
        },
    )

    assert accepted["prompt"] == "test"
    assert accepted["guidance_scale"] == 7.5
    assert "cross_attention_kwargs" not in accepted


def test_latents_to_rgb_supports_non_sdxl_channel_counts():
    """Preview conversion must not assume SDXL's 4-channel latent layout."""
    main = pytest.importorskip("transformerlab.plugins.image_diffusion.main")
    torch = pytest.importorskip("torch")

    sample = torch.randn(16, 8, 8)
    image = main.latents_to_rgb(sample)

    assert image.mode == "RGB"
    assert image.size == (8, 8)
112 changes: 104 additions & 8 deletions api/transformerlab/plugin_sdk/plugin_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,115 @@
import sys
import argparse
import traceback
import asyncio
import sqlite3
from typing import Optional


def get_db_config_value(key: str, team_id: Optional[str] = None, user_id: Optional[str] = None) -> Optional[str]:
    """
    Read a config value directly from sqlite without importing transformerlab.plugin.

    This keeps harness startup independent from heavy ML dependencies.

    Lookup priority (first match wins):
      1. user-specific config (requires both user_id and team_id)
      2. team-wide config (team_id only)
      3. global config (no user or team scope)

    Args:
        key: Config key to look up.
        team_id: Optional team scope for the lookup.
        user_id: Optional user scope; only consulted together with team_id.

    Returns:
        The stored value, or None when no matching row exists.
    """
    # Imported lazily so merely loading this module stays dependency-free.
    from lab import HOME_DIR

    # Build the queries in priority order instead of repeating the
    # cursor/fetch/close boilerplate three times.
    queries: list[tuple[str, tuple]] = []
    if user_id and team_id:
        queries.append(
            ("SELECT value FROM config WHERE key = ? AND user_id = ? AND team_id = ?", (key, user_id, team_id))
        )
    if team_id:
        queries.append(
            ("SELECT value FROM config WHERE key = ? AND user_id IS NULL AND team_id = ?", (key, team_id))
        )
    queries.append(
        ("SELECT value FROM config WHERE key = ? AND user_id IS NULL AND team_id IS NULL", (key,))
    )

    db_path = f"{HOME_DIR}/llmlab.sqlite3"
    db = sqlite3.connect(db_path, isolation_level=None)
    # Wait up to 30s for a concurrent writer instead of failing immediately.
    db.execute("PRAGMA busy_timeout=30000")
    try:
        for sql, params in queries:
            cursor = db.execute(sql, params)
            row = cursor.fetchone()
            cursor.close()
            if row is not None:
                return row[0]
        return None
    finally:
        db.close()


parser = argparse.ArgumentParser()
parser.add_argument("--plugin_dir", type=str, required=True)
args, unknown = parser.parse_known_args()


def set_config_env_vars(env_var: str, target_env_var: str = None, user_id: str = None, team_id: str = None):
try:
from transformerlab.plugin import get_db_config_value
def configure_plugin_runtime_library_paths(plugin_dir: str) -> None:
    """
    Prefer CUDA/NCCL libraries from the plugin venv over system-wide libraries.

    Prepends the venv's torch and nvidia lib directories to LD_LIBRARY_PATH so
    that stale host NCCL installs cannot shadow the versions the plugin was
    installed against. No-op on Windows or when the plugin has no venv.

    Args:
        plugin_dir: Root directory of the plugin (expects a "venv" subdir).
    """
    if os.name == "nt":
        # LD_LIBRARY_PATH is a POSIX dynamic-linker mechanism only.
        return

    venv_path = os.path.join(plugin_dir, "venv")
    if not os.path.isdir(venv_path):
        return

    pyver = f"python{sys.version_info.major}.{sys.version_info.minor}"
    site_packages = os.path.join(venv_path, "lib", pyver, "site-packages")

    candidate_paths: list[str] = []
    torch_lib = os.path.join(site_packages, "torch", "lib")
    if os.path.isdir(torch_lib):
        candidate_paths.append(torch_lib)

    # Each nvidia-* wheel (nccl, cublas, ...) ships its libs under nvidia/<pkg>/lib.
    nvidia_root = os.path.join(site_packages, "nvidia")
    if os.path.isdir(nvidia_root):
        for pkg_name in os.listdir(nvidia_root):
            lib_dir = os.path.join(nvidia_root, pkg_name, "lib")
            if os.path.isdir(lib_dir):
                candidate_paths.append(lib_dir)

    if not candidate_paths:
        return

    existing_paths = [p for p in os.environ.get("LD_LIBRARY_PATH", "").split(os.pathsep) if p]
    candidate_norm = {os.path.normpath(c) for c in candidate_paths}

    # Venv paths go first; keep pre-existing entries that aren't duplicates.
    merged = list(candidate_paths)
    for path in existing_paths:
        if os.path.normpath(path) not in candidate_norm:
            merged.append(path)

    if merged != existing_paths:
        os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(merged)
        print("Configured LD_LIBRARY_PATH for plugin runtime libraries")

configure_plugin_runtime_library_paths(args.plugin_dir)


def set_config_env_vars(
    env_var: str,
    target_env_var: Optional[str] = None,
    user_id: Optional[str] = None,
    team_id: Optional[str] = None,
) -> None:
    """
    Copy a DB config value into an environment variable.

    Args:
        env_var: Config key to read from the database.
        target_env_var: Environment variable to set; defaults to env_var.
        user_id: Optional user scope for the config lookup.
        team_id: Optional team scope for the config lookup.

    Failures are non-fatal: a warning is printed and the variable is left unset.
    """
    target_key = target_env_var or env_var
    try:
        value = get_db_config_value(env_var, user_id=user_id, team_id=team_id)
        if value:
            os.environ[target_key] = value
            # Value deliberately not echoed — config entries may hold secrets.
            print(f"Set {target_key} from {'user' if user_id else 'team'} config")
    except Exception as e:
        print(f"Warning: Could not set {target_key} from {'user' if user_id else 'team'} config: {e}")


# Set organization context from environment variable if provided
Expand Down Expand Up @@ -69,6 +160,11 @@ def set_config_env_vars(env_var: str, target_env_var: str = None, user_id: str =
except ImportError as e:
print(f"Error executing plugin: {e}")
traceback.print_exc()
if "ncclCommShrink" in str(e):
print(
"Detected CUDA/NCCL mismatch while importing torch. "
"Reinstall the plugin venv with a torch build matching this machine's CUDA runtime."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should never face this issue since we do the base install right?

)

# if e is a ModuleNotFoundError, the plugin is missing a required package
if isinstance(e, ModuleNotFoundError):
Expand Down
5 changes: 3 additions & 2 deletions api/transformerlab/plugins/diffusion_trainer/index.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
"description": "A plugin for fine-tuning Stable Diffusion using LoRA adapters.",
"plugin-format": "python",
"type": "trainer",
"version": "0.1.10",
"version": "0.1.11",
"git": "",
"url": "",
"model_architectures": [
"StableDiffusionPipeline",
"StableDiffusionXLPipeline",
"StableDiffusion3Pipeline",
"FluxPipeline"
"FluxPipeline",
"ZImagePipeline"
],
"files": ["main.py", "setup.sh"],
"supported_hardware_architectures": ["cuda", "amd"],
Expand Down
Loading
Loading