Skip to content

Commit 0652a51

Browse files
author
Andrea
committed
merge: integrate Z-Image i2L (Image-to-LoRA) from upstream PR filipstrand#361
Adds mflux-z-image-i2l CLI command that generates LoRA adapters from reference images using SigLIP2 + DINOv3 + i2L decoder, entirely on-device. Also adds .default. LoRA naming support to ZImageLoRAMapping.
2 parents 0f25a92 + a19d647 commit 0652a51

33 files changed

+1572
-4
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.16.4] - 2026-02-15
9+
10+
### 🐛 Bug Fixes
11+
12+
- **Training preview stability**: Always offload optimizer state during preview generation to avoid memory pressure and improve preview reliability.
13+
- **Apple Silicon compile guard**: Narrow the M1/M2 compile fallback so it excludes Max and Ultra variants, preserving expected optimized behavior on those chips.
14+
15+
---
16+
817
## [0.16.3] - 2026-02-14
918

1019
### 🐛 Bug Fixes

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ source-exclude = [
1414

1515
[project]
1616
name = "mflux"
17-
version = "0.16.3"
17+
version = "0.16.4"
1818
description = "MLX native implementations of state-of-the-art generative image models."
1919
readme = "README.md"
2020
keywords = ["flux", "ai", "ml", "transformers", "mlx", "huggingface", "apple-silicon", "diffusers", "qwen", "qwen-image", "seedvr2", "z-image"]
@@ -89,6 +89,7 @@ mflux-generate-qwen-edit = "mflux.models.qwen.cli.qwen_image_edit_generate:main"
8989
mflux-generate-fibo = "mflux.models.fibo.cli.fibo_generate:main"
9090
mflux-generate-z-image = "mflux.models.z_image.cli.z_image_generate:main"
9191
mflux-generate-z-image-turbo = "mflux.models.z_image.cli.z_image_turbo_generate:main"
92+
mflux-z-image-i2l = "mflux.models.z_image.cli.z_image_i2l:main"
9293
mflux-refine-fibo = "mflux.models.fibo_vlm.cli.fibo_refine:main"
9394
mflux-inspire-fibo = "mflux.models.fibo_vlm.cli.fibo_inspire:main"
9495
mflux-concept = "mflux.models.flux.cli.flux_concept:main"

src/mflux/models/common/training/trainer.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

3+
import gc
34
import random
5+
import tempfile
46
from pathlib import Path
57

68
import mlx.core as mx
79
from mlx import nn
10+
from mlx.utils import tree_unflatten
811
from PIL import Image as PILImage
912
from tqdm import tqdm
1013

@@ -120,7 +123,7 @@ def train(
120123
)
121124

122125
if training_spec.monitoring is not None and training_state.iterator.num_iterations == 0:
123-
TrainingTrainer._generate_previews(adapter, training_spec, training_state)
126+
TrainingTrainer._generate_previews_with_optimizer_offload(adapter, training_spec, training_state)
124127
validation_batch = training_state.iterator.get_validation_batch()
125128
validation_loss = TrainingTrainer.compute_loss(adapter, training_spec, base_config, validation_batch)
126129
training_state.statistics.append_values(step=training_state.iterator.num_iterations, loss=float(validation_loss)) # fmt: off
@@ -147,7 +150,7 @@ def train(
147150
del validation_loss
148151

149152
if training_state.should_generate_image(training_spec):
150-
TrainingTrainer._generate_previews(adapter, training_spec, training_state)
153+
TrainingTrainer._generate_previews_with_optimizer_offload(adapter, training_spec, training_state)
151154

152155
if training_state.should_save(training_spec):
153156
training_state.save(adapter, training_spec)
@@ -227,3 +230,25 @@ def _generate_previews(
227230
)
228231
)
229232
del image
233+
234+
    @staticmethod
    def _generate_previews_with_optimizer_offload(
        adapter: TrainingAdapter,
        training_spec: TrainingSpec,
        training_state: TrainingState,
    ) -> None:
        """Generate preview images while the optimizer state is offloaded to disk.

        Preview generation is memory-heavy; temporarily saving the optimizer
        state to a temp file and dropping the in-memory copy reduces peak
        memory pressure during the preview pass. The state is restored in a
        ``finally`` block, so it survives even if preview generation raises.

        Args:
            adapter: The training adapter whose previews are generated.
            training_spec: Training configuration (monitoring/preview settings).
            training_state: Holds the optimizer whose state is offloaded.
        """
        optimizer = training_state.optimizer
        with tempfile.TemporaryDirectory() as tmp_dir:
            offload_path = Path(tmp_dir) / "optimizer_offload.safetensors"
            # Persist the full optimizer state, then drop the in-memory copy.
            optimizer.optimizer.state = []
            optimizer.save(offload_path)

            # Release Python garbage and cached MLX buffers before the
            # memory-heavy preview generation.
            gc.collect()
            mx.clear_cache()
            try:
                TrainingTrainer._generate_previews(adapter, training_spec, training_state)
            finally:
                # Restore the optimizer state even when preview generation fails,
                # so training can continue from where it left off.
                restored_state = tree_unflatten(list(mx.load(str(offload_path)).items()))
                optimizer.optimizer.state = restored_state
                gc.collect()
                mx.clear_cache()
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""CLI entrypoint for Z-Image Image-to-LoRA (i2L).
2+
3+
Usage:
4+
mflux-z-image-i2l --image-path ./style_images --output style_lora.safetensors
5+
mflux-z-image-i2l --image-path img1.jpg img2.jpg --output style_lora.safetensors
6+
"""
7+
8+
import argparse
9+
import sys
10+
from pathlib import Path
11+
12+
from PIL import Image
13+
14+
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif"}
15+
16+
17+
def _collect_images(paths: list[str]) -> list[Path]:
18+
"""Resolve a mix of files and directories into a sorted list of image paths."""
19+
result = []
20+
for p_str in paths:
21+
p = Path(p_str)
22+
if not p.exists():
23+
print(f"Error: Path not found: {p_str}", file=sys.stderr)
24+
sys.exit(1)
25+
if p.is_dir():
26+
found = sorted(f for f in p.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS)
27+
if not found:
28+
print(f"Error: No images found in directory: {p_str}", file=sys.stderr)
29+
sys.exit(1)
30+
result.extend(found)
31+
elif p.suffix.lower() in IMAGE_EXTENSIONS:
32+
result.append(p)
33+
else:
34+
print(f"Error: Unsupported file type: {p_str}", file=sys.stderr)
35+
sys.exit(1)
36+
return result
37+
38+
39+
def main() -> None:
    """CLI entry point: resolve image inputs, load them, and generate a LoRA file."""
    parser = argparse.ArgumentParser(
        description="Generate LoRA weights from style reference images using Z-Image i2L.",
        # Raw formatter keeps the hand-formatted example lines in the epilog intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  mflux-z-image-i2l --image-path ./my_style
  mflux-z-image-i2l --image-path ./my_style --output my_style.safetensors
  mflux-z-image-i2l --image-path img1.jpg img2.jpg img3.jpg img4.jpg
  mflux-z-image-i2l --image-path ./style_a ./style_b/photo.png

The generated LoRA can then be used with mflux-generate-z-image-turbo:
  mflux-generate-z-image-turbo --prompt "a cat" --lora-paths style.safetensors
""",
    )
    parser.add_argument(
        "--image-path",
        "-i",
        nargs="+",
        required=True,
        type=str,
        help="Image files or directories containing style reference images.",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default="lora.safetensors",
        help="Output path for the generated LoRA file. Default: lora.safetensors",
    )

    args = parser.parse_args()

    # Collect image paths from files and directories (exits on invalid input)
    image_paths = _collect_images(args.image_path)

    # Load images eagerly so any unreadable file fails before model download
    print(f"Loading {len(image_paths)} image(s)...")
    images = []
    for p in image_paths:
        img = Image.open(p).convert("RGB")
        images.append(img)
        print(f"  {p.name}: {img.size[0]}x{img.size[1]}")

    # Import here to avoid slow startup for --help
    from mflux.models.z_image.model.z_image_i2l.i2l_pipeline import ZImageI2LPipeline

    # Create pipeline and generate LoRA
    pipeline = ZImageI2LPipeline.from_pretrained()
    pipeline.generate_lora(images=images, output_path=args.output)


