Skip to content

Commit 79f4076

Browse files
committed
refactor(i2l): remove --rank and --lora-scale from CLI and pipeline
These are unnecessary:
- --lora-scale: already available via --lora-scales at generation time
- --rank: determined naturally by the number of images (rank = 4 * N)
The CLI is now minimal: --image-path and --output only.
1 parent 2c94ef3 commit 79f4076

File tree

2 files changed

+5
-48
lines changed

2 files changed

+5
-48
lines changed

src/mflux/models/z_image/cli/z_image_i2l.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,23 +66,6 @@ def main():
6666
default="lora.safetensors",
6767
help="Output path for the generated LoRA file. Default: lora.safetensors",
6868
)
69-
parser.add_argument(
70-
"--lora-scale",
71-
type=float,
72-
default=1.0,
73-
help="Scale factor applied to all LoRA weights before saving. "
74-
"Values > 1.0 amplify the style effect, < 1.0 soften it. Default: 1.0",
75-
)
76-
parser.add_argument(
77-
"--rank",
78-
"-r",
79-
type=int,
80-
default=None,
81-
help="Target LoRA rank. Base rank is 4 per image. With N images, "
82-
"default rank is 4*N (e.g. 16 for 4 images). Set this to override: "
83-
"each image will be repeated ceil(rank / (4*N)) times to reach the "
84-
"target. Valid values: multiples of 4. Typical: 4, 16, 32, 64, 128.",
85-
)
8669

8770
args = parser.parse_args()
8871

@@ -102,7 +85,7 @@ def main():
10285

10386
# Create pipeline and generate LoRA
10487
pipeline = ZImageI2LPipeline.from_pretrained()
105-
pipeline.generate_lora(images=images, output_path=args.output, lora_scale=args.lora_scale, target_rank=args.rank)
88+
pipeline.generate_lora(images=images, output_path=args.output)
10689

10790

10891
if __name__ == "__main__":

src/mflux/models/z_image/model/z_image_i2l/i2l_pipeline.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -158,17 +158,15 @@ def generate_lora(
158158
self,
159159
images: list[Image.Image],
160160
output_path: str | Path = "lora.safetensors",
161-
lora_scale: float = 1.0,
162-
target_rank: int | None = None,
163161
) -> Path:
164162
"""Full pipeline: encode images and save LoRA weights.
165163
164+
Each image contributes rank 4 to the output LoRA. With N images,
165+
the output has effective rank 4*N (merged by concatenation).
166+
166167
Args:
167168
images: List of style reference images.
168169
output_path: Where to save the generated .safetensors file.
169-
lora_scale: Scale factor for LoRA weights.
170-
target_rank: Target LoRA rank (must be multiple of 4).
171-
If None, uses 4 * num_images (natural merge rank).
172170
173171
Returns:
174172
Path to the saved LoRA file.
@@ -182,26 +180,7 @@ def generate_lora(
182180
embeddings = self.encode_images(images)
183181
encode_time = time.time() - t0
184182
print(f" Encoding done in {encode_time:.1f}s")
185-
186-
# Determine repetitions to reach target rank
187-
base_rank = 4
188-
num_images = len(images)
189-
natural_rank = base_rank * num_images
190-
191-
if target_rank is not None:
192-
if target_rank % base_rank != 0:
193-
raise ValueError(f"target_rank must be a multiple of {base_rank}, got {target_rank}")
194-
repeats = max(1, (target_rank + natural_rank - 1) // natural_rank) # ceil division
195-
if repeats > 1:
196-
print(
197-
f" Repeating embeddings {repeats}x to reach rank {natural_rank * repeats} (target: {target_rank})"
198-
)
199-
embeddings = mx.concatenate([embeddings] * repeats, axis=0)
200-
else:
201-
target_rank = natural_rank
202-
203-
effective_rank = base_rank * embeddings.shape[0]
204-
print(f" Output LoRA rank: {effective_rank}")
183+
print(f" Output LoRA rank: {4 * len(images)}")
205184

206185
# Decode to LoRA
207186
print("Generating LoRA weights...")
@@ -211,11 +190,6 @@ def generate_lora(
211190
print(f" Decoding done in {decode_time:.1f}s")
212191
print(f" Generated {len(lora)} LoRA weight tensors")
213192

214-
# Apply scale
215-
if lora_scale != 1.0:
216-
print(f" Applying LoRA scale: {lora_scale}")
217-
lora = {k: v * lora_scale for k, v in lora.items()}
218-
219193
# Save as safetensors (convert MLX bfloat16 -> float32 -> torch bfloat16)
220194
print(f"Saving to {output_path}...")
221195
lora_torch = {}

0 commit comments

Comments (0)