Skip to content

Commit 06c568e

Browse files
authored
[SERVING] Add FLUX.2 LoRA tests& support sigmas (#674)
* flux2 turbo lora * flux2 turbo lora
1 parent f3b78f3 commit 06c568e

File tree

3 files changed

+121
-2
lines changed

3 files changed

+121
-2
lines changed

src/cache_dit/serve/api_server.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class GenerateRequestAPI(BaseModel):
3939
height: int = Field(1024, description="Image/Video height", ge=64, le=4096)
4040
num_inference_steps: int = Field(50, description="Number of inference steps", ge=1, le=200)
4141
guidance_scale: float = Field(7.5, description="Guidance scale", ge=0.0, le=20.0)
42+
sigmas: Optional[List[float]] = Field(
43+
None,
44+
description="Custom sigma schedule (e.g. for turbo inference). Length should typically match num_inference_steps.",
45+
)
4246
seed: Optional[int] = Field(None, description="Random seed")
4347
num_images: int = Field(1, description="Number of images to generate", ge=1, le=4)
4448
image_urls: Optional[List[str]] = Field(
@@ -120,6 +124,7 @@ async def generate(request: GenerateRequestAPI):
120124
height=request.height,
121125
num_inference_steps=request.num_inference_steps,
122126
guidance_scale=request.guidance_scale,
127+
sigmas=request.sigmas,
123128
seed=request.seed,
124129
num_images=request.num_images,
125130
image_urls=request.image_urls,

src/cache_dit/serve/model_manager.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import time
99
import base64
10+
import inspect
1011
import tempfile
1112
import math
1213
import torch
@@ -83,6 +84,7 @@ class GenerateRequest:
8384
height: int = 1024
8485
num_inference_steps: int = 50
8586
guidance_scale: float = 7.5
87+
sigmas: Optional[List[float]] = None
8688
seed: Optional[int] = None
8789
num_images: int = 1
8890
image_urls: Optional[List[str]] = None
@@ -594,6 +596,16 @@ def generate(self, request: GenerateRequest) -> GenerateResponse:
594596
"generator": generator,
595597
}
596598

599+
if request.sigmas is not None:
600+
try:
601+
sig = inspect.signature(self.pipe.__call__)
602+
if "sigmas" in sig.parameters:
603+
pipe_kwargs["sigmas"] = request.sigmas
604+
else:
605+
logger.warning("Pipeline does not support sigmas, ignoring request.sigmas")
606+
except Exception:
607+
pipe_kwargs["sigmas"] = request.sigmas
608+
597609
# Add num_frames for video generation
598610
if is_video_mode:
599611
pipe_kwargs["num_frames"] = request.num_frames
@@ -614,8 +626,6 @@ def generate(self, request: GenerateRequest) -> GenerateResponse:
614626
# Some pipelines (like Flux2Pipeline) don't support negative_prompt
615627
if request.negative_prompt:
616628
try:
617-
import inspect
618-
619629
sig = inspect.signature(self.pipe.__call__)
620630
if "negative_prompt" in sig.parameters:
621631
pipe_kwargs["negative_prompt"] = request.negative_prompt
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Test FLUX.2 Turbo LoRA model serving.
2+
3+
Server setup:
4+
CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 \
5+
-m cache_dit.serve.serve \
6+
--model-path black-forest-labs/FLUX.2-dev \
7+
--lora-path fal/FLUX.2-dev-Turbo \
8+
--lora-name flux.2-turbo-lora.safetensors \
9+
--parallel-type ulysses \
10+
--parallel-text-encoder \
11+
--quantize-type float8_wo \
12+
--attn _flash_3 \
13+
--cache \
14+
--compile \
15+
--ulysses-anything
16+
17+
This test calls /generate with a custom sigma schedule (TURBO_SIGMAS) for 8-step turbo inference.
18+
19+
Reference LoRA: https://huggingface.co/fal/FLUX.2-dev-Turbo
20+
Base model: https://huggingface.co/black-forest-labs/FLUX.2-dev
21+
"""
22+
23+
import os
24+
import requests
25+
import base64
26+
from PIL import Image
27+
from io import BytesIO
28+
29+
30+
# Pre-shifted custom sigmas for 8-step turbo inference
31+
TURBO_SIGMAS = [1.0, 0.6509, 0.4374, 0.2932, 0.1893, 0.1108, 0.0495, 0.00031]
32+
33+
34+
def call_api(prompt, name="flux2_turbo", **kwargs):
35+
host = os.environ.get("CACHE_DIT_HOST", "localhost")
36+
port = int(os.environ.get("CACHE_DIT_PORT", 8000))
37+
url = f"http://{host}:{port}/generate"
38+
39+
payload = {
40+
"prompt": prompt,
41+
"width": kwargs.get("width", 1024),
42+
"height": kwargs.get("height", 1024),
43+
"num_inference_steps": kwargs.get("num_inference_steps", 8),
44+
"guidance_scale": kwargs.get("guidance_scale", 2.5),
45+
"sigmas": kwargs.get("sigmas", TURBO_SIGMAS),
46+
"seed": kwargs.get("seed", 42),
47+
"num_images": kwargs.get("num_images", 1),
48+
}
49+
50+
if "output_format" in kwargs:
51+
payload["output_format"] = kwargs["output_format"]
52+
if "output_dir" in kwargs:
53+
payload["output_dir"] = kwargs["output_dir"]
54+
55+
response = requests.post(url, json=payload, timeout=600)
56+
response.raise_for_status()
57+
result = response.json()
58+
59+
assert "images" in result and result["images"], "No images in response"
60+
61+
if payload.get("output_format", "base64") == "path":
62+
filename = result["images"][0]
63+
assert os.path.exists(filename)
64+
img = Image.open(filename)
65+
print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
66+
return filename
67+
68+
img_data = base64.b64decode(result["images"][0])
69+
img = Image.open(BytesIO(img_data))
70+
71+
filename = f"{name}.png"
72+
img.save(filename)
73+
print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
74+
return filename
75+
76+
77+
def test_flux2_turbo_lora():
78+
prompt = (
79+
"Industrial product shot of a chrome turbocharger with glowing hot exhaust manifold, "
80+
"engraved text 'FLUX.2 [dev] Turbo by fal' on the compressor housing and 'fal' on the turbine wheel, "
81+
"gradient heat glow from orange to electric blue , studio lighting with dramatic shadows, "
82+
"shallow depth of field, engineering blueprint pattern in background."
83+
)
84+
85+
return call_api(
86+
prompt=prompt,
87+
name="flux2_turbo_lora",
88+
num_inference_steps=8,
89+
guidance_scale=2.5,
90+
sigmas=TURBO_SIGMAS,
91+
width=1024,
92+
height=1024,
93+
seed=42,
94+
)
95+
96+
97+
if __name__ == "__main__":
98+
print("=" * 80)
99+
print("Testing FLUX.2 Turbo LoRA Model Serving")
100+
print("=" * 80)
101+
test_flux2_turbo_lora()
102+
print("=" * 80)
103+
print("Done")
104+
print("=" * 80)

0 commit comments

Comments
 (0)