Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions tests/serving/test_flux2_image_edit_serving.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
"""
CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 \
-m cache_dit.serve.serve \
--model-path black-forest-labs/FLUX.2-dev \
--parallel-type ulysses \
--parallel-text-encoder \
--quantize-type float8_wo \
--attn _flash_3 \
--cache \
--compile \
--ulysses-anything
"""

import os
import requests
import base64
Expand All @@ -22,27 +35,19 @@ def call_api(prompt, image_urls=None, name="test", **kwargs):
if image_urls:
payload["image_urls"] = image_urls

try:
response = requests.post(url, json=payload, timeout=300)
response.raise_for_status()

result = response.json()

if "images" not in result or len(result["images"]) == 0:
return None

img_data = base64.b64decode(result["images"][0])
img = Image.open(BytesIO(img_data))
response = requests.post(url, json=payload, timeout=300)
response.raise_for_status()
result = response.json()
assert "images" in result and len(result["images"]) > 0, "No images in response"

filename = f"{name}.png"
img.save(filename)
img_data = base64.b64decode(result["images"][0])
img = Image.open(BytesIO(img_data))

print(f"Saved: {filename}")
return filename
filename = f"{name}.png"
img.save(filename)

except Exception as e:
print(f"Error: {e}")
return None
print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
return filename


def test_single():
Expand Down Expand Up @@ -89,8 +94,22 @@ def test_text():
)


def test_text_ulysses_bad_resolution_regression():
filename = call_api(
prompt="A beautiful landscape with mountains and lakes",
name="text_gen_724x1080",
width=724,
height=1080,
num_inference_steps=8,
)
img = Image.open(filename)
assert img.size == (720, 1080)
return filename


if __name__ == "__main__":
test_single()
test_multi()
test_base64()
test_text()
test_text_ulysses_bad_resolution_regression()
Loading