if __name__ == "__main__":
    main()

src/mflux/models/z_image/model/z_image_i2l/__init__.py

Whitespace-only changes.

src/mflux/models/z_image/model/z_image_i2l/dinov3/__init__.py

Whitespace-only changes.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import mlx.core as mx
2+
from mlx import nn
3+
from mlx.core.fast import scaled_dot_product_attention
4+
5+
from mflux.models.z_image.model.z_image_i2l.dinov3.dinov3_rope import apply_dinov3_rope
6+
7+
8+
class DINOv3Attention(nn.Module):
    """DINOv3 attention with RoPE on patch tokens.

    hidden_size=4096, num_heads=32, head_dim=128.
    Bias config: q=False, k=False, v=False, o=True.
    """

    def __init__(self):
        super().__init__()
        self.num_heads = 32
        self.head_dim = 128  # 4096 / 32
        dim = 4096
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=True)

    def __call__(
        self,
        hidden_states: mx.array,
        cos: mx.array,
        sin: mx.array,
        num_prefix_tokens: int = 5,
    ) -> mx.array:
        """Apply multi-head self-attention with RoPE on the patch tokens.

        Args:
            hidden_states: (B, N, 4096) token sequence.
            cos: Rotary embedding cosine table for the patch tokens.
            sin: Rotary embedding sine table for the patch tokens.
            num_prefix_tokens: Leading tokens (CLS + registers) that skip RoPE.

        Returns:
            (B, N, 4096) attention output.
        """
        B, N, _ = hidden_states.shape

        q = self.q_proj(hidden_states).reshape(B, N, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_proj(hidden_states).reshape(B, N, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.v_proj(hidden_states).reshape(B, N, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)

        # Apply RoPE only to patch tokens
        q, k = apply_dinov3_rope(q, k, cos, sin, num_prefix_tokens=num_prefix_tokens)

        # mx.fast.scaled_dot_product_attention expects a Python float `scale`,
        # not an mx.array; use the standard 1/sqrt(head_dim).
        scale = self.head_dim**-0.5
        out = scaled_dot_product_attention(q, k, v, scale=scale)

        out = out.transpose(0, 2, 1, 3).reshape(B, N, -1)
        return self.o_proj(out)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import mlx.core as mx
2+
import mlx.nn as nn
3+
4+
5+
class DINOv3Embeddings(nn.Module):
    """DINOv3 embeddings: CLS token + register tokens + patch embeddings.

    image_size=224, patch_size=16, hidden_size=4096, num_register_tokens=4.
    Token order: [CLS, reg0, reg1, reg2, reg3, patch0, patch1, ...],
    i.e. 5 prefix tokens (1 CLS + 4 registers) ahead of the patches.
    """

    def __init__(self):
        super().__init__()
        self.hidden_size = 4096
        self.patch_size = 16
        self.image_size = 224
        self.num_register_tokens = 4

        # Learned prefix tokens (random init; expected to be replaced by loaded weights).
        self.cls_token = mx.random.normal(shape=(1, 1, self.hidden_size))
        self.register_tokens = mx.random.normal(shape=(1, self.num_register_tokens, self.hidden_size))
        # Non-overlapping patchification: stride == kernel_size.
        self.patch_embeddings = nn.Conv2d(
            in_channels=3,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

    def __call__(self, pixel_values: mx.array) -> mx.array:
        """
        Args:
            pixel_values: (B, 3, 224, 224)
        Returns:
            (B, 1 + 4 + 196, 4096) = (B, 201, 4096)
        """
        batch = pixel_values.shape[0]

        # MLX Conv2d expects channels-last input, so go NCHW -> NHWC first.
        patches = self.patch_embeddings(mx.transpose(pixel_values, (0, 2, 3, 1)))
        # Collapse the spatial grid into a token axis: (B, num_patches, hidden).
        patches = patches.reshape(batch, -1, self.hidden_size)

        # Broadcast the learned prefix tokens across the batch and prepend them.
        prefix = [
            mx.broadcast_to(self.cls_token, (batch, 1, self.hidden_size)),
            mx.broadcast_to(self.register_tokens, (batch, self.num_register_tokens, self.hidden_size)),
        ]
        return mx.concatenate(prefix + [patches], axis=1)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import mlx.core as mx
2+
import mlx.nn as nn
3+
4+
5+
class DINOv3LayerScale(nn.Module):
    """Learnable per-channel scaling, same pattern as DINOv2."""

    def __init__(self, dims: int = 4096, init_values: float = 1.0):
        super().__init__()
        # One learnable scale per channel, initialised uniformly to `init_values`.
        self.gamma = mx.ones((dims,)) * init_values

    def __call__(self, x: mx.array) -> mx.array:
        # Broadcasts over all leading dimensions of x.
        return self.gamma * x
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import mlx.core as mx
2+
import mlx.nn as nn
3+
4+
5+
class DINOv3GatedMLP(nn.Module):
    """DINOv3 gated MLP with SiLU activation.

    hidden_size=4096, intermediate_size=8192. All projections have bias.
    Formula: down_proj(silu(gate_proj(x)) * up_proj(x))
    """

    def __init__(self):
        super().__init__()
        hidden, inner = 4096, 8192
        self.gate_proj = nn.Linear(hidden, inner, bias=True)
        self.up_proj = nn.Linear(hidden, inner, bias=True)
        self.down_proj = nn.Linear(inner, hidden, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        # SwiGLU-style gating: the SiLU-activated gate modulates the up projection.
        gate = nn.silu(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))

0 commit comments

Comments
 (0)